Pass the GPU streams to StatePropagatorDataGpu constructor
[alexxy/gromacs.git] / src / gromacs / mdtypes / state_propagator_data_gpu_impl_gpu.cpp
index 89446f247759253f88a37854c1e3a810930bb23b..3cc0bb7653bc4286265bb178ec67785732f5f135 100644 (file)
 namespace gmx
 {
 
-StatePropagatorDataGpu::Impl::Impl(gmx_unused const void *commandStream,
-                                   gmx_unused const void *deviceContext,
+StatePropagatorDataGpu::Impl::Impl(const void            *pmeStream,
+                                   const void            *localStream,
+                                   const void            *nonLocalStream,
+                                   const void            *deviceContext,
                                    GpuApiCallBehavior     transferKind,
                                    int                    paddingSize) :
     transferKind_(transferKind),
@@ -72,18 +74,29 @@ StatePropagatorDataGpu::Impl::Impl(gmx_unused const void *commandStream,
 
     GMX_RELEASE_ASSERT(getenv("GMX_USE_GPU_BUFFER_OPS") == nullptr, "GPU buffer ops are not supported in this build.");
 
-    // Set the stream-context pair for the OpenCL builds,
-    // use the nullptr stream for CUDA builds
-#if GMX_GPU == GMX_GPU_OPENCL
-    if (commandStream != nullptr)
+    if (pmeStream != nullptr)
+    {
+        pmeStream_ = *static_cast<const CommandStream*>(pmeStream);
+    }
+    if (localStream != nullptr)
+    {
+        localStream_ = *static_cast<const CommandStream*>(localStream);
+    }
+    if (nonLocalStream != nullptr)
     {
-        commandStream_ = *static_cast<const CommandStream*>(commandStream);
+        nonLocalStream_ = *static_cast<const CommandStream*>(nonLocalStream);
     }
+// The OpenCL build will never use the updateStream
+// TODO: The update stream should be created only when it is needed.
+#if GMX_GPU == GMX_GPU_CUDA
+    cudaError_t stat;
+    stat = cudaStreamCreate(&updateStream_);
+    CU_RET_ERR(stat, "CUDA stream creation failed in StatePropagatorDataGpu");
+#endif
     if (deviceContext != nullptr)
     {
         deviceContext_ = *static_cast<const DeviceContext*>(deviceContext);
     }
-#endif
 
 }
 
@@ -114,7 +127,10 @@ void StatePropagatorDataGpu::Impl::reinit(int numAtomsLocal, int numAtomsAll)
     const size_t paddingAllocationSize = numAtomsPadded - numAtomsAll_;
     if (paddingAllocationSize > 0)
     {
-        clearDeviceBufferAsync(&d_x_, DIM*numAtomsAll_, DIM*paddingAllocationSize, commandStream_);
+        // The PME stream is used here because:
+        // 1. The padding clearing is only needed by PME.
+        // 2. It is the stream that is created in the PME OpenCL context.
+        clearDeviceBufferAsync(&d_x_, DIM*numAtomsAll_, DIM*paddingAllocationSize, pmeStream_);
     }
 
     reallocateDeviceBuffer(&d_v_, DIM*numAtomsAll_, &d_vSize_, &d_vCapacity_, deviceContext_);
@@ -151,11 +167,16 @@ std::tuple<int, int> StatePropagatorDataGpu::Impl::getAtomRangesFromAtomLocality
 void StatePropagatorDataGpu::Impl::copyToDevice(DeviceBuffer<float>                   d_data,
                                                 const gmx::ArrayRef<const gmx::RVec>  h_data,
                                                 int                                   dataSize,
-                                                AtomLocality                          atomLocality)
+                                                AtomLocality                          atomLocality,
+                                                CommandStream                         commandStream)
 {
 
 #if GMX_GPU == GMX_GPU_OPENCL
     GMX_ASSERT(deviceContext_ != nullptr, "GPU context should be set in OpenCL builds.");
+    // The PME stream is used for OpenCL builds, because it is the context that it associated with the
+    // PME task which requires the coordinates managed here in OpenCL.
+    // TODO: This will have to be changed when the OpenCL implementation will be extended.
+    commandStream = pmeStream_;
 #endif
 
     GMX_UNUSED_VALUE(dataSize);
@@ -173,21 +194,25 @@ void StatePropagatorDataGpu::Impl::copyToDevice(DeviceBuffer<float>
         GMX_ASSERT(elementsStartAt + numElementsToCopy <= dataSize, "The device allocation is smaller than requested copy range.");
         GMX_ASSERT(atomsStartAt + numAtomsToCopy <= h_data.ssize(), "The host buffer is smaller than the requested copy range.");
 
-        // TODO: Use the proper stream
         copyToDeviceBuffer(&d_data, reinterpret_cast<const float *>(&h_data.data()[atomsStartAt]),
                            elementsStartAt, numElementsToCopy,
-                           commandStream_, transferKind_, nullptr);
+                           commandStream, transferKind_, nullptr);
     }
 }
 
 void StatePropagatorDataGpu::Impl::copyFromDevice(gmx::ArrayRef<gmx::RVec>  h_data,
                                                   DeviceBuffer<float>       d_data,
                                                   int                       dataSize,
-                                                  AtomLocality              atomLocality)
+                                                  AtomLocality              atomLocality,
+                                                  CommandStream             commandStream)
 {
 
 #if GMX_GPU == GMX_GPU_OPENCL
     GMX_ASSERT(deviceContext_ != nullptr, "GPU context should be set in OpenCL builds.");
+    // The PME stream is used for OpenCL builds, because it is the context that it associated with the
+    // PME task which requires the coordinates managed here in OpenCL.
+    // TODO: This will have to be changed when the OpenCL implementation will be extended.
+    commandStream = pmeStream_;
 #endif
 
     GMX_UNUSED_VALUE(dataSize);
@@ -205,11 +230,9 @@ void StatePropagatorDataGpu::Impl::copyFromDevice(gmx::ArrayRef<gmx::RVec>  h_da
         GMX_ASSERT(elementsStartAt + numElementsToCopy <= dataSize, "The device allocation is smaller than requested copy range.");
         GMX_ASSERT(atomsStartAt + numAtomsToCopy <= h_data.ssize(), "The host buffer is smaller than the requested copy range.");
 
-        // TODO: Use the proper stream
         copyFromDeviceBuffer(reinterpret_cast<float*>(&h_data.data()[atomsStartAt]), &d_data,
                              elementsStartAt, numElementsToCopy,
-                             commandStream_, transferKind_, nullptr);
-
+                             commandStream, transferKind_, nullptr);
     }
 }
 
