Rework GPU halo and state propagator streams and dependencies to get better overlap
[alexxy/gromacs.git] / src / gromacs / domdec / domdec.cpp
index 8a634b7c07594c99bb43fb46e62557beaeacab4e..f063f3b0e75cc38a6edb6364e4cd1c1fd6d3680e 100644 (file)
@@ -3205,14 +3205,7 @@ void constructGpuHaloExchange(const gmx::MDLogger&            mdlog,
         for (int pulse = cr.dd->gpuHaloExchange[d].size(); pulse < cr.dd->comm->cd[d].numPulses(); pulse++)
         {
             cr.dd->gpuHaloExchange[d].push_back(std::make_unique<gmx::GpuHaloExchange>(
-                    cr.dd,
-                    d,
-                    cr.mpi_comm_mygroup,
-                    deviceStreamManager.context(),
-                    deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal),
-                    deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal),
-                    pulse,
-                    wcycle));
+                    cr.dd, d, cr.mpi_comm_mygroup, deviceStreamManager.context(), pulse, wcycle));
         }
     }
 }
@@ -3230,26 +3223,31 @@ void reinitGpuHaloExchange(const t_commrec&              cr,
     }
 }
 
-void communicateGpuHaloCoordinates(const t_commrec&      cr,
-                                   const matrix          box,
-                                   GpuEventSynchronizer* coordinatesReadyOnDeviceEvent)
+GpuEventSynchronizer* communicateGpuHaloCoordinates(const t_commrec&      cr,
+                                                    const matrix          box,
+                                                    GpuEventSynchronizer* dependencyEvent) // event passed to the first pulse
 {
+    GpuEventSynchronizer* eventPtr = dependencyEvent; // running event, replaced after every pulse
     for (int d = 0; d < cr.dd->ndim; d++)
     {
         for (int pulse = 0; pulse < cr.dd->comm->cd[d].numPulses(); pulse++)
         {
-            cr.dd->gpuHaloExchange[d][pulse]->communicateHaloCoordinates(box, coordinatesReadyOnDeviceEvent);
+            eventPtr = cr.dd->gpuHaloExchange[d][pulse]->communicateHaloCoordinates(box, eventPtr); // each pulse receives the event returned by the previous pulse
         }
     }
+    return eventPtr; // event returned by the last pulse; dependencyEvent itself if no pulse ran
 }
 
-void communicateGpuHaloForces(const t_commrec& cr, bool accumulateForces)
+void communicateGpuHaloForces(const t_commrec&                                    cr,
+                              bool                                                accumulateForces,
+                              gmx::FixedCapacityVector<GpuEventSynchronizer*, 2>* dependencyEvents) // in/out: forwarded to every pulse and appended to below
 {
     for (int d = cr.dd->ndim - 1; d >= 0; d--)
     {
         for (int pulse = cr.dd->comm->cd[d].numPulses() - 1; pulse >= 0; pulse--)
         {
-            cr.dd->gpuHaloExchange[d][pulse]->communicateHaloForces(accumulateForces);
+            cr.dd->gpuHaloExchange[d][pulse]->communicateHaloForces(accumulateForces, dependencyEvents); // dimensions and pulses are visited in descending order
+            dependencyEvents->push_back(cr.dd->gpuHaloExchange[d][pulse]->getForcesReadyOnDeviceEvent()); // NOTE(review): one event is pushed per pulse into a capacity-2 list; presumably communicateHaloForces() drains it first - confirm
         }
     }
 }