Remove thread-MPI limitation for GPU direct PME-PP communication
[alexxy/gromacs.git] src/gromacs/ewald/pme_only.cpp
index 5130034a9fb3c06b2fc1a1ba565c91ac2e743674..64f685ab44a61090250d34e98340f0fed5428669 100644
 #include "pme_output.h"
 #include "pme_pp_communication.h"
 
-/*! \brief environment variable to enable GPU P2P communication */
-static const bool c_enableGpuPmePpComms =
-        GMX_GPU_CUDA && GMX_THREAD_MPI && (getenv("GMX_GPU_PME_PP_COMMS") != nullptr);
-
 /*! \brief Master PP-PME communication data structure */
 struct gmx_pme_pp
 {
@@ -466,8 +462,16 @@ static int gmx_pme_recv_coeffs_coords(struct gmx_pme_t*            pme,
                 {
                     if (pme_pp->useGpuDirectComm)
                     {
-                        pme_pp->pmeCoordinateReceiverGpu->receiveCoordinatesSynchronizerFromPpCudaDirect(
-                                sender.rankId);
+                        if (GMX_THREAD_MPI)
+                        {
+                            pme_pp->pmeCoordinateReceiverGpu->receiveCoordinatesSynchronizerFromPpCudaDirect(
+                                    sender.rankId);
+                        }
+                        else
+                        {
+                            pme_pp->pmeCoordinateReceiverGpu->launchReceiveCoordinatesFromPpCudaMpi(
+                                    stateGpu->getCoordinates(), nat, sender.numAtoms * sizeof(rvec), sender.rankId);
+                        }
                     }
                     else
                     {
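In the lib-MPI branch above, each sender's coordinates are received straight into the device buffer returned by stateGpu->getCoordinates(), rather than exchanging a CUDA synchronizer as the thread-MPI path does. A minimal sketch of that idea, assuming a CUDA-aware MPI library; the function and parameter names below are illustrative, not the actual PmeCoordinateReceiverGpu implementation:

#include <mpi.h>

/* Non-blocking receive of one PP rank's coordinates directly into GPU memory.
 * With a CUDA-aware MPI library the device pointer can be passed to MPI as-is,
 * so no host staging copy is needed. */
static void launchReceiveCoordinatesSketch(void*        d_x,        // device coordinate buffer
                                           int          byteOffset, // where this sender's atoms start
                                           int          numBytes,   // sender.numAtoms * sizeof(rvec)
                                           int          ppRank,     // sending PP rank
                                           MPI_Comm     comm,
                                           MPI_Request* request)
{
    MPI_Irecv(static_cast<char*>(d_x) + byteOffset, numBytes, MPI_BYTE, ppRank, 0, comm, request);
}

Because the receive is non-blocking, the PME rank still has to wait on it before using the coordinates, which is what the unified synchronization call in the next hunk provides.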
@@ -493,7 +497,7 @@ static int gmx_pme_recv_coeffs_coords(struct gmx_pme_t*            pme,
 
             if (pme_pp->useGpuDirectComm)
             {
-                pme_pp->pmeCoordinateReceiverGpu->enqueueWaitReceiveCoordinatesFromPpCudaDirect();
+                pme_pp->pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromPpRanks();
             }
 
             status = pmerecvqxX;
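The renamed synchronizeOnCoordinatesFromPpRanks() has to cover both communication modes: under thread-MPI the PP ranks hand over GPU synchronization objects, while under lib-MPI the outstanding CUDA-aware MPI receives must complete. A sketch of the two wait strategies, using raw CUDA events and MPI requests in place of the internal GROMACS wrappers (all names below are illustrative assumptions):

#include <mpi.h>
#include <cuda_runtime.h>
#include <vector>

/* Unified wait on PP coordinates before PME spreads the charges. */
static void synchronizeOnCoordinatesSketch(bool                      useThreadMpi,
                                           std::vector<cudaEvent_t>& ppSyncEvents, // recorded by PP ranks (thread-MPI)
                                           std::vector<MPI_Request>& mpiRequests,  // outstanding receives (lib-MPI)
                                           cudaStream_t              pmeStream)
{
    if (useThreadMpi)
    {
        // Thread-MPI: make the PME stream wait on each PP rank's event; no host-side blocking.
        for (cudaEvent_t e : ppSyncEvents)
        {
            cudaStreamWaitEvent(pmeStream, e, 0);
        }
    }
    else
    {
        // Lib-MPI: block until every MPI_Irecv into the device buffer has completed.
        MPI_Waitall(static_cast<int>(mpiRequests.size()), mpiRequests.data(), MPI_STATUSES_IGNORE);
    }
}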
@@ -531,7 +535,8 @@ static int gmx_pme_recv_coeffs_coords(struct gmx_pme_t*            pme,
 }
 
 /*! \brief Send the PME mesh force, virial and energy to the PP-only ranks. */
-static void gmx_pme_send_force_vir_ener(gmx_pme_pp*      pme_pp,
+static void gmx_pme_send_force_vir_ener(const gmx_pme_t& pme,
+                                        gmx_pme_pp*      pme_pp,
                                         const PmeOutput& output,
                                         real             dvdlambda_q,
                                         real             dvdlambda_lj,
@@ -547,18 +552,32 @@ static void gmx_pme_send_force_vir_ener(gmx_pme_pp*      pme_pp,
     ind_end  = 0;
     for (const auto& receiver : pme_pp->ppRanks)
     {
-        ind_start     = ind_end;
-        ind_end       = ind_start + receiver.numAtoms;
-        void* sendbuf = const_cast<void*>(static_cast<const void*>(output.forces_[ind_start]));
+        ind_start = ind_end;
+        ind_end   = ind_start + receiver.numAtoms;
         if (pme_pp->useGpuDirectComm)
         {
             GMX_ASSERT((pme_pp->pmeForceSenderGpu != nullptr),
                        "The use of GPU direct communication for PME-PP is enabled, "
                        "but the PME GPU force reciever object does not exist");
-            pme_pp->pmeForceSenderGpu->sendFSynchronizerToPpCudaDirect(receiver.rankId);
+
+            if (GMX_THREAD_MPI)
+            {
+                pme_pp->pmeForceSenderGpu->sendFSynchronizerToPpCudaDirect(receiver.rankId);
+            }
+            else
+            {
+                pme_pp->pmeForceSenderGpu->sendFToPpCudaMpi(pme_gpu_get_device_f(&pme),
+                                                            ind_start,
+                                                            receiver.numAtoms * sizeof(rvec),
+                                                            receiver.rankId,
+                                                            &pme_pp->req[messages]);
+
+                messages++;
+            }
         }
         else
         {
+            void* sendbuf = const_cast<void*>(static_cast<const void*>(output.forces_[ind_start]));
             // Send using MPI
             MPI_Isend(sendbuf,
                       receiver.numAtoms * sizeof(rvec),
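On the send side, the lib-MPI branch hands the device force pointer from pme_gpu_get_device_f() to a non-blocking send and increments messages, so the MPI_Waitall at the end of the function covers the GPU path as well as the host-buffer sends. A minimal sketch of such a send, again assuming a CUDA-aware MPI library; the names are illustrative and not the real sendFToPpCudaMpi():

#include <mpi.h>

/* Non-blocking send of one PP rank's slice of the PME forces directly from
 * GPU memory; the caller gathers the requests and waits on them later. */
static void sendForcesSketch(const void*  d_f,        // device force buffer (pme_gpu_get_device_f)
                             int          byteOffset, // ind_start * sizeof(rvec), illustrative
                             int          numBytes,   // receiver.numAtoms * sizeof(rvec)
                             int          ppRank,
                             MPI_Comm     comm,
                             MPI_Request* request)
{
    MPI_Isend(static_cast<const char*>(d_f) + byteOffset, numBytes, MPI_BYTE, ppRank, 0, comm, request);
}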
@@ -593,6 +612,7 @@ static void gmx_pme_send_force_vir_ener(gmx_pme_pp*      pme_pp,
     MPI_Waitall(messages, pme_pp->req.data(), pme_pp->stat.data());
 #else
     GMX_RELEASE_ASSERT(false, "Invalid call to gmx_pme_send_force_vir_ener");
+    GMX_UNUSED_VALUE(pme);
     GMX_UNUSED_VALUE(pme_pp);
     GMX_UNUSED_VALUE(output);
     GMX_UNUSED_VALUE(dvdlambda_q);
@@ -608,6 +628,7 @@ int gmx_pmeonly(struct gmx_pme_t*               pme,
                 gmx_walltime_accounting_t       walltime_accounting,
                 t_inputrec*                     ir,
                 PmeRunMode                      runMode,
+                bool                            useGpuPmePpCommunication,
                 const gmx::DeviceStreamManager* deviceStreamManager)
 {
     int     ret;
@@ -640,7 +661,7 @@ int gmx_pmeonly(struct gmx_pme_t*               pme,
                            "Device stream can not be nullptr when using GPU in PME-only rank");
         changePinningPolicy(&pme_pp->chargeA, pme_get_pinning_policy());
         changePinningPolicy(&pme_pp->x, pme_get_pinning_policy());
-        if (c_enableGpuPmePpComms)
+        if (useGpuPmePpCommunication)
         {
             pme_pp->pmeCoordinateReceiverGpu = std::make_unique<gmx::PmeCoordinateReceiverGpu>(
                     deviceStreamManager->stream(gmx::DeviceStreamType::Pme),
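With c_enableGpuPmePpComms removed, the feature is no longer restricted to thread-MPI builds at compile time; the caller decides at run time and passes useGpuPmePpCommunication into gmx_pmeonly(). A hypothetical caller-side derivation, shown only to illustrate what moved out of this file (the real decision is made in the mdrun setup code):

#include <cstdlib>

/* Illustrative only: the same inputs the removed constant used, minus the
 * GMX_THREAD_MPI restriction. GMX_GPU_CUDA and the GMX_GPU_PME_PP_COMMS
 * environment variable are taken from the removed code above. */
static bool decideGpuPmePpCommunication()
{
    const bool devFlagSet = (std::getenv("GMX_GPU_PME_PP_COMMS") != nullptr);
    return GMX_GPU_CUDA && devFlagSet;
}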
@@ -780,7 +801,7 @@ int gmx_pmeonly(struct gmx_pme_t*               pme,
         }
 
         cycles = wallcycle_stop(wcycle, WallCycleCounter::PmeMesh);
-        gmx_pme_send_force_vir_ener(pme_pp.get(), output, dvdlambda_q, dvdlambda_lj, cycles);
+        gmx_pme_send_force_vir_ener(*pme, pme_pp.get(), output, dvdlambda_q, dvdlambda_lj, cycles);
 
         count++;
     } /***** end of quasi-loop, we stop with the break above */