@@ -221,13 +244,15 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getCoordinates()
 void StatePropagatorDataGpu::Impl::copyCoordinatesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_x,
                                                         AtomLocality                          atomLocality)
 {
-    copyToDevice(d_x_, h_x, d_xSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyToDevice(d_x_, h_x, d_xSize_, atomLocality, nullptr);
 }
 
 void StatePropagatorDataGpu::Impl::copyCoordinatesFromGpu(gmx::ArrayRef<gmx::RVec>  h_x,
                                                           AtomLocality              atomLocality)
 {
-    copyFromDevice(h_x, d_x_, d_xSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyFromDevice(h_x, d_x_, d_xSize_, atomLocality, nullptr);
 }
 
 
@@ -239,13 +264,15 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getVelocities()
 void StatePropagatorDataGpu::Impl::copyVelocitiesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_v,
                                                        AtomLocality                          atomLocality)
 {
-    copyToDevice(d_v_, h_v, d_vSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyToDevice(d_v_, h_v, d_vSize_, atomLocality, nullptr);
 }
 
 void StatePropagatorDataGpu::Impl::copyVelocitiesFromGpu(gmx::ArrayRef<gmx::RVec>  h_v,
                                                          AtomLocality              atomLocality)
 {
-    copyFromDevice(h_v, d_v_, d_vSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyFromDevice(h_v, d_v_, d_vSize_, atomLocality, nullptr);
 }
 
 
@@ -257,18 +284,20 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getForces()
 void StatePropagatorDataGpu::Impl::copyForcesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_f,
                                                    AtomLocality                          atomLocality)
 {
-    copyToDevice(d_f_, h_f, d_fSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyToDevice(d_f_, h_f, d_fSize_, atomLocality, nullptr);
 }
 
 void StatePropagatorDataGpu::Impl::copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  h_f,
                                                      AtomLocality              atomLocality)
 {
-    copyFromDevice(h_f, d_f_, d_fSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyFromDevice(h_f, d_f_, d_fSize_, atomLocality, nullptr);
 }
 
-void StatePropagatorDataGpu::Impl::synchronizeStream()
+void* StatePropagatorDataGpu::Impl::getUpdateStream()
 {
-    gpuStreamSynchronize(commandStream_);
+    return updateStream_;
 }
 
 int StatePropagatorDataGpu::Impl::numAtomsLocal()
@@ -283,11 +312,15 @@ int StatePropagatorDataGpu::Impl::numAtomsAll()
 
 
 
-StatePropagatorDataGpu::StatePropagatorDataGpu(const void        *commandStream,
+StatePropagatorDataGpu::StatePropagatorDataGpu(const void        *pmeStream,
+                                               const void        *localStream,
+                                               const void        *nonLocalStream,
                                                const void        *deviceContext,
                                                GpuApiCallBehavior transferKind,
                                                int                paddingSize)
-    : impl_(new Impl(commandStream,
+    : impl_(new Impl(pmeStream,
+                     localStream,
+                     nonLocalStream,
                      deviceContext,
                      transferKind,
                      paddingSize))
@@ -365,9 +398,9 @@ void StatePropagatorDataGpu::copyForcesFromGpu(gmx::ArrayRef<RVec>  h_f,
     return impl_->copyForcesFromGpu(h_f, atomLocality);
 }
 
-void StatePropagatorDataGpu::synchronizeStream()
+void* StatePropagatorDataGpu::getUpdateStream()
 {
-    return impl_->synchronizeStream();
+    return impl_->getUpdateStream();
 }
 
 int StatePropagatorDataGpu::numAtomsLocal()