flags.useGpuXHalo = simulationWork.useGpuHaloExchange;
flags.useGpuFHalo = simulationWork.useGpuHaloExchange && flags.useGpuFBufferOps;
flags.haveGpuPmeOnThisRank = simulationWork.useGpuPme && rankHasPmeDuty && flags.computeSlowForces;
+ flags.combineMtsForcesBeforeHaloExchange =
+ (flags.computeForces && simulationWork.useMts && flags.computeSlowForces
+ && flags.useOnlyMtsCombinedForceBuffer
+ && !(flags.computeVirial || simulationWork.useGpuNonbonded || flags.haveGpuPmeOnThisRank));
return flags;
}
*/
static int getLocalAtomCount(const gmx_domdec_t* dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition)
{
// dd is dereferenced below whenever havePPDomainDecomposition is true, so a
// null dd is only acceptable when there is no PP domain decomposition.
- GMX_ASSERT(!(havePPDomainDecomposition && (dd == nullptr)), "Can't have PP decomposition with dd uninitialized!");
+ GMX_ASSERT(!(havePPDomainDecomposition && (dd == nullptr)),
+ "Can't have PP decomposition with dd uninitialized!");
// With PP domain decomposition the count is taken from dd_numAtomsZones(*dd);
// otherwise it is mdatoms.homenr.
return havePPDomainDecomposition ? dd_numAtomsZones(*dd) : mdatoms.homenr;
}
// Force output for MTS combined forces, only set at level1 MTS steps
std::optional<ForceOutputs> forceOutMts =
- (fr->useMts && stepWork.computeSlowForces)
+ (simulationWork.useMts && stepWork.computeSlowForces)
? std::optional(setupForceOutputs(&fr->forceHelperBuffers[1],
forceView->forceMtsCombinedWithPadding(),
domainWork,
: std::nullopt;
ForceOutputs* forceOutMtsLevel1 =
- fr->useMts ? (stepWork.computeSlowForces ? &forceOutMts.value() : nullptr) : &forceOutMtsLevel0;
+ simulationWork.useMts ? (stepWork.computeSlowForces ? &forceOutMts.value() : nullptr)
+ : &forceOutMtsLevel0;
const bool nonbondedAtMtsLevel1 = runScheduleWork->simulationWork.computeNonbondedAtMtsLevel1;
set_pbc_dd(&pbc, fr->pbcType, DOMAINDECOMP(cr) ? cr->dd->numCells : nullptr, TRUE, box);
}
- for (int mtsIndex = 0; mtsIndex < (fr->useMts && stepWork.computeSlowForces ? 2 : 1); mtsIndex++)
+ for (int mtsIndex = 0; mtsIndex < (simulationWork.useMts && stepWork.computeSlowForces ? 2 : 1);
+ mtsIndex++)
{
ListedForces& listedForces = fr->listedForces[mtsIndex];
ForceOutputs& forceOut = (mtsIndex == 0 ? forceOutMtsLevel0 : *forceOutMtsLevel1);
/* Combining the forces for multiple time stepping before the halo exchange, when possible,
* avoids an extra halo exchange (when DD is used) and post-processing step.
*/
- const bool combineMtsForcesBeforeHaloExchange =
- (stepWork.computeForces && fr->useMts && stepWork.computeSlowForces && stepWork.useOnlyMtsCombinedForceBuffer
- && !(stepWork.computeVirial || simulationWork.useGpuNonbonded || stepWork.haveGpuPmeOnThisRank));
- if (combineMtsForcesBeforeHaloExchange)
+ if (stepWork.combineMtsForcesBeforeHaloExchange)
{
combineMtsForces(getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)),
force.unpaddedArrayRef(),
// Without MTS or with MTS at slow steps with uncombined forces we need to
// communicate the fast forces
- if (!fr->useMts || !combineMtsForcesBeforeHaloExchange)
+ if (!simulationWork.useMts || !stepWork.combineMtsForcesBeforeHaloExchange)
{
dd_move_f(cr->dd, &forceOutMtsLevel0.forceWithShiftForces(), wcycle);
}
// With MTS we need to communicate the slow or combined (in forceOutMtsLevel1) forces
- if (fr->useMts && stepWork.computeSlowForces)
+ if (simulationWork.useMts && stepWork.computeSlowForces)
{
dd_move_f(cr->dd, &forceOutMtsLevel1->forceWithShiftForces(), wcycle);
}
dd_force_flop_stop(cr->dd, nrnb);
}
- const bool haveCombinedMtsForces = (stepWork.computeForces && fr->useMts && stepWork.computeSlowForces
- && combineMtsForcesBeforeHaloExchange);
+ const bool haveCombinedMtsForces = (stepWork.computeForces && simulationWork.useMts && stepWork.computeSlowForces
+ && stepWork.combineMtsForcesBeforeHaloExchange);
if (stepWork.computeForces)
{
postProcessForceWithShiftForces(
nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOutMtsLevel0, vir_force, *mdatoms, *fr, vsite, stepWork);
- if (fr->useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
+ if (simulationWork.useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
{
postProcessForceWithShiftForces(
nrnb, wcycle, box, x.unpaddedArrayRef(), forceOutMtsLevel1, vir_force, *mdatoms, *fr, vsite, stepWork);
postProcessForces(
cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOutCombined, vir_force, mdatoms, fr, vsite, stepWork);
- if (fr->useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
+ if (simulationWork.useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
{
postProcessForces(
cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), forceOutMtsLevel1, vir_force, mdatoms, fr, vsite, stepWork);