Multiple pulses for GPU Halo Exchange
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index 07644556aaf76f22885166f8bafb2fccc4db81ae..61878b8f5af58b95b8d7c6b1c05183d001ced540 100644 (file)
@@ -1001,6 +1001,7 @@ void do_force(FILE*                               fplog,
     {
         if (stepWork.doNeighborSearch)
         {
+            // TODO move this stateGpu re-initialization to do_md, after partitioning.
             stateGpu->reinit(mdatoms->homenr,
                              cr->dd != nullptr ? dd_numAtomsZones(*cr->dd) : mdatoms->homenr);
             if (useGpuPmeOnThisRank)
@@ -1023,9 +1024,8 @@ void do_force(FILE*                               fplog,
     // The conditions for gpuHaloExchange e.g. using GPU buffer
     // operations were checked before construction, so here we can
     // just use it and assert upon any conditions.
-    gmx::GpuHaloExchange* gpuHaloExchange =
-            (havePPDomainDecomposition(cr) ? cr->dd->gpuHaloExchange.get() : nullptr);
-    const bool ddUsesGpuDirectCommunication = (gpuHaloExchange != nullptr);
+    const bool ddUsesGpuDirectCommunication =
+            ((cr->dd != nullptr) && (!cr->dd->gpuHaloExchange.empty()));
     GMX_ASSERT(!ddUsesGpuDirectCommunication || stepWork.useGpuXBufferOps,
                "Must use coordinate buffer ops with GPU halo exchange");
     const bool useGpuForcesHaloExchange = ddUsesGpuDirectCommunication && stepWork.useGpuFBufferOps;
@@ -1259,9 +1259,13 @@ void do_force(FILE*                               fplog,
             nbv->setupGpuShortRangeWork(fr->gpuBonded, InteractionLocality::NonLocal);
             wallcycle_sub_stop(wcycle, ewcsNBS_SEARCH_NONLOCAL);
             wallcycle_stop(wcycle, ewcNS);
+            // TODO move this GPU halo exchange re-initialization to the
+            // location in do_md where the GPU halo exchange is constructed
+            // at partitioning, once the stateGpu re-initialization above
+            // has similarly been moved.
             if (ddUsesGpuDirectCommunication)
             {
-                gpuHaloExchange->reinitHalo(stateGpu->getCoordinates(), stateGpu->getForces());
+                reinitGpuHaloExchange(*cr, stateGpu->getCoordinates(), stateGpu->getForces());
             }
         }
         else
@@ -1270,7 +1274,7 @@ void do_force(FILE*                               fplog,
             {
                 // The following must be called after local setCoordinates (which records an event
                 // when the coordinate data has been copied to the device).
-                gpuHaloExchange->communicateHaloCoordinates(box, localXReadyOnDevice);
+                communicateGpuHaloCoordinates(*cr, box, localXReadyOnDevice);
 
                 if (domainWork.haveCpuBondedWork || domainWork.haveFreeEnergyWork)
                 {
@@ -1590,7 +1594,7 @@ void do_force(FILE*                               fplog,
                 {
                     stateGpu->copyForcesToGpu(forceOut.forceWithShiftForces().force(), AtomLocality::Local);
                 }
-                gpuHaloExchange->communicateHaloForces(domainWork.haveCpuLocalForceWork);
+                communicateGpuHaloForces(*cr, domainWork.haveCpuLocalForceWork);
             }
             else
             {
@@ -1731,7 +1735,7 @@ void do_force(FILE*                               fplog,
             }
             if (useGpuForcesHaloExchange)
             {
-                dependencyList.push_back(gpuHaloExchange->getForcesReadyOnDeviceEvent());
+                dependencyList.push_back(cr->dd->gpuHaloExchange[0]->getForcesReadyOnDeviceEvent());
             }
             nbv->atomdata_add_nbat_f_to_f_gpu(AtomLocality::Local, stateGpu->getForces(), pmeForcePtr,
                                               dependencyList, stepWork.useGpuPmeFReduction,