Rework GPU halo and state propagator streams and dependencies to get better overlap
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cu
index 5796b96980873f3f4036388d021aab7d9e6b2ab3..b0dab24e8a978a689a5d44dbb9f3e7d213675512 100644 (file)
@@ -200,7 +200,7 @@ void GpuHaloExchange::Impl::reinitHalo(float3* d_coordinatesBuffer, float3* d_fo
         std::copy(ind.index.begin(), ind.index.end(), h_indexMap_.begin());
 
         copyToDeviceBuffer(
-                &d_indexMap_, h_indexMap_.data(), 0, newSize, nonLocalStream_, GpuApiCallBehavior::Async, nullptr);
+                &d_indexMap_, h_indexMap_.data(), 0, newSize, *haloStream_, GpuApiCallBehavior::Async, nullptr);
     }
 
 #if GMX_MPI
@@ -270,19 +270,16 @@ void GpuHaloExchange::Impl::enqueueWaitRemoteCoordinatesReadyEvent(GpuEventSynch
                  0,
                  mpi_comm_mysim_,
                  MPI_STATUS_IGNORE);
-    remoteCoordinatesReadyOnDeviceEvent->enqueueWaitEvent(nonLocalStream_);
+    remoteCoordinatesReadyOnDeviceEvent->enqueueWaitEvent(*haloStream_);
 }
 
-void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box,
-                                                       GpuEventSynchronizer* coordinatesReadyOnDeviceEvent)
+GpuEventSynchronizer* GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix box,
+                                                                        GpuEventSynchronizer* dependencyEvent)
 {
-
     wallcycle_start(wcycle_, WallCycleCounter::LaunchGpu);
-    if (pulse_ == 0)
-    {
-        // ensure stream waits until coordinate data is available on device
-        coordinatesReadyOnDeviceEvent->enqueueWaitEvent(nonLocalStream_);
-    }
+
+    // ensure stream waits until dependency has been satisfied
+    dependencyEvent->enqueueWaitEvent(*haloStream_);
 
     wallcycle_sub_start(wcycle_, WallCycleSubCounter::LaunchGpuMoveX);
 
@@ -318,8 +315,7 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
         const auto kernelArgs = prepareGpuKernelArguments(
                 kernelFn, config, &sendBuf, &d_x, &indexMap, &size, &coordinateShift);
 
-        launchGpuKernel(
-                kernelFn, config, nonLocalStream_, nullptr, "Domdec GPU Apply X Halo Exchange", kernelArgs);
+        launchGpuKernel(kernelFn, config, *haloStream_, nullptr, "Domdec GPU Apply X Halo Exchange", kernelArgs);
     }
 
     wallcycle_sub_stop(wcycle_, WallCycleSubCounter::LaunchGpuMoveX);
@@ -331,28 +327,41 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
 
     // wait for remote co-ordinates is implicit with process-MPI as non-local stream is synchronized before MPI calls
     // and MPI_Waitall call makes sure both neighboring ranks' non-local stream is synchronized before data transfer is initiated
-    if (GMX_THREAD_MPI && pulse_ == 0)
+    if (GMX_THREAD_MPI && dimIndex_ == 0 && pulse_ == 0)
     {
-        enqueueWaitRemoteCoordinatesReadyEvent(coordinatesReadyOnDeviceEvent);
+        enqueueWaitRemoteCoordinatesReadyEvent(dependencyEvent);
     }
 
     float3* recvPtr = GMX_THREAD_MPI ? remoteXPtr_ : &d_x_[atomOffset_];
     communicateHaloData(d_sendBuf_, xSendSize_, sendRankX_, recvPtr, xRecvSize_, recvRankX_);
 
+    coordinateHaloLaunched_.markEvent(*haloStream_);
+
     wallcycle_stop(wcycle_, WallCycleCounter::MoveX);
+
+    return &coordinateHaloLaunched_;
 }
 
 // The following method should be called after non-local buffer operations,
