Remove thread-MPI limitation for GPU PP Halo exchange
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cu
index ebc98e784feca58b79c89a32edd86b416d27ccd8..65af08d35d73aee7b8a594a0f832962e9bdd44e1 100644 (file)
@@ -206,47 +206,77 @@ void GpuHaloExchange::Impl::reinitHalo(float3* d_coordinatesBuffer, float3* d_fo
         copyToDeviceBuffer(
                 &d_indexMap_, h_indexMap_.data(), 0, newSize, nonLocalStream_, GpuApiCallBehavior::Async, nullptr);
     }
-    // This rank will push data to its neighbor, so needs to know
-    // the remote receive address and similarly send its receive
-    // address to other neighbour. We can do this here in reinit fn
-    // since the pointers will not change until the next NS step.
 
-    // Coordinates buffer:
-    void* recvPtr = static_cast<void*>(&d_x_[atomOffset_]);
 #if GMX_MPI
-    MPI_Sendrecv(&recvPtr,
-                 sizeof(void*),
+    // Exchange of remote addresses from neighboring ranks is needed only with CUDA-direct (thread-MPI),
+    // since cudaMemcpy needs both the source and destination pointers. MPI calls such as MPI_Send do not
+    // need the receiving address; that is taken care of by the MPI_Recv call on the neighboring rank.
+    if (GMX_THREAD_MPI)
+    {
+        // This rank will push data to its neighbor, so needs to know
+        // the remote receive address and similarly send its receive
+        // address to other neighbour. We can do this here in reinit fn
+        // since the pointers will not change until the next NS step.
+
+        // Coordinates buffer:
+        float3* recvPtr = &d_x_[atomOffset_];
+        MPI_Sendrecv(&recvPtr,
+                     sizeof(void*),
+                     MPI_BYTE,
+                     recvRankX_,
+                     0,
+                     &remoteXPtr_,
+                     sizeof(void*),
+                     MPI_BYTE,
+                     sendRankX_,
+                     0,
+                     mpi_comm_mysim_,
+                     MPI_STATUS_IGNORE);
+
+        // Force buffer:
+        recvPtr = d_recvBuf_;
+        MPI_Sendrecv(&recvPtr,
+                     sizeof(void*),
+                     MPI_BYTE,
+                     recvRankF_,
+                     0,
+                     &remoteFPtr_,
+                     sizeof(void*),
+                     MPI_BYTE,
+                     sendRankF_,
+                     0,
+                     mpi_comm_mysim_,
+                     MPI_STATUS_IGNORE);
+    }
+#endif
+
+    wallcycle_sub_stop(wcycle_, ewcsDD_GPU);
+    wallcycle_stop(wcycle_, ewcDOMDEC);
+
+    return;
+}
+
+void GpuHaloExchange::Impl::enqueueWaitRemoteCoordinatesReadyEvent(GpuEventSynchronizer* coordinatesReadyOnDeviceEvent)
+{
+    GMX_ASSERT(coordinatesReadyOnDeviceEvent != nullptr,
+               "Co-ordinate Halo exchange requires valid co-ordinate ready event");
+
+    // Wait for the event from the receiving task signalling that the remote coordinates are ready, and
+    // enqueue that event in the stream used for the subsequent data push. This avoids a race condition
+    // with the remote data being written in the previous timestep. Similarly, send an event to the task
+    // that will push data to this task.
+    GpuEventSynchronizer* remoteCoordinatesReadyOnDeviceEvent;
+    MPI_Sendrecv(&coordinatesReadyOnDeviceEvent,
+                 sizeof(GpuEventSynchronizer*),
                  MPI_BYTE,
                  recvRankX_,
                  0,
-                 &remoteXPtr_,
-                 sizeof(void*),
+                 &remoteCoordinatesReadyOnDeviceEvent,
+                 sizeof(GpuEventSynchronizer*),
                  MPI_BYTE,
                  sendRankX_,
                  0,
                  mpi_comm_mysim_,
                  MPI_STATUS_IGNORE);
-
-    // Force buffer:
-    recvPtr = static_cast<void*>(d_recvBuf_);
-    MPI_Sendrecv(&recvPtr,
-                 sizeof(void*),
-                 MPI_BYTE,
-                 recvRankF_,
-                 0,
-                 &remoteFPtr_,
-                 sizeof(void*),
-                 MPI_BYTE,
-                 sendRankF_,
-                 0,
-                 mpi_comm_mysim_,
-                 MPI_STATUS_IGNORE);
-#endif
-
-    wallcycle_sub_stop(wcycle_, ewcsDD_GPU);
-    wallcycle_stop(wcycle_, ewcDOMDEC);
-
-    return;
+    remoteCoordinatesReadyOnDeviceEvent->enqueueWaitEvent(nonLocalStream_);
 }
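The address exchange above relies on the fact that with thread-MPI all ranks share one address space, so a raw device pointer received from a neighbour can be used directly in a cudaMemcpyAsync. A minimal sketch of that pattern, with illustrative names (exchangeRemoteAddress and pushToNeighbor are not part of the GROMACS code), could look like this:

    #include <cuda_runtime.h>
    #include <mpi.h>

    // Tell the neighbour that will push data to us where our receive buffer lives,
    // and learn where to push our own data. This only works when all ranks share an
    // address space (thread-MPI); with process-MPI the exchanged pointer would be
    // meaningless on the remote side.
    static void exchangeRemoteAddress(void*    localRecvBuffer,
                                      int      rankPushingToUs,
                                      int      rankWePushTo,
                                      MPI_Comm comm,
                                      void**   remoteRecvBuffer)
    {
        MPI_Sendrecv(&localRecvBuffer, sizeof(void*), MPI_BYTE, rankPushingToUs, 0,
                     remoteRecvBuffer, sizeof(void*), MPI_BYTE, rankWePushTo, 0,
                     comm, MPI_STATUS_IGNORE);
    }

    // Push the packed local send buffer straight into the neighbour's device buffer.
    static void pushToNeighbor(const void* d_sendBuffer, void* d_remoteRecvBuffer,
                               size_t numBytes, cudaStream_t stream)
    {
        cudaMemcpyAsync(d_remoteRecvBuffer, d_sendBuffer, numBytes,
                        cudaMemcpyDeviceToDevice, stream);
    }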
 
 void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box,
@@ -305,7 +335,15 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
     // ToDo: We need further refinement here as communicateHaloData includes launch time for cudamemcpyasync
     wallcycle_start(wcycle_, ewcMOVEX);
 
-    communicateHaloData(d_x_, HaloQuantity::HaloCoordinates, coordinatesReadyOnDeviceEvent);
+    // With process-MPI the wait for remote coordinates is implicit: the non-local stream is synchronized
+    // before the MPI calls, and the MPI wait call ensures that both neighboring ranks' non-local streams
+    // are synchronized before the data transfer is initiated.
+    if (GMX_THREAD_MPI && pulse_ == 0)
+    {
+        enqueueWaitRemoteCoordinatesReadyEvent(coordinatesReadyOnDeviceEvent);
+    }
+
+    float3* recvPtr = GMX_THREAD_MPI ? remoteXPtr_ : &d_x_[atomOffset_];
+    communicateHaloData(d_sendBuf_, xSendSize_, sendRankX_, recvPtr, xRecvSize_, recvRankX_);
 
     wallcycle_stop(wcycle_, ewcMOVEX);
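The enqueueWaitRemoteCoordinatesReadyEvent call above is the GpuEventSynchronizer-based version of a plain CUDA event handshake. A rough equivalent with raw CUDA events, again assuming the thread-MPI shared address space and using made-up names, would be:

    #include <cuda_runtime.h>
    #include <mpi.h>

    // Exchange coordinates-ready events between neighbours and make the transfer
    // stream wait for the remote one, so the push does not race with the remote
    // coordinate data still being written.
    static void waitForRemoteCoordinatesReady(cudaEvent_t  localCoordinatesReady,
                                              int          rankPushingToUs,
                                              int          rankWePushTo,
                                              MPI_Comm     comm,
                                              cudaStream_t transferStream)
    {
        cudaEvent_t remoteCoordinatesReady;
        // Hand our ready-event to the rank that pushes into our buffer, and obtain the
        // ready-event of the rank whose buffer we are about to push into.
        MPI_Sendrecv(&localCoordinatesReady, sizeof(cudaEvent_t), MPI_BYTE, rankPushingToUs, 0,
                     &remoteCoordinatesReady, sizeof(cudaEvent_t), MPI_BYTE, rankWePushTo, 0,
                     comm, MPI_STATUS_IGNORE);
        // Do not start pushing until the remote buffer is safe to write.
        cudaStreamWaitEvent(transferStream, remoteCoordinatesReady, 0);
    }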
 
@@ -320,8 +358,10 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     // ToDo: We need further refinement here as communicateHaloData includes launch time for cudamemcpyasync
     wallcycle_start(wcycle_, ewcMOVEF);
 
+    float3* recvPtr = GMX_THREAD_MPI ? remoteFPtr_ : d_recvBuf_;
+
     // Communicate halo data (in non-local stream)
-    communicateHaloData(d_f_, HaloQuantity::HaloForces, nullptr);
+    communicateHaloData(&(d_f_[atomOffset_]), fSendSize_, sendRankF_, recvPtr, fRecvSize_, recvRankF_);
 
     wallcycle_stop(wcycle_, ewcMOVEF);
 
@@ -386,65 +426,62 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     wallcycle_stop(wcycle_, ewcLAUNCH_GPU);
 }
 
-
-void GpuHaloExchange::Impl::communicateHaloData(float3*               d_ptr,
-                                                HaloQuantity          haloQuantity,
-                                                GpuEventSynchronizer* coordinatesReadyOnDeviceEvent)
+void GpuHaloExchange::Impl::communicateHaloData(float3* sendPtr,
+                                                int     sendSize,
+                                                int     sendRank,
+                                                float3* recvPtr,
+                                                int     recvSize,
+                                                int     recvRank)
 {
-
-    void* sendPtr;
-    int   sendSize;
-    void* remotePtr;
-    int   sendRank;
-    int   recvRank;
-
-    if (haloQuantity == HaloQuantity::HaloCoordinates)
+    if (GMX_THREAD_MPI)
     {
-        sendPtr   = static_cast<void*>(d_sendBuf_);
-        sendSize  = xSendSize_;
-        remotePtr = remoteXPtr_;
-        sendRank  = sendRankX_;
-        recvRank  = recvRankX_;
-
-#if GMX_MPI
-        // Wait for event from receiving task that remote coordinates are ready, and enqueue that event to stream used
-        // for subsequent data push. This avoids a race condition with the remote data being written in the previous timestep.
-        // Similarly send event to task that will push data to this task.
-        GpuEventSynchronizer* remoteCoordinatesReadyOnDeviceEvent;
-        MPI_Sendrecv(&coordinatesReadyOnDeviceEvent,
-                     sizeof(GpuEventSynchronizer*),
-                     MPI_BYTE,
-                     recvRank,
-                     0,
-                     &remoteCoordinatesReadyOnDeviceEvent,
-                     sizeof(GpuEventSynchronizer*),
-                     MPI_BYTE,
-                     sendRank,
-                     0,
-                     mpi_comm_mysim_,
-                     MPI_STATUS_IGNORE);
-        remoteCoordinatesReadyOnDeviceEvent->enqueueWaitEvent(nonLocalStream_);
-#else
-        GMX_UNUSED_VALUE(coordinatesReadyOnDeviceEvent);
-#endif
+        // No explicit synchronization is needed with thread-MPI, as all operations are
+        // launched in the correct stream anyway.
+        communicateHaloDataWithCudaDirect(sendPtr, sendSize, sendRank, recvPtr, recvRank);
     }
     else
     {
-        sendPtr   = static_cast<void*>(&(d_ptr[atomOffset_]));
-        sendSize  = fSendSize_;
-        remotePtr = remoteFPtr_;
-        sendRank  = sendRankF_;
-        recvRank  = recvRankF_;
+        communicateHaloDataWithCudaMPI(sendPtr, sendSize, sendRank, recvPtr, recvSize, recvRank);
     }
+}
 
-    communicateHaloDataWithCudaDirect(sendPtr, sendSize, sendRank, remotePtr, recvRank);
+void GpuHaloExchange::Impl::communicateHaloDataWithCudaMPI(float3* sendPtr,
+                                                           int     sendSize,
+                                                           int     sendRank,
+                                                           float3* recvPtr,
+                                                           int     recvSize,
+                                                           int     recvRank)
+{
+    // No need to wait for the haloDataReadyOnDevice event if this rank is not sending any data.
+    if (sendSize > 0)
+    {
+        // Wait for the non-local stream to complete all outstanding
+        // activities, to ensure that the buffer is up to date in GPU memory
+        // before transferring to the remote rank.
+
+        // ToDo: Replace stream synchronize with event synchronize
+        nonLocalStream_.synchronize();
+    }
+
+    // perform halo exchange directly in device buffers
+#if GMX_MPI
+    MPI_Request request;
+
+    // recv remote data into halo region
+    MPI_Irecv(recvPtr, recvSize * DIM, MPI_FLOAT, recvRank, 0, mpi_comm_mysim_, &request);
+
+    // send data to remote halo region
+    MPI_Send(sendPtr, sendSize * DIM, MPI_FLOAT, sendRank, 0, mpi_comm_mysim_);
+
+    MPI_Wait(&request, MPI_STATUS_IGNORE);
+#endif
 }
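With library (process) MPI the change instead hands device pointers straight to the MPI library, as in communicateHaloDataWithCudaMPI above. A self-contained sketch of that pattern, assuming a CUDA-aware MPI installation and using illustrative names, is:

    #include <cuda_runtime.h>
    #include <mpi.h>

    // Exchange halo regions between device buffers through a CUDA-aware MPI library.
    // The stream that packed the send buffer is synchronized first so the data is
    // complete in device memory before MPI reads it.
    static void haloExchangeWithCudaAwareMpi(float* d_sendBuffer, int sendCount, int sendRank,
                                             float* d_recvBuffer, int recvCount, int recvRank,
                                             MPI_Comm comm, cudaStream_t stream)
    {
        cudaStreamSynchronize(stream);

        MPI_Request request;
        // Post the receive into the device-side halo region, ...
        MPI_Irecv(d_recvBuffer, recvCount, MPI_FLOAT, recvRank, 0, comm, &request);
        // ... send the packed halo data from device memory, ...
        MPI_Send(d_sendBuffer, sendCount, MPI_FLOAT, sendRank, 0, comm);
        // ... and wait until the incoming halo data has arrived.
        MPI_Wait(&request, MPI_STATUS_IGNORE);
    }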
 
-void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(void* sendPtr,
-                                                              int   sendSize,
-                                                              int   sendRank,
-                                                              void* remotePtr,
-                                                              int   recvRank)
+void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(float3* sendPtr,
+                                                              int     sendSize,
+                                                              int     sendRank,
+                                                              float3* remotePtr,
+                                                              int     recvRank)
 {
 
     cudaError_t stat;
@@ -474,6 +511,9 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(void* sendPtr,
     // to its stream.
     GpuEventSynchronizer* haloDataTransferRemote;
 
+    GMX_ASSERT(haloDataTransferLaunched_ != nullptr,
+               "Halo exchange requires valid event to synchronize data transfer initiated in "
+               "remote rank");
     haloDataTransferLaunched_->markEvent(nonLocalStream_);
 
     MPI_Sendrecv(&haloDataTransferLaunched_,
@@ -516,7 +556,7 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
     sendRankF_(dd->neighbor[dimIndex][0]),
     recvRankF_(dd->neighbor[dimIndex][1]),
     usePBC_(dd->ci[dd->dim[dimIndex]] == 0),
-    haloDataTransferLaunched_(new GpuEventSynchronizer()),
+    haloDataTransferLaunched_(GMX_THREAD_MPI ? new GpuEventSynchronizer() : nullptr),
     mpi_comm_mysim_(mpi_comm_mysim),
     deviceContext_(deviceContext),
     localStream_(localStream),
@@ -525,10 +565,6 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
     pulse_(pulse),
     wcycle_(wcycle)
 {
-
-    GMX_RELEASE_ASSERT(GMX_THREAD_MPI,
-                       "GPU Halo exchange is currently only supported with thread-MPI enabled");
-
     if (usePBC_ && dd->unitCellInfo.haveScrewPBC)
     {
         gmx_fatal(FARGS, "Error: screw is not yet supported in GPU halo exchange\n");