Pipeline GPU PME Spline/Spread with PP Comms
[alexxy/gromacs.git] / src / gromacs / ewald / pme_coordinate_receiver_gpu_impl.h
index 604079c0b0fbcead4bebf03ae2a99eea132c5962..d268091771bab038053f1b3b4ad8031da141d880 100644 (file)
@@ -52,25 +52,41 @@ class GpuEventSynchronizer;
 
 namespace gmx
 {
-/*! \internal \brief Class with interfaces and data for CUDA version of PME coordinate receiving functionality */
 
+/*! \brief Object to manage communications with a specific PP rank */
+struct PpCommManager
+{
+    //! Details of PP rank that may be updated after repartitioning
+    const PpRanks& ppRank;
+    //! Stream used communication with for PP rank
+    std::unique_ptr<DeviceStream> stream;
+    //! Synchronization event to receive from PP rank
+    GpuEventSynchronizer* sync = nullptr;
+    //! Range of atoms corresponding to PP rank
+    std::tuple<int, int> atomRange = { 0, 0 };
+};
+
+/*! \internal \brief Class with interfaces and data for CUDA version of PME coordinate receiving functionality */
 class PmeCoordinateReceiverGpu::Impl
 {
 
 public:
     /*! \brief Creates PME GPU coordinate receiver object
-     * \param[in] pmeStream       CUDA stream used for PME computations
      * \param[in] comm            Communicator used for simulation
+     * \param[in] deviceContext   GPU context
      * \param[in] ppRanks         List of PP ranks
      */
-    Impl(const DeviceStream& pmeStream, MPI_Comm comm, gmx::ArrayRef<PpRanks> ppRanks);
+    Impl(MPI_Comm comm, const DeviceContext& deviceContext, gmx::ArrayRef<const PpRanks> ppRanks);
     ~Impl();
 
     /*! \brief
-     * send coordinates buffer address to PP rank
+     * Re-initialize: set atom ranges and, for thread-MPI case,
+     * send coordinates buffer address to PP rank.
+     * This is required after repartitioning since atom ranges and
+     * buffer allocations may have changed.
      * \param[in] d_x   coordinates buffer in GPU memory
      */
-    void sendCoordinateBufferAddressToPpRanks(DeviceBuffer<RVec> d_x);
+    void reinitCoordinateReceiver(DeviceBuffer<RVec> d_x);
 
     /*! \brief
      * Receive coordinate synchronizer pointer from the PP ranks.
@@ -88,24 +104,47 @@ public:
     void launchReceiveCoordinatesFromPpCudaMpi(DeviceBuffer<RVec> recvbuf, int numAtoms, int numBytes, int ppRank);
 
     /*! \brief
-     * For lib MPI, wait for coordinates from PP ranks
-     * For thread MPI, enqueue PP co-ordinate transfer event into PME stream
+     * For lib MPI, wait for coordinates from any PP rank
+     * For thread MPI, enqueue PP co-ordinate transfer event received from PP
+     * rank determined from pipeline stage into given stream
+     * \param[in] pipelineStage  stage of pipeline corresponding to this transfer
+     * \param[in] deviceStream   stream in which to enqueue the wait event.
+     * \returns                  rank of sending PP task
+     */
+    int synchronizeOnCoordinatesFromPpRank(int pipelineStage, const DeviceStream& deviceStream);
+
+    /*! \brief Perform above synchronizeOnCoordinatesFromPpRanks for all PP ranks,
+     * enqueueing all events to a single stream
+     * \param[in] deviceStream   stream in which to enqueue the wait events.
+     */
+    void synchronizeOnCoordinatesFromAllPpRanks(const DeviceStream& deviceStream);
+
+    /*! \brief
+     * Return pointer to stream associated with specific PP rank sender index
+     * \param[in] senderIndex    Index of sender PP rank.
+     */
+    DeviceStream* ppCommStream(int senderIndex);
+
+    /*! \brief
+     * Returns range of atoms involved in communication associated with specific PP rank sender
+     * index \param[in] senderIndex    Index of sender PP rank.
+     */
+    std::tuple<int, int> ppCommAtomRange(int senderIndex);
+
+    /*! \brief
+     * Return number of PP ranks involved in PME-PP communication
      */
-    void synchronizeOnCoordinatesFromPpRanks();
+    int ppCommNumSenderRanks();
 
 private:
-    //! CUDA stream for PME operations
-    const DeviceStream& pmeStream_;
     //! communicator for simulation
     MPI_Comm comm_;
-    //! list of PP ranks
-    gmx::ArrayRef<PpRanks> ppRanks_;
-    //! vector of MPI requests
-    std::vector<MPI_Request> request_;
-    //! vector of synchronization events to receive from PP tasks
-    std::vector<GpuEventSynchronizer*> ppSync_;
-    //! counter of messages to receive
-    int recvCount_ = 0;
+    //! MPI requests, one per PP rank
+    std::vector<MPI_Request> requests_;
+    //! GPU context handle (not used in CUDA)
+    const DeviceContext& deviceContext_;
+    //! Communication manager objects corresponding to multiple sending PP ranks
+    std::vector<PpCommManager> ppCommManagers_;
 };
 
 } // namespace gmx