Remove thread-MPI limitation for GPU direct PME-PP communication
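
Both the thread-MPI path (the PP rank pushes coordinates and a sync event via cudaMemcpy)
and the Lib-MPI path (a CUDA-aware MPI_Irecv directly into the PME device buffer) are now
supported. As a rough, illustrative sketch of the intended call pattern on the PME task
(the object and loop variables pmeCoordinateReceiverGpu, ppRanks, sender, d_x and
atomOffset are hypothetical names; only the member functions come from this change):

    int atomOffset = 0;
    for (const auto& sender : ppRanks)
    {
        if (GMX_THREAD_MPI)
        {
            // Thread-MPI: the PP rank pushes the data; only receive its synchronization event.
            pmeCoordinateReceiverGpu->receiveCoordinatesSynchronizerFromPpCudaDirect(sender.rankId);
        }
        else
        {
            // Lib-MPI: post a CUDA-aware receive directly into the device coordinate buffer.
            pmeCoordinateReceiverGpu->launchReceiveCoordinatesFromPpCudaMpi(
                    d_x, atomOffset, sender.numAtoms * sizeof(gmx::RVec), sender.rankId);
        }
        atomOffset += sender.numAtoms;
    }
    // Wait for all pending receives; for thread-MPI, also make the PME stream wait on the
    // PP-to-PME transfer events.
    pmeCoordinateReceiverGpu->synchronizeOnCoordinatesFromPpRanks();
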
[alexxy/gromacs.git] / src / gromacs / ewald / pme_coordinate_receiver_gpu_impl.cu
index 10f48f1d979085e75572f4bf91ed8ed228acb971..7fa2122dfb8a1cc5a65f03084e9356de114fc7a9 100644 (file)
@@ -43,6 +43,7 @@
  */
 #include "gmxpre.h"
 
+#include "gromacs/ewald/pme_pp_communication.h"
 #include "pme_coordinate_receiver_gpu_impl.h"
 
 #include "config.h"
@@ -62,9 +63,6 @@ PmeCoordinateReceiverGpu::Impl::Impl(const DeviceStream&    pmeStream,
     comm_(comm),
     ppRanks_(ppRanks)
 {
-    GMX_RELEASE_ASSERT(
-            GMX_THREAD_MPI,
-            "PME-PP GPU Communication is currently only supported with thread-MPI enabled");
     request_.resize(ppRanks.size());
     ppSync_.resize(ppRanks.size());
 }
@@ -73,28 +71,34 @@ PmeCoordinateReceiverGpu::Impl::~Impl() = default;
 
 void PmeCoordinateReceiverGpu::Impl::sendCoordinateBufferAddressToPpRanks(DeviceBuffer<RVec> d_x)
 {
-
-    int ind_start = 0;
-    int ind_end   = 0;
-    for (const auto& receiver : ppRanks_)
+    // Send the coordinate-buffer address to the PP ranks only for thread-MPI, where each PP rank pushes the data directly with cudaMemcpy
+    if (GMX_THREAD_MPI)
     {
-        ind_start = ind_end;
-        ind_end   = ind_start + receiver.numAtoms;
-
-        // Data will be transferred directly from GPU.
-        void* sendBuf = reinterpret_cast<void*>(&d_x[ind_start]);
+        int ind_start = 0;
+        int ind_end   = 0;
+        for (const auto& receiver : ppRanks_)
+        {
+            ind_start = ind_end;
+            ind_end   = ind_start + receiver.numAtoms;
 
+            // Data will be transferred directly from GPU.
+            void* sendBuf = reinterpret_cast<void*>(&d_x[ind_start]);
 #if GMX_MPI
-        MPI_Send(&sendBuf, sizeof(void**), MPI_BYTE, receiver.rankId, 0, comm_);
+            MPI_Send(&sendBuf, sizeof(void**), MPI_BYTE, receiver.rankId, 0, comm_);
 #else
-        GMX_UNUSED_VALUE(sendBuf);
+            GMX_UNUSED_VALUE(sendBuf);
 #endif
+        }
     }
 }
 
 /*! \brief Receive coordinate synchronizer pointer from the PP ranks. */
 void PmeCoordinateReceiverGpu::Impl::receiveCoordinatesSynchronizerFromPpCudaDirect(int ppRank)
 {
+    GMX_ASSERT(GMX_THREAD_MPI,
+               "receiveCoordinatesSynchronizerFromPpCudaDirect is expected to be called only for "
+               "Thread-MPI");
+
     // Data will be pushed directly from PP task
 
 #if GMX_MPI
@@ -106,18 +110,44 @@ void PmeCoordinateReceiverGpu::Impl::receiveCoordinatesSynchronizerFromPpCudaDir
 #endif
 }
 
-void PmeCoordinateReceiverGpu::Impl::enqueueWaitReceiveCoordinatesFromPpCudaDirect()
+/*! \brief Receive coordinate data using CUDA-aware MPI. */
+void PmeCoordinateReceiverGpu::Impl::launchReceiveCoordinatesFromPpCudaMpi(DeviceBuffer<RVec> recvbuf,
+                                                                           int numAtoms,
+                                                                           int numBytes,
+                                                                           int ppRank)
+{
+    GMX_ASSERT(GMX_LIB_MPI,
+               "launchReceiveCoordinatesFromPpCudaMpi is expected to be called only for Lib-MPI");
+
+#if GMX_MPI
+    MPI_Irecv(&recvbuf[numAtoms], numBytes, MPI_BYTE, ppRank, eCommType_COORD_GPU, comm_, &request_[recvCount_++]);
+#else
+    GMX_UNUSED_VALUE(recvbuf);
+    GMX_UNUSED_VALUE(numAtoms);
+    GMX_UNUSED_VALUE(numBytes);
+    GMX_UNUSED_VALUE(ppRank);
+#endif
+}
+
+void PmeCoordinateReceiverGpu::Impl::synchronizeOnCoordinatesFromPpRanks()
 {
     if (recvCount_ > 0)
     {
-        // ensure PME calculation doesn't commence until coordinate data has been transferred
+        // ensure the PME calculation doesn't commence until the coordinate data (Lib-MPI) or
+        // the remote synchronization events (thread-MPI) have been received
 #if GMX_MPI
         MPI_Waitall(recvCount_, request_.data(), MPI_STATUS_IGNORE);
 #endif
-        for (int i = 0; i < recvCount_; i++)
+
+        // Make the PME stream wait on the PP-to-PME data transfer events
+        if (GMX_THREAD_MPI)
         {
-            ppSync_[i]->enqueueWaitEvent(pmeStream_);
+            for (int i = 0; i < recvCount_; i++)
+            {
+                ppSync_[i]->enqueueWaitEvent(pmeStream_);
+            }
         }
+
         // reset receive counter
         recvCount_ = 0;
     }
@@ -142,9 +172,17 @@ void PmeCoordinateReceiverGpu::receiveCoordinatesSynchronizerFromPpCudaDirect(in
     impl_->receiveCoordinatesSynchronizerFromPpCudaDirect(ppRank);
 }
 
-void PmeCoordinateReceiverGpu::enqueueWaitReceiveCoordinatesFromPpCudaDirect()
+void PmeCoordinateReceiverGpu::launchReceiveCoordinatesFromPpCudaMpi(DeviceBuffer<RVec> recvbuf,
+                                                                     int                numAtoms,
+                                                                     int                numBytes,
+                                                                     int                ppRank)
+{
+    impl_->launchReceiveCoordinatesFromPpCudaMpi(recvbuf, numAtoms, numBytes, ppRank);
+}
+
+void PmeCoordinateReceiverGpu::synchronizeOnCoordinatesFromPpRanks()
 {
-    impl_->enqueueWaitReceiveCoordinatesFromPpCudaDirect();
+    impl_->synchronizeOnCoordinatesFromPpRanks();
 }
 
 } // namespace gmx
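
For context on the new Lib-MPI path above, a minimal sketch of what the matching send on the
PP side would look like; only the MPI_BYTE datatype and the eCommType_COORD_GPU tag follow
from the MPI_Irecv in this change, while sendPtr, numAtoms, pmeRank and comm are hypothetical
names for the PP rank's device pointer, atom count, PME rank and communicator:

    #if GMX_MPI
        // Send the local coordinate block straight from device memory (CUDA-aware MPI);
        // the tag must match the eCommType_COORD_GPU used by the PME-side MPI_Irecv.
        MPI_Send(sendPtr, numAtoms * sizeof(gmx::RVec), MPI_BYTE, pmeRank, eCommType_COORD_GPU, comm);
    #endif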