GPU Force Halo Exchange
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cuh
index bd125654eec47bd3440497a7cc39a579363dd66e..10d9118927bc45f0e0d867162bfea788ba7b227b 100644 (file)
@@ -68,11 +68,13 @@ class GpuHaloExchange::Impl
          *
          * \param [inout] dd                       domdec structure
          * \param [in]    mpi_comm_mysim           communicator used for simulation
+         * \param [in]    localStream              local NB CUDA stream
          * \param [in]    nonLocalStream           non-local NB CUDA stream
          * \param [in]    coordinatesOnDeviceEvent event recorded when coordinates have been copied to device
          */
         Impl(gmx_domdec_t *dd,
              MPI_Comm mpi_comm_mysim,
+             void *localStream,
              void *nonLocalStream,
              void *coordinatesOnDeviceEvent);
         ~Impl();
@@ -80,8 +82,10 @@ class GpuHaloExchange::Impl
         /*! \brief
          * (Re-) Initialization for GPU halo exchange
          * \param [in] d_coordinatesBuffer  pointer to coordinates buffer in GPU memory
+         * \param [in] d_forcesBuffer   pointer to forces buffer in GPU memory
          */
-        void reinitHalo(float3 *d_coordinatesBuffer);
+        void reinitHalo(float3 *d_coordinatesBuffer,
+                        float3 *d_forcesBuffer);
 
 
         /*! \brief
@@ -90,6 +94,11 @@ class GpuHaloExchange::Impl
          */
         void communicateHaloCoordinates(const matrix box);
 
+        /*! \brief  GPU halo exchange of force buffer
+         * \param [in] accumulateForces  True if forces should be accumulated, false if they should be set
+         */
+        void communicateHaloForces(bool accumulateForces);
+
     private:
 
         /*! \brief Data transfer wrapper for GPU halo exchange
@@ -106,11 +115,11 @@ class GpuHaloExchange::Impl
          * \param [inout] remotePtr  remote address to recv data
          * \param [in] recvRank      rank to recv data from
          */
-        void communicateHaloDataWithCudaDirect(void *sendPtr,
-                                               int   sendSize,
-                                               int   sendRank,
-                                               void* remotePtr,
-                                               int   recvRank);
+        void communicateHaloDataWithCudaDirect(void        *sendPtr,
+                                               int          sendSize,
+                                               int          sendRank,
+                                               void       * remotePtr,
+                                               int          recvRank);
 
         //! Domain decomposition object
         gmx_domdec_t               *dd_                       = nullptr;
@@ -152,8 +161,8 @@ class GpuHaloExchange::Impl
         int                         fSendSize_                = 0;
         //! recv copy size to this rank for F
         int                         fRecvSize_                = 0;
-        //! offset of local halo region
-        int                         localOffset_              = 0;
+        //! number of home atoms, which is also the offset of the local halo region
+        int                         numHomeAtoms_             = 0;
         //! remote GPU coordinates buffer pointer for pushing data
         void                       *remoteXPtr_               = 0;
         //! remote GPU force buffer pointer for pushing data
@@ -166,12 +175,16 @@ class GpuHaloExchange::Impl
         GpuEventSynchronizer       *haloDataTransferLaunched_ = nullptr;
         //! MPI communicator used for simulation
         MPI_Comm                    mpi_comm_mysim_;
+        //! CUDA stream for local non-bonded calculations
+        cudaStream_t                localStream_              = nullptr;
         //! CUDA stream for non-local non-bonded calculations
         cudaStream_t                nonLocalStream_           = nullptr;
         //! Event triggered when coordinates have been copied to device
         GpuEventSynchronizer       *coordinatesOnDeviceEvent_ = nullptr;
         //! full coordinates buffer in GPU memory
         float3                     *d_x_                      = nullptr;
+        //! full forces buffer in GPU memory
+        float3                     *d_f_                      = nullptr;
 
 };