-// and before the local buffer operations. It operates in the non-local stream.
-void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
+// and before the local buffer operations.
+void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces,
+                                                  FixedCapacityVector<GpuEventSynchronizer*, 2>* dependencyEvents)
 {
+
     // Consider time spent in communicateHaloData as Comm.F counter
     // ToDo: We need further refinement here as communicateHaloData includes launch time for cudamemcpyasync
     wallcycle_start(wcycle_, WallCycleCounter::MoveF);
 
+    while (dependencyEvents->size() > 0)
+    {
+        auto dependency = dependencyEvents->back();
+        dependency->enqueueWaitEvent(*haloStream_);
+        dependencyEvents->pop_back();
+    }
+
     float3* recvPtr = GMX_THREAD_MPI ? remoteFPtr_ : d_recvBuf_;
 
-    // Communicate halo data (in non-local stream)
+    // Communicate halo data
     communicateHaloData(&(d_f_[atomOffset_]), fSendSize_, sendRankF_, recvPtr, fRecvSize_, recvRankF_);
 
     wallcycle_stop(wcycle_, WallCycleCounter::MoveF);
@@ -361,19 +370,6 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     wallcycle_sub_start(wcycle_, WallCycleSubCounter::LaunchGpuMoveF);
 
     float3* d_f = d_f_;
-    // If this is the last pulse and index (noting the force halo
-    // exchanges across multiple pulses and indices are called in
-    // reverse order) then perform the following preparation
-    // activities
-    if ((pulse_ == (dd_->comm->cd[dimIndex_].numPulses() - 1)) && (dimIndex_ == (dd_->ndim - 1)))
-    {
-        // ensure non-local stream waits for local stream, due to dependence on
-        // the previous H2D copy of CPU forces (if accumulateForces is true)
-        // or local force clearing.
-        GpuEventSynchronizer eventLocal;
-        eventLocal.markEvent(localStream_);
-        eventLocal.enqueueWaitEvent(nonLocalStream_);
-    }
 
     // Unpack halo buffer into force array
 
@@ -405,14 +401,10 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
         const auto kernelArgs =
                 prepareGpuKernelArguments(kernelFn, config, &d_f, &recvBuf, &indexMap, &size);
 
-        launchGpuKernel(
-                kernelFn, config, nonLocalStream_, nullptr, "Domdec GPU Apply F Halo Exchange", kernelArgs);
+        launchGpuKernel(kernelFn, config, *haloStream_, nullptr, "Domdec GPU Apply F Halo Exchange", kernelArgs);
     }
 
-    if (pulse_ == 0)
-    {
-        fReadyOnDevice_.markEvent(nonLocalStream_);
-    }
+    fReadyOnDevice_.markEvent(*haloStream_);
 
     wallcycle_sub_stop(wcycle_, WallCycleSubCounter::LaunchGpuMoveF);
     wallcycle_stop(wcycle_, WallCycleCounter::LaunchGpu);
@@ -447,12 +439,12 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaMPI(float3* sendPtr,
     // no need to wait for haloDataReadyOnDevice event if this rank is not sending any data
     if (sendSize > 0)
     {
-        // wait for non local stream to complete all outstanding
+        // wait for halo stream to complete all outstanding
         // activities, to ensure that buffer is up-to-date in GPU memory
         // before transferring to remote rank
 
         // ToDo: Replace stream synchronize with event synchronize
-        nonLocalStream_.synchronize();
+        haloStream_->synchronize();
     }
 
     // perform halo exchange directly in device buffers
@@ -491,7 +483,7 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(float3* sendPtr,
                                sendPtr,
                                sendSize * DIM * sizeof(float),
                                cudaMemcpyDeviceToDevice,
-                               nonLocalStream_.stream());
+                               haloStream_->stream());
 
         CU_RET_ERR(stat, "cudaMemcpyAsync on GPU Domdec CUDA direct data transfer failed");
     }
