StatePropagatorDataGpu object to manage GPU forces, positions and velocities buffers
[alexxy/gromacs.git] / src / gromacs / mdlib / update_constrain_cuda_impl.cu
index f373d2791135c56cd38695e1ed9238ae2687523a..9ba96267380acc8efc92e24ce88857a6bf5f9871 100644 (file)
@@ -105,6 +105,9 @@ void UpdateConstrainCuda::Impl::integrate(const real                        dt,
         }
     }
 
+    // TODO: This should be eliminated
+    cudaMemcpy(d_x_, d_xp_, numAtoms_*sizeof(float3), cudaMemcpyDeviceToDevice);
+
     return;
 }
 
@@ -124,16 +127,24 @@ UpdateConstrainCuda::Impl::~Impl()
 {
 }
 
-void UpdateConstrainCuda::Impl::set(const t_idef    &idef,
-                                    const t_mdatoms &md,
-                                    const int        numTempScaleValues)
+void UpdateConstrainCuda::Impl::set(DeviceBuffer<float>        d_x,
+                                    DeviceBuffer<float>        d_v,
+                                    const DeviceBuffer<float>  d_f,
+                                    const t_idef              &idef,
+                                    const t_mdatoms           &md,
+                                    const int                  numTempScaleValues)
 {
+    GMX_ASSERT(d_x != nullptr, "Coordinates device buffer should not be null.");
+    GMX_ASSERT(d_v != nullptr, "Velocities device buffer should not be null.");
+    GMX_ASSERT(d_f != nullptr, "Forces device buffer should not be null.");
+
+    d_x_ = reinterpret_cast<float3*>(d_x);
+    d_v_ = reinterpret_cast<float3*>(d_v);
+    d_f_ = reinterpret_cast<float3*>(d_f);
+
     numAtoms_ = md.nr;
 
-    reallocateDeviceBuffer(&d_x_,  numAtoms_, &numX_,  &numXAlloc_,  nullptr);
     reallocateDeviceBuffer(&d_xp_, numAtoms_, &numXp_, &numXpAlloc_, nullptr);
-    reallocateDeviceBuffer(&d_v_,  numAtoms_, &numV_,  &numVAlloc_,  nullptr);
-    reallocateDeviceBuffer(&d_f_,  numAtoms_, &numF_,  &numFAlloc_,  nullptr);
 
     reallocateDeviceBuffer(&d_inverseMasses_, numAtoms_,
                            &numInverseMasses_, &numInverseMassesAlloc_, nullptr);
@@ -152,44 +163,6 @@ void UpdateConstrainCuda::Impl::setPbc(const t_pbc *pbc)
     settleCuda_->setPbc(pbc);
 }
 
-void UpdateConstrainCuda::Impl::copyCoordinatesToGpu(const rvec *h_x)
-{
-    copyToDeviceBuffer(&d_x_, (float3*)h_x, 0, numAtoms_, commandStream_, GpuApiCallBehavior::Sync, nullptr);
-}
-
-void UpdateConstrainCuda::Impl::copyVelocitiesToGpu(const rvec *h_v)
-{
-    copyToDeviceBuffer(&d_v_, (float3*)h_v, 0, numAtoms_, commandStream_, GpuApiCallBehavior::Sync, nullptr);
-}
-
-void UpdateConstrainCuda::Impl::copyForcesToGpu(const rvec *h_f)
-{
-    copyToDeviceBuffer(&d_f_, (float3*)h_f, 0, numAtoms_, commandStream_, GpuApiCallBehavior::Sync, nullptr);
-}
-
-void UpdateConstrainCuda::Impl::copyCoordinatesFromGpu(rvec *h_xp)
-{
-    copyFromDeviceBuffer((float3*)h_xp, &d_xp_, 0, numAtoms_, commandStream_, GpuApiCallBehavior::Sync, nullptr);
-}
-
-void UpdateConstrainCuda::Impl::copyVelocitiesFromGpu(rvec *h_v)
-{
-    copyFromDeviceBuffer((float3*)h_v, &d_v_, 0, numAtoms_, commandStream_, GpuApiCallBehavior::Sync, nullptr);
-}
-
-void UpdateConstrainCuda::Impl::copyForcesFromGpu(rvec *h_f)
-{
-    copyFromDeviceBuffer((float3*)h_f, &d_f_, 0, numAtoms_, commandStream_, GpuApiCallBehavior::Sync, nullptr);
-}
-
-void UpdateConstrainCuda::Impl::setXVFPointers(rvec *d_x, rvec *d_xp, rvec *d_v, rvec *d_f)
-{
-    d_x_  = (float3*)d_x;
-    d_xp_ = (float3*)d_xp;
-    d_v_  = (float3*)d_v;
-    d_f_  = (float3*)d_f;
-}
-
 UpdateConstrainCuda::UpdateConstrainCuda(const t_inputrec  &ir,
                                          const gmx_mtop_t  &mtop,
                                          const void        *commandStream)
