Make stepWorkload.useGpuXBufferOps flag consistent
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index 8795ec6188643a4d001b362b55928739c0831ed7..89da40d77cded04ba52a717538d170598a75bc2d 100644 (file)
@@ -54,6 +54,7 @@
 #include "gromacs/domdec/partition.h"
 #include "gromacs/essentialdynamics/edsam.h"
 #include "gromacs/ewald/pme.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
 #include "gromacs/ewald/pme_pp.h"
 #include "gromacs/ewald/pme_pp_comm_gpu.h"
 #include "gromacs/gmxlib/network.h"
@@ -526,7 +527,8 @@ static real averageKineticEnergyEstimate(const t_grpopts& groupOptions)
  *
  * \param[in] step      The step number, used for checking and printing
  * \param[in] enerd     The energy data; the non-bonded group energies need to be added to
- * enerd.term[F_EPOT] before calling this routine \param[in] inputrec  The input record
+ *                      \c enerd.term[F_EPOT] before calling this routine
+ * \param[in] inputrec  The input record
  */
 static void checkPotentialEnergyValidity(int64_t step, const gmx_enerdata_t& enerd, const t_inputrec& inputrec)
 {
@@ -746,7 +748,10 @@ static inline void launchPmeGpuSpread(gmx_pme_t*            pmedata,
                                       gmx_wallcycle*        wcycle)
 {
     pme_gpu_prepare_computation(pmedata, box, wcycle, stepWork);
-    pme_gpu_launch_spread(pmedata, xReadyOnDevice, wcycle, lambdaQ);
+    bool                           useGpuDirectComm         = false;
+    gmx::PmeCoordinateReceiverGpu* pmeCoordinateReceiverGpu = nullptr;
+    pme_gpu_launch_spread(
+            pmedata, xReadyOnDevice, wcycle, lambdaQ, useGpuDirectComm, pmeCoordinateReceiverGpu);
 }
 
 /*! \brief Launch the FFT and gather stages of PME GPU
@@ -900,12 +905,11 @@ static ForceOutputs setupForceOutputs(ForceHelperBuffers*                 forceH
 
 /*! \brief Set up flags that have the lifetime of the domain indicating what type of work is there to compute.
  */
-static DomainLifetimeWorkload setupDomainLifetimeWorkload(const t_inputrec& inputrec,
-                                                          const t_forcerec& fr,
-                                                          const pull_t*     pull_work,
-                                                          const gmx_edsam*  ed,
-                                                          const t_mdatoms&  mdatoms,
-                                                          bool havePPDomainDecomposition,
+static DomainLifetimeWorkload setupDomainLifetimeWorkload(const t_inputrec&         inputrec,
+                                                          const t_forcerec&         fr,
+                                                          const pull_t*             pull_work,
+                                                          const gmx_edsam*          ed,
+                                                          const t_mdatoms&          mdatoms,
                                                           const SimulationWorkload& simulationWork,
                                                           const StepWorkload&       stepWork)
 {
@@ -938,7 +942,7 @@ static DomainLifetimeWorkload setupDomainLifetimeWorkload(const t_inputrec& inpu
             || domainWork.haveFreeEnergyWork || simulationWork.useCpuNonbonded || simulationWork.useCpuPme
             || simulationWork.haveEwaldSurfaceContribution || inputrec.nwall > 0;
     domainWork.haveLocalForceContribInCpuBuffer =
-            domainWork.haveCpuLocalForceWork || havePPDomainDecomposition;
+            domainWork.haveCpuLocalForceWork || simulationWork.havePpDomainDecomposition;
     domainWork.haveNonLocalForceContribInCpuBuffer =
             domainWork.haveCpuBondedWork || domainWork.haveFreeEnergyWork;
 
@@ -951,15 +955,13 @@ static DomainLifetimeWorkload setupDomainLifetimeWorkload(const t_inputrec& inpu
  * \param[in]      mtsLevels            The multiple time-stepping levels, either empty or 2 levels
  * \param[in]      step                 The current MD step
  * \param[in]      simulationWork       Simulation workload description.
- * \param[in]      rankHasPmeDuty       If this rank computes PME.
  *
  * \returns New Stepworkload description.
  */
 static StepWorkload setupStepWorkload(const int                     legacyFlags,
                                       ArrayRef<const gmx::MtsLevel> mtsLevels,
                                       const int64_t                 step,
-                                      const SimulationWorkload&     simulationWork,
-                                      const bool                    rankHasPmeDuty)
+                                      const SimulationWorkload&     simulationWork)
 {
     GMX_ASSERT(mtsLevels.empty() || mtsLevels.size() == 2, "Expect 0 or 2 MTS levels");
     const bool computeSlowForces = (mtsLevels.empty() || step % mtsLevels[1].stepFactor == 0);
@@ -984,14 +986,15 @@ static StepWorkload setupStepWorkload(const int                     legacyFlags,
         GMX_ASSERT(simulationWork.useGpuNonbonded,
                    "Can only offload buffer ops if nonbonded computation is also offloaded");
     }
-    flags.useGpuXBufferOps = simulationWork.useGpuBufferOps;
+    flags.useGpuXBufferOps = simulationWork.useGpuBufferOps && !flags.doNeighborSearch;
     // on virial steps the CPU reduction path is taken
-    flags.useGpuFBufferOps = simulationWork.useGpuBufferOps && !flags.computeVirial;
-    flags.useGpuPmeFReduction = flags.computeSlowForces && flags.useGpuFBufferOps && simulationWork.useGpuPme
-                                && (rankHasPmeDuty || simulationWork.useGpuPmePpCommunication);
-    flags.useGpuXHalo          = simulationWork.useGpuHaloExchange;
+    flags.useGpuFBufferOps       = simulationWork.useGpuBufferOps && !flags.computeVirial;
+    const bool rankHasGpuPmeTask = simulationWork.useGpuPme && !simulationWork.haveSeparatePmeRank;
+    flags.useGpuPmeFReduction    = flags.computeSlowForces && flags.useGpuFBufferOps
+                                && (rankHasGpuPmeTask || simulationWork.useGpuPmePpCommunication);
+    flags.useGpuXHalo          = simulationWork.useGpuHaloExchange && !flags.doNeighborSearch;
     flags.useGpuFHalo          = simulationWork.useGpuHaloExchange && flags.useGpuFBufferOps;
-    flags.haveGpuPmeOnThisRank = simulationWork.useGpuPme && rankHasPmeDuty && flags.computeSlowForces;
+    flags.haveGpuPmeOnThisRank = rankHasGpuPmeTask && flags.computeSlowForces;
     flags.combineMtsForcesBeforeHaloExchange =
             (flags.computeForces && simulationWork.useMts && flags.computeSlowForces
              && flags.useOnlyMtsCombinedForceBuffer
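
This hunk is the core of the commit: `useGpuXBufferOps` and `useGpuXHalo` are now cleared on neighbor-search steps at flag-derivation time, and whether this rank runs the GPU PME task is derived from `simulationWork.haveSeparatePmeRank` instead of the removed `rankHasPmeDuty` argument. A minimal standalone sketch of the resulting invariants, using simplified hypothetical structs rather than the real `SimulationWorkload`/`StepWorkload`:

```cpp
#include <cassert>

// Simplified stand-ins for the GROMACS workload structs (toy fields only).
struct Sim  { bool useGpuBufferOps, useGpuPme, haveSeparatePmeRank, useGpuHaloExchange; };
struct Step { bool doNeighborSearch, computeVirial, computeSlowForces;
              bool useGpuXBufferOps, useGpuFBufferOps, useGpuXHalo, haveGpuPmeOnThisRank; };

// Mirrors the derivation in setupStepWorkload() above.
void deriveGpuFlags(const Sim& sim, Step* step)
{
    // X buffer ops and the GPU coordinate halo are consistently off on
    // search steps, where the current coordinates live on the host.
    step->useGpuXBufferOps = sim.useGpuBufferOps && !step->doNeighborSearch;
    step->useGpuXHalo      = sim.useGpuHaloExchange && !step->doNeighborSearch;
    // On virial steps the CPU force-reduction path is taken.
    step->useGpuFBufferOps = sim.useGpuBufferOps && !step->computeVirial;
    // Derived from the simulation workload instead of a rankHasPmeDuty argument.
    const bool rankHasGpuPmeTask = sim.useGpuPme && !sim.haveSeparatePmeRank;
    step->haveGpuPmeOnThisRank   = rankHasGpuPmeTask && step->computeSlowForces;
}

int main()
{
    Sim  sim{ true, true, false, true };
    Step step{ true, false, true, false, false, false, false }; // a search step
    deriveGpuFlags(sim, &step);
    assert(!step.useGpuXBufferOps && !step.useGpuXHalo); // both forced off
    return 0;
}
```

With the invariant established here, the hunks further down can drop their search-step special cases, e.g. `if (stepWork.doNeighborSearch || !stepWork.useGpuXBufferOps)` simplifies to `if (!stepWork.useGpuXBufferOps)`.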
@@ -1093,7 +1096,7 @@ static void reduceAndUpdateMuTot(DipoleData*                   dipoleData,
     }
 }
 
-/*! \brief Combines MTS level0 and level1 force buffes into a full and MTS-combined force buffer.
+/*! \brief Combines MTS level0 and level1 force buffers into a full and MTS-combined force buffer.
  *
  * \param[in]     numAtoms        The number of atoms to combine forces for
  * \param[in,out] forceMtsLevel0  Input: F_level0, output: F_level0 + F_level1
@@ -1115,89 +1118,117 @@ static void combineMtsForces(const int      numAtoms,
     }
 }
 
-/*! \brief Setup for the local and non-local GPU force reductions:
+/*! \brief Setup for the local GPU force reduction:
  * reinitialization plus the registration of forces and dependencies.
  *
- * \param [in] runScheduleWork               Schedule workload flag structure
- * \param [in] cr                            Communication record object
- * \param [in] fr                            Force record object
+ * \param [in] runScheduleWork     Schedule workload flag structure
+ * \param [in] nbv                 Non-bonded Verlet object
+ * \param [in] stateGpu            GPU state propagator object
+ * \param [in] gpuForceReduction   GPU force reduction object
+ * \param [in] pmePpCommGpu        PME-PP GPU communication object
+ * \param [in] pmedata             PME data object
+ * \param [in] dd                  Domain decomposition object
  */
-static void setupGpuForceReductions(gmx::MdrunScheduleWorkload* runScheduleWork,
-                                    const t_commrec*            cr,
-                                    t_forcerec*                 fr)
+static void setupLocalGpuForceReduction(const gmx::MdrunScheduleWorkload* runScheduleWork,
+                                        const nonbonded_verlet_t*         nbv,
+                                        gmx::StatePropagatorDataGpu*      stateGpu,
+                                        gmx::GpuForceReduction*           gpuForceReduction,
+                                        gmx::PmePpCommGpu*                pmePpCommGpu,
+                                        const gmx_pme_t*                  pmedata,
+                                        const gmx_domdec_t*               dd)
 {
-
-    nonbonded_verlet_t*          nbv      = fr->nbv.get();
-    gmx::StatePropagatorDataGpu* stateGpu = fr->stateGpu;
+    GMX_ASSERT(!runScheduleWork->simulationWork.useMts,
+               "GPU force reduction is not compatible with MTS");
 
     // (re-)initialize local GPU force reduction
-    const bool accumulate =
-            runScheduleWork->domainWork.haveCpuLocalForceWork || havePPDomainDecomposition(cr);
+    const bool accumulate = runScheduleWork->domainWork.haveCpuLocalForceWork
+                            || runScheduleWork->simulationWork.havePpDomainDecomposition;
     const int atomStart = 0;
-    fr->gpuForceReduction[gmx::AtomLocality::Local]->reinit(stateGpu->getForces(),
-                                                            nbv->getNumAtoms(AtomLocality::Local),
-                                                            nbv->getGridIndices(),
-                                                            atomStart,
-                                                            accumulate,
-                                                            stateGpu->fReducedOnDevice());
+    gpuForceReduction->reinit(stateGpu->getForces(),
+                              nbv->getNumAtoms(AtomLocality::Local),
+                              nbv->getGridIndices(),
+                              atomStart,
+                              accumulate,
+                              stateGpu->fReducedOnDevice(AtomLocality::Local));
 
     // register forces and add dependencies
-    fr->gpuForceReduction[gmx::AtomLocality::Local]->registerNbnxmForce(Nbnxm::gpu_get_f(nbv->gpu_nbv));
-
-    if (runScheduleWork->simulationWork.useGpuPme
-        && (thisRankHasDuty(cr, DUTY_PME) || runScheduleWork->simulationWork.useGpuPmePpCommunication))
-    {
-        DeviceBuffer<gmx::RVec> forcePtr =
-                thisRankHasDuty(cr, DUTY_PME) ? pme_gpu_get_device_f(fr->pmedata)
-                                              :                    // PME force buffer on same GPU
-                        fr->pmePpCommGpu->getGpuForceStagingPtr(); // buffer received from other GPU
-        fr->gpuForceReduction[gmx::AtomLocality::Local]->registerRvecForce(forcePtr);
+    gpuForceReduction->registerNbnxmForce(Nbnxm::gpu_get_f(nbv->gpu_nbv));
 
-        GpuEventSynchronizer* const pmeSynchronizer =
-                (thisRankHasDuty(cr, DUTY_PME) ? pme_gpu_get_f_ready_synchronizer(fr->pmedata)
-                                               : // PME force buffer on same GPU
-                         fr->pmePpCommGpu->getForcesReadySynchronizer()); // buffer received from other GPU
+    DeviceBuffer<gmx::RVec> pmeForcePtr;
+    GpuEventSynchronizer*   pmeSynchronizer     = nullptr;
+    bool                    havePmeContribution = false;
 
+    if (runScheduleWork->simulationWork.useGpuPme && !runScheduleWork->simulationWork.haveSeparatePmeRank)
+    {
+        pmeForcePtr         = pme_gpu_get_device_f(pmedata);
+        pmeSynchronizer     = pme_gpu_get_f_ready_synchronizer(pmedata);
+        havePmeContribution = true;
+    }
+    else if (runScheduleWork->simulationWork.useGpuPmePpCommunication)
+    {
+        pmeForcePtr = pmePpCommGpu->getGpuForceStagingPtr();
         if (GMX_THREAD_MPI)
         {
-            GMX_ASSERT(pmeSynchronizer != nullptr, "PME force ready cuda event should not be NULL");
-            fr->gpuForceReduction[gmx::AtomLocality::Local]->addDependency(pmeSynchronizer);
+            pmeSynchronizer = pmePpCommGpu->getForcesReadySynchronizer();
         }
+        havePmeContribution = true;
     }
 
-    if (runScheduleWork->domainWork.haveCpuLocalForceWork && !runScheduleWork->simulationWork.useGpuHaloExchange)
+    if (havePmeContribution)
     {
-        // in the DD case we use the same stream for H2D and reduction, hence no explicit dependency needed
-        if (!havePPDomainDecomposition(cr))
+        gpuForceReduction->registerRvecForce(pmeForcePtr);
+        if (!runScheduleWork->simulationWork.useGpuPmePpCommunication || GMX_THREAD_MPI)
         {
-            const bool useGpuForceBufferOps = true;
-            fr->gpuForceReduction[gmx::AtomLocality::Local]->addDependency(
-                    stateGpu->getForcesReadyOnDeviceEvent(AtomLocality::All, useGpuForceBufferOps));
+            GMX_ASSERT(pmeSynchronizer != nullptr, "PME force ready CUDA event should not be NULL");
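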
+            gpuForceReduction->addDependency(pmeSynchronizer);
         }
     }
 
-    if (runScheduleWork->simulationWork.useGpuHaloExchange)
+    if (runScheduleWork->domainWork.haveCpuLocalForceWork
+        || (runScheduleWork->simulationWork.havePpDomainDecomposition
+            && !runScheduleWork->simulationWork.useGpuHaloExchange))
     {
-        fr->gpuForceReduction[gmx::AtomLocality::Local]->addDependency(
-                cr->dd->gpuHaloExchange[0][0]->getForcesReadyOnDeviceEvent());
+        gpuForceReduction->addDependency(stateGpu->fReadyOnDevice(AtomLocality::Local));
     }
 
-    if (havePPDomainDecomposition(cr))
+    if (runScheduleWork->simulationWork.useGpuHaloExchange)
     {
-        // (re-)initialize non-local GPU force reduction
-        const bool accumulate = runScheduleWork->domainWork.haveCpuBondedWork
-                                || runScheduleWork->domainWork.haveFreeEnergyWork;
-        const int atomStart = dd_numHomeAtoms(*cr->dd);
-        fr->gpuForceReduction[gmx::AtomLocality::NonLocal]->reinit(stateGpu->getForces(),
-                                                                   nbv->getNumAtoms(AtomLocality::NonLocal),
-                                                                   nbv->getGridIndices(),
-                                                                   atomStart,
-                                                                   accumulate);
+        gpuForceReduction->addDependency(dd->gpuHaloExchange[0][0]->getForcesReadyOnDeviceEvent());
+    }
+}
+
+/*! \brief Setup for the non-local GPU force reduction:
+ * reinitialization plus the registration of forces and dependencies.
+ *
+ * \param [in] runScheduleWork     Schedule workload flag structure
+ * \param [in] nbv                 Non-bonded Verlet object
+ * \param [in] stateGpu            GPU state propagator object
+ * \param [in] gpuForceReduction   GPU force reduction object
+ * \param [in] dd                  Domain decomposition object
+ */
+static void setupNonLocalGpuForceReduction(const gmx::MdrunScheduleWorkload* runScheduleWork,
+                                           const nonbonded_verlet_t*         nbv,
+                                           gmx::StatePropagatorDataGpu*      stateGpu,
+                                           gmx::GpuForceReduction*           gpuForceReduction,
+                                           const gmx_domdec_t*               dd)
+{
+    // (re-)initialize non-local GPU force reduction
+    const bool accumulate = runScheduleWork->domainWork.haveCpuBondedWork
+                            || runScheduleWork->domainWork.haveFreeEnergyWork;
+    const int atomStart = dd_numHomeAtoms(*dd);
+    gpuForceReduction->reinit(stateGpu->getForces(),
+                              nbv->getNumAtoms(AtomLocality::NonLocal),
+                              nbv->getGridIndices(),
+                              atomStart,
+                              accumulate,
+                              stateGpu->fReducedOnDevice(AtomLocality::NonLocal));
 
-        // register forces and add dependencies
-        // in the DD case we use the same stream for H2D and reduction, hence no explicit dependency needed
-        fr->gpuForceReduction[gmx::AtomLocality::NonLocal]->registerNbnxmForce(
-                Nbnxm::gpu_get_f(nbv->gpu_nbv));
+    // register forces and add dependencies
+    gpuForceReduction->registerNbnxmForce(Nbnxm::gpu_get_f(nbv->gpu_nbv));
+
+    if (runScheduleWork->domainWork.haveNonLocalForceContribInCpuBuffer)
+    {
+        gpuForceReduction->addDependency(stateGpu->fReadyOnDevice(AtomLocality::NonLocal));
     }
 }
 
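
The monolithic `setupGpuForceReductions(runScheduleWork, cr, fr)` is split into per-locality helpers whose inputs are passed explicitly rather than dug out of `t_commrec`/`t_forcerec`. The wiring pattern is the same in both: `reinit()` with the output buffer and atom range, register the force buffers to be summed, then add the events the reduction must wait on. A toy, self-contained model of that pattern (stand-in types, not the real `gmx::GpuForceReduction`):

```cpp
#include <cstdio>
#include <vector>

// Stand-in for GpuEventSynchronizer; sources stand in for device force buffers.
struct Event { const char* name; };

class ToyForceReduction
{
public:
    void registerForce(const char* source) { sources_.push_back(source); }
    void addDependency(Event* e) { dependencies_.push_back(e); }
    void execute() const
    {
        for (const Event* e : dependencies_) { std::printf("wait on %s\n", e->name); }
        for (const char* s : sources_) { std::printf("reduce %s\n", s); }
    }
private:
    std::vector<const char*> sources_;
    std::vector<Event*>      dependencies_;
};

int main()
{
    Event pmeForcesReady{ "pmeSynchronizer" };
    Event cpuForcesOnDevice{ "fReadyOnDevice(Local)" };

    ToyForceReduction local;
    local.registerForce("nbnxmForce");       // always present
    local.registerForce("pmeForce");         // same-GPU PME or staged PME-PP buffer
    local.addDependency(&pmeForcesReady);    // skipped for lib-MPI GPU PME-PP comms
    local.addDependency(&cpuForcesOnDevice); // only if CPU forces feed the reduction
    local.execute();
    return 0;
}
```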
@@ -1238,6 +1269,7 @@ void do_force(FILE*                               fplog,
               rvec                                muTotal,
               double                              t,
               gmx_edsam*                          ed,
+              CpuPpLongRangeNonbondeds*           longRangeNonbondeds,
               int                                 legacyFlags,
               const DDBalanceRegionHandler&       ddBalanceRegionHandler)
 {
@@ -1253,10 +1285,20 @@ void do_force(FILE*                               fplog,
 
     const SimulationWorkload& simulationWork = runScheduleWork->simulationWork;
 
-    runScheduleWork->stepWork = setupStepWorkload(
-            legacyFlags, inputrec.mtsLevels, step, simulationWork, thisRankHasDuty(cr, DUTY_PME));
+    runScheduleWork->stepWork = setupStepWorkload(legacyFlags, inputrec.mtsLevels, step, simulationWork);
     const StepWorkload& stepWork = runScheduleWork->stepWork;
 
+    if (stepWork.useGpuFHalo && !runScheduleWork->domainWork.haveCpuLocalForceWork)
+    {
+        // GPU force halo exchange will set a subset of local atoms with remote non-local data.
+        // First clear the local portion of the force array, so that untouched atoms are zero.
+        // The dependency for this is that the forces from the previous timestep have been
+        // consumed, which is satisfied when getCoordinatesReadyOnDeviceEvent has been marked.
+        stateGpu->clearForcesOnGpu(AtomLocality::Local,
+                                   stateGpu->getCoordinatesReadyOnDeviceEvent(
+                                           AtomLocality::Local, simulationWork, stepWork));
+    }
+
     /* At a search step we need to start the first balancing region
      * somewhere early inside the step after communication during domain
      * decomposition (and not during the previous step as usual).
@@ -1279,7 +1321,7 @@ void do_force(FILE*                               fplog,
         }
 
         const bool fillGrid = (stepWork.doNeighborSearch && stepWork.stateChanged);
-        const bool calcCGCM = (fillGrid && !DOMAINDECOMP(cr));
+        const bool calcCGCM = (fillGrid && !haveDDAtomOrdering(*cr));
         if (calcCGCM)
         {
             put_atoms_in_box_omp(fr->pbcType,
@@ -1293,56 +1335,53 @@ void do_force(FILE*                               fplog,
     nbnxn_atomdata_copy_shiftvec(stepWork.haveDynamicBox, fr->shift_vec, nbv->nbat.get());
 
     const bool pmeSendCoordinatesFromGpu =
-            GMX_MPI && simulationWork.useGpuPmePpCommunication && !(stepWork.doNeighborSearch);
+            simulationWork.useGpuPmePpCommunication && !(stepWork.doNeighborSearch);
     const bool reinitGpuPmePpComms =
-            GMX_MPI && simulationWork.useGpuPmePpCommunication && (stepWork.doNeighborSearch);
+            simulationWork.useGpuPmePpCommunication && (stepWork.doNeighborSearch);
 
     auto* localXReadyOnDevice = (stepWork.haveGpuPmeOnThisRank || simulationWork.useGpuBufferOps)
                                         ? stateGpu->getCoordinatesReadyOnDeviceEvent(
                                                 AtomLocality::Local, simulationWork, stepWork)
                                         : nullptr;
 
+    GMX_ASSERT(simulationWork.useGpuHaloExchange
+                       == ((cr->dd != nullptr) && (!cr->dd->gpuHaloExchange[0].empty())),
+               "The GPU halo exchange is active, but it has not been constructed.");
+
+    bool gmx_used_in_debug haveCopiedXFromGpu = false;
     // Copy coordinate from the GPU if update is on the GPU and there
     // are forces to be computed on the CPU, or for the computation of
     // virial, or if host-side data will be transferred from this task
     // to a remote task for halo exchange or PME-PP communication. At
     // search steps the current coordinates are already on the host,
     // hence copy is not needed.
-    const bool haveHostPmePpComms =
-            !thisRankHasDuty(cr, DUTY_PME) && !simulationWork.useGpuPmePpCommunication;
-
-    GMX_ASSERT(simulationWork.useGpuHaloExchange
-                       == ((cr->dd != nullptr) && (!cr->dd->gpuHaloExchange[0].empty())),
-               "The GPU halo exchange is active, but it has not been constructed.");
-    const bool haveHostHaloExchangeComms =
-            havePPDomainDecomposition(cr) && !simulationWork.useGpuHaloExchange;
-
-    bool gmx_used_in_debug haveCopiedXFromGpu = false;
     if (simulationWork.useGpuUpdate && !stepWork.doNeighborSearch
         && (runScheduleWork->domainWork.haveCpuLocalForceWork || stepWork.computeVirial
-            || haveHostPmePpComms || haveHostHaloExchangeComms || simulationWork.computeMuTot))
+            || simulationWork.useCpuPmePpCommunication || simulationWork.useCpuHaloExchange
+            || simulationWork.computeMuTot))
     {
         stateGpu->copyCoordinatesFromGpu(x.unpaddedArrayRef(), AtomLocality::Local);
         haveCopiedXFromGpu = true;
     }
 
+    if (stepWork.doNeighborSearch && ((stepWork.haveGpuPmeOnThisRank || simulationWork.useGpuBufferOps)))
+    {
+        // TODO refactor this to do_md, after partitioning.
+        stateGpu->reinit(mdatoms->homenr,
+                         getLocalAtomCount(cr->dd, *mdatoms, simulationWork.havePpDomainDecomposition));
+        if (stepWork.haveGpuPmeOnThisRank)
+        {
+            // TODO: This should be moved into PME setup function ( pme_gpu_prepare_computation(...) )
+            pme_gpu_set_device_x(fr->pmedata, stateGpu->getCoordinates());
+        }
+    }
+
     // Coordinates on the device are needed if PME or BufferOps are offloaded.
     // The local coordinates can be copied right away.
     // NOTE: Consider moving this copy to right after they are updated and constrained,
     //       if the later is not offloaded.
     if (stepWork.haveGpuPmeOnThisRank || stepWork.useGpuXBufferOps)
     {
-        if (stepWork.doNeighborSearch)
-        {
-            // TODO refactor this to do_md, after partitioning.
-            stateGpu->reinit(mdatoms->homenr,
-                             getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)));
-            if (stepWork.haveGpuPmeOnThisRank)
-            {
-                // TODO: This should be moved into PME setup function ( pme_gpu_prepare_computation(...) )
-                pme_gpu_set_device_x(fr->pmedata, stateGpu->getCoordinates());
-            }
-        }
         // We need to copy coordinates when:
         // 1. Update is not offloaded
         // 2. The buffers were reinitialized on search step
@@ -1350,10 +1389,25 @@ void do_force(FILE*                               fplog,
         {
             GMX_ASSERT(stateGpu != nullptr, "stateGpu should not be null");
             stateGpu->copyCoordinatesToGpu(x.unpaddedArrayRef(), AtomLocality::Local);
+            if (stepWork.doNeighborSearch)
+            {
+                /* On NS steps, we skip X buffer ops. So, unless we use PME or direct GPU
+                 * communication, we don't wait for the coordinates on the device,
+                 * and we must consume the event here.
+                 * Issue #3988. */
+                const bool eventWillBeConsumedByGpuPme = stepWork.haveGpuPmeOnThisRank;
+                const bool eventWillBeConsumedByGpuPmePPComm =
+                        (simulationWork.haveSeparatePmeRank && stepWork.computeSlowForces)
+                        && pmeSendCoordinatesFromGpu;
+                if (!eventWillBeConsumedByGpuPme && !eventWillBeConsumedByGpuPmePPComm)
+                {
+                    stateGpu->consumeCoordinatesCopiedToDeviceEvent(AtomLocality::Local);
+                }
+            }
         }
     }
 
-    if (GMX_MPI && !thisRankHasDuty(cr, DUTY_PME) && stepWork.computeSlowForces)
+    if (simulationWork.haveSeparatePmeRank && stepWork.computeSlowForces)
     {
         /* Send particle coordinates to the pme nodes */
         if (!pmeSendCoordinatesFromGpu && !stepWork.doNeighborSearch && simulationWork.useGpuUpdate)
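
The manual `consumeCoordinatesCopiedToDeviceEvent()` call above reflects the consumption bookkeeping referenced by Issue #3988 in the comments: an event that is marked but never waited on would leave the synchronizer in an inconsistent state for the next step. A toy model of that bookkeeping, assuming a single marked/consumed flag rather than the real `GpuEventSynchronizer` implementation:

```cpp
#include <cassert>

// Toy model of mark/consume bookkeeping; the real class wraps a GPU event
// and enqueues stream waits when the event is consumed.
class ToyEventSynchronizer
{
public:
    void markEvent()
    {
        assert(consumed_ && "previous event was never consumed");
        consumed_ = false;
    }
    void consume() { consumed_ = true; } // a stream wait, or a manual consume
private:
    bool consumed_ = true;
};

int main()
{
    ToyEventSynchronizer xCopiedToDevice;
    xCopiedToDevice.markEvent(); // recorded by copyCoordinatesToGpu()
    // On a search step with neither GPU PME on this rank nor GPU PME-PP
    // communication, nothing will wait on the event, so consume it manually:
    xCopiedToDevice.consume();
    xCopiedToDevice.markEvent(); // the next step's copy is legal again
    return 0;
}
```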
@@ -1374,6 +1428,7 @@ void do_force(FILE*                               fplog,
                                  simulationWork.useGpuPmePpCommunication,
                                  reinitGpuPmePpComms,
                                  pmeSendCoordinatesFromGpu,
+                                 stepWork.useGpuPmeFReduction,
                                  localXReadyOnDevice,
                                  wcycle);
     }
@@ -1399,7 +1454,7 @@ void do_force(FILE*                               fplog,
         }
 
         wallcycle_start(wcycle, WallCycleCounter::NS);
-        if (!DOMAINDECOMP(cr))
+        if (!haveDDAtomOrdering(*cr))
         {
             const rvec vzero       = { 0.0_real, 0.0_real, 0.0_real };
             const rvec boxDiagonal = { box[XX][XX], box[YY][YY], box[ZZ][ZZ] };
@@ -1463,26 +1518,40 @@ void do_force(FILE*                               fplog,
         // Need to run after the GPU-offload bonded interaction lists
         // are set up to be able to determine whether there is bonded work.
         runScheduleWork->domainWork = setupDomainLifetimeWorkload(
-                inputrec, *fr, pull_work, ed, *mdatoms, havePPDomainDecomposition(cr), simulationWork, stepWork);
+                inputrec, *fr, pull_work, ed, *mdatoms, simulationWork, stepWork);
 
         wallcycle_start_nocount(wcycle, WallCycleCounter::NS);
         wallcycle_sub_start(wcycle, WallCycleSubCounter::NBSSearchLocal);
         /* Note that with a GPU the launch overhead of the list transfer is not timed separately */
         nbv->constructPairlist(InteractionLocality::Local, top->excls, step, nrnb);
 
-        nbv->setupGpuShortRangeWork(fr->listedForcesGpu, InteractionLocality::Local);
+        nbv->setupGpuShortRangeWork(fr->listedForcesGpu.get(), InteractionLocality::Local);
 
         wallcycle_sub_stop(wcycle, WallCycleSubCounter::NBSSearchLocal);
         wallcycle_stop(wcycle, WallCycleCounter::NS);
 
-        if (stepWork.useGpuXBufferOps)
+        if (simulationWork.useGpuBufferOps)
         {
             nbv->atomdata_init_copy_x_to_nbat_x_gpu();
         }
 
         if (simulationWork.useGpuBufferOps)
         {
-            setupGpuForceReductions(runScheduleWork, cr, fr);
+            setupLocalGpuForceReduction(runScheduleWork,
+                                        fr->nbv.get(),
+                                        stateGpu,
+                                        fr->gpuForceReduction[gmx::AtomLocality::Local].get(),
+                                        fr->pmePpCommGpu.get(),
+                                        fr->pmedata,
+                                        cr->dd);
+            if (runScheduleWork->simulationWork.havePpDomainDecomposition)
+            {
+                setupNonLocalGpuForceReduction(runScheduleWork,
+                                               fr->nbv.get(),
+                                               stateGpu,
+                                               fr->gpuForceReduction[gmx::AtomLocality::NonLocal].get(),
+                                               cr->dd);
+            }
         }
     }
     else if (!EI_TPI(inputrec.eI) && stepWork.computeNonbondedForces)
@@ -1512,7 +1581,7 @@ void do_force(FILE*                               fplog,
         wallcycle_start(wcycle, WallCycleCounter::LaunchGpu);
         wallcycle_sub_start(wcycle, WallCycleSubCounter::LaunchGpuNonBonded);
         Nbnxm::gpu_upload_shiftvec(nbv->gpu_nbv, nbv->nbat.get());
-        if (stepWork.doNeighborSearch || !stepWork.useGpuXBufferOps)
+        if (!stepWork.useGpuXBufferOps)
         {
             Nbnxm::gpu_copy_xq_to_gpu(nbv->gpu_nbv, nbv->nbat.get(), AtomLocality::Local);
         }
@@ -1522,7 +1591,7 @@ void do_force(FILE*                               fplog,
 
         // bonded work not split into separate local and non-local, so with DD
         // we can only launch the kernel after non-local coordinates have been received.
-        if (domainWork.haveGpuBondedWork && !havePPDomainDecomposition(cr))
+        if (domainWork.haveGpuBondedWork && !simulationWork.havePpDomainDecomposition)
         {
             fr->listedForcesGpu->setPbcAndlaunchKernel(fr->pbcType, box, fr->bMolPBC, stepWork);
         }
@@ -1549,7 +1618,7 @@ void do_force(FILE*                               fplog,
 
     /* Communicate coordinates and sum dipole if necessary +
        do non-local pair search */
-    if (havePPDomainDecomposition(cr))
+    if (simulationWork.havePpDomainDecomposition)
     {
         if (stepWork.doNeighborSearch)
         {
@@ -1559,7 +1628,7 @@ void do_force(FILE*                               fplog,
             /* Note that with a GPU the launch overhead of the list transfer is not timed separately */
             nbv->constructPairlist(InteractionLocality::NonLocal, top->excls, step, nrnb);
 
-            nbv->setupGpuShortRangeWork(fr->listedForcesGpu, InteractionLocality::NonLocal);
+            nbv->setupGpuShortRangeWork(fr->listedForcesGpu.get(), InteractionLocality::NonLocal);
             wallcycle_sub_stop(wcycle, WallCycleSubCounter::NBSSearchNonLocal);
             wallcycle_stop(wcycle, WallCycleCounter::NS);
             // TODO refactor this GPU halo exchange re-initialisation
@@ -1573,16 +1642,18 @@ void do_force(FILE*                               fplog,
         }
         else
         {
+            GpuEventSynchronizer* gpuCoordinateHaloLaunched = nullptr;
             if (stepWork.useGpuXHalo)
             {
                 // The following must be called after local setCoordinates (which records an event
                 // when the coordinate data has been copied to the device).
-                communicateGpuHaloCoordinates(*cr, box, localXReadyOnDevice);
+                gpuCoordinateHaloLaunched = communicateGpuHaloCoordinates(*cr, box, localXReadyOnDevice);
 
                 if (domainWork.haveCpuBondedWork || domainWork.haveFreeEnergyWork)
                 {
                     // non-local part of coordinate buffer must be copied back to host for CPU work
-                    stateGpu->copyCoordinatesFromGpu(x.unpaddedArrayRef(), AtomLocality::NonLocal);
+                    stateGpu->copyCoordinatesFromGpu(
+                            x.unpaddedArrayRef(), AtomLocality::NonLocal, gpuCoordinateHaloLaunched);
                 }
             }
             else
@@ -1598,14 +1669,15 @@ void do_force(FILE*                               fplog,
 
             if (stepWork.useGpuXBufferOps)
             {
-                if (!stepWork.haveGpuPmeOnThisRank && !stepWork.useGpuXHalo)
+                if (!stepWork.useGpuXHalo)
                 {
                     stateGpu->copyCoordinatesToGpu(x.unpaddedArrayRef(), AtomLocality::NonLocal);
                 }
-                nbv->convertCoordinatesGpu(AtomLocality::NonLocal,
-                                           stateGpu->getCoordinates(),
-                                           stateGpu->getCoordinatesReadyOnDeviceEvent(
-                                                   AtomLocality::NonLocal, simulationWork, stepWork));
+                nbv->convertCoordinatesGpu(
+                        AtomLocality::NonLocal,
+                        stateGpu->getCoordinates(),
+                        stateGpu->getCoordinatesReadyOnDeviceEvent(
+                                AtomLocality::NonLocal, simulationWork, stepWork, gpuCoordinateHaloLaunched));
             }
             else
             {
@@ -1616,7 +1688,7 @@ void do_force(FILE*                               fplog,
         if (simulationWork.useGpuNonbonded)
         {
 
-            if (stepWork.doNeighborSearch || !stepWork.useGpuXBufferOps)
+            if (!stepWork.useGpuXBufferOps)
             {
                 wallcycle_start(wcycle, WallCycleCounter::LaunchGpu);
                 wallcycle_sub_start(wcycle, WallCycleSubCounter::LaunchGpuNonBonded);
@@ -1639,13 +1711,22 @@ void do_force(FILE*                               fplog,
         }
     }
 
+    // With FEP we set up the reduction over threads for local+non-local simultaneously,
+    // so we need to do that here after the local and non-local pairlist construction.
+    if (stepWork.doNeighborSearch && fr->efep != FreeEnergyPerturbationType::No)
+    {
+        wallcycle_sub_start(wcycle, WallCycleSubCounter::NonbondedFep);
+        nbv->setupFepThreadedForceBuffer(fr->natoms_force_constr);
+        wallcycle_sub_stop(wcycle, WallCycleSubCounter::NonbondedFep);
+    }
+
     if (simulationWork.useGpuNonbonded && stepWork.computeNonbondedForces)
     {
         /* launch D2H copy-back F */
         wallcycle_start_nocount(wcycle, WallCycleCounter::LaunchGpu);
         wallcycle_sub_start_nocount(wcycle, WallCycleSubCounter::LaunchGpuNonBonded);
 
-        if (havePPDomainDecomposition(cr))
+        if (simulationWork.havePpDomainDecomposition)
         {
             Nbnxm::gpu_launch_cpyback(nbv->gpu_nbv, nbv->nbat.get(), stepWork, AtomLocality::NonLocal);
         }
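
Moving `setupFepThreadedForceBuffer()` to this point is needed because, as the added comment says, the per-thread FEP force buffers are shared between the local and non-local kernels and so can only be set up once both pairlists exist. A toy illustration of the threaded force-buffer pattern itself (plain C++ with made-up sizes):

```cpp
#include <cstdio>
#include <vector>

int main()
{
    // Each thread accumulates into its own full-size buffer; one reduction
    // sums them. Local and non-local contributions share the same buffers,
    // which is why setup must follow both pairlist constructions.
    const int numAtoms = 8, numThreads = 2;
    std::vector<std::vector<double>> threadForce(numThreads, std::vector<double>(numAtoms, 0.0));

    threadForce[0][1] += 1.0; // "local" kernel contribution, thread 0
    threadForce[1][5] += 2.0; // "non-local" kernel contribution, thread 1

    std::vector<double> force(numAtoms, 0.0);
    for (const auto& tf : threadForce)
    {
        for (int i = 0; i < numAtoms; i++)
        {
            force[i] += tf[i];
        }
    }
    std::printf("f[1]=%g f[5]=%g\n", force[1], force[5]);
    return 0;
}
```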
@@ -1665,18 +1746,26 @@ void do_force(FILE*                               fplog,
         xWholeMolecules = fr->wholeMoleculeTransform->wholeMoleculeCoordinates(x.unpaddedArrayRef(), box);
     }
 
-    DipoleData dipoleData;
-
-    if (simulationWork.computeMuTot)
+    // For the rest of the CPU tasks that depend on GPU-update produced coordinates,
+    // this wait ensures that the D2H transfer is complete.
+    if (simulationWork.useGpuUpdate && !stepWork.doNeighborSearch)
     {
-        const int start = 0;
-
-        if (simulationWork.useGpuUpdate && !stepWork.doNeighborSearch)
+        const bool needCoordsOnHost  = (runScheduleWork->domainWork.haveCpuLocalForceWork
+                                       || stepWork.computeVirial || simulationWork.computeMuTot);
+        const bool haveAlreadyWaited = simulationWork.useCpuHaloExchange;
+        if (needCoordsOnHost && !haveAlreadyWaited)
         {
             GMX_ASSERT(haveCopiedXFromGpu,
                        "a wait should only be triggered if copy has been scheduled");
             stateGpu->waitCoordinatesReadyOnHost(AtomLocality::Local);
         }
+    }
+
+    DipoleData dipoleData;
+
+    if (simulationWork.computeMuTot)
+    {
+        const int start = 0;
 
         /* Calculate total (local) dipole moment in a temporary common array.
          * This makes it possible to sum them over nodes faster.
@@ -1701,21 +1790,12 @@ void do_force(FILE*                               fplog,
     /* Reset energies */
     reset_enerdata(enerd);
 
-    if (DOMAINDECOMP(cr) && !thisRankHasDuty(cr, DUTY_PME))
+    if (haveDDAtomOrdering(*cr) && simulationWork.haveSeparatePmeRank)
     {
         wallcycle_start(wcycle, WallCycleCounter::PpDuringPme);
         dd_force_flop_start(cr->dd, nrnb);
     }
 
-    // For the rest of the CPU tasks that depend on GPU-update produced coordinates,
-    // this wait ensures that the D2H transfer is complete.
-    if (simulationWork.useGpuUpdate && !stepWork.doNeighborSearch
-        && (runScheduleWork->domainWork.haveCpuLocalForceWork || stepWork.computeVirial))
-    {
-        GMX_ASSERT(haveCopiedXFromGpu, "a wait should only be triggered if copy has been scheduled");
-        stateGpu->waitCoordinatesReadyOnHost(AtomLocality::Local);
-    }
-
     if (inputrec.bRot)
     {
         wallcycle_start(wcycle, WallCycleCounter::Rot);
@@ -1736,7 +1816,7 @@ void do_force(FILE*                               fplog,
      * With multiple time-stepping the use is different for MTS fast (level0 only) and slow steps.
      */
     ForceOutputs forceOutMtsLevel0 = setupForceOutputs(
-            &fr->forceHelperBuffers[0], force, domainWork, stepWork, havePPDomainDecomposition(cr), wcycle);
+            &fr->forceHelperBuffers[0], force, domainWork, stepWork, simulationWork.havePpDomainDecomposition, wcycle);
 
     // Force output for MTS combined forces, only set at level1 MTS steps
     std::optional<ForceOutputs> forceOutMts =
@@ -1745,7 +1825,7 @@ void do_force(FILE*                               fplog,
                                                       forceView->forceMtsCombinedWithPadding(),
                                                       domainWork,
                                                       stepWork,
-                                                      havePPDomainDecomposition(cr),
+                                                      simulationWork.havePpDomainDecomposition,
                                                       wcycle))
                     : std::nullopt;
 
@@ -1782,9 +1862,8 @@ void do_force(FILE*                               fplog,
         /* Calculate the local and non-local free energy interactions here.
          * Happens here on the CPU both with and without GPU.
          */
-        nbv->dispatchFreeEnergyKernel(
-                InteractionLocality::Local,
-                x.unpaddedArrayRef(),
+        nbv->dispatchFreeEnergyKernels(
+                x,
                 &forceOutNonbonded->forceWithShiftForces(),
                 fr->use_simd_kernels,
                 fr->ntype,
@@ -1806,39 +1885,11 @@ void do_force(FILE*                               fplog,
                 enerd,
                 stepWork,
                 nrnb);
-
-        if (havePPDomainDecomposition(cr))
-        {
-            nbv->dispatchFreeEnergyKernel(
-                    InteractionLocality::NonLocal,
-                    x.unpaddedArrayRef(),
-                    &forceOutNonbonded->forceWithShiftForces(),
-                    fr->use_simd_kernels,
-                    fr->ntype,
-                    fr->rlist,
-                    *fr->ic,
-                    fr->shift_vec,
-                    fr->nbfp,
-                    fr->ljpme_c6grid,
-                    mdatoms->chargeA ? gmx::arrayRefFromArray(mdatoms->chargeA, mdatoms->nr)
-                                     : gmx::ArrayRef<real>{},
-                    mdatoms->chargeB ? gmx::arrayRefFromArray(mdatoms->chargeB, mdatoms->nr)
-                                     : gmx::ArrayRef<real>{},
-                    mdatoms->typeA ? gmx::arrayRefFromArray(mdatoms->typeA, mdatoms->nr)
-                                   : gmx::ArrayRef<int>{},
-                    mdatoms->typeB ? gmx::arrayRefFromArray(mdatoms->typeB, mdatoms->nr)
-                                   : gmx::ArrayRef<int>{},
-                    inputrec.fepvals.get(),
-                    lambda,
-                    enerd,
-                    stepWork,
-                    nrnb);
-        }
     }
 
     if (stepWork.computeNonbondedForces && !useOrEmulateGpuNb)
     {
-        if (havePPDomainDecomposition(cr))
+        if (simulationWork.havePpDomainDecomposition)
         {
             do_nb_verlet(fr, ic, enerd, stepWork, InteractionLocality::NonLocal, enbvClearFNo, step, nrnb, wcycle);
         }
@@ -1917,7 +1968,7 @@ void do_force(FILE*                               fplog,
             /* Since all atoms are in the rectangular or triclinic unit-cell,
              * only single box vector shifts (2 in x) are required.
              */
-            set_pbc_dd(&pbc, fr->pbcType, DOMAINDECOMP(cr) ? cr->dd->numCells : nullptr, TRUE, box);
+            set_pbc_dd(&pbc, fr->pbcType, haveDDAtomOrdering(*cr) ? cr->dd->numCells : nullptr, TRUE, box);
         }
 
         for (int mtsIndex = 0; mtsIndex < (simulationWork.useMts && stepWork.computeSlowForces ? 2 : 1);
@@ -1941,27 +1992,23 @@ void do_force(FILE*                               fplog,
                                    nrnb,
                                    lambda,
                                    mdatoms,
-                                   DOMAINDECOMP(cr) ? cr->dd->globalAtomIndices.data() : nullptr,
+                                   haveDDAtomOrdering(*cr) ? cr->dd->globalAtomIndices.data() : nullptr,
                                    stepWork);
         }
     }
 
     if (stepWork.computeSlowForces)
     {
-        calculateLongRangeNonbondeds(fr,
-                                     inputrec,
-                                     cr,
-                                     nrnb,
-                                     wcycle,
-                                     mdatoms,
-                                     x.unpaddedConstArrayRef(),
-                                     &forceOutMtsLevel1->forceWithVirial(),
-                                     enerd,
-                                     box,
-                                     lambda,
-                                     dipoleData.muStateAB,
-                                     stepWork,
-                                     ddBalanceRegionHandler);
+        longRangeNonbondeds->calculate(fr->pmedata,
+                                       cr,
+                                       x.unpaddedConstArrayRef(),
+                                       &forceOutMtsLevel1->forceWithVirial(),
+                                       enerd,
+                                       box,
+                                       lambda,
+                                       dipoleData.muStateAB,
+                                       stepWork,
+                                       ddBalanceRegionHandler);
     }
 
     wallcycle_stop(wcycle, WallCycleCounter::Force);
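
`calculateLongRangeNonbondeds(fr, inputrec, cr, nrnb, wcycle, mdatoms, ...)` becomes a method on the new `CpuPpLongRangeNonbondeds` object that `do_force()` now receives, so step-invariant state no longer travels through the call on every step. A sketch of that refactoring pattern under toy assumptions (hypothetical class and members, not the real interface):

```cpp
#include <cstdio>

struct StepInputs { double boxVolume; };

// Step-invariant state is captured at construction; only per-step inputs
// remain parameters of calculate(). (Toy members; the real class holds the
// input record, atom data, cycle counters, and more.)
class ToyLongRangeNonbondeds
{
public:
    explicit ToyLongRangeNonbondeds(int numAtoms) : numAtoms_(numAtoms) {}
    void calculate(const StepInputs& step) const
    {
        std::printf("long-range forces: %d atoms, box volume %g\n", numAtoms_, step.boxVolume);
    }
private:
    int numAtoms_;
};

int main()
{
    ToyLongRangeNonbondeds lrnb(1000);    // constructed once, outside the MD loop
    lrnb.calculate(StepInputs{ 15.625 }); // called on slow-force steps
    return 0;
}
```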
@@ -2008,7 +2055,7 @@ void do_force(FILE*                               fplog,
                          ed,
                          stepWork.doNeighborSearch);
 
-    if (havePPDomainDecomposition(cr) && stepWork.computeForces && stepWork.useGpuFHalo
+    if (simulationWork.havePpDomainDecomposition && stepWork.computeForces && stepWork.useGpuFHalo
         && domainWork.haveCpuLocalForceWork)
     {
         stateGpu->copyForcesToGpu(forceOutMtsLevel0.forceWithShiftForces().force(), AtomLocality::Local);
@@ -2026,7 +2073,7 @@ void do_force(FILE*                               fplog,
         auto& forceWithShiftForces = forceOutNonbonded->forceWithShiftForces();
 
         /* wait for non-local forces (or calculate in emulation mode) */
-        if (havePPDomainDecomposition(cr))
+        if (simulationWork.havePpDomainDecomposition)
         {
             if (simulationWork.useGpuNonbonded)
             {
@@ -2060,6 +2107,10 @@ void do_force(FILE*                               fplog,
 
                 if (!stepWork.useGpuFHalo)
                 {
+                    /* We don't explicitly wait for the forces to be reduced on the device,
+                     * but instead wait for them to finish copying to the CPU.
+                     * So we manually consume the event; see Issue #3988. */
+                    stateGpu->consumeForcesReducedOnDeviceEvent(AtomLocality::NonLocal);
                     // copy from GPU input for dd_move_f()
                     stateGpu->copyForcesFromGpu(forceOutMtsLevel0.forceWithShiftForces().force(),
                                                 AtomLocality::NonLocal);
@@ -2082,13 +2133,13 @@ void do_force(FILE*                               fplog,
      */
     if (stepWork.combineMtsForcesBeforeHaloExchange)
     {
-        combineMtsForces(getLocalAtomCount(cr->dd, *mdatoms, havePPDomainDecomposition(cr)),
+        combineMtsForces(getLocalAtomCount(cr->dd, *mdatoms, simulationWork.havePpDomainDecomposition),
                          force.unpaddedArrayRef(),
                          forceView->forceMtsCombined(),
                          inputrec.mtsLevels[1].stepFactor);
     }
 
-    if (havePPDomainDecomposition(cr))
+    if (simulationWork.havePpDomainDecomposition)
     {
         /* We are done with the CPU compute.
          * We will now communicate the non-local forces.
@@ -2104,13 +2155,11 @@ void do_force(FILE*                               fplog,
             {
                 // If there exist CPU forces, data from halo exchange should accumulate into these
                 bool accumulateForces = domainWork.haveCpuLocalForceWork;
-                if (!accumulateForces)
-                {
-                    // Force halo exchange will set a subset of local atoms with remote non-local data
-                    // First clear local portion of force array, so that untouched atoms are zero
-                    stateGpu->clearForcesOnGpu(AtomLocality::Local);
-                }
-                communicateGpuHaloForces(*cr, accumulateForces);
+                gmx::FixedCapacityVector<GpuEventSynchronizer*, 2> gpuForceHaloDependencies;
+                gpuForceHaloDependencies.push_back(stateGpu->fReadyOnDevice(AtomLocality::Local));
+                gpuForceHaloDependencies.push_back(stateGpu->fReducedOnDevice(AtomLocality::NonLocal));
+
+                communicateGpuHaloForces(*cr, accumulateForces, &gpuForceHaloDependencies);
             }
             else
             {
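
Rather than the reduction implicitly clearing the local force buffer, `communicateGpuHaloForces()` now receives its prerequisites as an explicit, stack-allocated event list (the clearing moved to the top of `do_force()`, see the `clearForcesOnGpu` hunk above). A minimal stand-in for `gmx::FixedCapacityVector` showing the hand-off, with a toy `Event` type in place of `GpuEventSynchronizer*`:

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>

// Fixed-capacity vector: bounded size, no heap allocation, so building the
// two-entry dependency list every step is cheap.
template<typename T, std::size_t Cap>
class ToyFixedCapacityVector
{
public:
    void push_back(const T& v)
    {
        assert(size_ < Cap);
        data_[size_++] = v;
    }
    const T* begin() const { return data_.data(); }
    const T* end() const { return data_.data() + size_; }
private:
    std::array<T, Cap> data_{};
    std::size_t        size_ = 0;
};

struct Event { const char* name; };

int main()
{
    Event fReadyLocal{ "fReadyOnDevice(Local)" };
    Event fReducedNonLocal{ "fReducedOnDevice(NonLocal)" };

    ToyFixedCapacityVector<Event*, 2> deps;
    deps.push_back(&fReadyLocal);
    deps.push_back(&fReducedNonLocal);
    for (const Event* e : deps) { std::printf("force halo waits on %s\n", e->name); }
    return 0;
}
```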
@@ -2137,8 +2186,9 @@ void do_force(FILE*                               fplog,
     // With both nonbonded and PME offloaded a GPU on the same rank, we use
     // an alternating wait/reduction scheme.
     bool alternateGpuWait =
-            (!c_disableAlternatingWait && stepWork.haveGpuPmeOnThisRank
-             && simulationWork.useGpuNonbonded && !DOMAINDECOMP(cr) && !stepWork.useGpuFBufferOps);
+            (!c_disableAlternatingWait && stepWork.haveGpuPmeOnThisRank && simulationWork.useGpuNonbonded
+             && !simulationWork.havePpDomainDecomposition && !stepWork.useGpuFBufferOps);
+
     if (alternateGpuWait)
     {
         alternatePmeNbGpuWaitReduce(fr->nbv.get(),
@@ -2206,7 +2256,7 @@ void do_force(FILE*                               fplog,
                      enerd,
                      stepWork,
                      InteractionLocality::Local,
-                     DOMAINDECOMP(cr) ? enbvClearFNo : enbvClearFYes,
+                     haveDDAtomOrdering(*cr) ? enbvClearFNo : enbvClearFYes,
                      step,
                      nrnb,
                      wcycle);
@@ -2215,8 +2265,8 @@ void do_force(FILE*                               fplog,
 
     // If on GPU PME-PP comms path, receive forces from PME before GPU buffer ops
     // TODO refactor this and unify with below default-path call to the same function
-    if (PAR(cr) && !thisRankHasDuty(cr, DUTY_PME) && stepWork.computeSlowForces
-        && simulationWork.useGpuPmePpCommunication)
+    if (PAR(cr) && simulationWork.haveSeparatePmeRank && simulationWork.useGpuPmePpCommunication
+        && stepWork.computeSlowForces)
     {
         /* In case of node-splitting, the PP nodes receive the long-range
          * forces, virial and energy from the PME nodes here.
@@ -2250,14 +2300,7 @@ void do_force(FILE*                               fplog,
             //   These should be unified.
             if (domainWork.haveLocalForceContribInCpuBuffer && !stepWork.useGpuFHalo)
             {
-                // Note: AtomLocality::All is used for the non-DD case because, as in this
-                // case copyForcesToGpu() uses a separate stream, it allows overlap of
-                // CPU force H2D with GPU force tasks on all streams including those in the
-                // local stream which would otherwise be implicit dependencies for the
-                // transfer and would not overlap.
-                auto locality = havePPDomainDecomposition(cr) ? AtomLocality::Local : AtomLocality::All;
-
-                stateGpu->copyForcesToGpu(forceWithShift, locality);
+                stateGpu->copyForcesToGpu(forceWithShift, AtomLocality::Local);
             }
 
             if (stepWork.computeNonbondedForces)
@@ -2272,8 +2315,16 @@ void do_force(FILE*                               fplog,
             // NOTE: If there are virtual sites, the forces are modified on host after this D2H copy. Hence,
             //       they should not be copied in do_md(...) for the output.
             if (!simulationWork.useGpuUpdate
-                || (simulationWork.useGpuUpdate && DOMAINDECOMP(cr) && haveHostPmePpComms) || vsite)
+                || (simulationWork.useGpuUpdate && haveDDAtomOrdering(*cr) && simulationWork.useCpuPmePpCommunication)
+                || vsite)
             {
+                if (stepWork.computeNonbondedForces)
+                {
+                    /* We have previously issued force reduction on the GPU, but we will
+                     * not use this event, relying instead on the stream being in-order.
+                     * Issue #3988. */
+                    stateGpu->consumeForcesReducedOnDeviceEvent(AtomLocality::Local);
+                }
                 stateGpu->copyForcesFromGpu(forceWithShift, AtomLocality::Local);
                 stateGpu->waitForcesReadyOnHost(AtomLocality::Local);
             }
@@ -2285,9 +2336,10 @@ void do_force(FILE*                               fplog,
         }
     }
 
-    launchGpuEndOfStepTasks(nbv, fr->listedForcesGpu, fr->pmedata, enerd, *runScheduleWork, step, wcycle);
+    launchGpuEndOfStepTasks(
+            nbv, fr->listedForcesGpu.get(), fr->pmedata, enerd, *runScheduleWork, step, wcycle);
 
-    if (DOMAINDECOMP(cr))
+    if (haveDDAtomOrdering(*cr))
     {
         dd_force_flop_stop(cr->dd, nrnb);
     }
@@ -2307,7 +2359,7 @@ void do_force(FILE*                               fplog,
     }
 
     // TODO refactor this and unify with above GPU PME-PP / GPU update path call to the same function
-    if (PAR(cr) && !thisRankHasDuty(cr, DUTY_PME) && !simulationWork.useGpuPmePpCommunication
+    if (PAR(cr) && simulationWork.haveSeparatePmeRank && simulationWork.useCpuPmePpCommunication
         && stepWork.computeSlowForces)
     {
         /* In case of node-splitting, the PP nodes receive the long-range