Pass the GPU streams to StatePropagatorDataGpu constructor
[alexxy/gromacs.git] / src / gromacs / mdtypes / state_propagator_data_gpu_impl.h
index f32d6df9debf731fbf64c6615b285b4451b32677..16aead662a809c097ae04014741780de5d406435 100644 (file)
@@ -80,12 +80,16 @@ class StatePropagatorDataGpu::Impl
          * \todo A DeviceContext object is visible in CPU parts of the code so we
          *       can stop passing a void*.
          *
-         *  \param[in] commandStream  GPU stream, nullptr allowed.
-         *  \param[in] deviceContext  GPU context, nullptr allowed.
-         *  \param[in] transferKind   H2D/D2H transfer call behavior (synchronous or not).
-         *  \param[in] paddingSize    Padding size for coordinates buffer.
+         *  \param[in] pmeStream       Device PME stream, nullptr allowed.
+         *  \param[in] localStream     Device NBNXM local stream, nullptr allowed.
+         *  \param[in] nonLocalStream  Device NBNXM non-local stream, nullptr allowed.
+         *  \param[in] deviceContext   Device context, nullptr allowed.
+         *  \param[in] transferKind    H2D/D2H transfer call behavior (synchronous or not).
+         *  \param[in] paddingSize     Padding size for coordinates buffer.
          */
-        Impl(const void        *commandStream,
+        Impl(const void        *pmeStream,
+             const void        *localStream,
+             const void        *nonLocalStream,
              const void        *deviceContext,
              GpuApiCallBehavior transferKind,
              int                paddingSize);
@@ -98,6 +102,9 @@ class StatePropagatorDataGpu::Impl
          * The coordinates buffer is reallocated with the padding added at the end. The
          * size of padding is set by the constructor.
          *
+         * \note PME requires the padding to be cleared, which is done in pmeStream_.
+         *       Hence pmeStream_ should be created in gpuContext_.
+         *
          *  \param[in] numAtomsLocal  Number of atoms in local domain.
          *  \param[in] numAtomsAll    Total number of atoms to handle.
          */
@@ -184,9 +191,13 @@ class StatePropagatorDataGpu::Impl
         void copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  h_f,
                                AtomLocality              atomLocality);
 
-        /*! \brief Synchronize the underlying GPU stream
+        /*! \brief Getter for the update stream.
+         *
+         *  \todo This getter is only here temporarily, until the management of this stream is taken over.
+         *
+         *  \returns The device command stream to use in update-constraints.
          */
-        void synchronizeStream();
+        void* getUpdateStream();
 
         /*! \brief Getter for the number of local atoms.
          *
@@ -202,10 +213,14 @@ class StatePropagatorDataGpu::Impl
 
     private:
 
-        /*! \brief GPU stream.
-         * \todo The stream should be set to non-nullptr once the synchronization points are restored
-         */
-        CommandStream        commandStream_              = nullptr;
+        //! GPU PME stream.
+        CommandStream        pmeStream_                  = nullptr;
+        //! GPU NBNXM local stream.
+        CommandStream        localStream_                = nullptr;
+        //! GPU NBNXM non-local stream.
+        CommandStream        nonLocalStream_             = nullptr;
+        //! GPU Update-constraints stream.
+        CommandStream        updateStream_               = nullptr;
         /*! \brief GPU context (for OpenCL builds)
          * \todo Make a Context class usable in CPU code
          */
@@ -245,27 +260,31 @@ class StatePropagatorDataGpu::Impl
          *
          * \todo Template on locality.
          *
-         * \param[in,out]  d_data        Device-side buffer.
-         * \param[in,out]  h_data        Host-side buffer.
-         * \param[in]      dataSize      Device-side data allocation size.
-         * \param[in]      atomLocality  If all, local or non-local ranges should be copied.
+         *  \param[out] d_data         Device-side buffer.
+         *  \param[in]  h_data         Host-side buffer.
+         *  \param[in]  dataSize       Device-side data allocation size.
+         *  \param[in]  atomLocality   Selects whether all, local or non-local ranges are copied.
+         *  \param[in]  commandStream  GPU stream to execute copy in.
          */
-        void copyToDevice(DeviceBuffer<float>            d_data,
-                          gmx::ArrayRef<const gmx::RVec> h_data,
-                          int                            dataSize,
-                          AtomLocality                   atomLocality);
+        void copyToDevice(DeviceBuffer<float>                   d_data,
+                          const gmx::ArrayRef<const gmx::RVec>  h_data,
+                          int                                   dataSize,
+                          AtomLocality                          atomLocality,
+                          CommandStream                         commandStream);
 
         /*! \brief Performs the copy of data from device to host buffer.
          *
-         * \param[in,out]  h_data        Host-side buffer.
-         * \param[in,out]  d_data        Device-side buffer.
-         * \param[in]      dataSize      Device-side data allocation size.
-         * \param[in]      atomLocality  If all, local or non-local ranges should be copied.
+         *  \param[out] h_data         Host-side buffer.
+         *  \param[in]  d_data         Device-side buffer.
+         *  \param[in]  dataSize       Device-side data allocation size.
+         *  \param[in]  atomLocality   Selects whether all, local or non-local ranges are copied.
+         *  \param[in]  commandStream  GPU stream to execute copy in.
          */
         void copyFromDevice(gmx::ArrayRef<gmx::RVec>  h_data,
                             DeviceBuffer<float>       d_data,
                             int                       dataSize,
-                            AtomLocality              atomLocality);
+                            AtomLocality              atomLocality,
+                            CommandStream             commandStream);
 };
 
 }      // namespace gmx