StatePropagatorDataGpu object to manage GPU forces, positions and velocities buffers
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index 16c891532dc3865e87dd4ea43338cb0f8ffb9662..a77e5751cff80ab82b91a1f749564ae24a09433a 100644 (file)
@@ -86,6 +86,7 @@
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/mdtypes/simulation_workload.h"
 #include "gromacs/mdtypes/state.h"
+#include "gromacs/mdtypes/state_propagator_data_gpu.h"
 #include "gromacs/nbnxm/atomdata.h"
 #include "gromacs/nbnxm/gpu_data_mgmt.h"
 #include "gromacs/nbnxm/nbnxm.h"
@@ -601,7 +602,6 @@ computeSpecialForces(FILE                          *fplog,
  *
  * \param[in]  pmedata              The PME structure
  * \param[in]  box                  The box matrix
- * \param[in]  x                    Coordinate array
  * \param[in]  stepWork             Step schedule flags
  * \param[in]  pmeFlags             PME flags
  * \param[in]  useGpuForceReduction True if GPU-based force reduction is active this step
@@ -609,14 +609,12 @@ computeSpecialForces(FILE                          *fplog,
  */
 static inline void launchPmeGpuSpread(gmx_pme_t          *pmedata,
                                       const matrix        box,
-                                      const rvec          x[],
                                       const StepWorkload &stepWork,
                                       int                 pmeFlags,
                                       bool                useGpuForceReduction,
                                       gmx_wallcycle_t     wcycle)
 {
     pme_gpu_prepare_computation(pmedata, stepWork.haveDynamicBox, box, wcycle, pmeFlags, useGpuForceReduction);
-    pme_gpu_copy_coordinates_to_gpu(pmedata, x, wcycle);
     pme_gpu_launch_spread(pmedata, wcycle);
 }
 
@@ -889,12 +887,13 @@ void do_force(FILE                                     *fplog,
               int                                       legacyFlags,
               const DDBalanceRegionHandler             &ddBalanceRegionHandler)
 {
-    int                  i, j;
-    double               mu[2*DIM];
-    gmx_bool             bFillGrid, bCalcCGCM;
-    gmx_bool             bUseGPU, bUseOrEmulGPU;
-    nonbonded_verlet_t  *nbv = fr->nbv.get();
-    interaction_const_t *ic  = fr->ic;
+    int                          i, j;
+    double                       mu[2*DIM];
+    gmx_bool                     bFillGrid, bCalcCGCM;
+    gmx_bool                     bUseGPU, bUseOrEmulGPU;
+    nonbonded_verlet_t          *nbv      = fr->nbv.get();
+    interaction_const_t         *ic       = fr->ic;
+    gmx::StatePropagatorDataGpu *stateGpu = fr->stateGpu;
 
     // TODO remove the code below when the legacy flags are not in use anymore
     /* modify force flag if not doing nonbonded */
@@ -998,9 +997,27 @@ void do_force(FILE                                     *fplog,
     }
 #endif /* GMX_MPI */
 
+    // Coordinates on the device are needed if PME or BufferOps are offloaded.
+    // The local coordinates can be copied right away.
+    // NOTE: Consider moving this copy to right after they are updated and constrained,
+    //       if the latter is not offloaded.
+    if (useGpuPme || useGpuXBufOps == BufferOpsUseGpu::True)
+    {
+        if (stepWork.doNeighborSearch)
+        {
+            stateGpu->reinit(mdatoms->homenr, cr->dd != nullptr ? dd_numAtomsZones(*cr->dd) : mdatoms->homenr);
+            if (useGpuPme)
+            {
+                // TODO: This should be moved into PME setup function ( pme_gpu_prepare_computation(...) )
+                pme_gpu_set_device_x(fr->pmedata, stateGpu->getCoordinates());
+            }
+        }
+        stateGpu->copyCoordinatesToGpu(x.unpaddedArrayRef(), gmx::StatePropagatorDataGpu::AtomLocality::Local);
+    }
+
     if (useGpuPme)
     {
-        launchPmeGpuSpread(fr->pmedata, box, as_rvec_array(x.unpaddedArrayRef().data()), stepWork, pmeFlags, useGpuPmeFReduction, wcycle);
+        launchPmeGpuSpread(fr->pmedata, box, stepWork, pmeFlags, useGpuPmeFReduction, wcycle);
     }
 
     /* do gridding for pair search */
@@ -1124,14 +1141,8 @@ void do_force(FILE                                     *fplog,
     {
         if (useGpuXBufOps == BufferOpsUseGpu::True)
         {
-            // The condition here was (pme != nullptr && pme_gpu_get_device_x(fr->pmedata) != nullptr)
-            if (!useGpuPme)
-            {
-                nbv->copyCoordinatesToGpu(Nbnxm::AtomLocality::Local, false,
-                                          x.unpaddedArrayRef());
-            }
             nbv->convertCoordinatesGpu(Nbnxm::AtomLocality::Local, false,
-                                       useGpuPme ? pme_gpu_get_device_x(fr->pmedata) : nbv->getDeviceCoordinates());
+                                       stateGpu->getCoordinates());
         }
         else
         {
@@ -1210,9 +1221,7 @@ void do_force(FILE                                     *fplog,
             wallcycle_stop(wcycle, ewcNS);
             if (ddUsesGpuDirectCommunication)
             {
-                rvec* d_x    = static_cast<rvec *> (nbv->get_gpu_xrvec());
-                rvec* d_f    = static_cast<rvec *> (nbv->get_gpu_frvec());
-                gpuHaloExchange->reinitHalo(d_x, d_f);
+                gpuHaloExchange->reinitHalo(stateGpu->getCoordinates(), stateGpu->getForces());
             }
         }
         else
@@ -1226,7 +1235,7 @@ void do_force(FILE                                     *fplog,
                 if (domainWork.haveCpuBondedWork || domainWork.haveFreeEnergyWork)
                 {
                     //non-local part of coordinate buffer must be copied back to host for CPU work
-                    nbv->launch_copy_x_from_gpu(as_rvec_array(x.unpaddedArrayRef().data()), Nbnxm::AtomLocality::NonLocal);
+                    stateGpu->copyCoordinatesFromGpu(x.unpaddedArrayRef(), gmx::StatePropagatorDataGpu::AtomLocality::NonLocal);
                 }
             }
             else
@@ -1239,11 +1248,10 @@ void do_force(FILE                                     *fplog,
                 // The condition here was (pme != nullptr && pme_gpu_get_device_x(fr->pmedata) != nullptr)
                 if (!useGpuPme && !ddUsesGpuDirectCommunication)
                 {
-                    nbv->copyCoordinatesToGpu(Nbnxm::AtomLocality::NonLocal, false,
-                                              x.unpaddedArrayRef());
+                    stateGpu->copyCoordinatesToGpu(x.unpaddedArrayRef(), gmx::StatePropagatorDataGpu::AtomLocality::NonLocal);
                 }
                 nbv->convertCoordinatesGpu(Nbnxm::AtomLocality::NonLocal, false,
-                                           useGpuPme ? pme_gpu_get_device_x(fr->pmedata) : nbv->getDeviceCoordinates());
+                                           stateGpu->getCoordinates());
             }
             else
             {
@@ -1494,17 +1502,16 @@ void do_force(FILE                                     *fplog,
                 // which are a dependency for the GPU force reduction.
                 bool  haveNonLocalForceContribInCpuBuffer = domainWork.haveCpuBondedWork || domainWork.haveFreeEnergyWork;
 
-                rvec *f = as_rvec_array(forceWithShiftForces.force().data());
                 if (haveNonLocalForceContribInCpuBuffer)
                 {
-                    nbv->launch_copy_f_to_gpu(f, Nbnxm::AtomLocality::NonLocal);
+                    stateGpu->copyForcesToGpu(forceOut.forceWithShiftForces().force(), gmx::StatePropagatorDataGpu::AtomLocality::NonLocal);
                 }
                 nbv->atomdata_add_nbat_f_to_f_gpu(Nbnxm::AtomLocality::NonLocal,
-                                                  nbv->getDeviceForces(),
+                                                  stateGpu->getForces(),
                                                   pme_gpu_get_device_f(fr->pmedata),
                                                   pme_gpu_get_f_ready_synchronizer(fr->pmedata),
                                                   useGpuPmeFReduction, haveNonLocalForceContribInCpuBuffer);
-                nbv->launch_copy_f_from_gpu(f, Nbnxm::AtomLocality::NonLocal);
+                stateGpu->copyForcesFromGpu(forceOut.forceWithShiftForces().force(), gmx::StatePropagatorDataGpu::AtomLocality::NonLocal);
             }
             else
             {
@@ -1538,17 +1545,14 @@ void do_force(FILE                                     *fplog,
 
         if (stepWork.computeForces)
         {
-            gmx::ArrayRef<gmx::RVec>  force  = forceOut.forceWithShiftForces().force();
-            rvec                     *f      = as_rvec_array(force.data());
 
             if (useGpuForcesHaloExchange)
             {
                 if (haveCpuLocalForces)
                 {
-                    nbv->launch_copy_f_to_gpu(f, Nbnxm::AtomLocality::Local);
+                    stateGpu->copyForcesToGpu(forceOut.forceWithShiftForces().force(), gmx::StatePropagatorDataGpu::AtomLocality::Local);
                 }
-                bool accumulateHaloForces = haveCpuLocalForces;
-                gpuHaloExchange->communicateHaloForces(accumulateHaloForces);
+                gpuHaloExchange->communicateHaloForces(haveCpuLocalForces);
             }
             else
             {
@@ -1643,10 +1647,9 @@ void do_force(FILE                                     *fplog,
             // - copy is not perfomed if GPU force halo exchange is active, because it would overwrite the result
             //   of the halo exchange. In that case the copy is instead performed above, before the exchange.
             //   These should be unified.
-            rvec *f = as_rvec_array(forceWithShift.data());
             if (haveLocalForceContribInCpuBuffer && !useGpuForcesHaloExchange)
             {
-                nbv->launch_copy_f_to_gpu(f, Nbnxm::AtomLocality::Local);
+                stateGpu->copyForcesToGpu(forceWithShift, gmx::StatePropagatorDataGpu::AtomLocality::Local);
             }
             if (useGpuForcesHaloExchange)
             {
@@ -1658,12 +1661,13 @@ void do_force(FILE                                     *fplog,
                 nbv->stream_local_wait_for_nonlocal();
             }
             nbv->atomdata_add_nbat_f_to_f_gpu(Nbnxm::AtomLocality::Local,
-                                              nbv->getDeviceForces(),
+                                              stateGpu->getForces(),
                                               pme_gpu_get_device_f(fr->pmedata),
                                               pme_gpu_get_f_ready_synchronizer(fr->pmedata),
                                               useGpuPmeFReduction, haveLocalForceContribInCpuBuffer);
-            nbv->launch_copy_f_from_gpu(f, Nbnxm::AtomLocality::Local);
+            // This function call synchronizes the local stream
             nbv->wait_for_gpu_force_reduction(Nbnxm::AtomLocality::Local);
+            stateGpu->copyForcesFromGpu(forceWithShift, gmx::StatePropagatorDataGpu::AtomLocality::Local);
         }
         else
         {