@@ -506,7 +498,7 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(float3* sendPtr,
     GMX_ASSERT(haloDataTransferLaunched_ != nullptr,
                "Halo exchange requires valid event to synchronize data transfer initiated in "
                "remote rank");
-    haloDataTransferLaunched_->markEvent(nonLocalStream_);
+    haloDataTransferLaunched_->markEvent(*haloStream_);
 
     MPI_Sendrecv(&haloDataTransferLaunched_,
                  sizeof(GpuEventSynchronizer*), //NOLINT(bugprone-sizeof-expression)
@@ -521,7 +513,7 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(float3* sendPtr,
                  mpi_comm_mysim_,
                  MPI_STATUS_IGNORE);
 
-    haloDataTransferRemote->enqueueWaitEvent(nonLocalStream_);
+    haloDataTransferRemote->enqueueWaitEvent(*haloStream_);
 #else
     GMX_UNUSED_VALUE(sendRank);
     GMX_UNUSED_VALUE(recvRank);
@@ -538,8 +530,6 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
                             int                  dimIndex,
                             MPI_Comm             mpi_comm_mysim,
                             const DeviceContext& deviceContext,
-                            const DeviceStream&  localStream,
-                            const DeviceStream&  nonLocalStream,
                             int                  pulse,
                             gmx_wallcycle*       wcycle) :
     dd_(dd),
@@ -551,8 +541,7 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
     haloDataTransferLaunched_(GMX_THREAD_MPI ? new GpuEventSynchronizer() : nullptr),
     mpi_comm_mysim_(mpi_comm_mysim),
     deviceContext_(deviceContext),
-    localStream_(localStream),
-    nonLocalStream_(nonLocalStream),
+    haloStream_(new DeviceStream(deviceContext, DeviceStreamPriority::High, false)),
     dimIndex_(dimIndex),
     pulse_(pulse),
     wcycle_(wcycle)
@@ -580,11 +569,9 @@ GpuHaloExchange::GpuHaloExchange(gmx_domdec_t*        dd,
                                  int                  dimIndex,
                                  MPI_Comm             mpi_comm_mysim,
                                  const DeviceContext& deviceContext,
-                                 const DeviceStream&  localStream,
-                                 const DeviceStream&  nonLocalStream,
                                  int                  pulse,
                                  gmx_wallcycle*       wcycle) :
-    impl_(new Impl(dd, dimIndex, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse, wcycle))
+    impl_(new Impl(dd, dimIndex, mpi_comm_mysim, deviceContext, pulse, wcycle))
 {
 }
 
@@ -603,15 +590,16 @@ void GpuHaloExchange::reinitHalo(DeviceBuffer<RVec> d_coordinatesBuffer, DeviceB
     impl_->reinitHalo(asFloat3(d_coordinatesBuffer), asFloat3(d_forcesBuffer));
 }
 
-void GpuHaloExchange::communicateHaloCoordinates(const matrix          box,
-                                                 GpuEventSynchronizer* coordinatesReadyOnDeviceEvent)
+GpuEventSynchronizer* GpuHaloExchange::communicateHaloCoordinates(const matrix          box,
+                                                                  GpuEventSynchronizer* dependencyEvent)
 {
-    impl_->communicateHaloCoordinates(box, coordinatesReadyOnDeviceEvent);
+    return impl_->communicateHaloCoordinates(box, dependencyEvent);
 }
 
-void GpuHaloExchange::communicateHaloForces(bool accumulateForces)
+void GpuHaloExchange::communicateHaloForces(bool accumulateForces,
+                                            FixedCapacityVector<GpuEventSynchronizer*, 2>* dependencyEvents)
 {
-    impl_->communicateHaloForces(accumulateForces);
+    impl_->communicateHaloForces(accumulateForces, dependencyEvents);
 }
 
 GpuEventSynchronizer* GpuHaloExchange::getForcesReadyOnDeviceEvent()