Pass the GPU streams to StatePropagatorDataGpu constructor
[alexxy/gromacs.git] / src / gromacs / mdrun / runner.cpp
index f7f24405b553349924a728a8d2232533b272355d..92057c8320fe263aaad67100d7d79f35a817ab3e 100644 (file)
@@ -1501,26 +1501,32 @@ int Mdrunner::mdrunner()
                                                          fcd->orires.nr != 0,
                                                          fcd->disres.nsystems != 0);
 
-        const void *commandStream = ((GMX_GPU == GMX_GPU_OPENCL) && thisRankHasPmeGpuTask) ? pme_gpu_get_device_stream(fr->pmedata) : nullptr;
-        const void *deviceContext = ((GMX_GPU == GMX_GPU_OPENCL) && thisRankHasPmeGpuTask) ? pme_gpu_get_device_context(fr->pmedata) : nullptr;
-        const int   paddingSize   = pme_gpu_get_padding_size(fr->pmedata);
-
-        const bool  inputIsCompatibleWithModularSimulator = ModularSimulator::isInputCompatible(
+        const bool inputIsCompatibleWithModularSimulator = ModularSimulator::isInputCompatible(
                     false,
                     inputrec, doRerun, vsite.get(), ms, replExParams,
                     fcd, static_cast<int>(filenames.size()), filenames.data(),
                     &observablesHistory, membed);
 
-        const bool          useModularSimulator = inputIsCompatibleWithModularSimulator && !(getenv("GMX_DISABLE_MODULAR_SIMULATOR") != nullptr);
-        GpuApiCallBehavior  transferKind        = (inputrec->eI == eiMD && !doRerun && !useModularSimulator) ? GpuApiCallBehavior::Async : GpuApiCallBehavior::Sync;
+        const bool useModularSimulator = inputIsCompatibleWithModularSimulator && !(getenv("GMX_DISABLE_MODULAR_SIMULATOR") != nullptr);
 
-        // We initialize GPU state even for the CPU runs so we will have a more verbose
-        // error if someone will try accessing it from the CPU codepath
-        gmx::StatePropagatorDataGpu stateGpu(commandStream,
-                                             deviceContext,
-                                             transferKind,
-                                             paddingSize);
-        fr->stateGpu = &stateGpu;
+        std::unique_ptr<gmx::StatePropagatorDataGpu> stateGpu;
+        if (gpusWereDetected && ((useGpuForPme && thisRankHasDuty(cr, DUTY_PME)) || useGpuForUpdate))
+        {
+            const void         *pmeStream      = pme_gpu_get_device_stream(fr->pmedata);
+            const void         *localStream    = fr->nbv->gpu_nbv != nullptr ? Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, Nbnxm::InteractionLocality::Local) : nullptr;
+            const void         *nonLocalStream = fr->nbv->gpu_nbv != nullptr ? Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, Nbnxm::InteractionLocality::NonLocal) : nullptr;
+            const void         *deviceContext  = pme_gpu_get_device_context(fr->pmedata);
+            const int           paddingSize    = pme_gpu_get_padding_size(fr->pmedata);
+            GpuApiCallBehavior  transferKind   = (inputrec->eI == eiMD && !doRerun && !useModularSimulator) ? GpuApiCallBehavior::Async : GpuApiCallBehavior::Sync;
+
+            stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(pmeStream,
+                                                                     localStream,
+                                                                     nonLocalStream,
+                                                                     deviceContext,
+                                                                     transferKind,
+                                                                     paddingSize);
+            fr->stateGpu = stateGpu.get();
+        }
 
         // TODO This is not the right place to manage the lifetime of
         // this data structure, but currently it's the easiest way to