Wrap more device pointers in DeviceBuffer
[alexxy/gromacs.git] / src / gromacs / mdlib / gpuforcereduction_impl.cu
index f95f6f1439c36896d8bb67e0eb41cfe0d52e922f..ac89b47d24712f8bb8474b25fad098871809b185 100644 (file)
@@ -1,7 +1,7 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2020, by the GROMACS development team, led by
+ * Copyright (c) 2020,2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -43,7 +43,7 @@
 
 #include "gmxpre.h"
 
-#include "gpuforcereduction_impl.cuh"
+#include "gpuforcereduction_impl.h"
 
 #include <stdio.h>
 
@@ -112,15 +112,15 @@ GpuForceReduction::Impl::Impl(const DeviceContext& deviceContext,
     deviceStream_(deviceStream),
     wcycle_(wcycle){};
 
-void GpuForceReduction::Impl::reinit(float3*               baseForcePtr,
-                                     const int             numAtoms,
-                                     ArrayRef<const int>   cell,
-                                     const int             atomStart,
-                                     const bool            accumulate,
-                                     GpuEventSynchronizer* completionMarker)
+void GpuForceReduction::Impl::reinit(DeviceBuffer<gmx::RVec> baseForce,
+                                     const int               numAtoms,
+                                     ArrayRef<const int>     cell,
+                                     const int               atomStart,
+                                     const bool              accumulate,
+                                     GpuEventSynchronizer*   completionMarker)
 {
-    GMX_ASSERT((baseForcePtr != nullptr), "Input base force for reduction has no data");
-    baseForce_        = &(baseForcePtr[atomStart]);
+    GMX_ASSERT((baseForce != nullptr), "Input base force for reduction has no data");
+    baseForce_        = baseForce;
     numAtoms_         = numAtoms;
     atomStart_        = atomStart;
     accumulate_       = accumulate;
@@ -144,13 +144,13 @@ void GpuForceReduction::Impl::reinit(float3*               baseForcePtr,
 
 void GpuForceReduction::Impl::registerNbnxmForce(DeviceBuffer<RVec> forcePtr)
 {
-    GMX_ASSERT((forcePtr != nullptr), "Input force for reduction has no data");
+    GMX_ASSERT((forcePtr), "Input force for reduction has no data");
     nbnxmForceToAdd_ = forcePtr;
 };
 
 void GpuForceReduction::Impl::registerRvecForce(DeviceBuffer<RVec> forcePtr)
 {
-    GMX_ASSERT((forcePtr != nullptr), "Input force for reduction has no data");
+    GMX_ASSERT((forcePtr), "Input force for reduction has no data");
     rvecForceToAdd_ = forcePtr;
 };
 
@@ -172,11 +172,12 @@ void GpuForceReduction::Impl::execute()
     GMX_ASSERT((nbnxmForceToAdd_ != nullptr), "Nbnxm force for reduction has no data");
 
     // Enqueue wait on all dependencies passed
-    for (auto const synchronizer : dependencyList_)
+    for (const auto& synchronizer : dependencyList_)
     {
         synchronizer->enqueueWaitEvent(deviceStream_);
     }
 
+    float3* d_baseForce      = &(asFloat3(baseForce_)[atomStart_]);
     float3* d_nbnxmForce     = asFloat3(nbnxmForceToAdd_);
     float3* d_rvecForceToAdd = &(asFloat3(rvecForceToAdd_)[atomStart_]);
 
@@ -195,7 +196,7 @@ void GpuForceReduction::Impl::execute()
                             : (accumulate_ ? reduceKernel<false, true> : reduceKernel<false, false>);
 
     const auto kernelArgs = prepareGpuKernelArguments(
-            kernelFn, config, &d_nbnxmForce, &d_rvecForceToAdd, &baseForce_, &cellInfo_.d_cell, &numAtoms_);
+            kernelFn, config, &d_nbnxmForce, &d_rvecForceToAdd, &d_baseForce, &cellInfo_.d_cell, &numAtoms_);
 
     launchGpuKernel(kernelFn, config, deviceStream_, nullptr, "Force Reduction", kernelArgs);
 
@@ -218,14 +219,14 @@ GpuForceReduction::GpuForceReduction(const DeviceContext& deviceContext,
 {
 }
 
-void GpuForceReduction::registerNbnxmForce(void* forcePtr)
+void GpuForceReduction::registerNbnxmForce(DeviceBuffer<RVec> forcePtr)
 {
-    impl_->registerNbnxmForce(reinterpret_cast<DeviceBuffer<RVec>>(forcePtr));
+    impl_->registerNbnxmForce(std::move(forcePtr));
 }
 
-void GpuForceReduction::registerRvecForce(void* forcePtr)
+void GpuForceReduction::registerRvecForce(DeviceBuffer<RVec> forcePtr)
 {
-    impl_->registerRvecForce(reinterpret_cast<DeviceBuffer<RVec>>(forcePtr));
+    impl_->registerRvecForce(std::move(forcePtr));
 }
 
 void GpuForceReduction::addDependency(GpuEventSynchronizer* const dependency)
@@ -233,14 +234,14 @@ void GpuForceReduction::addDependency(GpuEventSynchronizer* const dependency)
     impl_->addDependency(dependency);
 }
 
-void GpuForceReduction::reinit(DeviceBuffer<RVec>    baseForcePtr,
+void GpuForceReduction::reinit(DeviceBuffer<RVec>    baseForce,
                                const int             numAtoms,
                                ArrayRef<const int>   cell,
                                const int             atomStart,
                                const bool            accumulate,
                                GpuEventSynchronizer* completionMarker)
 {
-    impl_->reinit(asFloat3(baseForcePtr), numAtoms, cell, atomStart, accumulate, completionMarker);
+    impl_->reinit(baseForce, numAtoms, cell, atomStart, accumulate, completionMarker);
 }
 void GpuForceReduction::execute()
 {