Pass list of EventSynchronizers to GPU reduction
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_cuda.cu
index ca88b5e94f8accde6f1d1e4944848cfb14d21daa..23502b3fbe8bac6d88c26a5c7910c65d5790e2b2 100644 (file)
@@ -806,15 +806,15 @@ void nbnxn_gpu_x_to_nbat_x(const Nbnxm::Grid               &grid,
 }
 
 /* F buffer operations on GPU: performs force summations and conversion from nb to rvec format. */
-void nbnxn_gpu_add_nbat_f_to_f(const AtomLocality               atomLocality,
-                               DeviceBuffer<float>              totalForcesDevice,
-                               gmx_nbnxn_gpu_t                 *nb,
-                               void                            *pmeForcesDevice,
-                               GpuEventSynchronizer            *pmeForcesReady,
-                               int                              atomStart,
-                               int                              numAtoms,
-                               bool                             useGpuFPmeReduction,
-                               bool                             accumulateForce)
+void nbnxn_gpu_add_nbat_f_to_f(const AtomLocality                          atomLocality,
+                               DeviceBuffer<float>                         totalForcesDevice,
+                               gmx_nbnxn_gpu_t                            *nb,
+                               void                                       *pmeForcesDevice,
+                               gmx::ArrayRef<GpuEventSynchronizer* const>  dependencyList,
+                               int                                         atomStart,
+                               int                                         numAtoms,
+                               bool                                        useGpuFPmeReduction,
+                               bool                                        accumulateForce)
 {
     GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
     GMX_ASSERT(numAtoms != 0, "Cannot call function with no atoms");
@@ -824,10 +824,15 @@ void nbnxn_gpu_add_nbat_f_to_f(const AtomLocality               atomLocality,
     cudaStream_t              stream        = nb->stream[iLocality];
     cu_atomdata_t            *adat          = nb->atdat;
 
-    if (useGpuFPmeReduction)
+    size_t gmx_used_in_debug  numDependency =
+        static_cast<size_t>((useGpuFPmeReduction == true)) +
+        static_cast<size_t>((accumulateForce == true));
+    GMX_ASSERT(numDependency >= dependencyList.size(), "Mismatching number of dependencies and call signature");
+
+    // Enqueue wait on all dependencies passed
+    for (auto const synchronizer : dependencyList)
     {
-        //Stream must wait for PME force completion
-        pmeForcesReady->enqueueWaitEvent(stream);
+        synchronizer->enqueueWaitEvent(stream);
     }
 
     /* launch kernel */