StatePropagatorDataGpu object to manage GPU forces, positions and velocities buffers
[alexxy/gromacs.git] / src / gromacs / ewald / pme_only.cpp
index 1a872bb383eb3a3ca16cdbb57482447343ae57cb..eb81ebb79f90b9d458ed94da5c32a0c0178b9890 100644 (file)
@@ -84,6 +84,7 @@
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/forceoutput.h"
 #include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/mdtypes/state_propagator_data_gpu.h"
 #include "gromacs/timing/cyclecounter.h"
 #include "gromacs/timing/wallcycle.h"
 #include "gromacs/utility/fatalerror.h"
@@ -543,15 +544,21 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
     std::vector<gmx_pme_t *> pmedata;
     pmedata.push_back(pme);
 
-    auto       pme_pp       = gmx_pme_pp_init(cr);
+    auto        pme_pp       = gmx_pme_pp_init(cr);
     //TODO the variable below should be queried from the task assignment info
-    const bool useGpuForPme = (runMode == PmeRunMode::GPU) || (runMode == PmeRunMode::Mixed);
+    const bool  useGpuForPme   = (runMode == PmeRunMode::GPU) || (runMode == PmeRunMode::Mixed);
+    const void *commandStream  = useGpuForPme ? pme_gpu_get_device_stream(pme) : nullptr;
+    const void *gpuContext     = useGpuForPme ? pme_gpu_get_device_context(pme) : nullptr;
+    const int   paddingSize    = pme_gpu_get_padding_size(pme);
     if (useGpuForPme)
     {
         changePinningPolicy(&pme_pp->chargeA, pme_get_pinning_policy());
         changePinningPolicy(&pme_pp->x, pme_get_pinning_policy());
     }
 
+    // Unconditionally initialize the StatePropagatorDataGpu object to get a more verbose message if it is used in CPU builds
+    auto stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(commandStream, gpuContext, GpuApiCallBehavior::Sync, paddingSize);
+
     clear_nrnb(mynrnb);
 
     count = 0;
@@ -585,6 +592,11 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
             if (atomSetChanged)
             {
                 gmx_pme_reinit_atoms(pme, natoms, pme_pp->chargeA.data());
+                if (useGpuForPme)
+                {
+                    stateGpu->reinit(natoms, natoms);
+                    pme_gpu_set_device_x(pme, stateGpu->getCoordinates());
+                }
             }
 
             if (ret == pmerecvqxRESETCOUNTERS)
@@ -625,7 +637,8 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
             //TODO this should be set properly by gmx_pme_recv_coeffs_coords,
             // or maybe use inputrecDynamicBox(ir), at the very least - change this when this codepath is tested!
             pme_gpu_prepare_computation(pme, boxChanged, box, wcycle, pmeFlags, useGpuPmeForceReduction);
-            pme_gpu_copy_coordinates_to_gpu(pme, as_rvec_array(pme_pp->x.data()), wcycle);
+            stateGpu->copyCoordinatesToGpu(gmx::ArrayRef<gmx::RVec>(pme_pp->x), gmx::StatePropagatorDataGpu::AtomLocality::All);
+
             pme_gpu_launch_spread(pme, wcycle);
             pme_gpu_launch_complex_transforms(pme, wcycle);
             pme_gpu_launch_gather(pme, wcycle, PmeForceOutputHandling::Set);