Move computeSlowForces into stepWork
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index eed67b4540f411755c78505b3d43066558f5c53f..2af9a25280dafb8e2baa4b3bb74ace3919039c2e 100644 (file)
@@ -44,6 +44,7 @@
 #include <cstring>
 
 #include <array>
+#include <optional>
 
 #include "gromacs/applied_forces/awh/awh.h"
 #include "gromacs/domdec/dlbtiming.h"
@@ -92,6 +93,7 @@
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/mdtypes/mdatom.h"
+#include "gromacs/mdtypes/multipletimestepping.h"
 #include "gromacs/mdtypes/simulation_workload.h"
 #include "gromacs/mdtypes/state.h"
 #include "gromacs/mdtypes/state_propagator_data_gpu.h"
@@ -716,7 +718,8 @@ static void launchPmeGpuFftAndGather(gmx_pme_t*               pmedata,
  *
  * \param[in]     nbv              Nonbonded verlet structure
  * \param[in,out] pmedata          PME module data
- * \param[in,out] forceOutputs     Output buffer for the forces and virial
+ * \param[in,out] forceOutputsNonbonded  Force outputs for the non-bonded forces and shift forces
+ * \param[in,out] forceOutputsPme  Force outputs for the PME forces and virial
  * \param[in,out] enerd            Energy data structure results are reduced into
  * \param[in]     lambdaQ          The Coulomb lambda of the current system state.
  * \param[in]     stepWork         Step schedule flags
@@ -724,7 +727,8 @@ static void launchPmeGpuFftAndGather(gmx_pme_t*               pmedata,
  */
 static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
                                         gmx_pme_t*          pmedata,
-                                        gmx::ForceOutputs*  forceOutputs,
+                                        gmx::ForceOutputs*  forceOutputsNonbonded,
+                                        gmx::ForceOutputs*  forceOutputsPme,
                                         gmx_enerdata_t*     enerd,
                                         const real          lambdaQ,
                                         const StepWorkload& stepWork,
@@ -733,10 +737,6 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
     bool isPmeGpuDone = false;
     bool isNbGpuDone  = false;
 
-
-    gmx::ForceWithShiftForces& forceWithShiftForces = forceOutputs->forceWithShiftForces();
-    gmx::ForceWithVirial&      forceWithVirial      = forceOutputs->forceWithVirial();
-
     gmx::ArrayRef<const gmx::RVec> pmeGpuForces;
 
     while (!isPmeGpuDone || !isNbGpuDone)
@@ -745,22 +745,24 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
         {
             GpuTaskCompletion completionType =
                     (isNbGpuDone) ? GpuTaskCompletion::Wait : GpuTaskCompletion::Check;
-            isPmeGpuDone = pme_gpu_try_finish_task(pmedata, stepWork, wcycle, &forceWithVirial,
-                                                   enerd, lambdaQ, completionType);
+            isPmeGpuDone = pme_gpu_try_finish_task(pmedata, stepWork, wcycle,
+                                                   &forceOutputsPme->forceWithVirial(), enerd,
+                                                   lambdaQ, completionType);
         }
 
         if (!isNbGpuDone)
         {
+            auto&             forceBuffersNonbonded = forceOutputsNonbonded->forceWithShiftForces();
             GpuTaskCompletion completionType =
                     (isPmeGpuDone) ? GpuTaskCompletion::Wait : GpuTaskCompletion::Check;
             isNbGpuDone = Nbnxm::gpu_try_finish_task(
                     nbv->gpu_nbv, stepWork, AtomLocality::Local, enerd->grpp.ener[egLJSR].data(),
-                    enerd->grpp.ener[egCOULSR].data(), forceWithShiftForces.shiftForces(),
+                    enerd->grpp.ener[egCOULSR].data(), forceBuffersNonbonded.shiftForces(),
                     completionType, wcycle);
 
             if (isNbGpuDone)
             {
-                nbv->atomdata_add_nbat_f_to_f(AtomLocality::Local, forceWithShiftForces.force());
+                nbv->atomdata_add_nbat_f_to_f(AtomLocality::Local, forceBuffersNonbonded.force());
             }
         }
     }
@@ -769,8 +771,6 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
 /*! \brief Set up the different force buffers; also does clearing.
  *
  * \param[in] forceHelperBuffers  Helper force buffers
- * \param[in] pull_work The pull work object.
- * \param[in] inputrec  input record
  * \param[in] force     force array
  * \param[in] stepWork  Step schedule flags
  * \param[out] wcycle   wallcycle recording structure
@@ -778,8 +778,6 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t* nbv,
  * \returns             Cleared force output structure
  */
 static ForceOutputs setupForceOutputs(ForceHelperBuffers*                 forceHelperBuffers,
-                                      pull_t*                             pull_work,
-                                      const t_inputrec&                   inputrec,
                                       gmx::ArrayRefWithPadding<gmx::RVec> force,
                                       const StepWorkload&                 stepWork,
                                       gmx_wallcycle_t                     wcycle)
@@ -823,11 +821,6 @@ static ForceOutputs setupForceOutputs(ForceHelperBuffers*                 forceH
         clearRVecs(forceWithVirial.force_, true);
     }
 
-    if (inputrec.bPull && pull_have_constraint(pull_work))
-    {
-        clear_pull_forces(pull_work);
-    }
-
     wallcycle_sub_stop(wcycle, ewcsCLEAR_FORCE_BUFFER);
 
     return ForceOutputs(forceWithShiftForces, forceHelperBuffers->haveDirectVirialContributions(),
@@ -878,25 +871,34 @@ static DomainLifetimeWorkload setupDomainLifetimeWorkload(const t_inputrec&
 /*! \brief Set up force flag stuct from the force bitmask.
  *
  * \param[in]      legacyFlags          Force bitmask flags used to construct the new flags
+ * \param[in]      mtsLevels            The multiple time-stepping levels, either empty or 2 levels
+ * \param[in]      step                 The current MD step
  * \param[in]      simulationWork       Simulation workload description.
  * \param[in]      rankHasPmeDuty       If this rank computes PME.
  *
  * \returns New Stepworkload description.
  */
-static StepWorkload setupStepWorkload(const int                 legacyFlags,
-                                      const SimulationWorkload& simulationWork,
-                                      const bool                rankHasPmeDuty)
+static StepWorkload setupStepWorkload(const int                     legacyFlags,
+                                      ArrayRef<const gmx::MtsLevel> mtsLevels,
+                                      const int64_t                 step,
+                                      const SimulationWorkload&     simulationWork,
+                                      const bool                    rankHasPmeDuty)
 {
+    GMX_ASSERT(mtsLevels.empty() || mtsLevels.size() == 2, "Expect 0 or 2 MTS levels");
+    const bool computeSlowForces = (mtsLevels.empty() || step % mtsLevels[1].stepFactor == 0);
+
     StepWorkload flags;
     flags.stateChanged        = ((legacyFlags & GMX_FORCE_STATECHANGED) != 0);
     flags.haveDynamicBox      = ((legacyFlags & GMX_FORCE_DYNAMICBOX) != 0);
     flags.doNeighborSearch    = ((legacyFlags & GMX_FORCE_NS) != 0);
+    flags.computeSlowForces   = computeSlowForces;
     flags.computeVirial       = ((legacyFlags & GMX_FORCE_VIRIAL) != 0);
     flags.computeEnergy       = ((legacyFlags & GMX_FORCE_ENERGY) != 0);
     flags.computeForces       = ((legacyFlags & GMX_FORCE_FORCES) != 0);
     flags.computeListedForces = ((legacyFlags & GMX_FORCE_LISTED) != 0);
     flags.computeNonbondedForces =
-            ((legacyFlags & GMX_FORCE_NONBONDED) != 0) && simulationWork.computeNonbonded;
+            ((legacyFlags & GMX_FORCE_NONBONDED) != 0) && simulationWork.computeNonbonded
+            && !(simulationWork.computeNonbondedAtMtsLevel1 && !computeSlowForces);
     flags.computeDhdl = ((legacyFlags & GMX_FORCE_DHDL) != 0);
 
     if (simulationWork.useGpuBufferOps)
@@ -906,10 +908,9 @@ static StepWorkload setupStepWorkload(const int                 legacyFlags,
     }
     flags.useGpuXBufferOps = simulationWork.useGpuBufferOps;
     // on virial steps the CPU reduction path is taken
-    flags.useGpuFBufferOps    = simulationWork.useGpuBufferOps && !flags.computeVirial;
-    flags.useGpuPmeFReduction = flags.useGpuFBufferOps
-                                && (simulationWork.useGpuPme
-                                    && (rankHasPmeDuty || simulationWork.useGpuPmePpCommunication));
+    flags.useGpuFBufferOps = simulationWork.useGpuBufferOps && !flags.computeVirial;
+    flags.useGpuPmeFReduction = flags.computeSlowForces && flags.useGpuFBufferOps && simulationWork.useGpuPme
+                                && (rankHasPmeDuty || simulationWork.useGpuPmePpCommunication);
 
     return flags;
 }
@@ -929,7 +930,7 @@ static void launchGpuEndOfStepTasks(nonbonded_verlet_t*               nbv,
                                     int64_t                           step,
                                     gmx_wallcycle_t                   wcycle)
 {
-    if (runScheduleWork.simulationWork.useGpuNonbonded)
+    if (runScheduleWork.simulationWork.useGpuNonbonded && runScheduleWork.stepWork.computeNonbondedForces)
     {
         /* Launch pruning before buffer clearing because the API overhead of the
          * clear kernel launches can leave the GPU idle while it could be running
@@ -1008,6 +1009,27 @@ static void reduceAndUpdateMuTot(DipoleData*                   dipoleData,
     }
 }
 
+/*! \brief Combines MTS level0 and level1 force buffes into a full and MTS-combined force buffer.
+ *
+ * \param[in]     numAtoms        The number of atoms to combine forces for
+ * \param[in,out] forceMtsLevel0  Input: F_level0, output: F_level0 + F_level1
+ * \param[in,out] forceMts        Input: F_level1, output: F_level0 + mtsFactor * F_level1
+ * \param[in]     mtsFactor       The factor between the level0 and level1 time step
+ */
+static void combineMtsForces(const int      numAtoms,
+                             ArrayRef<RVec> forceMtsLevel0,
+                             ArrayRef<RVec> forceMts,
+                             const real     mtsFactor)
+{
+    const int gmx_unused numThreads = gmx_omp_nthreads_get(emntDefault);
+#pragma omp parallel for num_threads(numThreads) schedule(static)
+    for (int i = 0; i < numAtoms; i++)
+    {
+        const RVec forceMtsLevel0Tmp = forceMtsLevel0[i];
+        forceMtsLevel0[i] += forceMts[i];
+        forceMts[i] = forceMtsLevel0Tmp + mtsFactor * forceMts[i];
+    }
+}
 
 /*! \brief Setup for the local and non-local GPU force reductions:
  * reinitialization plus the registration of forces and dependencies.
@@ -1120,19 +1142,19 @@ void do_force(FILE*                               fplog,
                "The size of the force buffer should be at least the number of atoms to compute "
                "forces for");
 
-    nonbonded_verlet_t*          nbv      = fr->nbv.get();
-    interaction_const_t*         ic       = fr->ic;
+    nonbonded_verlet_t*  nbv = fr->nbv.get();
+    interaction_const_t* ic  = fr->ic;
+
     gmx::StatePropagatorDataGpu* stateGpu = fr->stateGpu;
 
     const SimulationWorkload& simulationWork = runScheduleWork->simulationWork;
 
-
-    runScheduleWork->stepWork =
-            setupStepWorkload(legacyFlags, simulationWork, thisRankHasDuty(cr, DUTY_PME));
+    runScheduleWork->stepWork    = setupStepWorkload(legacyFlags, inputrec->mtsLevels, step,
+                                                  simulationWork, thisRankHasDuty(cr, DUTY_PME));
     const StepWorkload& stepWork = runScheduleWork->stepWork;
 
-
-    const bool useGpuPmeOnThisRank = simulationWork.useGpuPme && thisRankHasDuty(cr, DUTY_PME);
+    const bool useGpuPmeOnThisRank =
+            simulationWork.useGpuPme && thisRankHasDuty(cr, DUTY_PME) && stepWork.computeSlowForces;
 
     /* At a search step we need to start the first balancing region
      * somewhere early inside the step after communication during domain
@@ -1179,7 +1201,7 @@ void do_force(FILE*                               fplog,
 
     // If coordinates are to be sent to PME task from CPU memory, perform that send here.
     // Otherwise the send will occur after H2D coordinate transfer.
-    if (GMX_MPI && !thisRankHasDuty(cr, DUTY_PME) && !pmeSendCoordinatesFromGpu)
+    if (GMX_MPI && !thisRankHasDuty(cr, DUTY_PME) && !pmeSendCoordinatesFromGpu && stepWork.computeSlowForces)
     {
         /* Send particle coordinates to the pme nodes */
         if (!stepWork.doNeighborSearch && simulationWork.useGpuUpdate)
@@ -1364,7 +1386,7 @@ void do_force(FILE*                               fplog,
             setupGpuForceReductions(runScheduleWork, cr, fr, ddUsesGpuDirectCommunication);
         }
     }
-    else if (!EI_TPI(inputrec->eI))
+    else if (!EI_TPI(inputrec->eI) && stepWork.computeNonbondedForces)
     {
         if (stepWork.useGpuXBufferOps)
         {
@@ -1385,7 +1407,7 @@ void do_force(FILE*                               fplog,
         }
     }
 
-    if (simulationWork.useGpuNonbonded)
+    if (simulationWork.useGpuNonbonded && (stepWork.computeNonbondedForces || domainWork.haveGpuBondedWork))
     {
         ddBalanceRegionHandler.openBeforeForceComputationGpu();
 
@@ -1516,7 +1538,7 @@ void do_force(FILE*                               fplog,
         }
     }
 
-    if (simulationWork.useGpuNonbonded)
+    if (simulationWork.useGpuNonbonded && stepWork.computeNonbondedForces)
     {
         /* launch D2H copy-back F */
         wallcycle_start_nocount(wcycle, ewcLAUNCH_GPU);
@@ -1589,10 +1611,35 @@ void do_force(FILE*                               fplog,
      */
     wallcycle_start(wcycle, ewcFORCE);
 
-    // Set up and clear force outputs.
-    // We use std::move to keep the compiler happy, it has no effect.
-    ForceOutputs forceOut = setupForceOutputs(fr->forceHelperBuffers.get(), pull_work, *inputrec,
-                                              std::move(force), stepWork, wcycle);
+    /* Set up and clear force outputs:
+     * forceOutMtsLevel0:  everything except what is in the other two outputs
+     * forceOutMtsLevel1:  PME-mesh and listed-forces group 1
+     * forceOutNonbonded: non-bonded forces
+     * Without multiple time stepping all point to the same object.
+     * With multiple time-stepping the use is different for MTS fast (level0 only) and slow steps.
+     */
+    ForceOutputs forceOutMtsLevel0 =
+            setupForceOutputs(&fr->forceHelperBuffers[0], force, stepWork, wcycle);
+
+    // Force output for MTS combined forces, only set at level1 MTS steps
+    std::optional<ForceOutputs> forceOutMts =
+            (fr->useMts && stepWork.computeSlowForces)
+                    ? std::optional(setupForceOutputs(&fr->forceHelperBuffers[1],
+                                                      forceView->forceMtsCombinedWithPadding(),
+                                                      stepWork, wcycle))
+                    : std::nullopt;
+
+    ForceOutputs* forceOutMtsLevel1 =
+            fr->useMts ? (stepWork.computeSlowForces ? &forceOutMts.value() : nullptr) : &forceOutMtsLevel0;
+
+    const bool nonbondedAtMtsLevel1 = runScheduleWork->simulationWork.computeNonbondedAtMtsLevel1;
+
+    ForceOutputs* forceOutNonbonded = nonbondedAtMtsLevel1 ? forceOutMtsLevel1 : &forceOutMtsLevel0;
+
+    if (inputrec->bPull && pull_have_constraint(pull_work))
+    {
+        clear_pull_forces(pull_work);
+    }
 
     /* We calculate the non-bonded forces, when done on the CPU, here.
      * We do this before calling do_force_lowlevel, because in that
@@ -1609,26 +1656,26 @@ void do_force(FILE*                               fplog,
         do_nb_verlet(fr, ic, enerd, stepWork, InteractionLocality::Local, enbvClearFYes, step, nrnb, wcycle);
     }
 
-    if (fr->efep != efepNO)
+    if (fr->efep != efepNO && stepWork.computeNonbondedForces)
     {
         /* Calculate the local and non-local free energy interactions here.
          * Happens here on the CPU both with and without GPU.
          */
         nbv->dispatchFreeEnergyKernel(InteractionLocality::Local, fr,
                                       as_rvec_array(x.unpaddedArrayRef().data()),
-                                      &forceOut.forceWithShiftForces(), *mdatoms, inputrec->fepvals,
-                                      lambda, enerd, stepWork, nrnb);
+                                      &forceOutNonbonded->forceWithShiftForces(), *mdatoms,
+                                      inputrec->fepvals, lambda, enerd, stepWork, nrnb);
 
         if (havePPDomainDecomposition(cr))
         {
             nbv->dispatchFreeEnergyKernel(InteractionLocality::NonLocal, fr,
                                           as_rvec_array(x.unpaddedArrayRef().data()),
-                                          &forceOut.forceWithShiftForces(), *mdatoms,
+                                          &forceOutNonbonded->forceWithShiftForces(), *mdatoms,
                                           inputrec->fepvals, lambda, enerd, stepWork, nrnb);
         }
     }
 
-    if (!useOrEmulateGpuNb)
+    if (stepWork.computeNonbondedForces && !useOrEmulateGpuNb)
     {
         if (havePPDomainDecomposition(cr))
         {
@@ -1643,7 +1690,8 @@ void do_force(FILE*                               fplog,
              * communication with calculation with domain decomposition.
              */
             wallcycle_stop(wcycle, ewcFORCE);
-            nbv->atomdata_add_nbat_f_to_f(AtomLocality::All, forceOut.forceWithShiftForces().force());
+            nbv->atomdata_add_nbat_f_to_f(AtomLocality::All,
+                                          forceOutNonbonded->forceWithShiftForces().force());
             wallcycle_start_nocount(wcycle, ewcFORCE);
         }
 
@@ -1652,8 +1700,8 @@ void do_force(FILE*                               fplog,
         {
             /* This is not in a subcounter because it takes a
                negligible and constant-sized amount of time */
-            nbnxn_atomdata_add_nbat_fshift_to_fshift(*nbv->nbat,
-                                                     forceOut.forceWithShiftForces().shiftForces());
+            nbnxn_atomdata_add_nbat_fshift_to_fshift(
+                    *nbv->nbat, forceOutNonbonded->forceWithShiftForces().shiftForces());
         }
     }
 
@@ -1670,7 +1718,7 @@ void do_force(FILE*                               fplog,
     {
         /* foreign lambda component for walls */
         real dvdl_walls = do_walls(*inputrec, *fr, box, *mdatoms, x.unpaddedConstArrayRef(),
-                                   &forceOut.forceWithVirial(), lambda[efptVDW],
+                                   &forceOutMtsLevel0.forceWithVirial(), lambda[efptVDW],
                                    enerd->grpp.ener[egLJSR].data(), nrnb);
         enerd->dvdl_lin[efptVDW] += dvdl_walls;
     }
@@ -1697,8 +1745,10 @@ void do_force(FILE*                               fplog,
             set_pbc_dd(&pbc, fr->pbcType, DOMAINDECOMP(cr) ? cr->dd->numCells : nullptr, TRUE, box);
         }
 
-        for (auto& listedForces : fr->listedForces)
+        for (int mtsIndex = 0; mtsIndex < (fr->useMts && stepWork.computeSlowForces ? 2 : 1); mtsIndex++)
         {
+            ListedForces& listedForces = fr->listedForces[mtsIndex];
+            ForceOutputs& forceOut     = (mtsIndex == 0 ? forceOutMtsLevel0 : *forceOutMtsLevel1);
             listedForces.calculate(
                     wcycle, box, inputrec->fepvals, cr, ms, x, xWholeMolecules, fr->fcdata.get(),
                     hist, &forceOut, fr, &pbc, enerd, nrnb, lambda.data(), mdatoms,
@@ -1706,9 +1756,13 @@ void do_force(FILE*                               fplog,
         }
     }
 
-    calculateLongRangeNonbondeds(fr, inputrec, cr, nrnb, wcycle, mdatoms, x.unpaddedConstArrayRef(),
-                                 &forceOut.forceWithVirial(), enerd, box, lambda.data(),
-                                 as_rvec_array(dipoleData.muStateAB), stepWork, ddBalanceRegionHandler);
+    if (stepWork.computeSlowForces)
+    {
+        calculateLongRangeNonbondeds(fr, inputrec, cr, nrnb, wcycle, mdatoms,
+                                     x.unpaddedConstArrayRef(), &forceOutMtsLevel1->forceWithVirial(),
+                                     enerd, box, lambda.data(), as_rvec_array(dipoleData.muStateAB),
+                                     stepWork, ddBalanceRegionHandler);
+    }
 
     wallcycle_stop(wcycle, ewcFORCE);
 
@@ -1733,16 +1787,19 @@ void do_force(FILE*                               fplog,
     }
 
     computeSpecialForces(fplog, cr, inputrec, awh, enforcedRotation, imdSession, pull_work, step, t,
-                         wcycle, fr->forceProviders, box, x.unpaddedArrayRef(), mdatoms, lambda,
-                         stepWork, &forceOut.forceWithVirial(), enerd, ed, stepWork.doNeighborSearch);
-
+                         wcycle, fr->forceProviders, box, x.unpaddedArrayRef(), mdatoms, lambda, stepWork,
+                         &forceOutMtsLevel0.forceWithVirial(), enerd, ed, stepWork.doNeighborSearch);
 
+    GMX_ASSERT(!(nonbondedAtMtsLevel1 && stepWork.useGpuFBufferOps),
+               "The schedule below does not allow for nonbonded MTS with GPU buffer ops");
+    GMX_ASSERT(!(nonbondedAtMtsLevel1 && useGpuForcesHaloExchange),
+               "The schedule below does not allow for nonbonded MTS with GPU halo exchange");
     // Will store the amount of cycles spent waiting for the GPU that
     // will be later used in the DLB accounting.
     float cycles_wait_gpu = 0;
-    if (useOrEmulateGpuNb)
+    if (useOrEmulateGpuNb && stepWork.computeNonbondedForces)
     {
-        auto& forceWithShiftForces = forceOut.forceWithShiftForces();
+        auto& forceWithShiftForces = forceOutNonbonded->forceWithShiftForces();
 
         /* wait for non-local forces (or calculate in emulation mode) */
         if (havePPDomainDecomposition(cr))
@@ -1771,7 +1828,7 @@ void do_force(FILE*                               fplog,
 
                 if (haveNonLocalForceContribInCpuBuffer)
                 {
-                    stateGpu->copyForcesToGpu(forceOut.forceWithShiftForces().force(),
+                    stateGpu->copyForcesToGpu(forceOutMtsLevel0.forceWithShiftForces().force(),
                                               AtomLocality::NonLocal);
                 }
 
@@ -1780,7 +1837,7 @@ void do_force(FILE*                               fplog,
                 if (!useGpuForcesHaloExchange)
                 {
                     // copy from GPU input for dd_move_f()
-                    stateGpu->copyForcesFromGpu(forceOut.forceWithShiftForces().force(),
+                    stateGpu->copyForcesFromGpu(forceOutMtsLevel0.forceWithShiftForces().force(),
                                                 AtomLocality::NonLocal);
                 }
             }
@@ -1789,7 +1846,6 @@ void do_force(FILE*                               fplog,
                 nbv->atomdata_add_nbat_f_to_f(AtomLocality::NonLocal, forceWithShiftForces.force());
             }
 
-
             if (fr->nbv->emulateGpu() && stepWork.computeVirial)
             {
                 nbnxn_atomdata_add_nbat_fshift_to_fshift(*nbv->nbat, forceWithShiftForces.shiftForces());
@@ -1797,6 +1853,20 @@ void do_force(FILE*                               fplog,
         }
     }
 
+    /* Combining the forces for multiple time stepping before the halo exchange, when possible,
+     * avoids an extra halo exchange (when DD is used) and post-processing step.
+     */
+    const bool combineMtsForcesBeforeHaloExchange =
+            (stepWork.computeForces && fr->useMts && stepWork.computeSlowForces
+             && (legacyFlags & GMX_FORCE_DO_NOT_NEED_NORMAL_FORCE) != 0
+             && !(stepWork.computeVirial || simulationWork.useGpuNonbonded || useGpuPmeOnThisRank));
+    if (combineMtsForcesBeforeHaloExchange)
+    {
+        const int numAtoms = havePPDomainDecomposition(cr) ? dd_numAtomsZones(*cr->dd) : mdatoms->homenr;
+        combineMtsForces(numAtoms, force.unpaddedArrayRef(), forceView->forceMtsCombined(),
+                         inputrec->mtsLevels[1].stepFactor);
+    }
+
     if (havePPDomainDecomposition(cr))
     {
         /* We are done with the CPU compute.
@@ -1808,12 +1878,12 @@ void do_force(FILE*                               fplog,
 
         if (stepWork.computeForces)
         {
-
             if (useGpuForcesHaloExchange)
             {
                 if (domainWork.haveCpuLocalForceWork)
                 {
-                    stateGpu->copyForcesToGpu(forceOut.forceWithShiftForces().force(), AtomLocality::Local);
+                    stateGpu->copyForcesToGpu(forceOutMtsLevel0.forceWithShiftForces().force(),
+                                              AtomLocality::Local);
                 }
                 communicateGpuHaloForces(*cr, domainWork.haveCpuLocalForceWork);
             }
@@ -1823,7 +1893,18 @@ void do_force(FILE*                               fplog,
                 {
                     stateGpu->waitForcesReadyOnHost(AtomLocality::NonLocal);
                 }
-                dd_move_f(cr->dd, &forceOut.forceWithShiftForces(), wcycle);
+
+                // Without MTS or with MTS at slow steps with uncombined forces we need to
+                // communicate the fast forces
+                if (!fr->useMts || !combineMtsForcesBeforeHaloExchange)
+                {
+                    dd_move_f(cr->dd, &forceOutMtsLevel0.forceWithShiftForces(), wcycle);
+                }
+                // With MTS we need to communicate the slow or combined (in forceOutMtsLevel1) forces
+                if (fr->useMts && stepWork.computeSlowForces)
+                {
+                    dd_move_f(cr->dd, &forceOutMtsLevel1->forceWithShiftForces(), wcycle);
+                }
             }
         }
     }
@@ -1834,18 +1915,18 @@ void do_force(FILE*                               fplog,
                              && !DOMAINDECOMP(cr) && !stepWork.useGpuFBufferOps);
     if (alternateGpuWait)
     {
-        alternatePmeNbGpuWaitReduce(fr->nbv.get(), fr->pmedata, &forceOut, enerd, lambda[efptCOUL],
-                                    stepWork, wcycle);
+        alternatePmeNbGpuWaitReduce(fr->nbv.get(), fr->pmedata, forceOutNonbonded,
+                                    forceOutMtsLevel1, enerd, lambda[efptCOUL], stepWork, wcycle);
     }
 
     if (!alternateGpuWait && useGpuPmeOnThisRank)
     {
-        pme_gpu_wait_and_reduce(fr->pmedata, stepWork, wcycle, &forceOut.forceWithVirial(), enerd,
-                                lambda[efptCOUL]);
+        pme_gpu_wait_and_reduce(fr->pmedata, stepWork, wcycle,
+                                &forceOutMtsLevel1->forceWithVirial(), enerd, lambda[efptCOUL]);
     }
 
     /* Wait for local GPU NB outputs on the non-alternating wait path */
-    if (!alternateGpuWait && simulationWork.useGpuNonbonded)
+    if (!alternateGpuWait && stepWork.computeNonbondedForces && simulationWork.useGpuNonbonded)
     {
         /* Measured overhead on CUDA and OpenCL with(out) GPU sharing
          * is between 0.5 and 1.5 Mcycles. So 2 MCycles is an overestimate,
@@ -1855,7 +1936,8 @@ void do_force(FILE*                               fplog,
         const float gpuWaitApiOverheadMargin = 2e6F; /* cycles */
         const float waitCycles               = Nbnxm::gpu_wait_finish_task(
                 nbv->gpu_nbv, stepWork, AtomLocality::Local, enerd->grpp.ener[egLJSR].data(),
-                enerd->grpp.ener[egCOULSR].data(), forceOut.forceWithShiftForces().shiftForces(), wcycle);
+                enerd->grpp.ener[egCOULSR].data(),
+                forceOutNonbonded->forceWithShiftForces().shiftForces(), wcycle);
 
         if (ddBalanceRegionHandler.useBalancingRegion())
         {
@@ -1886,13 +1968,13 @@ void do_force(FILE*                               fplog,
 
     // If on GPU PME-PP comms or GPU update path, receive forces from PME before GPU buffer ops
     // TODO refactor this and unify with below default-path call to the same function
-    if (PAR(cr) && !thisRankHasDuty(cr, DUTY_PME)
+    if (PAR(cr) && !thisRankHasDuty(cr, DUTY_PME) && stepWork.computeSlowForces
         && (simulationWork.useGpuPmePpCommunication || simulationWork.useGpuUpdate))
     {
         /* In case of node-splitting, the PP nodes receive the long-range
          * forces, virial and energy from the PME nodes here.
          */
-        pme_receive_force_ener(fr, cr, &forceOut.forceWithVirial(), enerd,
+        pme_receive_force_ener(fr, cr, &forceOutMtsLevel1->forceWithVirial(), enerd,
                                simulationWork.useGpuPmePpCommunication,
                                stepWork.useGpuPmeFReduction, wcycle);
     }
@@ -1900,12 +1982,14 @@ void do_force(FILE*                               fplog,
 
     /* Do the nonbonded GPU (or emulation) force buffer reduction
      * on the non-alternating path. */
+    GMX_ASSERT(!(nonbondedAtMtsLevel1 && stepWork.useGpuFBufferOps),
+               "The schedule below does not allow for nonbonded MTS with GPU buffer ops");
     if (useOrEmulateGpuNb && !alternateGpuWait)
     {
-        gmx::ArrayRef<gmx::RVec> forceWithShift = forceOut.forceWithShiftForces().force();
-
         if (stepWork.useGpuFBufferOps)
         {
+            ArrayRef<gmx::RVec> forceWithShift = forceOutNonbonded->forceWithShiftForces().force();
+
             // Flag to specify whether the CPU force buffer has contributions to
             // local atoms. This depends on whether there are CPU-based force tasks
             // or when DD is active the halo exchange has resulted in contributions
@@ -1932,7 +2016,10 @@ void do_force(FILE*                               fplog,
                 stateGpu->copyForcesToGpu(forceWithShift, locality);
             }
 
-            fr->gpuForceReduction[gmx::AtomLocality::Local]->execute();
+            if (stepWork.computeNonbondedForces)
+            {
+                fr->gpuForceReduction[gmx::AtomLocality::Local]->execute();
+            }
 
             // Copy forces to host if they are needed for update or if virtual sites are enabled.
             // If there are vsites, we need to copy forces every step to spread vsite forces on host.
@@ -1946,8 +2033,9 @@ void do_force(FILE*                               fplog,
                 stateGpu->waitForcesReadyOnHost(AtomLocality::Local);
             }
         }
-        else
+        else if (stepWork.computeNonbondedForces)
         {
+            ArrayRef<gmx::RVec> forceWithShift = forceOutNonbonded->forceWithShiftForces().force();
             nbv->atomdata_add_nbat_f_to_f(AtomLocality::Local, forceWithShift);
         }
     }
@@ -1960,27 +2048,49 @@ void do_force(FILE*                               fplog,
         dd_force_flop_stop(cr->dd, nrnb);
     }
 
+    const bool haveCombinedMtsForces = (stepWork.computeForces && fr->useMts && stepWork.computeSlowForces
+                                        && combineMtsForcesBeforeHaloExchange);
     if (stepWork.computeForces)
     {
-        postProcessForceWithShiftForces(nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut,
+        postProcessForceWithShiftForces(nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOutMtsLevel0,
                                         vir_force, *mdatoms, *fr, vsite, stepWork);
+
+        if (fr->useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
+        {
+            postProcessForceWithShiftForces(nrnb, wcycle, box, x.unpaddedArrayRef(), forceOutMtsLevel1,
+                                            vir_force, *mdatoms, *fr, vsite, stepWork);
+        }
     }
 
     // TODO refactor this and unify with above GPU PME-PP / GPU update path call to the same function
     if (PAR(cr) && !thisRankHasDuty(cr, DUTY_PME) && !simulationWork.useGpuPmePpCommunication
-        && !simulationWork.useGpuUpdate)
+        && !simulationWork.useGpuUpdate && stepWork.computeSlowForces)
     {
         /* In case of node-splitting, the PP nodes receive the long-range
          * forces, virial and energy from the PME nodes here.
          */
-        pme_receive_force_ener(fr, cr, &forceOut.forceWithVirial(), enerd,
+        pme_receive_force_ener(fr, cr, &forceOutMtsLevel1->forceWithVirial(), enerd,
                                simulationWork.useGpuPmePpCommunication, false, wcycle);
     }
 
     if (stepWork.computeForces)
     {
-        postProcessForces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOut, vir_force,
-                          mdatoms, fr, vsite, stepWork);
+        /* If we don't use MTS or if we already combined the MTS forces before, we only
+         * need to post-process one ForceOutputs object here, called forceOutCombined,
+         * otherwise we have to post-process two outputs and then combine them.
+         */
+        ForceOutputs& forceOutCombined = (haveCombinedMtsForces ? forceOutMts.value() : forceOutMtsLevel0);
+        postProcessForces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), &forceOutCombined,
+                          vir_force, mdatoms, fr, vsite, stepWork);
+
+        if (fr->useMts && stepWork.computeSlowForces && !haveCombinedMtsForces)
+        {
+            postProcessForces(cr, step, nrnb, wcycle, box, x.unpaddedArrayRef(), forceOutMtsLevel1,
+                              vir_force, mdatoms, fr, vsite, stepWork);
+
+            combineMtsForces(mdatoms->homenr, force.unpaddedArrayRef(),
+                             forceView->forceMtsCombined(), inputrec->mtsLevels[1].stepFactor);
+        }
     }
 
     if (stepWork.computeEnergy)