Use DeviceBuffer<RVec> in GPU force reduction and PME code
[alexxy/gromacs.git] / src / gromacs / ewald / pme_pp_comm_gpu_impl.cu
index 0ecf0281333003926fb8a875fd61eb60ca4337ff..cb9e787c446836d27549d3e1c6775017520bc07f 100644 (file)
@@ -64,7 +64,8 @@ PmePpCommGpu::Impl::Impl(MPI_Comm             comm,
     deviceContext_(deviceContext),
     pmePpCommStream_(deviceStream),
     comm_(comm),
-    pmeRank_(pmeRank)
+    pmeRank_(pmeRank),
+    d_pmeForces_(nullptr)
 {
     GMX_RELEASE_ASSERT(
             GMX_THREAD_MPI,
@@ -155,9 +156,10 @@ void PmePpCommGpu::Impl::sendCoordinatesToPmeCudaDirect(void* sendPtr,
     GMX_UNUSED_VALUE(coordinatesReadyOnDeviceEvent);
 #endif
 }
-void* PmePpCommGpu::Impl::getGpuForceStagingPtr()
+
+DeviceBuffer<Float3> PmePpCommGpu::Impl::getGpuForceStagingPtr()
 {
-    return static_cast<void*>(d_pmeForces_);
+    return d_pmeForces_;
 }
 
 GpuEventSynchronizer* PmePpCommGpu::Impl::getForcesReadySynchronizer()
@@ -194,7 +196,7 @@ void PmePpCommGpu::sendCoordinatesToPmeCudaDirect(void*                 sendPtr,
             sendPtr, sendSize, sendPmeCoordinatesFromGpu, coordinatesReadyOnDeviceEvent);
 }
 
-void* PmePpCommGpu::getGpuForceStagingPtr()
+DeviceBuffer<gmx::RVec> PmePpCommGpu::getGpuForceStagingPtr()
 {
     return impl_->getGpuForceStagingPtr();
 }