Remove MPI comm from GPU PME-PP force transfer initiation
[alexxy/gromacs.git] / src / gromacs / ewald / pme_pp_comm_gpu_impl.h
index 3d3039db6d04f1b3c267314bba0ad260338b0d4a..f62faea93374de33c2501d83d868d2fe8bd06f88 100644 (file)
@@ -58,12 +58,17 @@ class PmePpCommGpu::Impl
 public:
     /*! \brief Creates PME-PP GPU communication object.
      *
-     * \param[in] comm            Communicator used for simulation
-     * \param[in] pmeRank         Rank of PME task
-     * \param[in] deviceContext   GPU context.
-     * \param[in] deviceStream    GPU stream.
+     * \param[in] comm              Communicator used for simulation
+     * \param[in] pmeRank           Rank of PME task
+     * \param[in] pmeCpuForceBuffer Buffer for PME force in CPU memory
+     * \param[in] deviceContext     GPU context.
+     * \param[in] deviceStream      GPU stream.
      */
-    Impl(MPI_Comm comm, int pmeRank, const DeviceContext& deviceContext, const DeviceStream& deviceStream);
+    Impl(MPI_Comm                comm,
+         int                     pmeRank,
+         std::vector<gmx::RVec>& pmeCpuForceBuffer,
+         const DeviceContext&    deviceContext,
+         const DeviceStream&     deviceStream);
     ~Impl();
 
     /*! \brief Perform steps required when buffer size changes
@@ -115,10 +120,9 @@ private:
     /*! \brief Pull force buffer directly from GPU memory on PME
      * rank to either GPU or CPU memory on PP task using CUDA
      * Memory copy. This method is used with Thread-MPI.
-     * \param[out] recvPtr CPU buffer to receive PME force data
      * \param[in] receivePmeForceToGpu Whether receive is to GPU, otherwise CPU
      */
-    void receiveForceFromPmeCudaDirect(float3* recvPtr, bool receivePmeForceToGpu);
+    void receiveForceFromPmeCudaDirect(bool receivePmeForceToGpu);
 
     /*! \brief Pull force buffer directly from GPU memory on PME
      * rank to either GPU or CPU memory on PP task using CUDA-aware
@@ -160,6 +164,8 @@ private:
     MPI_Comm comm_;
     //! Rank of PME task
     int pmeRank_ = -1;
+    //! Buffer for PME force on CPU
+    std::vector<gmx::RVec>& pmeCpuForceBuffer_;
     //! Buffer for staging PME force on GPU
     DeviceBuffer<gmx::RVec> d_pmeForces_;
     //! number of atoms in PME force staging array