Move useMts flag from forcerec to simulationWorkload
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index 1439c40199f16d876c7a5acd7b5064c3a6b4c350..8795ec6188643a4d001b362b55928739c0831ed7 100644 (file)
@@ -992,6 +992,10 @@ static StepWorkload setupStepWorkload(const int                     legacyFlags,
     flags.useGpuXHalo          = simulationWork.useGpuHaloExchange;
     flags.useGpuFHalo          = simulationWork.useGpuHaloExchange && flags.useGpuFBufferOps;
     flags.haveGpuPmeOnThisRank = simulationWork.useGpuPme && rankHasPmeDuty && flags.computeSlowForces;
+    flags.combineMtsForcesBeforeHaloExchange =
+            (flags.computeForces && simulationWork.useMts && flags.computeSlowForces
+             && flags.useOnlyMtsCombinedForceBuffer
+             && !(flags.computeVirial || simulationWork.useGpuNonbonded || flags.haveGpuPmeOnThisRank));
 
     return flags;
 }
@@ -1202,7 +1206,8 @@ static void setupGpuForceReductions(gmx::MdrunScheduleWorkload* runScheduleWork,
  */
 static int getLocalAtomCount(const gmx_domdec_t* dd, const t_mdatoms& mdatoms, bool havePPDomainDecomposition)
 {
-    GMX_ASSERT(!(havePPDomainDecomposition && (dd == nullptr)), "Can't have PP decomposition with dd uninitialized!");
+    GMX_ASSERT(!(havePPDomainDecomposition && (dd == nullptr)),
+               "Can't have PP decomposition with dd uninitialized!");
     return havePPDomainDecomposition ? dd_numAtomsZones(*dd) : mdatoms.homenr;
 }
 
@@ -1735,7 +1740,7 @@ void do_force(FILE*                               fplog,
 
     // Force output for MTS combined forces, only set at level1 MTS steps
     std::optional<ForceOutputs> forceOutMts =
-            (fr->useMts && stepWork.computeSlowForces)
+            (simulationWork.useMts && stepWork.computeSlowForces)
                     ? std::optional(setupForceOutputs(&fr->forceHelperBuffers[1],
                                                       forceView->forceMtsCombinedWithPadding(),
                                                       domainWork,
@@ -1745,7 +1750,8 @@ void do_force(FILE*                               fplog,
                     : std::nullopt;
 
     ForceOutputs* forceOutMtsLevel1 =
-            fr->useMts ? (stepWork.computeSlowForces ? &forceOutMts.value() : nullptr) : &forceOutMtsLevel0;
+            simulationWork.useMts ? (stepWork.computeSlowForces ? &forceOutMts.value() : nullptr)
+                                  : &forceOutMtsLevel0;
 
     const bool nonbondedAtMtsLevel1 = runScheduleWork->simulationWork.computeNonbondedAtMtsLevel1;
 
@@ -1914,7 +1920,8 @@ void do_force(FILE*                               fplog,
             set_pbc_dd(&pbc, fr->pbcType, DOMAINDECOMP(cr) ? cr->dd->numCells : nullptr, TRUE, box);
         }
 
-        for (int mtsIndex = 0; mtsIndex < (fr->useMts && stepWork.computeSlowForces ? 2 : 1); mtsIndex++)
+        for (int mtsIndex = 0; mtsIndex < (simulationWork.useMts && stepWork.computeSlowForces ? 2 : 1);
+             mtsIndex++)
         {
             ListedForces& listedForces = fr->listedForces[mtsIndex];
             ForceOutputs& forceOut     = (mtsIndex == 0 ? forceOutMtsLevel0 : *forceOutMtsLevel1);
@@ -2073,10 +2080,7 @@ void do_force(FILE*                               fplog,
     /* Combining the forces for multiple time stepping before the halo exchange, when possible,
      * avoids an extra halo exchange (when DD is used) and post-processing step.
      */
-    const bool combineMtsForcesBeforeHaloExchange =
-            (stepWork.computeForces && fr->useMts && stepWork.computeSlowForces && stepWork.useOnlyMtsCombinedForceBuffer
-             && !(stepWork.computeVirial || simulationWork.useGpuNonbonded || stepWork.haveGpuPmeOnThisRank));
-    if (combineMtsForcesBeforeHaloExchange)
+    if (stepWork.combineMtsForcesBeforeHaloExchange)
     {
         combineMtsForces(getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)),
                          force.unpaddedArrayRef(),
@@ -2117,12 +2121,12 @@ void do_force(FILE*                               fplog,
 
                 // Without MTS or with MTS at slow steps with uncombined forces we need to
                 // communicate the fast forces
-                if (!fr->useMts || !combineMtsForcesBeforeHaloExchange)
+                if (!simulationWork.useMts || !stepWork.combineMtsForcesBeforeHaloExchange)
                 {
                     dd_move_f(cr->dd, &forceOutMtsLevel0.forceWithShiftForces(), wcycle);
                 }
                 // With MTS we need to communicate the slow or combined (in forceOutMtsLevel1) forces
-                if (fr->useMts && stepWork.computeSlowForces)
+                if (simulationWork.useMts && stepWork.computeSlowForces)
                 {
                     dd_move_f(cr->dd, &forceOutMtsLevel1->forceWithShiftForces(), wcycle);
                 }
@@ -2288,14 +2292,14 @@ void do_force(FILE*                               fplog,
         dd_force_flop_stop(cr->dd, nrnb);
     }
 
-    const bool haveCombinedMtsForces = (stepWork.computeForces && fr->useMts && stepWork.computeSlowForces
-                                        && combineMtsForcesBeforeHaloExchange);
+    const bool haveCombinedMtsForces = (stepWork.computeForces && simulationWork.useMts && stepWork.computeSlowForces
+                                        && stepWork.combineMtsForcesBeforeHaloExchange);
     if (stepWork.computeForces)
     {
         postProcessForceWithShiftForces(
                 nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOutMtsLevel0, vir_force, *mdatoms, *fr, vsite, stepWork);
 
-        if (fr->useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
+        if (simulationWork.useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
         {
             postProcessForceWithShiftForces(
                     nrnb, wcycle, box, x.unpaddedArrayRef(), forceOutMtsLevel1, vir_force, *mdatoms, *fr, vsite, stepWork);
@@ -2328,7 +2332,7 @@ void do_force(FILE*                               fplog,
         postProcessForces(
                 cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOutCombined, vir_force, mdatoms, fr, vsite, stepWork);
 
-        if (fr->useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
+        if (simulationWork.useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
         {
             postProcessForces(
                     cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), forceOutMtsLevel1, vir_force, mdatoms, fr, vsite, stepWork);