Make DeviceContext into a proper class
[alexxy/gromacs.git] / src / gromacs / mdtypes / state_propagator_data_gpu_impl_gpu.cpp
index 4b385a5a7b56ca5b82a22c57bb648a127b6b731d..d88f469711dcc5343560df94729a8455e7363d23 100644 (file)
@@ -50,6 +50,7 @@
 #        include "gromacs/gpu_utils/cudautils.cuh"
 #    endif
 #    include "gromacs/gpu_utils/devicebuffer.h"
+#    include "gromacs/gpu_utils/gputraits.h"
 #    if GMX_GPU == GMX_GPU_OPENCL
 #        include "gromacs/gpu_utils/oclutils.h"
 #    endif
 namespace gmx
 {
 
-StatePropagatorDataGpu::Impl::Impl(const void*        pmeStream,
-                                   const void*        localStream,
-                                   const void*        nonLocalStream,
-                                   const void*        deviceContext,
-                                   GpuApiCallBehavior transferKind,
-                                   int                paddingSize,
-                                   gmx_wallcycle*     wcycle) :
+StatePropagatorDataGpu::Impl::Impl(const void*          pmeStream,
+                                   const void*          localStream,
+                                   const void*          nonLocalStream,
+                                   const DeviceContext& deviceContext,
+                                   GpuApiCallBehavior   transferKind,
+                                   int                  paddingSize,
+                                   gmx_wallcycle*       wcycle) :
+    deviceContext_(deviceContext),
     transferKind_(transferKind),
     paddingSize_(paddingSize),
     wcycle_(wcycle)
@@ -81,13 +83,11 @@ StatePropagatorDataGpu::Impl::Impl(const void*        pmeStream,
     // TODO: Refactor when the StreamManager is introduced.
     if (GMX_GPU == GMX_GPU_OPENCL)
     {
-        GMX_ASSERT(deviceContext != nullptr, "GPU context should be set in OpenCL builds.");
         GMX_ASSERT(pmeStream != nullptr, "GPU PME stream should be set in OpenCL builds.");
 
         // The update stream is set to the PME stream in OpenCL, since PME stream is the only stream created in the PME context.
-        pmeStream_     = *static_cast<const CommandStream*>(pmeStream);
-        updateStream_  = *static_cast<const CommandStream*>(pmeStream);
-        deviceContext_ = *static_cast<const DeviceContext*>(deviceContext);
+        pmeStream_    = *static_cast<const CommandStream*>(pmeStream);
+        updateStream_ = *static_cast<const CommandStream*>(pmeStream);
         GMX_UNUSED_VALUE(localStream);
         GMX_UNUSED_VALUE(nonLocalStream);
     }
@@ -113,7 +113,6 @@ StatePropagatorDataGpu::Impl::Impl(const void*        pmeStream,
         stat = cudaStreamCreate(&updateStream_);
         CU_RET_ERR(stat, "CUDA stream creation failed in StatePropagatorDataGpu");
 #    endif
-        GMX_UNUSED_VALUE(deviceContext);
     }
 
     // Map the atom locality to the stream that will be used for coordinates,
@@ -132,11 +131,12 @@ StatePropagatorDataGpu::Impl::Impl(const void*        pmeStream,
     fCopyStreams_[AtomLocality::All]      = updateStream_;
 }
 
-StatePropagatorDataGpu::Impl::Impl(const void*        pmeStream,
-                                   const void*        deviceContext,
-                                   GpuApiCallBehavior transferKind,
-                                   int                paddingSize,
-                                   gmx_wallcycle*     wcycle) :
+StatePropagatorDataGpu::Impl::Impl(const void*          pmeStream,
+                                   const DeviceContext& deviceContext,
+                                   GpuApiCallBehavior   transferKind,
+                                   int                  paddingSize,
+                                   gmx_wallcycle*       wcycle) :
+    deviceContext_(deviceContext),
     transferKind_(transferKind),
     paddingSize_(paddingSize),
     wcycle_(wcycle)
@@ -144,12 +144,6 @@ StatePropagatorDataGpu::Impl::Impl(const void*        pmeStream,
     static_assert(GMX_GPU != GMX_GPU_NONE,
                   "This object should only be constructed on the GPU code-paths.");
 
-    if (GMX_GPU == GMX_GPU_OPENCL)
-    {
-        GMX_ASSERT(deviceContext != nullptr, "GPU context should be set in OpenCL builds.");
-        deviceContext_ = *static_cast<const DeviceContext*>(deviceContext);
-    }
-
     GMX_ASSERT(pmeStream != nullptr, "GPU PME stream should be set.");
     pmeStream_ = *static_cast<const CommandStream*>(pmeStream);
 
@@ -551,22 +545,22 @@ int StatePropagatorDataGpu::Impl::numAtomsAll()
 }
 
 
-StatePropagatorDataGpu::StatePropagatorDataGpu(const void*        pmeStream,
-                                               const void*        localStream,
-                                               const void*        nonLocalStream,
-                                               const void*        deviceContext,
-                                               GpuApiCallBehavior transferKind,
-                                               int                paddingSize,
-                                               gmx_wallcycle*     wcycle) :
+StatePropagatorDataGpu::StatePropagatorDataGpu(const void*          pmeStream,
+                                               const void*          localStream,
+                                               const void*          nonLocalStream,
+                                               const DeviceContext& deviceContext,
+                                               GpuApiCallBehavior   transferKind,
+                                               int                  paddingSize,
+                                               gmx_wallcycle*       wcycle) :
     impl_(new Impl(pmeStream, localStream, nonLocalStream, deviceContext, transferKind, paddingSize, wcycle))
 {
 }
 
-StatePropagatorDataGpu::StatePropagatorDataGpu(const void*        pmeStream,
-                                               const void*        deviceContext,
-                                               GpuApiCallBehavior transferKind,
-                                               int                paddingSize,
-                                               gmx_wallcycle*     wcycle) :
+StatePropagatorDataGpu::StatePropagatorDataGpu(const void*          pmeStream,
+                                               const DeviceContext& deviceContext,
+                                               GpuApiCallBehavior   transferKind,
+                                               int                  paddingSize,
+                                               gmx_wallcycle*       wcycle) :
     impl_(new Impl(pmeStream, deviceContext, transferKind, paddingSize, wcycle))
 {
 }