Rework GPU halo and state propagator streams and dependencies to get better overlap
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cuh
index 6190a56b9584d0b575a78cb5a14931b265680108..c29834616a5ab4153aff0cf501467311ea8c7ee4 100644 (file)
@@ -75,8 +75,6 @@ public:
      * \param [in]    dimIndex                 the dimension index for this instance
      * \param [in]    mpi_comm_mysim           communicator used for simulation
      * \param [in]    deviceContext            GPU device context
-     * \param [in]    localStream              local NB CUDA stream
-     * \param [in]    nonLocalStream           non-local NB CUDA stream
      * \param [in]    pulse                    the communication pulse for this instance
      * \param [in]    wcycle                   The wallclock counter
      */
@@ -84,8 +82,6 @@ public:
          int                  dimIndex,
          MPI_Comm             mpi_comm_mysim,
          const DeviceContext& deviceContext,
-         const DeviceStream&  localStream,
-         const DeviceStream&  nonLocalStream,
          int                  pulse,
          gmx_wallcycle*       wcycle);
     ~Impl();
@@ -101,14 +97,17 @@ public:
     /*! \brief
      * GPU halo exchange of coordinates buffer
      * \param [in] box  Coordinate box (from which shifts will be constructed)
-     * \param [in] coordinatesReadyOnDeviceEvent event recorded when coordinates have been copied to device
+     * \param [in] dependencyEvent   Dependency event for this operation
+     * \returns                      Event recorded when this operation has been launched
      */
-    void communicateHaloCoordinates(const matrix box, GpuEventSynchronizer* coordinatesReadyOnDeviceEvent);
+    GpuEventSynchronizer* communicateHaloCoordinates(const matrix box, GpuEventSynchronizer* dependencyEvent);
 
     /*! \brief  GPU halo exchange of force buffer
-     * \param[in] accumulateForces  True if forces should accumulate, otherwise they are set
+     * \param [in] accumulateForces  True if forces should accumulate, otherwise they are set
+     * \param [in] dependencyEvents  Dependency events for this operation
      */
-    void communicateHaloForces(bool accumulateForces);
+    void communicateHaloForces(bool                                           accumulateForces,
+                               FixedCapacityVector<GpuEventSynchronizer*, 2>* dependencyEvents);
 
     /*! \brief Get the event synchronizer for the forces ready on device.
      *  \returns  The event to synchronize the stream that consumes forces on device.
@@ -150,8 +149,8 @@ private:
                                         int     recvSize,
                                         int     recvRank);
 
-    /*! \brief Exchange coordinate-ready event with neighbor ranks and enqueue wait in non-local
-     * stream \param [in] eventSync    event recorded when coordinates/forces are ready to device
+    /*! \brief Exchange coordinate-ready event with neighbor ranks and enqueue wait in halo stream
+     * \param [in] eventSync    event recorded when coordinates/forces are ready to device
      */
     void enqueueWaitRemoteCoordinatesReadyEvent(GpuEventSynchronizer* coordinatesReadyOnDeviceEvent);
 
@@ -211,10 +210,8 @@ private:
     MPI_Comm mpi_comm_mysim_;
     //! GPU context object
     const DeviceContext& deviceContext_;
-    //! CUDA stream for local non-bonded calculations
-    const DeviceStream& localStream_;
-    //! CUDA stream for non-local non-bonded calculations
-    const DeviceStream& nonLocalStream_;
+    //! CUDA stream for this halo exchange
+    DeviceStream* haloStream_;
     //! full coordinates buffer in GPU memory
     float3* d_x_ = nullptr;
     //! full forces buffer in GPU memory
@@ -229,6 +226,8 @@ private:
     gmx_wallcycle* wcycle_ = nullptr;
     //! The atom offset for receive (x) or send (f) for dimension index and pulse corresponding to this halo exchange instance
     int atomOffset_ = 0;
+    //! Event triggered when coordinate halo has been launched
+    GpuEventSynchronizer coordinateHaloLaunched_;
 };
 
 } // namespace gmx