StatePropagatorDataGpu object to manage GPU forces, positions and velocities buffers
[alexxy/gromacs.git] / src / gromacs / ewald / tests / pmetestcommon.cpp
index d5ce8887984421f78defbeebbebeb1d15deb1fa7..61b28e587ab4cfae47f4e1f7cf19a214ae970cf9 100644 (file)
@@ -169,13 +169,14 @@ PmeSafePointer pmeInitEmpty(const t_inputrec         *inputRec,
 }
 
 //! PME initialization with atom data
-PmeSafePointer pmeInitAtoms(const t_inputrec         *inputRec,
-                            CodePath                  mode,
-                            const gmx_device_info_t  *gpuInfo,
-                            PmeGpuProgramHandle       pmeGpuProgram,
-                            const CoordinatesVector  &coordinates,
-                            const ChargesVector      &charges,
-                            const Matrix3x3          &box
+PmeSafePointer pmeInitAtoms(const t_inputrec                        *inputRec,
+                            CodePath                                 mode,
+                            const gmx_device_info_t                 *gpuInfo,
+                            PmeGpuProgramHandle                      pmeGpuProgram,
+                            const CoordinatesVector                 &coordinates,
+                            const ChargesVector                     &charges,
+                            const Matrix3x3                         &box,
+                            std::shared_ptr<StatePropagatorDataGpu>  stateGpu
                             )
 {
     const index     atomCount = coordinates.size();
@@ -199,7 +200,16 @@ PmeSafePointer pmeInitAtoms(const t_inputrec         *inputRec,
             // We need to set atc->n for passing the size in the tests
             atc->setNumAtoms(atomCount);
             gmx_pme_reinit_atoms(pmeSafe.get(), atomCount, charges.data());
-            pme_gpu_copy_input_coordinates(pmeSafe->gpu, as_rvec_array(coordinates.data()));
+
+            // TODO: Pin the host buffer and use async memory copies
+            stateGpu = std::make_unique<StatePropagatorDataGpu>(pme_gpu_get_device_stream(pmeSafe.get()),
+                                                                pme_gpu_get_device_context(pmeSafe.get()),
+                                                                GpuApiCallBehavior::Sync,
+                                                                pme_gpu_get_padding_size(pmeSafe.get()));
+            stateGpu->reinit(atomCount, atomCount);
+            stateGpu->copyCoordinatesToGpu(arrayRefFromArray(coordinates.data(), coordinates.size()), gmx::StatePropagatorDataGpu::AtomLocality::All);
+            pme_gpu_set_kernelparam_coordinates(pmeSafe->gpu, stateGpu->getCoordinates());
+
             break;
 
         default: