Split simulationWork.useGpuBufferOps into separate x and f flags
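This change passes the full SimulationWorkload descriptor into the stream manager instead of individual booleans, so the coordinate (x) and force (f) buffer-ops decisions can be consulted separately (the hunks below use useGpuXBufferOps when deciding whether the update/constraints stream is needed). As rough orientation only, the split flags might look like the sketch below; apart from useGpuXBufferOps, useGpuPme, useGpuUpdate, and useGpuPmePpCommunication, which appear in the diff, the field name and comments are assumptions rather than the actual header.

    // Sketch only: plausible shape of the per-direction buffer-ops flags
    // in gmx::SimulationWorkload after the split.
    struct SimulationWorkload
    {
        //! Coordinate (x) buffer operations run on the GPU.
        bool useGpuXBufferOps = false;
        //! Force (f) buffer operations run on the GPU (name assumed from the commit title).
        bool useGpuFBufferOps = false;
        //! Flags already consulted below: useGpuPme, useGpuUpdate, useGpuPmePpCommunication, ...
    };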
diff --git a/src/gromacs/gpu_utils/device_stream_manager.cpp b/src/gromacs/gpu_utils/device_stream_manager.cpp
index 8c7457a3d3b0da8c8b44b5f4ddef87632a14f162..910ccdd0e2cb60b251bb6186da9d09fe2170d011 100644
--- a/src/gromacs/gpu_utils/device_stream_manager.cpp
+++ b/src/gromacs/gpu_utils/device_stream_manager.cpp
@@ -1,7 +1,7 @@
 /*
  * This file is part of the GROMACS molecular simulation package.
  *
- * Copyright (c) 2019,2020, by the GROMACS development team, led by
+ * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -47,7 +47,7 @@
 
 #include "gromacs/gpu_utils/device_context.h"
 #include "gromacs/gpu_utils/device_stream.h"
-#include "gromacs/gpu_utils/gputraits.h"
+#include "gromacs/mdtypes/simulation_workload.h"
 #include "gromacs/utility/enumerationhelpers.h"
 #include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/gmxassert.h"
@@ -72,55 +72,54 @@ public:
      * \throws InternalError  If any of the required resources could not be initialized.
      */
     Impl(const DeviceInformation& deviceInfo,
-         bool                     useGpuForPme,
          bool                     havePpDomainDecomposition,
-         bool                     doGpuPmePpTransfer,
-         bool                     useGpuForUpdate,
+         SimulationWorkload       simulationWork,
          bool                     useTiming);
     ~Impl();
 
     //! Device context.
     DeviceContext context_;
     //! GPU command streams.
-    EnumerationArray<DeviceStreamType, DeviceStream> streams_;
+    EnumerationArray<DeviceStreamType, std::unique_ptr<DeviceStream>> streams_;
 };
 
 // DeviceStreamManager::Impl
 DeviceStreamManager::Impl::Impl(const DeviceInformation& deviceInfo,
-                                const bool               useGpuForPme,
                                 const bool               havePpDomainDecomposition,
-                                const bool               doGpuPmePpTransfer,
-                                const bool               useGpuForUpdate,
+                                const SimulationWorkload simulationWork,
                                 const bool               useTiming) :
     context_(deviceInfo)
 {
     try
     {
-        streams_[DeviceStreamType::NonBondedLocal].init(context_, DeviceStreamPriority::Normal, useTiming);
+        streams_[DeviceStreamType::NonBondedLocal] =
+                std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
 
-        if (useGpuForPme)
+        if (simulationWork.useGpuPme)
         {
             /* Creating a PME GPU stream:
              * - default high priority with CUDA
              * - no priorities implemented yet with OpenCL; see #2532
              */
-            streams_[DeviceStreamType::Pme].init(context_, DeviceStreamPriority::High, useTiming);
+            streams_[DeviceStreamType::Pme] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::High, useTiming);
         }
 
         if (havePpDomainDecomposition)
         {
-            streams_[DeviceStreamType::NonBondedNonLocal].init(context_, DeviceStreamPriority::High,
-                                                               useTiming);
+            streams_[DeviceStreamType::NonBondedNonLocal] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::High, useTiming);
         }
         // Update stream is used both for coordinates transfers and for GPU update/constraints
-        if (useGpuForPme || useGpuForUpdate)
+        if (simulationWork.useGpuPme || simulationWork.useGpuUpdate || simulationWork.useGpuXBufferOps)
         {
-            streams_[DeviceStreamType::UpdateAndConstraints].init(
-                    context_, DeviceStreamPriority::Normal, useTiming);
+            streams_[DeviceStreamType::UpdateAndConstraints] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
         }
-        if (doGpuPmePpTransfer)
+        if (simulationWork.useGpuPmePpCommunication)
         {
-            streams_[DeviceStreamType::PmePpTransfer].init(context_, DeviceStreamPriority::Normal, useTiming);
+            streams_[DeviceStreamType::PmePpTransfer] =
+                    std::make_unique<DeviceStream>(context_, DeviceStreamPriority::Normal, useTiming);
         }
     }
     GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
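With the streams now held as std::unique_ptr, only the streams the current workload actually needs are constructed; unused slots stay null. A minimal caller-side sketch of the updated constructor signature follows; only the parameter list comes from the diff, and how the arguments are obtained is a placeholder.

    // Sketch: constructing the manager with the consolidated workload flags.
    // Written as if inside namespace gmx; requires <memory>.
    std::unique_ptr<DeviceStreamManager> makeStreamManager(const DeviceInformation&  deviceInfo,
                                                           bool                      havePpDomainDecomposition,
                                                           const SimulationWorkload& simulationWork,
                                                           bool                      useTiming)
    {
        return std::make_unique<DeviceStreamManager>(deviceInfo, havePpDomainDecomposition,
                                                     simulationWork, useTiming);
    }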
@@ -130,12 +129,10 @@ DeviceStreamManager::Impl::~Impl() = default;
 
 // DeviceStreamManager
 DeviceStreamManager::DeviceStreamManager(const DeviceInformation& deviceInfo,
-                                         const bool               useGpuForPme,
                                          const bool               havePpDomainDecomposition,
-                                         const bool               doGpuPmePpTransfer,
-                                         const bool               useGpuForUpdate,
+                                         const SimulationWorkload simulationWork,
                                          const bool               useTiming) :
-    impl_(new Impl(deviceInfo, useGpuForPme, havePpDomainDecomposition, doGpuPmePpTransfer, useGpuForUpdate, useTiming))
+    impl_(new Impl(deviceInfo, havePpDomainDecomposition, simulationWork, useTiming))
 {
 }
 
@@ -153,7 +150,7 @@ const DeviceContext& DeviceStreamManager::context() const
 
 const DeviceStream& DeviceStreamManager::stream(DeviceStreamType streamToGet) const
 {
-    return impl_->streams_[streamToGet];
+    return *impl_->streams_[streamToGet];
 }
 
 const DeviceStream& DeviceStreamManager::bondedStream(bool hasPPDomainDecomposition) const
@@ -176,7 +173,7 @@ const DeviceStream& DeviceStreamManager::bondedStream(bool hasPPDomainDecomposit
 
 bool DeviceStreamManager::streamIsValid(DeviceStreamType streamToCheck) const
 {
-    return impl_->streams_[streamToCheck].isValid();
+    return impl_->streams_[streamToCheck] != nullptr && impl_->streams_[streamToCheck]->isValid();
 }
 
 } // namespace gmx
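Because a stream the workload did not require is now a null unique_ptr rather than a default-constructed DeviceStream, consumers should ask streamIsValid() before taking a reference. A short usage sketch, written as if in the same scope as the manager; the surrounding function is hypothetical.

    // Sketch: guard against streams this run never created.
    void enqueuePmePpTransfers(const DeviceStreamManager& streamManager)
    {
        if (streamManager.streamIsValid(DeviceStreamType::PmePpTransfer))
        {
            const DeviceStream& pmePpStream = streamManager.stream(DeviceStreamType::PmePpTransfer);
            // ... launch PME-PP transfer work on pmePpStream ...
            (void)pmePpStream;
        }
    }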