Revert "Wrap more device pointers in DeviceBuffer" (!1244)
[alexxy/gromacs.git] / src / gromacs / mdlib / gpuforcereduction_impl.cu
index ac89b47d24712f8bb8474b25fad098871809b185..6e1e7e920a50b2b91e8e2d26149beee45c6d1380 100644 (file)
@@ -43,7 +43,7 @@
 
 #include "gmxpre.h"
 
-#include "gpuforcereduction_impl.h"
+#include "gpuforcereduction_impl.cuh"
 
 #include <stdio.h>
 
@@ -112,15 +112,15 @@ GpuForceReduction::Impl::Impl(const DeviceContext& deviceContext,
     deviceStream_(deviceStream),
     wcycle_(wcycle){};
 
-void GpuForceReduction::Impl::reinit(DeviceBuffer<gmx::RVec> baseForce,
-                                     const int               numAtoms,
-                                     ArrayRef<const int>     cell,
-                                     const int               atomStart,
-                                     const bool              accumulate,
-                                     GpuEventSynchronizer*   completionMarker)
+void GpuForceReduction::Impl::reinit(float3*               baseForcePtr,
+                                     const int             numAtoms,
+                                     ArrayRef<const int>   cell,
+                                     const int             atomStart,
+                                     const bool            accumulate,
+                                     GpuEventSynchronizer* completionMarker)
 {
-    GMX_ASSERT((baseForce != nullptr), "Input base force for reduction has no data");
-    baseForce_        = baseForce;
+    GMX_ASSERT((baseForcePtr != nullptr), "Input base force for reduction has no data");
+    baseForce_        = &(baseForcePtr[atomStart]);
     numAtoms_         = numAtoms;
     atomStart_        = atomStart;
     accumulate_       = accumulate;
@@ -144,13 +144,13 @@ void GpuForceReduction::Impl::reinit(DeviceBuffer<gmx::RVec> baseForce,
 
 void GpuForceReduction::Impl::registerNbnxmForce(DeviceBuffer<RVec> forcePtr)
 {
-    GMX_ASSERT((forcePtr), "Input force for reduction has no data");
+    GMX_ASSERT((forcePtr != nullptr), "Input force for reduction has no data");
     nbnxmForceToAdd_ = forcePtr;
 };
 
 void GpuForceReduction::Impl::registerRvecForce(DeviceBuffer<RVec> forcePtr)
 {
-    GMX_ASSERT((forcePtr), "Input force for reduction has no data");
+    GMX_ASSERT((forcePtr != nullptr), "Input force for reduction has no data");
     rvecForceToAdd_ = forcePtr;
 };
 
@@ -172,12 +172,11 @@ void GpuForceReduction::Impl::execute()
     GMX_ASSERT((nbnxmForceToAdd_ != nullptr), "Nbnxm force for reduction has no data");
 
     // Enqueue wait on all dependencies passed
-    for (const auto& synchronizer : dependencyList_)
+    for (auto const synchronizer : dependencyList_)
     {
         synchronizer->enqueueWaitEvent(deviceStream_);
     }
 
-    float3* d_baseForce      = &(asFloat3(baseForce_)[atomStart_]);
     float3* d_nbnxmForce     = asFloat3(nbnxmForceToAdd_);
     float3* d_rvecForceToAdd = &(asFloat3(rvecForceToAdd_)[atomStart_]);
 
@@ -196,7 +195,7 @@ void GpuForceReduction::Impl::execute()
                             : (accumulate_ ? reduceKernel<false, true> : reduceKernel<false, false>);
 
     const auto kernelArgs = prepareGpuKernelArguments(
-            kernelFn, config, &d_nbnxmForce, &d_rvecForceToAdd, &d_baseForce, &cellInfo_.d_cell, &numAtoms_);
+            kernelFn, config, &d_nbnxmForce, &d_rvecForceToAdd, &baseForce_, &cellInfo_.d_cell, &numAtoms_);
 
     launchGpuKernel(kernelFn, config, deviceStream_, nullptr, "Force Reduction", kernelArgs);
 
@@ -219,14 +218,14 @@ GpuForceReduction::GpuForceReduction(const DeviceContext& deviceContext,
 {
 }
 
-void GpuForceReduction::registerNbnxmForce(DeviceBuffer<RVec> forcePtr)
+void GpuForceReduction::registerNbnxmForce(void* forcePtr)
 {
-    impl_->registerNbnxmForce(std::move(forcePtr));
+    impl_->registerNbnxmForce(reinterpret_cast<DeviceBuffer<RVec>>(forcePtr));
 }
 
-void GpuForceReduction::registerRvecForce(DeviceBuffer<RVec> forcePtr)
+void GpuForceReduction::registerRvecForce(void* forcePtr)
 {
-    impl_->registerRvecForce(std::move(forcePtr));
+    impl_->registerRvecForce(reinterpret_cast<DeviceBuffer<RVec>>(forcePtr));
 }
 
 void GpuForceReduction::addDependency(GpuEventSynchronizer* const dependency)
@@ -234,14 +233,14 @@ void GpuForceReduction::addDependency(GpuEventSynchronizer* const dependency)
     impl_->addDependency(dependency);
 }
 
-void GpuForceReduction::reinit(DeviceBuffer<RVec>    baseForce,
+void GpuForceReduction::reinit(DeviceBuffer<RVec>    baseForcePtr,
                                const int             numAtoms,
                                ArrayRef<const int>   cell,
                                const int             atomStart,
                                const bool            accumulate,
                                GpuEventSynchronizer* completionMarker)
 {
-    impl_->reinit(baseForce, numAtoms, cell, atomStart, accumulate, completionMarker);
+    impl_->reinit(asFloat3(baseForcePtr), numAtoms, cell, atomStart, accumulate, completionMarker);
 }
 void GpuForceReduction::execute()
 {