Make use of the DeviceStreamManager
[alexxy/gromacs.git] / src / gromacs / ewald / tests / pmetestcommon.cpp
index eaf697e1d5fd8d82ca7faaa39e8a32db4ab5ed67..888cac58734ead34abf88a1606470f1900d0c893 100644 (file)
@@ -59,6 +59,7 @@
 #include "gromacs/ewald/pme_solve.h"
 #include "gromacs/ewald/pme_spread.h"
 #include "gromacs/fft/parallel_3dfft.h"
+#include "gromacs/gpu_utils/device_stream_manager.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/math/invertmatrix.h"
 #include "gromacs/mdtypes/commrec.h"
@@ -106,21 +107,22 @@ uint64_t getSplineModuliDoublePrecisionUlps(int splineOrder)
 }
 
 //! PME initialization
-PmeSafePointer pmeInitWrapper(const t_inputrec*        inputRec,
-                              const CodePath           mode,
-                              const DeviceInformation* deviceInfo,
-                              const PmeGpuProgram*     pmeGpuProgram,
-                              const Matrix3x3&         box,
-                              const real               ewaldCoeff_q,
-                              const real               ewaldCoeff_lj)
+PmeSafePointer pmeInitWrapper(const t_inputrec*    inputRec,
+                              const CodePath       mode,
+                              const DeviceContext* deviceContext,
+                              const DeviceStream*  deviceStream,
+                              const PmeGpuProgram* pmeGpuProgram,
+                              const Matrix3x3&     box,
+                              const real           ewaldCoeff_q,
+                              const real           ewaldCoeff_lj)
 {
     const MDLogger dummyLogger;
     const auto     runMode       = (mode == CodePath::CPU) ? PmeRunMode::CPU : PmeRunMode::Mixed;
     t_commrec      dummyCommrec  = { 0 };
     NumPmeDomains  numPmeDomains = { 1, 1 };
-    gmx_pme_t*     pmeDataRaw =
-            gmx_pme_init(&dummyCommrec, numPmeDomains, inputRec, false, false, true, ewaldCoeff_q,
-                         ewaldCoeff_lj, 1, runMode, nullptr, deviceInfo, pmeGpuProgram, dummyLogger);
+    gmx_pme_t* pmeDataRaw = gmx_pme_init(&dummyCommrec, numPmeDomains, inputRec, false, false, true,
+                                         ewaldCoeff_q, ewaldCoeff_lj, 1, runMode, nullptr,
+                                         deviceContext, deviceStream, pmeGpuProgram, dummyLogger);
     PmeSafePointer pme(pmeDataRaw); // taking ownership
 
     // TODO get rid of this with proper matrix type
@@ -151,33 +153,35 @@ PmeSafePointer pmeInitWrapper(const t_inputrec*        inputRec,
 }
 
 //! Simple PME initialization based on input, no atom data
-PmeSafePointer pmeInitEmpty(const t_inputrec*        inputRec,
-                            const CodePath           mode,
-                            const DeviceInformation* deviceInfo,
-                            const PmeGpuProgram*     pmeGpuProgram,
-                            const Matrix3x3&         box,
-                            const real               ewaldCoeff_q,
-                            const real               ewaldCoeff_lj)
-{
-    return pmeInitWrapper(inputRec, mode, deviceInfo, pmeGpuProgram, box, ewaldCoeff_q, ewaldCoeff_lj);
+PmeSafePointer pmeInitEmpty(const t_inputrec*    inputRec,
+                            const CodePath       mode,
+                            const DeviceContext* deviceContext,
+                            const DeviceStream*  deviceStream,
+                            const PmeGpuProgram* pmeGpuProgram,
+                            const Matrix3x3&     box,
+                            const real           ewaldCoeff_q,
+                            const real           ewaldCoeff_lj)
+{
+    return pmeInitWrapper(inputRec, mode, deviceContext, deviceStream, pmeGpuProgram, box,
+                          ewaldCoeff_q, ewaldCoeff_lj);
     // hiding the fact that PME actually needs to know the number of atoms in advance
 }
 
 PmeSafePointer pmeInitEmpty(const t_inputrec* inputRec)
 {
     const Matrix3x3 defaultBox = { { 1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F } };
-    return pmeInitWrapper(inputRec, CodePath::CPU, nullptr, nullptr, defaultBox, 0.0F, 0.0F);
+    return pmeInitWrapper(inputRec, CodePath::CPU, nullptr, nullptr, nullptr, defaultBox, 0.0F, 0.0F);
 }
 
 //! Make a GPU state-propagator manager
 std::unique_ptr<StatePropagatorDataGpu> makeStatePropagatorDataGpu(const gmx_pme_t&     pme,
-                                                                   const DeviceContext& deviceContext)
+                                                                   const DeviceContext* deviceContext,
+                                                                   const DeviceStream* deviceStream)
 {
     // TODO: Pin the host buffer and use async memory copies
     // TODO: Special constructor for PME-only rank / PME-tests is used here. There should be a mechanism to
     //       restrict one from using other constructor here.
-    return std::make_unique<StatePropagatorDataGpu>(pme_gpu_get_device_stream(&pme), deviceContext,
-                                                    GpuApiCallBehavior::Sync,
+    return std::make_unique<StatePropagatorDataGpu>(deviceStream, *deviceContext, GpuApiCallBehavior::Sync,
                                                     pme_gpu_get_block_size(&pme), nullptr);
 }