Pipeline GPU PME Spline/Spread with PP Comms
[alexxy/gromacs.git] / src / gromacs / ewald / pme_only.cpp
index 70af46f6382c2d8ccbb236e096ab25bcce26df5e..56002966be057185ae0bf3924073cd6962b43759 100644 (file)
@@ -3,7 +3,8 @@
  *
  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
  * Copyright (c) 2001-2004, The GROMACS development team.
- * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
+ * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
+ * Copyright (c) 2018,2019,2020,2021, by the GROMACS development team, led by
  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
  * and including many others, as listed in the AUTHORS file in the
  * top-level source directory and at http://www.gromacs.org.
@@ -59,6 +60,8 @@
 
 #include "gmxpre.h"
 
+#include "pme_only.h"
+
 #include "config.h"
 
 #include <cassert>
 
 #include "gromacs/domdec/domdec.h"
 #include "gromacs/ewald/pme.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
+#include "gromacs/ewald/pme_force_sender_gpu.h"
 #include "gromacs/fft/parallel_3dfft.h"
 #include "gromacs/fileio/pdbio.h"
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/gmxlib/nrnb.h"
+#include "gromacs/gpu_utils/device_stream_manager.h"
 #include "gromacs/gpu_utils/hostallocator.h"
 #include "gromacs/math/gmxcomplex.h"
 #include "gromacs/math/units.h"
@@ -84,6 +90,8 @@
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/forceoutput.h"
 #include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/mdtypes/simulation_workload.h"
+#include "gromacs/mdtypes/state_propagator_data_gpu.h"
 #include "gromacs/timing/cyclecounter.h"
 #include "gromacs/timing/wallcycle.h"
 #include "gromacs/utility/fatalerror.h"
 
 #include "pme_gpu_internal.h"
 #include "pme_internal.h"
+#include "pme_output.h"
 #include "pme_pp_communication.h"
 
-//! Contains information about the PP ranks that partner this PME rank.
-struct PpRanks
-{
-    //! The MPI rank ID of this partner PP rank.
-    int rankId;
-    //! The number of atoms to communicate with this partner PP rank.
-    int numAtoms;
-};
-
 /*! \brief Master PP-PME communication data structure */
-struct gmx_pme_pp {
+struct gmx_pme_pp
+{
     MPI_Comm             mpi_comm_mysim; /**< MPI communicator for this simulation */
     std::vector<PpRanks> ppRanks;        /**< The PP partner ranks                 */
     int                  peerRankId;     /**< The peer PP rank id                  */
@@ -119,17 +120,27 @@ struct gmx_pme_pp {
     std::vector<real>           sigmaA;
     std::vector<real>           sigmaB;
     //@}
-    gmx::HostVector<gmx::RVec>  x; /**< Vector of atom coordinates to transfer to PME ranks */
-    std::vector<gmx::RVec>      f; /**< Vector of atom forces received from PME ranks */
+    gmx::HostVector<gmx::RVec> x; /**< Vector of atom coordinates to transfer to PME ranks */
+    std::vector<gmx::RVec>     f; /**< Vector of atom forces received from PME ranks */
     //@{
     /**< Vectors of MPI objects used in non-blocking communication between multiple PP ranks per PME rank */
     std::vector<MPI_Request> req;
     std::vector<MPI_Status>  stat;
     //@}
+
+    /*! \brief object for receiving coordinates using communications operating on GPU memory space */
+    std::unique_ptr<gmx::PmeCoordinateReceiverGpu> pmeCoordinateReceiverGpu;
+    /*! \brief object for sending PME force using communications operating on GPU memory space */
+    std::unique_ptr<gmx::PmeForceSenderGpu> pmeForceSenderGpu;
+
+    /*! \brief whether GPU direct communications are active for PME-PP transfers */
+    bool useGpuDirectComm = false;
+    /*! \brief whether GPU direct communications should send forces directly to remote GPU memory */
+    bool sendForcesDirectToPpGpu = false;
 };
 
 /*! \brief Initialize the PME-only side of the PME <-> PP communication */
-static std::unique_ptr<gmx_pme_pp> gmx_pme_pp_init(const t_commrec *cr)
+static std::unique_ptr<gmx_pme_pp> gmx_pme_pp_init(const t_commrec* cr)
 {
     auto pme_pp = std::make_unique<gmx_pme_pp>();
 
@@ -140,14 +151,14 @@ static std::unique_ptr<gmx_pme_pp> gmx_pme_pp_init(const t_commrec *cr)
     MPI_Comm_rank(cr->mpi_comm_mygroup, &rank);
     auto ppRanks = get_pme_ddranks(cr, rank);
     pme_pp->ppRanks.reserve(ppRanks.size());
-    for (const auto &ppRankId : ppRanks)
+    for (const auto& ppRankId : ppRanks)
     {
-        pme_pp->ppRanks.push_back({ppRankId, 0});
+        pme_pp->ppRanks.push_back({ ppRankId, 0 });
     }
     // The peer PP rank is the last one.
     pme_pp->peerRankId = pme_pp->ppRanks.back().rankId;
-    pme_pp->req.resize(eCommType_NR*pme_pp->ppRanks.size());
-    pme_pp->stat.resize(eCommType_NR*pme_pp->ppRanks.size());
+    pme_pp->req.resize(eCommType_NR * pme_pp->ppRanks.size());
+    pme_pp->stat.resize(eCommType_NR * pme_pp->ppRanks.size());
 #else
     GMX_UNUSED_VALUE(cr);
 #endif
@@ -155,17 +166,17 @@ static std::unique_ptr<gmx_pme_pp> gmx_pme_pp_init(const t_commrec *cr)
     return pme_pp;
 }
 
-static void reset_pmeonly_counters(gmx_wallcycle_t           wcycle,
+static void reset_pmeonly_counters(gmx_wallcycle*            wcycle,
                                    gmx_walltime_accounting_t walltime_accounting,
-                                   t_nrnb                   *nrnb,
+                                   t_nrnb*                   nrnb,
                                    int64_t                   step,
                                    bool                      useGpuForPme)
 {
     /* Reset all the counters related to performance over the run */
-    wallcycle_stop(wcycle, ewcRUN);
+    wallcycle_stop(wcycle, WallCycleCounter::Run);
     wallcycle_reset_all(wcycle);
     *nrnb = { 0 };
-    wallcycle_start(wcycle, ewcRUN);
+    wallcycle_start(wcycle, WallCycleCounter::Run);
     walltime_accounting_reset_time(walltime_accounting, step);
 
     if (useGpuForPme)
@@ -174,18 +185,18 @@ static void reset_pmeonly_counters(gmx_wallcycle_t           wcycle,
     }
 }
 
-static gmx_pme_t *gmx_pmeonly_switch(std::vector<gmx_pme_t *> *pmedata,
-                                     const ivec grid_size,
-                                     real ewaldcoeff_q, real ewaldcoeff_lj,
-                                     const t_commrec *cr, const t_inputrec *ir)
+static gmx_pme_t* gmx_pmeonly_switch(std::vector<gmx_pme_t*>* pmedata,
+                                     const ivec               grid_size,
+                                     real                     ewaldcoeff_q,
+                                     real                     ewaldcoeff_lj,
+                                     const t_commrec*         cr,
+                                     const t_inputrec*        ir)
 {
     GMX_ASSERT(pmedata, "Bad PME tuning list pointer");
-    for (auto &pme : *pmedata)
+    for (auto& pme : *pmedata)
     {
         GMX_ASSERT(pme, "Bad PME tuning list element pointer");
-        if (pme->nkx == grid_size[XX] &&
-            pme->nky == grid_size[YY] &&
-            pme->nkz == grid_size[ZZ])
+        if (gmx_pme_grid_matches(*pme, grid_size))
         {
             /* Here we have found an existing PME data structure that suits us.
              * However, in the GPU case, we have to reinitialize it - there's only one GPU structure.
@@ -198,8 +209,8 @@ static gmx_pme_t *gmx_pmeonly_switch(std::vector<gmx_pme_t *> *pmedata,
         }
     }
 
-    const auto &pme          = pmedata->back();
-    gmx_pme_t  *newStructure = nullptr;
+    const auto& pme          = pmedata->back();
+    gmx_pme_t*  newStructure = nullptr;
     // Copy last structure with new grid params
     gmx_pme_reinit(&newStructure, cr, pme, ir, grid_size, ewaldcoeff_q, ewaldcoeff_lj);
     pmedata->push_back(newStructure);
@@ -208,46 +219,55 @@ static gmx_pme_t *gmx_pmeonly_switch(std::vector<gmx_pme_t *> *pmedata,
 
 /*! \brief Called by PME-only ranks to receive coefficients and coordinates
  *
- * \param[in,out] pme_pp    PME-PP communication structure.
- * \param[out] natoms       Number of received atoms.
- * \param[out] box        System box, if received.
- * \param[out] maxshift_x        Maximum shift in X direction, if received.
- * \param[out] maxshift_y        Maximum shift in Y direction, if received.
- * \param[out] lambda_q         Free-energy lambda for electrostatics, if received.
- * \param[out] lambda_lj         Free-energy lambda for Lennard-Jones, if received.
- * \param[out] bEnerVir          Set to true if this is an energy/virial calculation step, otherwise set to false.
- * \param[out] step              MD integration step number.
- * \param[out] grid_size         PME grid size, if received.
- * \param[out] ewaldcoeff_q         Ewald cut-off parameter for electrostatics, if received.
- * \param[out] ewaldcoeff_lj         Ewald cut-off parameter for Lennard-Jones, if received.
- * \param[out] atomSetChanged    Set to true only if the local domain atom data (charges/coefficients)
- *                               has been received (after DD) and should be reinitialized. Otherwise not changed.
+ * Note that with GPU direct communication the transfer is only initiated; it is the responsibility
+ * of the caller to synchronize prior to launching spread.
  *
- * \retval pmerecvqxX             All parameters were set, chargeA and chargeB can be NULL.
- * \retval pmerecvqxFINISH        No parameters were set.
- * \retval pmerecvqxSWITCHGRID    Only grid_size and *ewaldcoeff were set.
- * \retval pmerecvqxRESETCOUNTERS *step was set.
+ * \param[in] pme                     PME data structure.
+ * \param[in,out] pme_pp              PME-PP communication structure.
+ * \param[out] natoms                 Number of received atoms.
+ * \param[out] box                    System box, if received.
+ * \param[out] maxshift_x             Maximum shift in X direction, if received.
+ * \param[out] maxshift_y             Maximum shift in Y direction, if received.
+ * \param[out] lambda_q               Free-energy lambda for electrostatics, if received.
+ * \param[out] lambda_lj              Free-energy lambda for Lennard-Jones, if received.
+ * \param[out] computeEnergyAndVirial Set to true if this is an energy/virial calculation
+ *                                    step, otherwise set to false.
+ * \param[out] step                   MD integration step number.
+ * \param[out] grid_size              PME grid size, if received.
+ * \param[out] ewaldcoeff_q           Ewald cut-off parameter for electrostatics, if received.
+ * \param[out] ewaldcoeff_lj          Ewald cut-off parameter for Lennard-Jones, if received.
+ * \param[in]  useGpuForPme           Flag on whether PME is on GPU.
+ * \param[in]  stateGpu               GPU state propagator object.
+ * \param[in]  runMode                PME run mode.
+ *
+ * \retval pmerecvqxX                 All parameters were set, chargeA and chargeB can be NULL.
+ * \retval pmerecvqxFINISH            No parameters were set.
+ * \retval pmerecvqxSWITCHGRID        Only grid_size and *ewaldcoeff were set.
+ * \retval pmerecvqxRESETCOUNTERS     *step was set.
  */
-static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
-                                      int               *natoms,
-                                      matrix             box,
-                                      int               *maxshift_x,
-                                      int               *maxshift_y,
-                                      real              *lambda_q,
-                                      real              *lambda_lj,
-                                      gmx_bool          *bEnerVir,
-                                      int64_t           *step,
-                                      ivec              *grid_size,
-                                      real              *ewaldcoeff_q,
-                                      real              *ewaldcoeff_lj,
-                                      bool              *atomSetChanged)
+static int gmx_pme_recv_coeffs_coords(struct gmx_pme_t*            pme,
+                                      gmx_pme_pp*                  pme_pp,
+                                      int*                         natoms,
+                                      matrix                       box,
+                                      int*                         maxshift_x,
+                                      int*                         maxshift_y,
+                                      real*                        lambda_q,
+                                      real*                        lambda_lj,
+                                      gmx_bool*                    computeEnergyAndVirial,
+                                      int64_t*                     step,
+                                      ivec*                        grid_size,
+                                      real*                        ewaldcoeff_q,
+                                      real*                        ewaldcoeff_lj,
+                                      bool                         useGpuForPme,
+                                      gmx::StatePropagatorDataGpu* stateGpu,
+                                      PmeRunMode gmx_unused        runMode)
 {
     int status = -1;
     int nat    = 0;
 
 #if GMX_MPI
-    unsigned int flags    = 0;
-    int          messages = 0;
+    int  messages       = 0;
+    bool atomSetChanged = false;
 
     do
     {
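The retval list documented above translates into a small state machine on the caller's side: grid switches and counter resets are handled inside an inner loop, FINISH breaks out, and X proceeds to the mesh computation. A condensed, standalone sketch of that control flow (illustrative enum and stub, not the GROMACS API):

// Standalone sketch of how the four pmerecvqx* statuses documented above are
// consumed; the types and the receive stub are illustrative only.
#include <cstdio>

enum class RecvStatus { Coordinates, Finish, SwitchGrid, ResetCounters };

static RecvStatus receiveFromPpSketch()
{
    return RecvStatus::Finish; // stub so the sketch is self-contained
}

int main()
{
    for (;;)
    {
        RecvStatus status;
        do
        {
            status = receiveFromPpSketch();
            if (status == RecvStatus::SwitchGrid)
            {
                std::puts("PME tuning: switch to the received grid size");
            }
            else if (status == RecvStatus::ResetCounters)
            {
                std::puts("reset cycle and flop counters at the received step");
            }
        } while (status == RecvStatus::SwitchGrid || status == RecvStatus::ResetCounters);

        if (status == RecvStatus::Finish)
        {
            break; // PP ranks are done
        }
        // status == RecvStatus::Coordinates: run the mesh part and send forces back
    }
    return 0;
}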
@@ -255,25 +275,27 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
         cnb.flags = 0;
 
         /* Receive the send count, box and time step from the peer PP node */
-        MPI_Recv(&cnb, sizeof(cnb), MPI_BYTE,
-                 pme_pp->peerRankId, eCommType_CNB,
-                 pme_pp->mpi_comm_mysim, MPI_STATUS_IGNORE);
-
-        /* We accumulate all received flags */
-        flags |= cnb.flags;
+        MPI_Recv(&cnb, sizeof(cnb), MPI_BYTE, pme_pp->peerRankId, eCommType_CNB, pme_pp->mpi_comm_mysim, MPI_STATUS_IGNORE);
 
-        *step  = cnb.step;
+        *step = cnb.step;
 
         if (debug)
         {
-            fprintf(debug, "PME only rank receiving:%s%s%s%s%s\n",
-                    (cnb.flags & PP_PME_CHARGE)        ? " charges" : "",
-                    (cnb.flags & PP_PME_COORD )        ? " coordinates" : "",
-                    (cnb.flags & PP_PME_FINISH)        ? " finish" : "",
-                    (cnb.flags & PP_PME_SWITCHGRID)    ? " switch grid" : "",
+            fprintf(debug,
+                    "PME only rank receiving:%s%s%s%s%s\n",
+                    (cnb.flags & PP_PME_CHARGE) ? " charges" : "",
+                    (cnb.flags & PP_PME_COORD) ? " coordinates" : "",
+                    (cnb.flags & PP_PME_FINISH) ? " finish" : "",
+                    (cnb.flags & PP_PME_SWITCHGRID) ? " switch grid" : "",
                     (cnb.flags & PP_PME_RESETCOUNTERS) ? " reset counters" : "");
         }
 
+        pme_pp->useGpuDirectComm = ((cnb.flags & PP_PME_GPUCOMMS) != 0);
+        GMX_ASSERT(!pme_pp->useGpuDirectComm || (pme_pp->pmeForceSenderGpu != nullptr),
+                   "The use of GPU direct communication for PME-PP is enabled, "
+                   "but the PME GPU force reciever object does not exist");
+        pme_pp->sendForcesDirectToPpGpu = ((cnb.flags & PP_PME_RECVFTOGPU) != 0);
+
         if (cnb.flags & PP_PME_FINISH)
         {
             status = pmerecvqxFINISH;
@@ -286,7 +308,7 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
             *ewaldcoeff_q  = cnb.ewaldcoeff_q;
             *ewaldcoeff_lj = cnb.ewaldcoeff_lj;
 
-            status         = pmerecvqxSWITCHGRID;
+            status = pmerecvqxSWITCHGRID;
         }
 
         if (cnb.flags & PP_PME_RESETCOUNTERS)
@@ -297,10 +319,10 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
 
         if (cnb.flags & (PP_PME_CHARGE | PP_PME_SQRTC6 | PP_PME_SIGMA))
         {
-            *atomSetChanged = true;
+            atomSetChanged = true;
 
             /* Receive the send counts from the other PP nodes */
-            for (auto &sender : pme_pp->ppRanks)
+            for (auto& sender : pme_pp->ppRanks)
             {
                 if (sender.rankId == pme_pp->peerRankId)
                 {
@@ -308,17 +330,20 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
                 }
                 else
                 {
-                    MPI_Irecv(&sender.numAtoms, sizeof(sender.numAtoms),
+                    MPI_Irecv(&sender.numAtoms,
+                              sizeof(sender.numAtoms),
                               MPI_BYTE,
-                              sender.rankId, eCommType_CNB,
-                              pme_pp->mpi_comm_mysim, &pme_pp->req[messages++]);
+                              sender.rankId,
+                              eCommType_CNB,
+                              pme_pp->mpi_comm_mysim,
+                              &pme_pp->req[messages++]);
                 }
             }
             MPI_Waitall(messages, pme_pp->req.data(), pme_pp->stat.data());
             messages = 0;
 
             nat = 0;
-            for (const auto &sender : pme_pp->ppRanks)
+            for (const auto& sender : pme_pp->ppRanks)
             {
                 nat += sender.numAtoms;
             }
@@ -357,40 +382,43 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
             /* Receive the charges in place */
             for (int q = 0; q < eCommType_NR; q++)
             {
-                real *bufferPtr;
+                real* bufferPtr;
 
-                if (!(cnb.flags & (PP_PME_CHARGE<<q)))
+                if (!(cnb.flags & (PP_PME_CHARGE << q)))
                 {
                     continue;
                 }
                 switch (q)
                 {
-                    case eCommType_ChargeA: bufferPtr = pme_pp->chargeA.data();  break;
-                    case eCommType_ChargeB: bufferPtr = pme_pp->chargeB.data();  break;
+                    case eCommType_ChargeA: bufferPtr = pme_pp->chargeA.data(); break;
+                    case eCommType_ChargeB: bufferPtr = pme_pp->chargeB.data(); break;
                     case eCommType_SQRTC6A: bufferPtr = pme_pp->sqrt_c6A.data(); break;
                     case eCommType_SQRTC6B: bufferPtr = pme_pp->sqrt_c6B.data(); break;
-                    case eCommType_SigmaA:  bufferPtr = pme_pp->sigmaA.data();   break;
-                    case eCommType_SigmaB:  bufferPtr = pme_pp->sigmaB.data();   break;
+                    case eCommType_SigmaA: bufferPtr = pme_pp->sigmaA.data(); break;
+                    case eCommType_SigmaB: bufferPtr = pme_pp->sigmaB.data(); break;
                     default: gmx_incons("Wrong eCommType");
                 }
                 nat = 0;
-                for (const auto &sender : pme_pp->ppRanks)
+                for (const auto& sender : pme_pp->ppRanks)
                 {
                     if (sender.numAtoms > 0)
                     {
-                        MPI_Irecv(bufferPtr+nat,
-                                  sender.numAtoms*sizeof(real),
+                        MPI_Irecv(bufferPtr + nat,
+                                  sender.numAtoms * sizeof(real),
                                   MPI_BYTE,
-                                  sender.rankId, q,
+                                  sender.rankId,
+                                  q,
                                   pme_pp->mpi_comm_mysim,
                                   &pme_pp->req[messages++]);
                         nat += sender.numAtoms;
                         if (debug)
                         {
-                            fprintf(debug, "Received from PP rank %d: %d %s\n",
-                                    sender.rankId, sender.numAtoms,
-                                    (q == eCommType_ChargeA ||
-                                     q == eCommType_ChargeB) ? "charges" : "params");
+                            fprintf(debug,
+                                    "Received from PP rank %d: %d %s\n",
+                                    sender.rankId,
+                                    sender.numAtoms,
+                                    (q == eCommType_ChargeA || q == eCommType_ChargeB) ? "charges"
+                                                                                       : "params");
                         }
                     }
                 }
@@ -399,31 +427,72 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
 
         if (cnb.flags & PP_PME_COORD)
         {
+            if (atomSetChanged)
+            {
+                gmx_pme_reinit_atoms(pme, nat, pme_pp->chargeA, pme_pp->chargeB);
+                if (useGpuForPme)
+                {
+                    stateGpu->reinit(nat, nat);
+                    pme_gpu_set_device_x(pme, stateGpu->getCoordinates());
+                }
+                if (pme_pp->useGpuDirectComm)
+                {
+                    GMX_ASSERT(runMode == PmeRunMode::GPU,
+                               "GPU Direct PME-PP communication has been enabled, "
+                               "but PME run mode is not PmeRunMode::GPU\n");
+
+                    // This rank will have its data accessed directly by PP rank, so needs to send the remote addresses and re-set atom ranges associated with transfers.
+                    pme_pp->pmeCoordinateReceiverGpu->reinitCoordinateReceiver(stateGpu->getCoordinates());
+                    pme_pp->pmeForceSenderGpu->setForceSendBuffer(pme_gpu_get_device_f(pme));
+                }
+            }
+
+
             /* The box, FE flag and lambda are sent along with the coordinates
              *  */
             copy_mat(cnb.box, box);
-            *lambda_q       = cnb.lambda_q;
-            *lambda_lj      = cnb.lambda_lj;
-            *bEnerVir       = ((cnb.flags & PP_PME_ENER_VIR) != 0U);
-            *step           = cnb.step;
+            *lambda_q               = cnb.lambda_q;
+            *lambda_lj              = cnb.lambda_lj;
+            *computeEnergyAndVirial = ((cnb.flags & PP_PME_ENER_VIR) != 0U);
+            *step                   = cnb.step;
 
             /* Receive the coordinates in place */
             nat = 0;
-            for (const auto &sender : pme_pp->ppRanks)
+            for (const auto& sender : pme_pp->ppRanks)
             {
                 if (sender.numAtoms > 0)
                 {
-                    MPI_Irecv(pme_pp->x[nat],
-                              sender.numAtoms*sizeof(rvec),
-                              MPI_BYTE,
-                              sender.rankId, eCommType_COORD,
-                              pme_pp->mpi_comm_mysim, &pme_pp->req[messages++]);
+                    if (pme_pp->useGpuDirectComm)
+                    {
+                        if (GMX_THREAD_MPI)
+                        {
+                            pme_pp->pmeCoordinateReceiverGpu->receiveCoordinatesSynchronizerFromPpCudaDirect(
+                                    sender.rankId);
+                        }
+                        else
+                        {
+                            pme_pp->pmeCoordinateReceiverGpu->launchReceiveCoordinatesFromPpCudaMpi(
+                                    stateGpu->getCoordinates(), nat, sender.numAtoms * sizeof(rvec), sender.rankId);
+                        }
+                    }
+                    else
+                    {
+                        MPI_Irecv(pme_pp->x[nat],
+                                  sender.numAtoms * sizeof(rvec),
+                                  MPI_BYTE,
+                                  sender.rankId,
+                                  eCommType_COORD,
+                                  pme_pp->mpi_comm_mysim,
+                                  &pme_pp->req[messages++]);
+                    }
                     nat += sender.numAtoms;
                     if (debug)
                     {
-                        fprintf(debug, "Received from PP rank %d: %d "
+                        fprintf(debug,
+                                "Received from PP rank %d: %d "
                                 "coordinates\n",
-                                sender.rankId, sender.numAtoms);
+                                sender.rankId,
+                                sender.numAtoms);
                     }
                 }
             }
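Two GPU-direct receive paths are distinguished above: with thread-MPI the PP rank writes straight into the PME rank's GPU coordinate buffer, so only a readiness synchronizer is enqueued, whereas with CUDA-aware process MPI an actual receive is posted on the device buffer. A minimal standalone sketch of that dispatch (the function names are placeholders, not the PmeCoordinateReceiverGpu API):

// Standalone sketch of the per-sender dispatch added above; all names are
// placeholders and the bodies only trace what each branch would do.
#include <cstdio>

static void enqueueCoordinateReadySync(int ppRank)
{
    std::printf("thread-MPI: enqueue 'coordinates ready' event from PP rank %d\n", ppRank);
}

static void postGpuRecv(int ppRank, int atomOffset, int numAtoms)
{
    std::printf("CUDA-aware MPI: recv %d atoms from PP rank %d at offset %d (device buffer)\n",
                numAtoms, ppRank, atomOffset);
}

static void receiveCoordinatesSketch(bool usingThreadMpi, int ppRank, int atomOffset, int numAtoms)
{
    if (usingThreadMpi)
    {
        enqueueCoordinateReadySync(ppRank); // the data itself is pushed by the PP rank
    }
    else
    {
        postGpuRecv(ppRank, atomOffset, numAtoms);
    }
}

int main()
{
    receiveCoordinatesSketch(true, 0, 0, 1000);
    receiveCoordinatesSketch(false, 1, 1000, 1200);
    return 0;
}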
@@ -434,57 +503,96 @@ static int gmx_pme_recv_coeffs_coords(gmx_pme_pp        *pme_pp,
         /* Wait for the coordinates and/or charges to arrive */
         MPI_Waitall(messages, pme_pp->req.data(), pme_pp->stat.data());
         messages = 0;
-    }
-    while (status == -1);
+    } while (status == -1);
 #else
+    GMX_UNUSED_VALUE(pme);
     GMX_UNUSED_VALUE(pme_pp);
     GMX_UNUSED_VALUE(box);
     GMX_UNUSED_VALUE(maxshift_x);
     GMX_UNUSED_VALUE(maxshift_y);
     GMX_UNUSED_VALUE(lambda_q);
     GMX_UNUSED_VALUE(lambda_lj);
-    GMX_UNUSED_VALUE(bEnerVir);
+    GMX_UNUSED_VALUE(computeEnergyAndVirial);
     GMX_UNUSED_VALUE(step);
     GMX_UNUSED_VALUE(grid_size);
     GMX_UNUSED_VALUE(ewaldcoeff_q);
     GMX_UNUSED_VALUE(ewaldcoeff_lj);
-    GMX_UNUSED_VALUE(atomSetChanged);
+    GMX_UNUSED_VALUE(useGpuForPme);
+    GMX_UNUSED_VALUE(stateGpu);
 
     status = pmerecvqxX;
 #endif
 
     if (status == pmerecvqxX)
     {
-        *natoms   = nat;
+        *natoms = nat;
     }
 
     return status;
 }
 
 /*! \brief Send the PME mesh force, virial and energy to the PP-only ranks. */
-static void gmx_pme_send_force_vir_ener(gmx_pme_pp *pme_pp,
-                                        const PmeOutput &output,
-                                        real dvdlambda_q, real dvdlambda_lj,
-                                        float cycles)
+static void gmx_pme_send_force_vir_ener(const gmx_pme_t& pme,
+                                        gmx_pme_pp*      pme_pp,
+                                        const PmeOutput& output,
+                                        real             dvdlambda_q,
+                                        real             dvdlambda_lj,
+                                        float            cycles)
 {
 #if GMX_MPI
     gmx_pme_comm_vir_ene_t cve;
     int                    messages, ind_start, ind_end;
     cve.cycles = cycles;
 
-    /* Now the evaluated forces have to be transferred to the PP nodes */
+    if (pme_pp->useGpuDirectComm)
+    {
+        GMX_ASSERT((pme_pp->pmeForceSenderGpu != nullptr),
+                   "The use of GPU direct communication for PME-PP is enabled, "
+                   "but the PME GPU force reciever object does not exist");
+    }
+
     messages = 0;
     ind_end  = 0;
-    for (const auto &receiver : pme_pp->ppRanks)
+
+    /* Now the evaluated forces have to be transferred to the PP ranks */
+    if (pme_pp->useGpuDirectComm && GMX_THREAD_MPI)
     {
-        ind_start = ind_end;
-        ind_end   = ind_start + receiver.numAtoms;
-        if (MPI_Isend(const_cast<void *>(static_cast<const void *>(output.forces_[ind_start])),
-                      (ind_end-ind_start)*sizeof(rvec), MPI_BYTE,
-                      receiver.rankId, 0,
-                      pme_pp->mpi_comm_mysim, &pme_pp->req[messages++]) != 0)
+        int numPpRanks = static_cast<int>(pme_pp->ppRanks.size());
+#    pragma omp parallel for num_threads(std::min(numPpRanks, pme.nthread)) schedule(static)
+        for (int i = 0; i < numPpRanks; i++)
         {
-            gmx_comm("MPI_Isend failed in do_pmeonly");
+            auto& receiver = pme_pp->ppRanks[i];
+            pme_pp->pmeForceSenderGpu->sendFToPpCudaDirect(
+                    receiver.rankId, receiver.numAtoms, pme_pp->sendForcesDirectToPpGpu);
+        }
+    }
+    else
+    {
+        for (const auto& receiver : pme_pp->ppRanks)
+        {
+            ind_start = ind_end;
+            ind_end   = ind_start + receiver.numAtoms;
+            if (pme_pp->useGpuDirectComm)
+            {
+                pme_pp->pmeForceSenderGpu->sendFToPpCudaMpi(pme_gpu_get_device_f(&pme),
+                                                            ind_start,
+                                                            receiver.numAtoms * sizeof(rvec),
+                                                            receiver.rankId,
+                                                            &pme_pp->req[messages]);
+            }
+            else
+            {
+                void* sendbuf = const_cast<void*>(static_cast<const void*>(output.forces_[ind_start]));
+                // Send using MPI
+                MPI_Isend(sendbuf,
+                          receiver.numAtoms * sizeof(rvec),
+                          MPI_BYTE,
+                          receiver.rankId,
+                          0,
+                          pme_pp->mpi_comm_mysim,
+                          &pme_pp->req[messages]);
+            }
+            messages++;
         }
     }
 
@@ -502,17 +610,15 @@ static void gmx_pme_send_force_vir_ener(gmx_pme_pp *pme_pp,
 
     if (debug)
     {
-        fprintf(debug, "PME rank sending to PP rank %d: virial and energy\n",
-                pme_pp->peerRankId);
+        fprintf(debug, "PME rank sending to PP rank %d: virial and energy\n", pme_pp->peerRankId);
     }
-    MPI_Isend(&cve, sizeof(cve), MPI_BYTE,
-              pme_pp->peerRankId, 1,
-              pme_pp->mpi_comm_mysim, &pme_pp->req[messages++]);
+    MPI_Isend(&cve, sizeof(cve), MPI_BYTE, pme_pp->peerRankId, 1, pme_pp->mpi_comm_mysim, &pme_pp->req[messages++]);
 
     /* Wait for the forces to arrive */
     MPI_Waitall(messages, pme_pp->req.data(), pme_pp->stat.data());
 #else
-    gmx_call("MPI not enabled");
+    GMX_RELEASE_ASSERT(false, "Invalid call to gmx_pme_send_force_vir_ener");
+    GMX_UNUSED_VALUE(pme);
     GMX_UNUSED_VALUE(pme_pp);
     GMX_UNUSED_VALUE(output);
     GMX_UNUSED_VALUE(dvdlambda_q);
@@ -521,35 +627,64 @@ static void gmx_pme_send_force_vir_ener(gmx_pme_pp *pme_pp,
 #endif
 }
 
-int gmx_pmeonly(struct gmx_pme_t *pme,
-                const t_commrec *cr, t_nrnb *mynrnb,
-                gmx_wallcycle  *wcycle,
-                gmx_walltime_accounting_t walltime_accounting,
-                t_inputrec *ir, PmeRunMode runMode)
+int gmx_pmeonly(struct gmx_pme_t*               pme,
+                const t_commrec*                cr,
+                t_nrnb*                         mynrnb,
+                gmx_wallcycle*                  wcycle,
+                gmx_walltime_accounting_t       walltime_accounting,
+                t_inputrec*                     ir,
+                PmeRunMode                      runMode,
+                bool                            useGpuPmePpCommunication,
+                const gmx::DeviceStreamManager* deviceStreamManager)
 {
-    int                ret;
-    int                natoms = 0;
-    matrix             box;
-    real               lambda_q   = 0;
-    real               lambda_lj  = 0;
-    int                maxshift_x = 0, maxshift_y = 0;
-    real               dvdlambda_q, dvdlambda_lj;
-    float              cycles;
-    int                count;
-    gmx_bool           bEnerVir = FALSE;
-    int64_t            step;
+    int     ret;
+    int     natoms = 0;
+    matrix  box;
+    real    lambda_q   = 0;
+    real    lambda_lj  = 0;
+    int     maxshift_x = 0, maxshift_y = 0;
+    real    dvdlambda_q, dvdlambda_lj;
+    float   cycles;
+    int     count;
+    bool    computeEnergyAndVirial = false;
+    int64_t step;
 
     /* This data will only be used with PME tuning, i.e. switching PME grids */
-    std::vector<gmx_pme_t *> pmedata;
+    std::vector<gmx_pme_t*> pmedata;
     pmedata.push_back(pme);
 
-    auto       pme_pp       = gmx_pme_pp_init(cr);
-    //TODO the variable below should be queried from the task assignment info
+    auto pme_pp = gmx_pme_pp_init(cr);
+
+    std::unique_ptr<gmx::StatePropagatorDataGpu> stateGpu;
+    // TODO the variable below should be queried from the task assignment info
     const bool useGpuForPme = (runMode == PmeRunMode::GPU) || (runMode == PmeRunMode::Mixed);
     if (useGpuForPme)
     {
+        GMX_RELEASE_ASSERT(
+                deviceStreamManager != nullptr,
+                "Device stream manager can not be nullptr when using GPU in PME-only rank.");
+        GMX_RELEASE_ASSERT(deviceStreamManager->streamIsValid(gmx::DeviceStreamType::Pme),
+                           "Device stream can not be nullptr when using GPU in PME-only rank");
         changePinningPolicy(&pme_pp->chargeA, pme_get_pinning_policy());
         changePinningPolicy(&pme_pp->x, pme_get_pinning_policy());
+        if (useGpuPmePpCommunication)
+        {
+            pme_pp->pmeCoordinateReceiverGpu = std::make_unique<gmx::PmeCoordinateReceiverGpu>(
+                    pme_pp->mpi_comm_mysim, deviceStreamManager->context(), pme_pp->ppRanks);
+            pme_pp->pmeForceSenderGpu =
+                    std::make_unique<gmx::PmeForceSenderGpu>(pme_gpu_get_f_ready_synchronizer(pme),
+                                                             pme_pp->mpi_comm_mysim,
+                                                             deviceStreamManager->context(),
+                                                             pme_pp->ppRanks);
+        }
+        // TODO: Special PME-only constructor is used here. There is no mechanism to prevent using the other constructor here.
+        //       This should be made safer.
+        stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(
+                &deviceStreamManager->stream(gmx::DeviceStreamType::Pme),
+                deviceStreamManager->context(),
+                GpuApiCallBehavior::Async,
+                pme_gpu_get_block_size(pme),
+                wcycle);
     }
 
     clear_nrnb(mynrnb);
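The setup above is layered: GPU-only state (pinned host buffers, the state propagator) is created whenever PME runs on a GPU, and the direct-communication objects only when useGpuPmePpCommunication is additionally requested. A minimal sketch of that conditional construction (illustrative placeholder types, not the GROMACS classes):

// Standalone sketch of the conditional construction above; the structs are
// placeholders for the GROMACS GPU communication objects.
#include <memory>

struct CoordinateReceiverSketch {};
struct ForceSenderSketch {};
struct StatePropagatorSketch {};

struct PmeOnlyGpuState
{
    std::unique_ptr<CoordinateReceiverSketch> coordinateReceiver;
    std::unique_ptr<ForceSenderSketch>        forceSender;
    std::unique_ptr<StatePropagatorSketch>    stateGpu;
};

PmeOnlyGpuState setupSketch(bool useGpuForPme, bool useGpuPmePpCommunication)
{
    PmeOnlyGpuState state;
    if (useGpuForPme)
    {
        if (useGpuPmePpCommunication)
        {
            state.coordinateReceiver = std::make_unique<CoordinateReceiverSketch>();
            state.forceSender        = std::make_unique<ForceSenderSketch>();
        }
        state.stateGpu = std::make_unique<StatePropagatorSketch>();
    }
    return state;
}

int main()
{
    const PmeOnlyGpuState state = setupSketch(true, true);
    return state.stateGpu ? 0 : 1;
}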
@@ -562,19 +697,23 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
         {
             /* Domain decomposition */
             ivec newGridSize;
-            bool atomSetChanged = false;
-            real ewaldcoeff_q   = 0, ewaldcoeff_lj = 0;
-            ret = gmx_pme_recv_coeffs_coords(pme_pp.get(),
+            real ewaldcoeff_q = 0, ewaldcoeff_lj = 0;
+            ret = gmx_pme_recv_coeffs_coords(pme,
+                                             pme_pp.get(),
                                              &natoms,
                                              box,
-                                             &maxshift_x, &maxshift_y,
-                                             &lambda_q, &lambda_lj,
-                                             &bEnerVir,
+                                             &maxshift_x,
+                                             &maxshift_y,
+                                             &lambda_q,
+                                             &lambda_lj,
+                                             &computeEnergyAndVirial,
                                              &step,
                                              &newGridSize,
                                              &ewaldcoeff_q,
                                              &ewaldcoeff_lj,
-                                             &atomSetChanged);
+                                             useGpuForPme,
+                                             stateGpu.get(),
+                                             runMode);
 
             if (ret == pmerecvqxSWITCHGRID)
             {
@@ -582,18 +721,12 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
                 pme = gmx_pmeonly_switch(&pmedata, newGridSize, ewaldcoeff_q, ewaldcoeff_lj, cr, ir);
             }
 
-            if (atomSetChanged)
-            {
-                gmx_pme_reinit_atoms(pme, natoms, pme_pp->chargeA.data());
-            }
-
             if (ret == pmerecvqxRESETCOUNTERS)
             {
                 /* Reset the cycle and flop counters */
                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, step, useGpuForPme);
             }
-        }
-        while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
+        } while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
 
         if (ret == pmerecvqxFINISH)
         {
@@ -603,11 +736,11 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
 
         if (count == 0)
         {
-            wallcycle_start(wcycle, ewcRUN);
+            wallcycle_start(wcycle, WallCycleCounter::Run);
             walltime_accounting_start_time(walltime_accounting);
         }
 
-        wallcycle_start(wcycle, ewcPMEMESH);
+        wallcycle_start(wcycle, WallCycleCounter::PmeMesh);
 
         dvdlambda_q  = 0;
         dvdlambda_lj = 0;
@@ -616,40 +749,72 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
         // of pme_pp (maybe box, energy and virial, too; and likewise
         // from mdatoms for the other call to gmx_pme_do), so we have
         // fewer lines of code and less parameter passing.
-        const int pmeFlags = GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0);
-        PmeOutput output   = {{}, 0, {{0}}, 0, {{0}}};
+        gmx::StepWorkload stepWork;
+        stepWork.computeVirial = computeEnergyAndVirial;
+        stepWork.computeEnergy = computeEnergyAndVirial;
+        stepWork.computeForces = true;
+        PmeOutput output       = { {}, false, 0, { { 0 } }, 0, 0, { { 0 } }, 0 };
         if (useGpuForPme)
         {
-            const bool boxChanged = false;
-            //TODO this should be set properly by gmx_pme_recv_coeffs_coords,
+            stepWork.haveDynamicBox      = false;
+            stepWork.useGpuPmeFReduction = pme_pp->useGpuDirectComm;
+            // TODO this should be set properly by gmx_pme_recv_coeffs_coords,
             // or maybe use inputrecDynamicBox(ir), at the very least - change this when this codepath is tested!
-            pme_gpu_prepare_computation(pme, boxChanged, box, wcycle, pmeFlags);
-            pme_gpu_launch_spread(pme, as_rvec_array(pme_pp->x.data()), wcycle);
-            pme_gpu_launch_complex_transforms(pme, wcycle);
-            pme_gpu_launch_gather(pme, wcycle, PmeForceOutputHandling::Set, false);
-            output = pme_gpu_wait_finish_task(pme, pmeFlags, wcycle);
+            pme_gpu_prepare_computation(pme, box, wcycle, stepWork);
+            if (!pme_pp->useGpuDirectComm)
+            {
+                stateGpu->copyCoordinatesToGpu(gmx::ArrayRef<gmx::RVec>(pme_pp->x),
+                                               gmx::AtomLocality::Local);
+            }
+            // On the separate PME rank we do not need a synchronizer as we schedule everything in a single stream
+            // TODO: with pme on GPU the receive should make a list of synchronizers and pass it here #3157
+            auto xReadyOnDevice = nullptr;
+
+            pme_gpu_launch_spread(pme,
+                                  xReadyOnDevice,
+                                  wcycle,
+                                  lambda_q,
+                                  pme_pp->useGpuDirectComm,
+                                  pme_pp->pmeCoordinateReceiverGpu.get());
+            pme_gpu_launch_complex_transforms(pme, wcycle, stepWork);
+            pme_gpu_launch_gather(pme, wcycle, lambda_q);
+            output = pme_gpu_wait_finish_task(pme, computeEnergyAndVirial, lambda_q, wcycle);
             pme_gpu_reinit_computation(pme, wcycle);
         }
         else
         {
-            GMX_ASSERT(pme_pp->x.size() == static_cast<size_t>(natoms), "The coordinate buffer should have size natoms");
-
-            gmx_pme_do(pme, pme_pp->x, pme_pp->f,
-                       pme_pp->chargeA.data(), pme_pp->chargeB.data(),
-                       pme_pp->sqrt_c6A.data(), pme_pp->sqrt_c6B.data(),
-                       pme_pp->sigmaA.data(), pme_pp->sigmaB.data(), box,
-                       cr, maxshift_x, maxshift_y, mynrnb, wcycle,
-                       output.coulombVirial_, output.lennardJonesVirial_,
-                       &output.coulombEnergy_, &output.lennardJonesEnergy_,
-                       lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
-                       pmeFlags);
+            GMX_ASSERT(pme_pp->x.size() == static_cast<size_t>(natoms),
+                       "The coordinate buffer should have size natoms");
+
+            gmx_pme_do(pme,
+                       pme_pp->x,
+                       pme_pp->f,
+                       pme_pp->chargeA,
+                       pme_pp->chargeB,
+                       pme_pp->sqrt_c6A,
+                       pme_pp->sqrt_c6B,
+                       pme_pp->sigmaA,
+                       pme_pp->sigmaB,
+                       box,
+                       cr,
+                       maxshift_x,
+                       maxshift_y,
+                       mynrnb,
+                       wcycle,
+                       output.coulombVirial_,
+                       output.lennardJonesVirial_,
+                       &output.coulombEnergy_,
+                       &output.lennardJonesEnergy_,
+                       lambda_q,
+                       lambda_lj,
+                       &dvdlambda_q,
+                       &dvdlambda_lj,
+                       stepWork);
             output.forces_ = pme_pp->f;
         }
 
-        cycles = wallcycle_stop(wcycle, ewcPMEMESH);
-
-        gmx_pme_send_force_vir_ener(pme_pp.get(), output,
-                                    dvdlambda_q, dvdlambda_lj, cycles);
+        cycles = wallcycle_stop(wcycle, WallCycleCounter::PmeMesh);
+        gmx_pme_send_force_vir_ener(*pme, pme_pp.get(), output, dvdlambda_q, dvdlambda_lj, cycles);
 
         count++;
     } /***** end of quasi-loop, we stop with the break above */
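Taken together, the changes give the pipelining the patch title refers to: instead of waiting for coordinates from every PP rank before spreading once, spread can be launched per PP-rank chunk as soon as that chunk's receive has completed, so spline/spread work overlaps the remaining PP transfers. A standalone sketch of the idea (illustrative types, serial stand-ins for the asynchronous GPU work):

// Standalone sketch of the pipelined spline/spread idea; waits and kernel
// launches are replaced by prints so the example is self-contained.
#include <cstdio>
#include <vector>

struct CoordinateChunkSketch
{
    int ppRank;
    int atomStart;
    int numAtoms;
};

static void waitForChunk(const CoordinateChunkSketch& chunk)
{
    std::printf("wait: coordinates from PP rank %d\n", chunk.ppRank);
}

static void launchSpreadForChunk(const CoordinateChunkSketch& chunk)
{
    std::printf("spread: atoms [%d, %d)\n", chunk.atomStart, chunk.atomStart + chunk.numAtoms);
}

int main()
{
    const std::vector<CoordinateChunkSketch> chunks = { { 0, 0, 1000 }, { 1, 1000, 1200 }, { 2, 2200, 800 } };
    for (const CoordinateChunkSketch& chunk : chunks)
    {
        waitForChunk(chunk);         // per-rank synchronization instead of one global wait
        launchSpreadForChunk(chunk); // this chunk's spline/spread overlaps later receives
    }
    return 0;
}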