Make use of the DeviceStreamManager
[alexxy/gromacs.git] / src / gromacs / nbnxm / opencl / nbnxm_ocl_data_mgmt.cpp
index f11aa2d807b357d9ca2b9ef56e9d764fa8b56421..bc913e0e24d6f2e45bf3e8b89c791f4aae1d7642 100644 (file)
@@ -52,6 +52,7 @@
 
 #include <cmath>
 
+#include "gromacs/gpu_utils/device_stream_manager.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/gpu_utils/oclutils.h"
 #include "gromacs/hardware/gpu_hw_info.h"
@@ -485,7 +486,7 @@ static void nbnxn_ocl_clear_e_fshift(NbnxmGpu* nb)
 
     cl_int           cl_error;
     cl_atomdata_t*   adat = nb->atdat;
-    cl_command_queue ls   = nb->deviceStreams[InteractionLocality::Local].stream();
+    cl_command_queue ls   = nb->deviceStreams[InteractionLocality::Local]->stream();
 
     size_t local_work_size[3]  = { 1, 1, 1 };
     size_t global_work_size[3] = { 1, 1, 1 };
@@ -555,16 +556,16 @@ static void nbnxn_ocl_init_const(cl_atomdata_t*                  atomData,
 
 
 //! This function is documented in the header file
-NbnxmGpu* gpu_init(const DeviceContext&       deviceContext,
-                   const interaction_const_t* ic,
-                   const PairlistParams&      listParams,
-                   const nbnxn_atomdata_t*    nbat,
-                   const bool                 bLocalAndNonlocal)
+NbnxmGpu* gpu_init(const gmx::DeviceStreamManager& deviceStreamManager,
+                   const interaction_const_t*      ic,
+                   const PairlistParams&           listParams,
+                   const nbnxn_atomdata_t*         nbat,
+                   const bool                      bLocalAndNonlocal)
 {
     GMX_ASSERT(ic, "Need a valid interaction constants object");
 
     auto nb            = new NbnxmGpu();
-    nb->deviceContext_ = &deviceContext;
+    nb->deviceContext_ = &deviceStreamManager.context();
     snew(nb->atdat, 1);
     snew(nb->nbparam, 1);
     snew(nb->plist[InteractionLocality::Local], 1);
@@ -578,6 +579,7 @@ NbnxmGpu* gpu_init(const DeviceContext&       deviceContext,
     nb->timers = new cl_timers_t();
     snew(nb->timings, 1);
 
+    /* set device info, just point it to the right GPU among the detected ones */
     nb->dev_rundata = new gmx_device_runtime_data_t();
 
     /* init nbst */
@@ -591,15 +593,20 @@ NbnxmGpu* gpu_init(const DeviceContext&       deviceContext,
     nb->bDoTime = (getenv("GMX_DISABLE_GPU_TIMING") == nullptr);
 
     /* local/non-local GPU streams */
-    nb->deviceStreams[InteractionLocality::Local].init(*nb->deviceContext_,
-                                                       DeviceStreamPriority::Normal, nb->bDoTime);
+    GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedLocal),
+                       "Local non-bonded stream should be initialized to use GPU for non-bonded.");
+    nb->deviceStreams[InteractionLocality::Local] =
+            &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedLocal);
 
     if (nb->bUseTwoStreams)
     {
         init_plist(nb->plist[InteractionLocality::NonLocal]);
 
-        nb->deviceStreams[InteractionLocality::NonLocal].init(
-                *nb->deviceContext_, DeviceStreamPriority::High, nb->bDoTime);
+        GMX_RELEASE_ASSERT(deviceStreamManager.streamIsValid(gmx::DeviceStreamType::NonBondedNonLocal),
+                           "Non-local non-bonded stream should be initialized to use GPU for "
+                           "non-bonded with domain decomposition.");
+        nb->deviceStreams[InteractionLocality::NonLocal] =
+                &deviceStreamManager.stream(gmx::DeviceStreamType::NonBondedNonLocal);
     }
 
     if (nb->bDoTime)
@@ -647,7 +654,7 @@ static void nbnxn_ocl_clear_f(NbnxmGpu* nb, int natoms_clear)
     cl_int gmx_used_in_debug cl_error;
 
     cl_atomdata_t*   atomData = nb->atdat;
-    cl_command_queue ls       = nb->deviceStreams[InteractionLocality::Local].stream();
+    cl_command_queue ls       = nb->deviceStreams[InteractionLocality::Local]->stream();
     cl_float         value    = 0.0F;
 
     cl_error = clEnqueueFillBuffer(ls, atomData->f, &value, sizeof(cl_float), 0,
@@ -669,7 +676,7 @@ void gpu_clear_outputs(NbnxmGpu* nb, bool computeVirial)
 
     /* kick off buffer clearing kernel to ensure concurrency with constraints/update */
     cl_int gmx_unused cl_error;
-    cl_error = clFlush(nb->deviceStreams[InteractionLocality::Local].stream());
+    cl_error = clFlush(nb->deviceStreams[InteractionLocality::Local]->stream());
     GMX_ASSERT(cl_error == CL_SUCCESS, ("clFlush failed: " + ocl_get_error_string(cl_error)).c_str());
 }
 
@@ -681,7 +688,7 @@ void gpu_init_pairlist(NbnxmGpu* nb, const NbnxnPairlistGpu* h_plist, const Inte
     // because getLastRangeTime() gets skipped with empty lists later
     // which leads to the counter not being reset.
     bool                bDoTime      = (nb->bDoTime && !h_plist->sci.empty());
-    const DeviceStream& deviceStream = nb->deviceStreams[iloc];
+    const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
     cl_plist_t*         d_plist      = nb->plist[iloc];
 
     if (d_plist->na_c < 0)
@@ -740,7 +747,7 @@ void gpu_init_pairlist(NbnxmGpu* nb, const NbnxnPairlistGpu* h_plist, const Inte
 void gpu_upload_shiftvec(NbnxmGpu* nb, const nbnxn_atomdata_t* nbatom)
 {
     cl_atomdata_t*   adat = nb->atdat;
-    cl_command_queue ls   = nb->deviceStreams[InteractionLocality::Local].stream();
+    cl_command_queue ls   = nb->deviceStreams[InteractionLocality::Local]->stream();
 
     /* only if we have a dynamic box */
     if (nbatom->bDynamicBox || !adat->bShiftVecUploaded)
@@ -760,7 +767,7 @@ void gpu_init_atomdata(NbnxmGpu* nb, const nbnxn_atomdata_t* nbat)
     bool                bDoTime      = nb->bDoTime;
     cl_timers_t*        timers       = nb->timers;
     cl_atomdata_t*      d_atdat      = nb->atdat;
-    const DeviceStream& deviceStream = nb->deviceStreams[InteractionLocality::Local];
+    const DeviceStream& deviceStream = *nb->deviceStreams[InteractionLocality::Local];
 
     natoms    = nbat->numAtoms();
     realloced = false;