Pipeline GPU PME Spline/Spread with PP Comms
[alexxy/gromacs.git] / src / gromacs / ewald / pme_only.cpp
index 3151ec1ca719e906edeae6acc140d6a03f73b51a..56002966be057185ae0bf3924073cd6962b43759 100644 (file)
@@ -218,6 +218,9 @@ static gmx_pme_t* gmx_pmeonly_switch(std::vector<gmx_pme_t*>* pmedata,
 }
 
 /*! \brief Called by PME-only ranks to receive coefficients and coordinates
+ *
+ * Note that with GPU direct communication the transfer is only initiated, it is the responsibility
+ * of the caller to synchronize prior to launching spread.
  *
  * \param[in] pme                     PME data structure.
  * \param[in,out] pme_pp              PME-PP communication structure.
@@ -438,9 +441,8 @@ static int gmx_pme_recv_coeffs_coords(struct gmx_pme_t*            pme,
                                "GPU Direct PME-PP communication has been enabled, "
                                "but PME run mode is not PmeRunMode::GPU\n");
 
-                    // This rank will have its data accessed directly by PP rank, so needs to send the remote addresses.
-                    pme_pp->pmeCoordinateReceiverGpu->sendCoordinateBufferAddressToPpRanks(
-                            stateGpu->getCoordinates());
+                    // This rank will have its data accessed directly by PP rank, so needs to send the remote addresses and re-set atom ranges associated with transfers.
+                    pme_pp->pmeCoordinateReceiverGpu->reinitCoordinateReceiver(stateGpu->getCoordinates());
                     pme_pp->pmeForceSenderGpu->setForceSendBuffer(pme_gpu_get_device_f(pme));
                 }
             }
@@ -495,11 +497,6 @@ static int gmx_pme_recv_coeffs_coords(struct gmx_pme_t*            pme,
                 }
             }
 
-            if (pme_pp->useGpuDirectComm)
-            {
-                pme_pp->pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromPpRanks();
-            }
-
             status = pmerecvqxX;
         }
 
@@ -673,9 +670,7 @@ int gmx_pmeonly(struct gmx_pme_t*               pme,
         if (useGpuPmePpCommunication)
         {
             pme_pp->pmeCoordinateReceiverGpu = std::make_unique<gmx::PmeCoordinateReceiverGpu>(
-                    deviceStreamManager->stream(gmx::DeviceStreamType::Pme),
-                    pme_pp->mpi_comm_mysim,
-                    pme_pp->ppRanks);
+                    pme_pp->mpi_comm_mysim, deviceStreamManager->context(), pme_pp->ppRanks);
             pme_pp->pmeForceSenderGpu =
                     std::make_unique<gmx::PmeForceSenderGpu>(pme_gpu_get_f_ready_synchronizer(pme),
                                                              pme_pp->mpi_comm_mysim,
@@ -775,7 +770,12 @@ int gmx_pmeonly(struct gmx_pme_t*               pme,
             // TODO: with pme on GPU the receive should make a list of synchronizers and pass it here #3157
             auto xReadyOnDevice = nullptr;
 
-            pme_gpu_launch_spread(pme, xReadyOnDevice, wcycle, lambda_q);
+            pme_gpu_launch_spread(pme,
+                                  xReadyOnDevice,
+                                  wcycle,
+                                  lambda_q,
+                                  pme_pp->useGpuDirectComm,
+                                  pme_pp->pmeCoordinateReceiverGpu.get());
             pme_gpu_launch_complex_transforms(pme, wcycle, stepWork);
             pme_gpu_launch_gather(pme, wcycle, lambda_q);
             output = pme_gpu_wait_finish_task(pme, computeEnergyAndVirial, lambda_q, wcycle);