Make use of the DeviceStreamManager
[alexxy/gromacs.git] / src / gromacs / mdtypes / state_propagator_data_gpu_impl_gpu.cpp
index e60e9fa73b7eb7ded130cccc5d27699d6c63da84..bf927f2da297084caae44ef98a499038fabbccc7 100644 (file)
 
 #if GMX_GPU != GMX_GPU_NONE
 
-#    if GMX_GPU == GMX_GPU_CUDA
-#        include "gromacs/gpu_utils/cudautils.cuh"
-#    endif
+#    include "gromacs/gpu_utils/device_stream_manager.h"
 #    include "gromacs/gpu_utils/devicebuffer.h"
 #    include "gromacs/gpu_utils/gputraits.h"
-#    if GMX_GPU == GMX_GPU_OPENCL
-#        include "gromacs/gpu_utils/oclutils.h"
-#    endif
 #    include "gromacs/math/vectypes.h"
 #    include "gromacs/mdtypes/state_propagator_data_gpu.h"
 #    include "gromacs/timing/wallcycle.h"
 namespace gmx
 {
 
-StatePropagatorDataGpu::Impl::Impl(const DeviceStream*  pmeStream,
-                                   const DeviceStream*  localStream,
-                                   const DeviceStream*  nonLocalStream,
-                                   const DeviceContext& deviceContext,
-                                   GpuApiCallBehavior   transferKind,
-                                   int                  allocationBlockSizeDivisor,
-                                   gmx_wallcycle*       wcycle) :
-    deviceContext_(deviceContext),
+StatePropagatorDataGpu::Impl::Impl(const DeviceStreamManager& deviceStreamManager,
+                                   GpuApiCallBehavior         transferKind,
+                                   int                        allocationBlockSizeDivisor,
+                                   gmx_wallcycle*             wcycle) :
+    deviceContext_(deviceStreamManager.context()),
     transferKind_(transferKind),
     allocationBlockSizeDivisor_(allocationBlockSizeDivisor),
     wcycle_(wcycle)
 {
-    static_assert(GMX_GPU != GMX_GPU_NONE,
-                  "This object should only be constructed on the GPU code-paths.");
+    static_assert(
+            GMX_GPU != GMX_GPU_NONE,
+            "GPU state propagator data object should only be constructed on the GPU code-paths.");
 
-    // TODO: Refactor when the StreamManager is introduced.
+    // We need to keep local copies for re-initialization.
+    pmeStream_      = &deviceStreamManager.stream(DeviceStreamType::Pme);
+    localStream_    = &deviceStreamManager.stream(DeviceStreamType::NonBondedLocal);
+    nonLocalStream_ = &deviceStreamManager.stream(DeviceStreamType::NonBondedNonLocal);
+    // PME stream is used in OpenCL for H2D coordinate transfer
     if (GMX_GPU == GMX_GPU_OPENCL)
     {
-        GMX_ASSERT(pmeStream != nullptr, "GPU PME stream should be set in OpenCL builds.");
-
-        // The update stream is set to the PME stream in OpenCL, since PME stream is the only stream created in the PME context.
-        pmeStream_    = pmeStream;
-        updateStream_ = pmeStream;
-        GMX_UNUSED_VALUE(localStream);
-        GMX_UNUSED_VALUE(nonLocalStream);
+        updateStream_ = &deviceStreamManager.stream(DeviceStreamType::Pme);
     }
-
-    if (GMX_GPU == GMX_GPU_CUDA)
+    else
     {
-        if (pmeStream != nullptr)
-        {
-            pmeStream_ = pmeStream;
-        }
-        if (localStream != nullptr)
-        {
-            localStream_ = localStream;
-        }
-        if (nonLocalStream != nullptr)
-        {
-            nonLocalStream_ = nonLocalStream;
-        }
-
-        // TODO: The update stream should be created only when it is needed.
-#    if (GMX_GPU == GMX_GPU_CUDA)
-        // In CUDA we only need priority to create stream.
-        // (note that this will be moved from here in the follow-up patch)
-        updateStreamOwn_.init(deviceContext, DeviceStreamPriority::Normal, false);
-        updateStream_ = &updateStreamOwn_;
-#    endif
+        updateStream_ = &deviceStreamManager.stream(DeviceStreamType::UpdateAndConstraints);
     }
 
     // Map the atom locality to the stream that will be used for coordinates,
@@ -142,10 +113,11 @@ StatePropagatorDataGpu::Impl::Impl(const DeviceStream*  pmeStream,
     allocationBlockSizeDivisor_(allocationBlockSizeDivisor),
     wcycle_(wcycle)
 {
-    static_assert(GMX_GPU != GMX_GPU_NONE,
-                  "This object should only be constructed on the GPU code-paths.");
+    static_assert(
+            GMX_GPU != GMX_GPU_NONE,
+            "GPU state propagator data object should only be constructed on the GPU code-paths.");
 
-    GMX_ASSERT(pmeStream != nullptr, "GPU PME stream should be set.");
+    GMX_ASSERT(pmeStream->isValid(), "GPU PME stream should be valid.");
     pmeStream_      = pmeStream;
     localStream_    = pmeStream; // For clearing the force buffer
     nonLocalStream_ = nullptr;
@@ -256,8 +228,7 @@ void StatePropagatorDataGpu::Impl::copyToDevice(DeviceBuffer<RVec>
 
     GMX_ASSERT(dataSize >= 0, "Trying to copy to device buffer before it was allocated.");
 
-    GMX_ASSERT(deviceStream.stream() != nullptr,
-               "No stream is valid for copying with given atom locality.");
+    GMX_ASSERT(deviceStream.isValid(), "No stream is valid for copying with given atom locality.");
     wallcycle_start_nocount(wcycle_, ewcLAUNCH_GPU);
     wallcycle_sub_start(wcycle_, ewcsLAUNCH_STATE_PROPAGATOR_DATA);
 
@@ -291,8 +262,7 @@ void StatePropagatorDataGpu::Impl::copyFromDevice(gmx::ArrayRef<gmx::RVec> h_dat
 
     GMX_ASSERT(dataSize >= 0, "Trying to copy from device buffer before it was allocated.");
 
-    GMX_ASSERT(deviceStream.stream() != nullptr,
-               "No stream is valid for copying with given atom locality.");
+    GMX_ASSERT(deviceStream.isValid(), "No stream is valid for copying with given atom locality.");
     wallcycle_start_nocount(wcycle_, ewcLAUNCH_GPU);
     wallcycle_sub_start(wcycle_, ewcsLAUNCH_STATE_PROPAGATOR_DATA);
 
@@ -546,14 +516,11 @@ int StatePropagatorDataGpu::Impl::numAtomsAll()
 }
 
 
-StatePropagatorDataGpu::StatePropagatorDataGpu(const DeviceStream*  pmeStream,
-                                               const DeviceStream*  localStream,
-                                               const DeviceStream*  nonLocalStream,
-                                               const DeviceContext& deviceContext,
-                                               GpuApiCallBehavior   transferKind,
-                                               int                  allocationBlockSizeDivisor,
-                                               gmx_wallcycle*       wcycle) :
-    impl_(new Impl(pmeStream, localStream, nonLocalStream, deviceContext, transferKind, allocationBlockSizeDivisor, wcycle))
+StatePropagatorDataGpu::StatePropagatorDataGpu(const DeviceStreamManager& deviceStreamManager,
+                                               GpuApiCallBehavior         transferKind,
+                                               int            allocationBlockSizeDivisor,
+                                               gmx_wallcycle* wcycle) :
+    impl_(new Impl(deviceStreamManager, transferKind, allocationBlockSizeDivisor, wcycle))
 {
 }