Take over management of OpenCL context from PME and NBNXM
[alexxy/gromacs.git] / src / gromacs / mdrun / runner.cpp
index b233b0737c9540c019f12b869fa4d906a2333ca4..081501bfff528aef8447c9dc2ec0799915b200e1 100644 (file)
@@ -73,6 +73,7 @@
 #include "gromacs/fileio/tpxio.h"
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/gmxlib/nrnb.h"
+#include "gromacs/gpu_utils/device_context.h"
 #include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/hardware/cpuinfo.h"
 #include "gromacs/hardware/detecthardware.h"
@@ -1140,9 +1141,20 @@ int Mdrunner::mdrunner()
             EEL_PME(inputrec->coulombtype) && thisRankHasDuty(cr, DUTY_PME));
 
     // Get the device handles for the modules, nullptr when no task is assigned.
+    // TODO: There should be only one DeviceInformation.
     DeviceInformation* nonbondedDeviceInfo = gpuTaskAssignments.initNonbondedDevice(cr);
     DeviceInformation* pmeDeviceInfo       = gpuTaskAssignments.initPmeDevice();
 
+    std::unique_ptr<DeviceContext> deviceContext = nullptr;
+    if (pmeDeviceInfo)
+    {
+        deviceContext = std::make_unique<DeviceContext>(*pmeDeviceInfo);
+    }
+    else if (nonbondedDeviceInfo)
+    {
+        deviceContext = std::make_unique<DeviceContext>(*nonbondedDeviceInfo);
+    }
+
     // TODO Initialize GPU streams here.
 
     // TODO Currently this is always built, yet DD partition code
@@ -1338,13 +1350,19 @@ int Mdrunner::mdrunner()
                       opt2fn("-tablep", filenames.size(), filenames.data()),
                       opt2fns("-tableb", filenames.size(), filenames.data()), pforce);
 
+        fr->deviceContext = deviceContext.get();
+
         if (devFlags.enableGpuPmePPComm && !thisRankHasDuty(cr, DUTY_PME))
         {
-            fr->pmePpCommGpu = std::make_unique<gmx::PmePpCommGpu>(cr->mpi_comm_mysim, cr->dd->pme_nodeid);
+            GMX_RELEASE_ASSERT(
+                    deviceContext != nullptr,
+                    "Device context can not be nullptr when PME-PP direct communications object.");
+            fr->pmePpCommGpu = std::make_unique<gmx::PmePpCommGpu>(
+                    cr->mpi_comm_mysim, cr->dd->pme_nodeid, *deviceContext);
         }
 
         fr->nbv = Nbnxm::init_nb_verlet(mdlog, inputrec, fr, cr, *hwinfo, nonbondedDeviceInfo,
-                                        &mtop, box, wcycle);
+                                        fr->deviceContext, &mtop, box, wcycle);
         if (useGpuForBonded)
         {
             auto stream = havePPDomainDecomposition(cr)
@@ -1352,7 +1370,10 @@ int Mdrunner::mdrunner()
                                             fr->nbv->gpu_nbv, gmx::InteractionLocality::NonLocal)
                                   : Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv,
                                                                   gmx::InteractionLocality::Local);
-            gpuBonded     = std::make_unique<GpuBonded>(mtop.ffparams, stream, wcycle);
+            GMX_RELEASE_ASSERT(
+                    fr->deviceContext != nullptr,
+                    "Device context can not be nullptr when computing bonded interactions on GPU.");
+            gpuBonded = std::make_unique<GpuBonded>(mtop.ffparams, *fr->deviceContext, stream, wcycle);
             fr->gpuBonded = gpuBonded.get();
         }
 
@@ -1428,7 +1449,13 @@ int Mdrunner::mdrunner()
     PmeGpuProgramStorage pmeGpuProgram;
     if (thisRankHasPmeGpuTask)
     {
-        pmeGpuProgram = buildPmeGpuProgram(pmeDeviceInfo);
+        GMX_RELEASE_ASSERT(
+                pmeDeviceInfo != nullptr,
+                "Device information can not be nullptr when building PME GPU program object.");
+        GMX_RELEASE_ASSERT(
+                deviceContext != nullptr,
+                "Device context can not be nullptr when building PME GPU program object.");
+        pmeGpuProgram = buildPmeGpuProgram(*pmeDeviceInfo, *deviceContext);
     }
 
     /* Initiate PME if necessary,
@@ -1566,14 +1593,16 @@ int Mdrunner::mdrunner()
                     fr->nbv->gpu_nbv != nullptr
                             ? Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, InteractionLocality::NonLocal)
                             : nullptr;
-            const DeviceContext& deviceContext = *pme_gpu_get_device_context(fr->pmedata);
-            const int            paddingSize   = pme_gpu_get_padding_size(fr->pmedata);
+            const int          paddingSize = pme_gpu_get_padding_size(fr->pmedata);
             GpuApiCallBehavior transferKind = (inputrec->eI == eiMD && !doRerun && !useModularSimulator)
                                                       ? GpuApiCallBehavior::Async
                                                       : GpuApiCallBehavior::Sync;
-
+            GMX_RELEASE_ASSERT(
+                    deviceContext != nullptr,
+                    "Device context can not be nullptr when building GPU propagator data object.");
             stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(
-                    pmeStream, localStream, nonLocalStream, deviceContext, transferKind, paddingSize, wcycle);
+                    pmeStream, localStream, nonLocalStream, *deviceContext, transferKind,
+                    paddingSize, wcycle);
             fr->stateGpu = stateGpu.get();
         }
 
@@ -1608,7 +1637,8 @@ int Mdrunner::mdrunner()
         GMX_RELEASE_ASSERT(pmedata, "pmedata was NULL while cr->duty was not DUTY_PP");
         /* do PME only */
         walltime_accounting = walltime_accounting_init(gmx_omp_nthreads_get(emntPME));
-        gmx_pmeonly(pmedata, cr, &nrnb, wcycle, walltime_accounting, inputrec, pmeRunMode);
+        gmx_pmeonly(pmedata, cr, &nrnb, wcycle, walltime_accounting, inputrec, pmeRunMode,
+                    deviceContext.get());
     }
 
     wallcycle_stop(wcycle, ewcRUN);
@@ -1670,6 +1700,7 @@ int Mdrunner::mdrunner()
 
     free_gpu(nonbondedDeviceInfo);
     free_gpu(pmeDeviceInfo);
+    deviceContext.reset(nullptr);
     sfree(fcd);
 
     if (doMembed)