Use DeviceBuffer in GPU update and NBNXM code
[alexxy/gromacs.git] / src / gromacs / mdlib / lincs_gpu.cu
index 0967c20781b0470c595c3274d59f7eca877ad852..466c250f4c72eccba7c220f5d2da8b0faa68aa92 100644 (file)
@@ -59,7 +59,8 @@
 #include "gromacs/gpu_utils/cuda_arch_utils.cuh"
 #include "gromacs/gpu_utils/cudautils.cuh"
 #include "gromacs/gpu_utils/devicebuffer.cuh"
-#include "gromacs/gpu_utils/gputraits.cuh"
+#include "gromacs/gpu_utils/gputraits.h"
+#include "gromacs/gpu_utils/typecasts.cuh"
 #include "gromacs/gpu_utils/vectype_ops.cuh"
 #include "gromacs/math/functions.h"
 #include "gromacs/math/vec.h"
@@ -427,14 +428,14 @@ inline auto getLincsKernelPtr(const bool updateVelocities, const bool computeVir
     return kernelPtr;
 }
 
-void LincsGpu::apply(const float3* d_x,
-                     float3*       d_xp,
-                     const bool    updateVelocities,
-                     float3*       d_v,
-                     const real    invdt,
-                     const bool    computeVirial,
-                     tensor        virialScaled,
-                     const PbcAiuc pbcAiuc)
+void LincsGpu::apply(const DeviceBuffer<Float3> d_x,
+                     DeviceBuffer<Float3>       d_xp,
+                     const bool                 updateVelocities,
+                     DeviceBuffer<Float3>       d_v,
+                     const real                 invdt,
+                     const bool                 computeVirial,
+                     tensor                     virialScaled,
+                     const PbcAiuc              pbcAiuc)
 {
     ensureNoPendingDeviceError("In CUDA version of LINCS");
 
@@ -479,8 +480,13 @@ void LincsGpu::apply(const float3* d_x,
 
     kernelParams_.pbcAiuc = pbcAiuc;
 
-    const auto kernelArgs =
-            prepareGpuKernelArguments(kernelPtr, config, &kernelParams_, &d_x, &d_xp, &d_v, &invdt);
+    const auto kernelArgs = prepareGpuKernelArguments(kernelPtr,
+                                                      config,
+                                                      &kernelParams_,
+                                                      asFloat3Pointer(&d_x),
+                                                      asFloat3Pointer(&d_xp),
+                                                      asFloat3Pointer(&d_v),
+                                                      &invdt);
 
     launchGpuKernel(kernelPtr,
                     config,