Add management for velocities and forces copy events to StatePropagatorDataGpu
[alexxy/gromacs.git] / src / gromacs / mdtypes / state_propagator_data_gpu_impl_gpu.cpp
index ba6850db45ed7f2e187b626794fd8f8f7cf4deff..7070891d6c706b42352ae9ba0d2d8e8447b908a5 100644 (file)
@@ -112,11 +112,20 @@ StatePropagatorDataGpu::Impl::Impl(const void            *pmeStream,
         GMX_UNUSED_VALUE(deviceContext);
     }
 
-    // Map the atom locality to the stream that will be used for coordinates transfer.
-    // Same streams are used for H2D and D2H copies
+    // Map the atom locality to the stream that will be used for coordinates,
+    // velocities and forces transfers. Same streams are used for H2D and D2H copies.
+    // Note, that nullptr stream is used here to indicate that the copy is not supported.
     xCopyStreams_[AtomLocality::Local]    = updateStream_;
     xCopyStreams_[AtomLocality::NonLocal] = nonLocalStream_;
     xCopyStreams_[AtomLocality::All]      = updateStream_;
+
+    vCopyStreams_[AtomLocality::Local]    = updateStream_;
+    vCopyStreams_[AtomLocality::NonLocal] = nullptr;
+    vCopyStreams_[AtomLocality::All]      = updateStream_;
+
+    fCopyStreams_[AtomLocality::Local]    = localStream_;
+    fCopyStreams_[AtomLocality::NonLocal] = nonLocalStream_;
+    fCopyStreams_[AtomLocality::All]      = nullptr;
 }
 
 StatePropagatorDataGpu::Impl::~Impl()
@@ -293,15 +302,38 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getVelocities()
 void StatePropagatorDataGpu::Impl::copyVelocitiesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_v,
                                                        AtomLocality                          atomLocality)
 {
-    // TODO: Use the correct stream
-    copyToDevice(d_v_, h_v, d_vSize_, atomLocality, nullptr);
+    GMX_ASSERT(atomLocality < AtomLocality::Count, "Wrong atom locality.");
+    CommandStream commandStream = vCopyStreams_[atomLocality];
+    GMX_ASSERT(commandStream != nullptr, "No stream is valid for copying velocities with given atom locality.");
+
+    copyToDevice(d_v_, h_v, d_vSize_, atomLocality, commandStream);
+    // TODO: Remove When event-based synchronization is introduced
+    gpuStreamSynchronize(commandStream);
+    vReadyOnDevice_[atomLocality].markEvent(commandStream);
+}
+
+GpuEventSynchronizer* StatePropagatorDataGpu::Impl::getVelocitiesReadyOnDeviceEvent(AtomLocality  atomLocality)
+{
+    return &vReadyOnDevice_[atomLocality];
 }
 
+
 void StatePropagatorDataGpu::Impl::copyVelocitiesFromGpu(gmx::ArrayRef<gmx::RVec>  h_v,
                                                          AtomLocality              atomLocality)
 {
-    // TODO: Use the correct stream
-    copyFromDevice(h_v, d_v_, d_vSize_, atomLocality, nullptr);
+    GMX_ASSERT(atomLocality < AtomLocality::Count, "Wrong atom locality.");
+    CommandStream commandStream = vCopyStreams_[atomLocality];
+    GMX_ASSERT(commandStream != nullptr, "No stream is valid for copying velocities with given atom locality.");
+
+    copyFromDevice(h_v, d_v_, d_vSize_, atomLocality, commandStream);
+    // TODO: Remove When event-based synchronization is introduced
+    gpuStreamSynchronize(commandStream);
+    vReadyOnHost_[atomLocality].markEvent(commandStream);
+}
+
+void StatePropagatorDataGpu::Impl::waitVelocitiesReadyOnHost(AtomLocality  atomLocality)
+{
+    vReadyOnHost_[atomLocality].waitForEvent();
 }
 
 
