Remove two-stage initialization in DeviceStream
[alexxy/gromacs.git] / src / gromacs / gpu_utils / device_stream_manager.cpp
index 8b09713695ae9c3e24465a4d262164dfc43d28ed..c732156942e9601b99c3754b137fbe821b1ad0a7 100644 (file)
@@ -80,7 +80,7 @@ public:
     //! Device context.
     DeviceContext context_;
     //! GPU command streams.
-    EnumerationArray<DeviceStreamType, DeviceStream> streams_;
+    EnumerationArray<DeviceStreamType, std::unique_ptr<DeviceStream>> streams_;
 };
 
 // DeviceStreamManager::Impl
@@ -92,7 +92,8 @@ DeviceStreamManager::Impl::Impl(const DeviceInformation& deviceInfo,
 {
     try
     {
-        streams_[DeviceStreamType::NonBondedLocal].init(context_, DeviceStreamPriority::Normal, useTiming);
+        streams_[DeviceStreamType::NonBondedLocal] =
+                std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
 
         if (simulationWork.useGpuPme)
         {
@@ -100,23 +101,25 @@ DeviceStreamManager::Impl::Impl(const DeviceInformation& deviceInfo,
              * - default high priority with CUDA
              * - no priorities implemented yet with OpenCL; see #2532
              */
-            streams_[DeviceStreamType::Pme].init(context_, DeviceStreamPriority::High, useTiming);
+            streams_[DeviceStreamType::Pme] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::High, useTiming);
         }
 
         if (havePpDomainDecomposition)
         {
-            streams_[DeviceStreamType::NonBondedNonLocal].init(context_, DeviceStreamPriority::High,
-                                                               useTiming);
+            streams_[DeviceStreamType::NonBondedNonLocal] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::High, useTiming);
         }
         // Update stream is used both for coordinates transfers and for GPU update/constraints
         if (simulationWork.useGpuPme || simulationWork.useGpuUpdate || simulationWork.useGpuBufferOps)
         {
-            streams_[DeviceStreamType::UpdateAndConstraints].init(
-                    context_, DeviceStreamPriority::Normal, useTiming);
+            streams_[DeviceStreamType::UpdateAndConstraints] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
         }
         if (simulationWork.useGpuPmePpCommunication)
         {
-            streams_[DeviceStreamType::PmePpTransfer].init(context_, DeviceStreamPriority::Normal, useTiming);
+            streams_[DeviceStreamType::PmePpTransfer] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
         }
     }
     GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
@@ -147,7 +150,7 @@ const DeviceContext& DeviceStreamManager::context() const
 
 const DeviceStream& DeviceStreamManager::stream(DeviceStreamType streamToGet) const
 {
-    return impl_->streams_[streamToGet];
+    return *impl_->streams_[streamToGet];
 }
 
 const DeviceStream& DeviceStreamManager::bondedStream(bool hasPPDomainDecomposition) const
@@ -170,7 +173,7 @@ const DeviceStream& DeviceStreamManager::bondedStream(bool hasPPDomainDecomposit
 
 bool DeviceStreamManager::streamIsValid(DeviceStreamType streamToCheck) const
 {
-    return impl_->streams_[streamToCheck].isValid();
+    return impl_->streams_[streamToCheck] != nullptr && impl_->streams_[streamToCheck]->isValid();
 }
 
 } // namespace gmx