#include "gromacs/fileio/tpxio.h"
#include "gromacs/gmxlib/network.h"
#include "gromacs/gmxlib/nrnb.h"
+#include "gromacs/gpu_utils/device_context.h"
#include "gromacs/gpu_utils/gpu_utils.h"
#include "gromacs/hardware/cpuinfo.h"
#include "gromacs/hardware/detecthardware.h"
EEL_PME(inputrec->coulombtype) && thisRankHasDuty(cr, DUTY_PME));
// Get the device handles for the modules, nullptr when no task is assigned.
+ // TODO: There should be only one DeviceInformation.
DeviceInformation* nonbondedDeviceInfo = gpuTaskAssignments.initNonbondedDevice(cr);
DeviceInformation* pmeDeviceInfo = gpuTaskAssignments.initPmeDevice();
+ std::unique_ptr<DeviceContext> deviceContext = nullptr;
+ if (pmeDeviceInfo)
+ {
+ deviceContext = std::make_unique<DeviceContext>(*pmeDeviceInfo);
+ }
+ else if (nonbondedDeviceInfo)
+ {
+ deviceContext = std::make_unique<DeviceContext>(*nonbondedDeviceInfo);
+ }
+
// TODO Initialize GPU streams here.
// TODO Currently this is always built, yet DD partition code
opt2fn("-tablep", filenames.size(), filenames.data()),
opt2fns("-tableb", filenames.size(), filenames.data()), pforce);
+ fr->deviceContext = deviceContext.get();
+
if (devFlags.enableGpuPmePPComm && !thisRankHasDuty(cr, DUTY_PME))
{
- fr->pmePpCommGpu = std::make_unique<gmx::PmePpCommGpu>(cr->mpi_comm_mysim, cr->dd->pme_nodeid);
+ GMX_RELEASE_ASSERT(
+ deviceContext != nullptr,
+ "Device context can not be nullptr when PME-PP direct communications object.");
+ fr->pmePpCommGpu = std::make_unique<gmx::PmePpCommGpu>(
+ cr->mpi_comm_mysim, cr->dd->pme_nodeid, *deviceContext);
}
fr->nbv = Nbnxm::init_nb_verlet(mdlog, inputrec, fr, cr, *hwinfo, nonbondedDeviceInfo,
- &mtop, box, wcycle);
+ fr->deviceContext, &mtop, box, wcycle);
if (useGpuForBonded)
{
auto stream = havePPDomainDecomposition(cr)
fr->nbv->gpu_nbv, gmx::InteractionLocality::NonLocal)
: Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv,
gmx::InteractionLocality::Local);
- gpuBonded = std::make_unique<GpuBonded>(mtop.ffparams, stream, wcycle);
+ GMX_RELEASE_ASSERT(
+ fr->deviceContext != nullptr,
+ "Device context can not be nullptr when computing bonded interactions on GPU.");
+ gpuBonded = std::make_unique<GpuBonded>(mtop.ffparams, *fr->deviceContext, stream, wcycle);
fr->gpuBonded = gpuBonded.get();
}
PmeGpuProgramStorage pmeGpuProgram;
if (thisRankHasPmeGpuTask)
{
- pmeGpuProgram = buildPmeGpuProgram(pmeDeviceInfo);
+ GMX_RELEASE_ASSERT(
+ pmeDeviceInfo != nullptr,
+ "Device information can not be nullptr when building PME GPU program object.");
+ GMX_RELEASE_ASSERT(
+ deviceContext != nullptr,
+ "Device context can not be nullptr when building PME GPU program object.");
+ pmeGpuProgram = buildPmeGpuProgram(*pmeDeviceInfo, *deviceContext);
}
/* Initiate PME if necessary,
fr->nbv->gpu_nbv != nullptr
? Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, InteractionLocality::NonLocal)
: nullptr;
- const DeviceContext& deviceContext = *pme_gpu_get_device_context(fr->pmedata);
- const int paddingSize = pme_gpu_get_padding_size(fr->pmedata);
+ const int paddingSize = pme_gpu_get_padding_size(fr->pmedata);
GpuApiCallBehavior transferKind = (inputrec->eI == eiMD && !doRerun && !useModularSimulator)
? GpuApiCallBehavior::Async
: GpuApiCallBehavior::Sync;
-
+ GMX_RELEASE_ASSERT(
+ deviceContext != nullptr,
+ "Device context can not be nullptr when building GPU propagator data object.");
stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(
- pmeStream, localStream, nonLocalStream, deviceContext, transferKind, paddingSize, wcycle);
+ pmeStream, localStream, nonLocalStream, *deviceContext, transferKind,
+ paddingSize, wcycle);
fr->stateGpu = stateGpu.get();
}
GMX_RELEASE_ASSERT(pmedata, "pmedata was NULL while cr->duty was not DUTY_PP");
/* do PME only */
walltime_accounting = walltime_accounting_init(gmx_omp_nthreads_get(emntPME));
- gmx_pmeonly(pmedata, cr, &nrnb, wcycle, walltime_accounting, inputrec, pmeRunMode);
+ gmx_pmeonly(pmedata, cr, &nrnb, wcycle, walltime_accounting, inputrec, pmeRunMode,
+ deviceContext.get());
}
wallcycle_stop(wcycle, ewcRUN);
free_gpu(nonbondedDeviceInfo);
free_gpu(pmeDeviceInfo);
+ deviceContext.reset(nullptr);
sfree(fcd);
if (doMembed)