@@ -313,15 +345,38 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getForces()
 void StatePropagatorDataGpu::Impl::copyForcesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_f,
                                                    AtomLocality                          atomLocality)
 {
-    // TODO: Use the correct stream
-    copyToDevice(d_f_, h_f, d_fSize_, atomLocality, nullptr);
+    GMX_ASSERT(atomLocality < AtomLocality::Count, "Wrong atom locality.");
+    CommandStream commandStream = fCopyStreams_[atomLocality];
+    GMX_ASSERT(commandStream != nullptr, "No stream is valid for copying forces with given atom locality.");
+
+    copyToDevice(d_f_, h_f, d_fSize_, atomLocality, commandStream);
+    // TODO: Remove When event-based synchronization is introduced
+    gpuStreamSynchronize(commandStream);
+    fReadyOnDevice_[atomLocality].markEvent(commandStream);
 }
 
+GpuEventSynchronizer* StatePropagatorDataGpu::Impl::getForcesReadyOnDeviceEvent(AtomLocality  atomLocality)
+{
+    return &fReadyOnDevice_[atomLocality];
+}
+
+
 void StatePropagatorDataGpu::Impl::copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  h_f,
                                                      AtomLocality              atomLocality)
 {
-    // TODO: Use the correct stream
-    copyFromDevice(h_f, d_f_, d_fSize_, atomLocality, nullptr);
+    GMX_ASSERT(atomLocality < AtomLocality::Count, "Wrong atom locality.");
+    CommandStream commandStream = fCopyStreams_[atomLocality];
+    GMX_ASSERT(commandStream != nullptr, "No stream is valid for copying forces with given atom locality.");
+
+    copyFromDevice(h_f, d_f_, d_fSize_, atomLocality, commandStream);
+    // TODO: Remove When event-based synchronization is introduced
+    gpuStreamSynchronize(commandStream);
+    fReadyOnHost_[atomLocality].markEvent(commandStream);
+}
+
+void StatePropagatorDataGpu::Impl::waitForcesReadyOnHost(AtomLocality  atomLocality)
+{
+    fReadyOnHost_[atomLocality].waitForEvent();
 }
 
 void* StatePropagatorDataGpu::Impl::getUpdateStream()
@@ -413,12 +468,22 @@ void StatePropagatorDataGpu::copyVelocitiesToGpu(const gmx::ArrayRef<const gmx::
     return impl_->copyVelocitiesToGpu(h_v, atomLocality);
 }
 
+GpuEventSynchronizer* StatePropagatorDataGpu::getVelocitiesReadyOnDeviceEvent(AtomLocality  atomLocality)
+{
+    return impl_->getVelocitiesReadyOnDeviceEvent(atomLocality);
+}
+
 void StatePropagatorDataGpu::copyVelocitiesFromGpu(gmx::ArrayRef<RVec>  h_v,
                                                    AtomLocality         atomLocality)
 {
     return impl_->copyVelocitiesFromGpu(h_v, atomLocality);
 }
 
+void StatePropagatorDataGpu::waitVelocitiesReadyOnHost(AtomLocality  atomLocality)
+{
+    return impl_->waitVelocitiesReadyOnHost(atomLocality);
+}
+
 
 DeviceBuffer<float> StatePropagatorDataGpu::getForces()
 {
@@ -431,12 +496,23 @@ void StatePropagatorDataGpu::copyForcesToGpu(const gmx::ArrayRef<const gmx::RVec
     return impl_->copyForcesToGpu(h_f, atomLocality);
 }
 
+GpuEventSynchronizer* StatePropagatorDataGpu::getForcesReadyOnDeviceEvent(AtomLocality  atomLocality)
+{
+    return impl_->getForcesReadyOnDeviceEvent(atomLocality);
+}
+
 void StatePropagatorDataGpu::copyForcesFromGpu(gmx::ArrayRef<RVec>  h_f,
                                                AtomLocality         atomLocality)
 {
     return impl_->copyForcesFromGpu(h_f, atomLocality);
 }
 
+void StatePropagatorDataGpu::waitForcesReadyOnHost(AtomLocality  atomLocality)
+{
+    return impl_->waitForcesReadyOnHost(atomLocality);
+}
+
+
 void* StatePropagatorDataGpu::getUpdateStream()
 {
     return impl_->getUpdateStream();