Take over management of OpenCL context from PME and NBNXM
[alexxy/gromacs.git] / src / gromacs / ewald / pme_only.cpp
index 2ee17b32674959e1d78363a29ed7d87420a8d1e3..845b1a33ecf45b99b78d846f0107533647e4f485 100644 (file)
@@ -603,7 +603,8 @@ int gmx_pmeonly(struct gmx_pme_t*         pme,
                 gmx_wallcycle*            wcycle,
                 gmx_walltime_accounting_t walltime_accounting,
                 t_inputrec*               ir,
-                PmeRunMode                runMode)
+                PmeRunMode                runMode,
+                const DeviceContext*      deviceContext)
 {
     int     ret;
     int     natoms = 0;
@@ -628,8 +629,7 @@ int gmx_pmeonly(struct gmx_pme_t*         pme,
     const bool useGpuForPme = (runMode == PmeRunMode::GPU) || (runMode == PmeRunMode::Mixed);
     if (useGpuForPme)
     {
-        const void*          commandStream = pme_gpu_get_device_stream(pme);
-        const DeviceContext& deviceContext = *pme_gpu_get_device_context(pme);
+        const void* commandStream = pme_gpu_get_device_stream(pme);
 
         changePinningPolicy(&pme_pp->chargeA, pme_get_pinning_policy());
         changePinningPolicy(&pme_pp->x, pme_get_pinning_policy());
@@ -640,10 +640,13 @@ int gmx_pmeonly(struct gmx_pme_t*         pme,
             pme_pp->pmeForceSenderGpu = std::make_unique<gmx::PmeForceSenderGpu>(
                     commandStream, pme_pp->mpi_comm_mysim, pme_pp->ppRanks);
         }
+        GMX_RELEASE_ASSERT(
+                deviceContext != nullptr,
+                "Device context can not be nullptr when building GPU propagator data object.");
         // TODO: Special PME-only constructor is used here. There is no mechanism to prevent from using the other constructor here.
         //       This should be made safer.
         stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(
-                commandStream, deviceContext, GpuApiCallBehavior::Async,
+                commandStream, *deviceContext, GpuApiCallBehavior::Async,
                 pme_gpu_get_padding_size(pme), wcycle);
     }