StatePropagatorDataGpu object to manage GPU forces, positions and velocities buffers
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu.cpp
index ddbc2f9485be5afab91754ade5734f4e99c2aac5..4685913dbde8e7922d0bff0e983f875b96fb4b94 100644 (file)
@@ -174,23 +174,6 @@ void pme_gpu_prepare_computation(gmx_pme_t            *pme,
     }
 }
 
-void pme_gpu_copy_coordinates_to_gpu(gmx_pme_t            *pme,
-                                     const rvec           *coordinatesHost,
-                                     gmx_wallcycle        *wcycle)
-{
-    GMX_ASSERT(pme_gpu_active(pme), "This should be a GPU run of PME but it is not enabled.");
-
-    PmeGpu *pmeGpu = pme->gpu;
-
-    // The only spot of PME GPU where LAUNCH_GPU counter increases call-count
-    wallcycle_start(wcycle, ewcLAUNCH_GPU);
-    // The only spot of PME GPU where ewcsLAUNCH_GPU_PME subcounter increases call-count
-    wallcycle_sub_start(wcycle, ewcsLAUNCH_GPU_PME);
-    pme_gpu_copy_input_coordinates(pmeGpu, coordinatesHost);
-    wallcycle_sub_stop(wcycle, ewcsLAUNCH_GPU_PME);
-    wallcycle_stop(wcycle, ewcLAUNCH_GPU);
-}
-
 void pme_gpu_launch_spread(gmx_pme_t            *pme,
                            gmx_wallcycle        *wcycle)
 {
@@ -444,6 +427,15 @@ void *pme_gpu_get_device_f(const gmx_pme_t *pme)
     return pme_gpu_get_kernelparam_forces(pme->gpu);
 }
 
+void pme_gpu_set_device_x(const gmx_pme_t     *pme,
+                          DeviceBuffer<float>  d_x)
+{
+    GMX_ASSERT(pme != nullptr, "Null pointer is passed as a PME to the set coordinates function.");
+    GMX_ASSERT(pme_gpu_active(pme), "This should be a GPU run of PME but it is not enabled.");
+
+    pme_gpu_set_kernelparam_coordinates(pme->gpu, d_x);
+}
+
 void *pme_gpu_get_device_stream(const gmx_pme_t *pme)
 {
     if (!pme || !pme_gpu_active(pme))
@@ -453,6 +445,15 @@ void *pme_gpu_get_device_stream(const gmx_pme_t *pme)
     return pme_gpu_get_stream(pme->gpu);
 }
 
+void *pme_gpu_get_device_context(const gmx_pme_t *pme)
+{
+    if (!pme || !pme_gpu_active(pme))
+    {
+        return nullptr;
+    }
+    return pme_gpu_get_context(pme->gpu);
+}
+
 GpuEventSynchronizer * pme_gpu_get_f_ready_synchronizer(const gmx_pme_t *pme)
 {
     if (!pme || !pme_gpu_active(pme))