Take over management of OpenCL context from PME and NBNXM
[alexxy/gromacs.git] / src / gromacs / mdrun / md.cpp
index 2cb1388dcabe0b8a1175e9dff2b611f5e155a720..941a7030c9eca738ea4dee3201bbefd4c0628db7 100644 (file)
@@ -400,8 +400,13 @@ void gmx::LegacySimulator::do_md()
         {
             GMX_LOG(mdlog.info).asParagraph().appendText("Updating coordinates on the GPU.");
         }
-        integrator = std::make_unique<UpdateConstrainGpu>(
-                *ir, *top_global, stateGpu->getUpdateStream(), stateGpu->xUpdatedOnDevice());
+
+        GMX_RELEASE_ASSERT(fr->deviceContext != nullptr,
+                           "GPU device context should be initialized to use GPU update.");
+
+        integrator = std::make_unique<UpdateConstrainGpu>(*ir, *top_global, *fr->deviceContext,
+                                                          stateGpu->getUpdateStream(),
+                                                          stateGpu->xUpdatedOnDevice());
 
         integrator->setPbc(PbcType::Xyz, state->box);
     }
@@ -866,7 +871,10 @@ void gmx::LegacySimulator::do_md()
                             Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, InteractionLocality::Local);
                     void* streamNonLocal = Nbnxm::gpu_get_command_stream(
                             fr->nbv->gpu_nbv, InteractionLocality::NonLocal);
-                    constructGpuHaloExchange(mdlog, *cr, streamLocal, streamNonLocal);
+                    GMX_RELEASE_ASSERT(
+                            fr->deviceContext != nullptr,
+                            "GPU device context should be initialized to use GPU halo exchange.");
+                    constructGpuHaloExchange(mdlog, *cr, *fr->deviceContext, streamLocal, streamNonLocal);
                 }
             }
         }