Make use of the DeviceStreamManager
[alexxy/gromacs.git] / src / gromacs / domdec / domdec.cpp
index ebcc92bf2ea1949d794860cc22f0db393771fefb..aed2af87cbe5522259c31d9311df501fa50cbae0 100644 (file)
@@ -64,6 +64,7 @@
 #include "gromacs/domdec/partition.h"
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/gmxlib/nrnb.h"
+#include "gromacs/gpu_utils/device_stream_manager.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/hardware/hw_info.h"
 #include "gromacs/listed_forces/manage_threading.h"
@@ -3200,13 +3201,16 @@ gmx_bool change_dd_cutoff(t_commrec* cr, const matrix box, gmx::ArrayRef<const g
     return bCutoffAllowed;
 }
 
-void constructGpuHaloExchange(const gmx::MDLogger& mdlog,
-                              const t_commrec&     cr,
-                              const DeviceContext& deviceContext,
-                              const DeviceStream&  streamLocal,
-                              const DeviceStream&  streamNonLocal)
+void constructGpuHaloExchange(const gmx::MDLogger&            mdlog,
+                              const t_commrec&                cr,
+                              const gmx::DeviceStreamManager& deviceStreamManager)
 {
-
+    GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
+                       "Local non-bonded stream should be valid when using"
+                       "GPU halo exchange.");
+    GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedNonLocal),
+                       "Non-local non-bonded stream should be valid when using "
+                       "GPU halo exchange.");
     int gpuHaloExchangeSize = 0;
     int pulseStart          = 0;
     if (cr.dd->gpuHaloExchange.empty())
@@ -3228,7 +3232,9 @@ void constructGpuHaloExchange(const gmx::MDLogger& mdlog,
         for (int pulse = pulseStart; pulse < cr.dd->comm->cd[0].numPulses(); pulse++)
         {
             cr.dd->gpuHaloExchange.push_back(std::make_unique<gmx::GpuHaloExchange>(
-                    cr.dd, cr.mpi_comm_mysim, deviceContext, streamLocal, streamNonLocal, pulse));
+                    cr.dd, cr.mpi_comm_mysim, deviceStreamManager.context(),
+                    deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal),
+                    deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal), pulse));
         }
     }
 }