GPU halo exchange
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index c540fa86f6d2d390ceec96aa0d212649b46daf5b..e2fffacd83c6c0366c9d9e6752417305940f9203 100644 (file)
@@ -49,6 +49,7 @@
 #include "gromacs/domdec/dlbtiming.h"
 #include "gromacs/domdec/domdec.h"
 #include "gromacs/domdec/domdec_struct.h"
+#include "gromacs/domdec/gpuhaloexchange.h"
 #include "gromacs/domdec/partition.h"
 #include "gromacs/essentialdynamics/edsam.h"
 #include "gromacs/ewald/pme.h"
@@ -125,6 +126,10 @@ static const bool c_disableAlternatingWait = (getenv("GMX_DISABLE_ALTERNATING_GP
 // TODO eventially tie this in with other existing GPU flags.
 static const bool c_enableGpuBufOps = (getenv("GMX_USE_GPU_BUFFER_OPS") != nullptr);
 
+/*! \brief environment variable to enable GPU P2P communication */
+static const bool c_enableGpuHaloExchange = (getenv("GMX_GPU_DD_COMMS") != nullptr)
+    && GMX_THREAD_MPI && (GMX_GPU == GMX_GPU_CUDA);
+
 static void sum_forces(rvec f[], gmx::ArrayRef<const gmx::RVec> forceToAdd)
 {
     const int      end = forceToAdd.size();
@@ -1180,6 +1185,12 @@ void do_force(FILE                                     *fplog,
         launchPmeGpuFftAndGather(fr->pmedata, wcycle);
     }
 
+    const bool            ddUsesGpuDirectCommunication
+        = c_enableGpuHaloExchange && c_enableGpuBufOps && bUseGPU && havePPDomainDecomposition(cr);
+    gmx::GpuHaloExchange *gpuHaloExchange = ddUsesGpuDirectCommunication ? cr->dd->gpuHaloExchange.get() : nullptr;
+    GMX_ASSERT(!ddUsesGpuDirectCommunication || gpuHaloExchange != nullptr,
+               "Must have valid gpuHaloExchange when doing halo exchange on the GPU");
+
     /* Communicate coordinates and sum dipole if necessary +
        do non-local pair search */
     if (havePPDomainDecomposition(cr))
@@ -1196,15 +1207,36 @@ void do_force(FILE                                     *fplog,
             nbv->setupGpuShortRangeWork(fr->gpuBonded, Nbnxm::InteractionLocality::NonLocal);
             wallcycle_sub_stop(wcycle, ewcsNBS_SEARCH_NONLOCAL);
             wallcycle_stop(wcycle, ewcNS);
+            if (ddUsesGpuDirectCommunication)
+            {
+                rvec* d_x    = static_cast<rvec *> (nbv->get_gpu_xrvec());
+                gpuHaloExchange->reinitHalo(d_x);
+            }
         }
         else
         {
-            dd_move_x(cr->dd, box, x.unpaddedArrayRef(), wcycle);
+            if (ddUsesGpuDirectCommunication)
+            {
+                // The following must be called after local setCoordinates (which records an event
+                // when the coordinate data has been copied to the device).
+                gpuHaloExchange->communicateHaloCoordinates(box);
+
+                // TODO Force flags should include haveFreeEnergyWork for this domain
+                if (forceWork.haveCpuBondedWork || (fr->efep != efepNO))
+                {
+                    //non-local part of coordinate buffer must be copied back to host for CPU work
+                    nbv->launch_copy_x_from_gpu(as_rvec_array(x.unpaddedArrayRef().data()), Nbnxm::AtomLocality::NonLocal);
+                }
+            }
+            else
+            {
+                dd_move_x(cr->dd, box, x.unpaddedArrayRef(), wcycle);
+            }
 
             if (useGpuXBufOps == BufferOpsUseGpu::True)
             {
                 // The condition here was (pme != nullptr && pme_gpu_get_device_x(fr->pmedata) != nullptr)
-                if (!useGpuPme)
+                if (!useGpuPme && !ddUsesGpuDirectCommunication)
                 {
                     nbv->copyCoordinatesToGpu(Nbnxm::AtomLocality::NonLocal, false,
                                               x.unpaddedArrayRef());
@@ -1403,6 +1435,13 @@ void do_force(FILE                                     *fplog,
         update_QMMMrec(cr, fr, as_rvec_array(x.unpaddedArrayRef().data()), mdatoms, box);
     }
 
+    // TODO Force flags should include haveFreeEnergyWork for this domain
+    if (ddUsesGpuDirectCommunication &&
+        (forceWork.haveCpuBondedWork || (fr->efep != efepNO)))
+    {
+        /* Wait for non-local coordinate data to be copied from device */
+        nbv->wait_nonlocal_x_copy_D2H_done();
+    }
     /* Compute the bonded and non-bonded energies and optionally forces */
     do_force_lowlevel(fr, inputrec, &(top->idef),
                       cr, ms, nrnb, wcycle, mdatoms,
@@ -1419,8 +1458,10 @@ void do_force(FILE                                     *fplog,
                          forceFlags, &forceOut.forceWithVirial(), enerd,
                          ed, forceFlags.doNeighborSearch);
 
-    bool                   useCpuFPmeReduction = thisRankHasDuty(cr, DUTY_PME) && !useGpuPmeFReduction;
-    bool                   haveCpuForces       = (forceWork.haveSpecialForces || forceWork.haveCpuListedForceWork || useCpuFPmeReduction);
+    const bool useCpuFPmeReduction = thisRankHasDuty(cr, DUTY_PME) && !useGpuPmeFReduction;
+    // TODO Force flags should include haveFreeEnergyWork for this domain
+    const bool haveCpuForces       = (forceWork.haveSpecialForces || forceWork.haveCpuListedForceWork ||
+                                      useCpuFPmeReduction || (fr->efep != efepNO));
 
     // Will store the amount of cycles spent waiting for the GPU that
     // will be later used in the DLB accounting.