Make DeviceStream into a class
[alexxy/gromacs.git] / src / gromacs / mdlib / settle_gpu.cu
index 20933baf965604f7d72e04f5dd02a12ae3bbeeca..76daf34c1acf718f5942dc52812471ed137d9683 100644 (file)
@@ -434,7 +434,7 @@ void SettleGpu::apply(const float3* d_x,
     {
         // Fill with zeros so the values can be reduced to it
         // Only 6 values are needed because virial is symmetrical
-        clearDeviceBufferAsync(&d_virialScaled_, 0, 6, commandStream_);
+        clearDeviceBufferAsync(&d_virialScaled_, 0, 6, deviceStream_);
     }
 
     auto kernelPtr = getSettleKernelPtr(updateVelocities, computeVirial);
@@ -455,7 +455,7 @@ void SettleGpu::apply(const float3* d_x,
     {
         config.sharedMemorySize = 0;
     }
-    config.stream = commandStream_;
+    config.stream = deviceStream_.stream();
 
     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, &numSettles_, &d_atomIds_,
                                                       &settleParameters_, &d_x, &d_xp, &invdt, &d_v,
@@ -465,7 +465,7 @@ void SettleGpu::apply(const float3* d_x,
 
     if (computeVirial)
     {
-        copyFromDeviceBuffer(h_virialScaled_.data(), &d_virialScaled_, 0, 6, commandStream_,
+        copyFromDeviceBuffer(h_virialScaled_.data(), &d_virialScaled_, 0, 6, deviceStream_,
                              GpuApiCallBehavior::Sync, nullptr);
 
         // Mapping [XX, XY, XZ, YY, YZ, ZZ] internal format to a tensor object
@@ -485,9 +485,9 @@ void SettleGpu::apply(const float3* d_x,
     return;
 }
 
-SettleGpu::SettleGpu(const gmx_mtop_t& mtop, const DeviceContext& deviceContext, CommandStream commandStream) :
+SettleGpu::SettleGpu(const gmx_mtop_t& mtop, const DeviceContext& deviceContext, const DeviceStream& deviceStream) :
     deviceContext_(deviceContext),
-    commandStream_(commandStream)
+    deviceStream_(deviceStream)
 {
     static_assert(sizeof(real) == sizeof(float),
                   "Real numbers should be in single precision in GPU code.");
@@ -622,7 +622,7 @@ void SettleGpu::set(const InteractionDefinitions& idef, const t_mdatoms gmx_unus
         settler.z        = iatoms[i * nral1 + 3]; // Second hydrogen index
         h_atomIds_.at(i) = settler;
     }
-    copyToDeviceBuffer(&d_atomIds_, h_atomIds_.data(), 0, numSettles_, commandStream_,
+    copyToDeviceBuffer(&d_atomIds_, h_atomIds_.data(), 0, numSettles_, deviceStream_,
                        GpuApiCallBehavior::Sync, nullptr);
 }