Make DeviceContext into a proper class
[alexxy/gromacs.git] / src / gromacs / mdlib / leapfrog_gpu.cu
index 61bc231e3660262406ed3e8ef49c79595a1fa040..b77162c1af47f363499338346efc02826f8057af 100644 (file)
@@ -316,7 +316,9 @@ void LeapFrogGpu::integrate(const float3*                     d_x,
     return;
 }
 
-LeapFrogGpu::LeapFrogGpu(CommandStream commandStream) : commandStream_(commandStream)
+LeapFrogGpu::LeapFrogGpu(const DeviceContext& deviceContext, CommandStream commandStream) :
+    deviceContext_(deviceContext),
+    commandStream_(commandStream)
 {
     numAtoms_ = 0;
 
@@ -342,7 +344,7 @@ void LeapFrogGpu::set(const t_mdatoms& md, const int numTempScaleValues, const u
     numTempScaleValues_ = numTempScaleValues;
 
     reallocateDeviceBuffer(&d_inverseMasses_, numAtoms_, &numInverseMasses_,
-                           &numInverseMassesAlloc_, nullptr);
+                           &numInverseMassesAlloc_, deviceContext_);
     copyToDeviceBuffer(&d_inverseMasses_, (float*)md.invmass, 0, numAtoms_, commandStream_,
                        GpuApiCallBehavior::Sync, nullptr);
 
@@ -350,7 +352,7 @@ void LeapFrogGpu::set(const t_mdatoms& md, const int numTempScaleValues, const u
     if (numTempScaleValues > 1)
     {
         reallocateDeviceBuffer(&d_tempScaleGroups_, numAtoms_, &numTempScaleGroups_,
-                               &numTempScaleGroupsAlloc_, nullptr);
+                               &numTempScaleGroupsAlloc_, deviceContext_);
         copyToDeviceBuffer(&d_tempScaleGroups_, tempScaleGroups, 0, numAtoms_, commandStream_,
                            GpuApiCallBehavior::Sync, nullptr);
     }
@@ -359,7 +361,8 @@ void LeapFrogGpu::set(const t_mdatoms& md, const int numTempScaleValues, const u
     if (numTempScaleValues_ > 0)
     {
         h_lambdas_.resize(numTempScaleValues);
-        reallocateDeviceBuffer(&d_lambdas_, numTempScaleValues_, &numLambdas_, &numLambdasAlloc_, nullptr);
+        reallocateDeviceBuffer(&d_lambdas_, numTempScaleValues_, &numLambdas_, &numLambdasAlloc_,
+                               deviceContext_);
     }
 }