Add separate constructor to StatePropagatorDataGpu for PME-only rank / PME tests
[alexxy/gromacs.git] / src / gromacs / mdtypes / state_propagator_data_gpu_impl_gpu.cpp
index 5aa64f6e2551482d8852be7a30b022f3d4bc12a6..70c73a1657d38fb21041e6b1a9cabd3d231ff04f 100644 (file)
@@ -128,6 +128,45 @@ StatePropagatorDataGpu::Impl::Impl(const void            *pmeStream,
     fCopyStreams_[AtomLocality::All]      = nullptr;
 }
 
+StatePropagatorDataGpu::Impl::Impl(const void            *pmeStream,
+                                   const void            *deviceContext,
+                                   GpuApiCallBehavior     transferKind,
+                                   int                    paddingSize) :
+    transferKind_(transferKind),
+    paddingSize_(paddingSize)
+{
+    static_assert(GMX_GPU != GMX_GPU_NONE, "This object should only be constructed on the GPU code-paths.");
+    GMX_RELEASE_ASSERT(getenv("GMX_USE_GPU_BUFFER_OPS") == nullptr, "GPU buffer ops are not supported in this build.");
+
+    if (GMX_GPU == GMX_GPU_OPENCL)
+    {
+        GMX_ASSERT(deviceContext != nullptr, "GPU context should be set in OpenCL builds.");
+        deviceContext_  = *static_cast<const DeviceContext*>(deviceContext);
+    }
+
+    GMX_ASSERT(pmeStream != nullptr, "GPU PME stream should be set.");
+    pmeStream_      = *static_cast<const CommandStream*>(pmeStream);
+
+    localStream_    = nullptr;
+    nonLocalStream_ = nullptr;
+    updateStream_   = nullptr;
+
+
+    // Only local/all coordinates are allowed to be copied in PME-only rank/ PME tests.
+    // This it temporary measure to make it safe to use this class in those cases.
+    xCopyStreams_[AtomLocality::Local]    = pmeStream_;
+    xCopyStreams_[AtomLocality::NonLocal] = nullptr;
+    xCopyStreams_[AtomLocality::All]      = pmeStream_;
+
+    vCopyStreams_[AtomLocality::Local]    = nullptr;
+    vCopyStreams_[AtomLocality::NonLocal] = nullptr;
+    vCopyStreams_[AtomLocality::All]      = nullptr;
+
+    fCopyStreams_[AtomLocality::Local]    = nullptr;
+    fCopyStreams_[AtomLocality::NonLocal] = nullptr;
+    fCopyStreams_[AtomLocality::All]      = nullptr;
+}
+
 StatePropagatorDataGpu::Impl::~Impl()
 {
 }
@@ -440,6 +479,17 @@ StatePropagatorDataGpu::StatePropagatorDataGpu(const void        *pmeStream,
 {
 }
 
+StatePropagatorDataGpu::StatePropagatorDataGpu(const void        *pmeStream,
+                                               const void        *deviceContext,
+                                               GpuApiCallBehavior transferKind,
+                                               int                paddingSize)
+    : impl_(new Impl(pmeStream,
+                     deviceContext,
+                     transferKind,
+                     paddingSize))
+{
+}
+
 StatePropagatorDataGpu::StatePropagatorDataGpu(StatePropagatorDataGpu && /* other */) noexcept = default;
 
 StatePropagatorDataGpu &StatePropagatorDataGpu::operator=(StatePropagatorDataGpu && /* other */) noexcept = default;