Link GPU coordinate producer and consumer tasks
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_internal.cpp
index 7d71782183dafc9b4c5caec0637aac125b5ad8d6..e98ae3a223c8be7ee88cd7aad5682b32866f2442 100644 (file)
@@ -1014,11 +1014,12 @@ std::pair<int, int> inline pmeGpuCreateGrid(const PmeGpu *pmeGpu, int blockCount
     return std::pair<int, int>(colCount, minRowCount);
 }
 
-void pme_gpu_spread(const PmeGpu    *pmeGpu,
-                    int gmx_unused   gridIndex,
-                    real            *h_grid,
-                    bool             computeSplines,
-                    bool             spreadCharges)
+void pme_gpu_spread(const PmeGpu         *pmeGpu,
+                    GpuEventSynchronizer *xReadyOnDevice,
+                    int gmx_unused        gridIndex,
+                    real                 *h_grid,
+                    bool                  computeSplines,
+                    bool                  spreadCharges)
 {
     GMX_ASSERT(computeSplines || spreadCharges, "PME spline/spread kernel has invalid input (nothing to do)");
     const auto   *kernelParamsPtr = pmeGpu->kernelParams.get();
@@ -1037,6 +1038,16 @@ void pme_gpu_spread(const PmeGpu    *pmeGpu,
     //(for spline data mostly, together with varying PME_GPU_PARALLEL_SPLINE define)
     GMX_ASSERT(!c_usePadding || !(c_pmeAtomDataAlignment % atomsPerBlock), "inconsistent atom data padding vs. spreading block size");
 
+    // Ensure that coordinates are ready on the device before launching spread;
+    // only needed with CUDA on PP+PME ranks, not on separate PME ranks, in unit tests
+    // nor in OpenCL as these cases use a single stream (hence xReadyOnDevice == nullptr).
+    // Note: Consider adding an assertion on xReadyOnDevice when we can detect
+    // here separate PME ranks.
+    if (xReadyOnDevice)
+    {
+        xReadyOnDevice->enqueueWaitEvent(pmeGpu->archSpecific->pmeStream);
+    }
+
     const int          blockCount = pmeGpu->nAtomsPadded / atomsPerBlock;
     auto               dimGrid    = pmeGpuCreateGrid(pmeGpu, blockCount);