@@ -207,18 +180,21 @@ void UpdateConstrainCuda::integrate(const real                        dt,
                                     gmx::ArrayRef<const t_grp_tcstat> tcstat,
                                     const bool                        doPressureCouple,
                                     const float                       dtPressureCouple,
-                                    const matrix                      pRVScalingMatrix)
+                                    const matrix                      velocityScalingMatrix)
 {
     impl_->integrate(dt, updateVelocities, computeVirial, virialScaled,
                      doTempCouple, tcstat,
-                     doPressureCouple, dtPressureCouple, pRVScalingMatrix);
+                     doPressureCouple, dtPressureCouple, velocityScalingMatrix);
 }
 
-void UpdateConstrainCuda::set(const t_idef    &idef,
-                              const t_mdatoms &md,
-                              const int        numTempScaleValues)
+void UpdateConstrainCuda::set(DeviceBuffer<float>        d_x,
+                              DeviceBuffer<float>        d_v,
+                              const DeviceBuffer<float>  d_f,
+                              const t_idef              &idef,
+                              const t_mdatoms           &md,
+                              const int                  numTempScaleValues)
 {
-    impl_->set(idef, md, numTempScaleValues);
+    impl_->set(d_x, d_v, d_f, idef, md, numTempScaleValues);
 }
 
 void UpdateConstrainCuda::setPbc(const t_pbc *pbc)
@@ -226,39 +202,4 @@ void UpdateConstrainCuda::setPbc(const t_pbc *pbc)
     impl_->setPbc(pbc);
 }
 
-void UpdateConstrainCuda::copyCoordinatesToGpu(const rvec *h_x)
-{
-    impl_->copyCoordinatesToGpu(h_x);
-}
-
-void UpdateConstrainCuda::copyVelocitiesToGpu(const rvec *h_v)
-{
-    impl_->copyVelocitiesToGpu(h_v);
-}
-
-void UpdateConstrainCuda::copyForcesToGpu(const rvec *h_f)
-{
-    impl_->copyForcesToGpu(h_f);
-}
-
-void UpdateConstrainCuda::copyCoordinatesFromGpu(rvec *h_xp)
-{
-    impl_->copyCoordinatesFromGpu(h_xp);
-}
-
-void UpdateConstrainCuda::copyVelocitiesFromGpu(rvec *h_v)
-{
-    impl_->copyVelocitiesFromGpu(h_v);
-}
-
-void UpdateConstrainCuda::copyForcesFromGpu(rvec *h_f)
-{
-    impl_->copyForcesFromGpu(h_f);
-}
-
-void UpdateConstrainCuda::setXVFPointers(rvec *d_x, rvec *d_xp, rvec *d_v, rvec *d_f)
-{
-    impl_->setXVFPointers(d_x, d_xp, d_v, d_f);
-}
-
 } //namespace gmx