Pass the GPU streams to StatePropagatorDataGpu constructor
authorArtem Zhmurov <zhmurov@gmail.com>
Fri, 27 Sep 2019 20:29:30 +0000 (22:29 +0200)
committerArtem Zhmurov <zhmurov@gmail.com>
Mon, 7 Oct 2019 16:51:29 +0000 (18:51 +0200)
Now the StatePropagatorDataGpu has a local copy of all GPU streams and
manages the update stream. This will allow to select the specific stream
for a specific copy event in the follow-ups. The update stream is now
created in the constructor of the StatePropagatorDataGPU object, which
is a temporary solution until there is a separate device stream manager
(#3115).

Notes:

 - The current implementation where StatePropagatorDataGpu is also used
   on PME-only ranks, where many of the streams do not exist, without
   any restriction on the methods which would require these streams is a
   weakness of the design that will be dealt with in follow-up
 - The OpenCL builds unconditionally use PME stream/context, since for
   these this object is only used when the initial coordinates are copied.
 - The update stream is created in the constructor, whereas the rest of
   the streams is passed as arguments. This asymmentry will be removed
   with introduction of the centralized management of context/streams.

Refs. #2816.

Change-Id: Ia9b1cabd1d3d4942dba8465c716bf644037581e7

15 files changed:
src/gromacs/ewald/pme_only.cpp
src/gromacs/ewald/tests/pmegathertest.cpp
src/gromacs/ewald/tests/pmesplinespreadtest.cpp
src/gromacs/ewald/tests/pmetestcommon.cpp
src/gromacs/ewald/tests/pmetestcommon.h
src/gromacs/mdlib/update_constrain_cuda.h
src/gromacs/mdlib/update_constrain_cuda_impl.cpp
src/gromacs/mdlib/update_constrain_cuda_impl.cu
src/gromacs/mdlib/update_constrain_cuda_impl.h
src/gromacs/mdrun/md.cpp
src/gromacs/mdrun/runner.cpp
src/gromacs/mdtypes/state_propagator_data_gpu.h
src/gromacs/mdtypes/state_propagator_data_gpu_impl.cpp
src/gromacs/mdtypes/state_propagator_data_gpu_impl.h
src/gromacs/mdtypes/state_propagator_data_gpu_impl_gpu.cpp

index 01518afbb709fd719717e3c37be3549ba59a477f..6e202ad32f4daf27561d5e46713d6eca8c6041a5 100644 (file)
@@ -556,8 +556,15 @@ int gmx_pmeonly(struct gmx_pme_t *pme,
         changePinningPolicy(&pme_pp->x, pme_get_pinning_policy());
     }
 
-    // Unconditionally initialize the StatePropagatorDataGpu object to get more verbose message if it is used from CPU builds
-    auto stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(commandStream, deviceContext, GpuApiCallBehavior::Sync, paddingSize);
+    std::unique_ptr<gmx::StatePropagatorDataGpu> stateGpu;
+    if (useGpuForPme)
+    {
+        // TODO: The local and non-local nonbonded streams are passed as nullptrs, since they will be not used for the GPU buffer
+        //       management in PME only ranks. Make the constructor safer.
+        stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(commandStream, nullptr, nullptr,
+                                                                 deviceContext, GpuApiCallBehavior::Sync, paddingSize);
+    }
+
 
     clear_nrnb(mynrnb);
 
index d2afc1e324f2022165d1ec2e3c66ce86491edc23..077970b4a942295ebbf882e91bc311e21254c30c 100644 (file)
@@ -407,9 +407,10 @@ class PmeGatherTest : public ::testing::TestWithParam<GatherInputParameters>
                                           (inputForceTreatment == PmeForceOutputHandling::ReduceWithInput) ? "with reduction" : "without reduction"
                                           ));
 
-                PmeSafePointer         pmeSafe  = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
-                StatePropagatorDataGpu stateGpu = makeStatePropagatorDataGpu(*pmeSafe.get());
-                pmeInitAtoms(pmeSafe.get(), &stateGpu, codePath, inputAtomData.coordinates, inputAtomData.charges);
+                PmeSafePointer pmeSafe = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
+                std::unique_ptr<StatePropagatorDataGpu> stateGpu = (codePath == CodePath::GPU) ? makeStatePropagatorDataGpu(*pmeSafe.get()) : nullptr;
+
+                pmeInitAtoms(pmeSafe.get(), stateGpu.get(), codePath, inputAtomData.coordinates, inputAtomData.charges);
 
                 /* Setting some more inputs */
                 pmeSetRealGrid(pmeSafe.get(), codePath, nonZeroGridValues);
index 136c08b69733a400c73d56c4831f89fb0e357d1f..4a30d4d1c52a9c6c21c6701b91514192f5cd2b56 100644 (file)
@@ -122,7 +122,6 @@ class PmeSplineAndSpreadTest : public ::testing::TestWithParam<SplineAndSpreadIn
 
             for (const auto &context : getPmeTestEnv()->getHardwareContexts())
             {
-                std::shared_ptr<StatePropagatorDataGpu> stateGpu;
                 CodePath   codePath       = context->getCodePath();
                 const bool supportedInput = pmeSupportsInputForMode(*getPmeTestEnv()->hwinfo(), &inputRec, codePath);
                 if (!supportedInput)
@@ -146,9 +145,10 @@ class PmeSplineAndSpreadTest : public ::testing::TestWithParam<SplineAndSpreadIn
 
                     /* Running the test */
 
-                    PmeSafePointer         pmeSafe  = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
-                    StatePropagatorDataGpu stateGpu = makeStatePropagatorDataGpu(*pmeSafe.get());
-                    pmeInitAtoms(pmeSafe.get(), &stateGpu, codePath, coordinates, charges);
+                    PmeSafePointer pmeSafe = pmeInitWrapper(&inputRec, codePath, context->getDeviceInfo(), context->getPmeGpuProgram(), box);
+                    std::unique_ptr<StatePropagatorDataGpu> stateGpu = (codePath == CodePath::GPU) ? makeStatePropagatorDataGpu(*pmeSafe.get()) : nullptr;
+
+                    pmeInitAtoms(pmeSafe.get(), stateGpu.get(), codePath, coordinates, charges);
 
                     const bool     computeSplines = (option.first == PmeSplineAndSpreadOptions::SplineOnly) || (option.first == PmeSplineAndSpreadOptions::SplineAndSpreadUnified);
                     const bool     spreadCharges  = (option.first == PmeSplineAndSpreadOptions::SpreadOnly) || (option.first == PmeSplineAndSpreadOptions::SplineAndSpreadUnified);
index c3e538eb00372ce1a7724e985749ec28c5854aea..c82c0c685d93281de9b78ea65765bc210b70dfe0 100644 (file)
@@ -168,14 +168,14 @@ PmeSafePointer pmeInitEmpty(const t_inputrec         *inputRec,
 }
 
 //! Make a GPU state-propagator manager
-StatePropagatorDataGpu
+std::unique_ptr<StatePropagatorDataGpu>
 makeStatePropagatorDataGpu(const gmx_pme_t &pme)
 {
     // TODO: Pin the host buffer and use async memory copies
-    return StatePropagatorDataGpu(pme_gpu_get_device_stream(&pme),
-                                  pme_gpu_get_device_context(&pme),
-                                  GpuApiCallBehavior::Sync,
-                                  pme_gpu_get_padding_size(&pme));
+    return std::make_unique<StatePropagatorDataGpu>(pme_gpu_get_device_stream(&pme), nullptr, nullptr,
+                                                    pme_gpu_get_device_context(&pme),
+                                                    GpuApiCallBehavior::Sync,
+                                                    pme_gpu_get_padding_size(&pme));
 }
 
 //! PME initialization with atom data
index dd27865ab7034b957190fef5afe0ce91aae13cf3..0f00f57e3855da4438818efd49c9f72ff65cd3f4 100644 (file)
@@ -134,7 +134,7 @@ PmeSafePointer pmeInitEmpty(const t_inputrec *inputRec,
                             const Matrix3x3 &box = {{1.0F, 0.0F, 0.0F, 0.0F, 1.0F, 0.0F, 0.0F, 0.0F, 1.0F}},
                             real ewaldCoeff_q = 0.0F, real ewaldCoeff_lj = 0.0F);
 //! Make a GPU state-propagator manager
-StatePropagatorDataGpu
+std::unique_ptr<StatePropagatorDataGpu>
 makeStatePropagatorDataGpu(const gmx_pme_t &pme);
 //! PME initialization with atom data and system box
 void pmeInitAtoms(gmx_pme_t               *pme,
index 25c9e7af38595fc69ddea6014e82b8ad22adb7da..f2035c687f4adf71ea8ed5bd8de42bf473e0dc39 100644 (file)
@@ -128,6 +128,9 @@ class UpdateConstrainCuda
          */
         void setPbc(const t_pbc *pbc);
 
+        /*! \brief Synchronize the device stream.
+         */
+        void synchronizeStream();
 
     private:
         class Impl;
index b5823ed97fbbe55612fb4fc6a5dd237719192068..f8f9c45a522ad71a98371f65cd2d6c00aa4950e5 100644 (file)
@@ -93,6 +93,11 @@ void UpdateConstrainCuda::setPbc(gmx_unused const t_pbc *pbc)
     GMX_ASSERT(false, "A CPU stub for UpdateConstrain was called insted of the correct implementation.");
 }
 
+void UpdateConstrainCuda::synchronizeStream()
+{
+    GMX_ASSERT(false, "A CPU stub for UpdateConstrain was called insted of the correct implementation.");
+}
+
 }      // namespace gmx
 
 #endif /* GMX_GPU != GMX_GPU_CUDA */
index b22f365358da3a593ae10f53b233dd89e7fb4ea1..e4667720947eb63fc93ce77a5f3442495a435558 100644 (file)
@@ -162,6 +162,11 @@ void UpdateConstrainCuda::Impl::setPbc(const t_pbc *pbc)
     settleCuda_->setPbc(pbc);
 }
 
+void UpdateConstrainCuda::Impl::synchronizeStream()
+{
+    gpuStreamSynchronize(commandStream_);
+}
+
 UpdateConstrainCuda::UpdateConstrainCuda(const t_inputrec  &ir,
                                          const gmx_mtop_t  &mtop,
                                          const void        *commandStream)
@@ -201,4 +206,9 @@ void UpdateConstrainCuda::setPbc(const t_pbc *pbc)
     impl_->setPbc(pbc);
 }
 
+void UpdateConstrainCuda::synchronizeStream()
+{
+    impl_->synchronizeStream();
+}
+
 } //namespace gmx
index 652dd84eb6e44fc93eb57c57a0a2f79095b62a9f..85c6fb0dc4c29e6059dc4c2782b44d10ca5d71d6 100644 (file)
@@ -132,6 +132,10 @@ class UpdateConstrainCuda::Impl
          */
         void setPbc(const t_pbc *pbc);
 
+        /*! \brief Synchronize the device stream.
+         */
+        void synchronizeStream();
+
     private:
 
         //! CUDA stream
index 99f2a4f0de72ac37dba5f2bb6b090e387e47b9cc..2cebfd79206bb9aacad6223e2cc5211fe91aca19 100644 (file)
@@ -352,7 +352,7 @@ void gmx::LegacySimulator::do_md()
             GMX_LOG(mdlog.info).asParagraph().
                 appendText("Updating coordinates on the GPU.");
         }
-        integrator = std::make_unique<UpdateConstrainCuda>(*ir, *top_global, nullptr);
+        integrator = std::make_unique<UpdateConstrainCuda>(*ir, *top_global, fr->stateGpu->getUpdateStream());
     }
 
     if (useGpuForPme || (useGpuForNonbonded && useGpuForBufferOps) || useGpuForUpdate)
@@ -1241,7 +1241,9 @@ void gmx::LegacySimulator::do_md()
                                   doPressureCouple, ir->nstpcouple*ir->delta_t, M);
             stateGpu->copyCoordinatesFromGpu(ArrayRef<RVec>(state->x), StatePropagatorDataGpu::AtomLocality::All);
             stateGpu->copyVelocitiesFromGpu(state->v, StatePropagatorDataGpu::AtomLocality::All);
-            stateGpu->synchronizeStream();
+            // Synchronize the update stream.
+            // TODO: Replace with event-based synchronization.
+            integrator->synchronizeStream();
         }
         else
         {
index f7f24405b553349924a728a8d2232533b272355d..92057c8320fe263aaad67100d7d79f35a817ab3e 100644 (file)
@@ -1501,26 +1501,32 @@ int Mdrunner::mdrunner()
                                                          fcd->orires.nr != 0,
                                                          fcd->disres.nsystems != 0);
 
-        const void *commandStream = ((GMX_GPU == GMX_GPU_OPENCL) && thisRankHasPmeGpuTask) ? pme_gpu_get_device_stream(fr->pmedata) : nullptr;
-        const void *deviceContext = ((GMX_GPU == GMX_GPU_OPENCL) && thisRankHasPmeGpuTask) ? pme_gpu_get_device_context(fr->pmedata) : nullptr;
-        const int   paddingSize   = pme_gpu_get_padding_size(fr->pmedata);
-
-        const bool  inputIsCompatibleWithModularSimulator = ModularSimulator::isInputCompatible(
+        const bool inputIsCompatibleWithModularSimulator = ModularSimulator::isInputCompatible(
                     false,
                     inputrec, doRerun, vsite.get(), ms, replExParams,
                     fcd, static_cast<int>(filenames.size()), filenames.data(),
                     &observablesHistory, membed);
 
-        const bool          useModularSimulator = inputIsCompatibleWithModularSimulator && !(getenv("GMX_DISABLE_MODULAR_SIMULATOR") != nullptr);
-        GpuApiCallBehavior  transferKind        = (inputrec->eI == eiMD && !doRerun && !useModularSimulator) ? GpuApiCallBehavior::Async : GpuApiCallBehavior::Sync;
+        const bool useModularSimulator = inputIsCompatibleWithModularSimulator && !(getenv("GMX_DISABLE_MODULAR_SIMULATOR") != nullptr);
 
-        // We initialize GPU state even for the CPU runs so we will have a more verbose
-        // error if someone will try accessing it from the CPU codepath
-        gmx::StatePropagatorDataGpu stateGpu(commandStream,
-                                             deviceContext,
-                                             transferKind,
-                                             paddingSize);
-        fr->stateGpu = &stateGpu;
+        std::unique_ptr<gmx::StatePropagatorDataGpu> stateGpu;
+        if (gpusWereDetected && ((useGpuForPme && thisRankHasDuty(cr, DUTY_PME)) || useGpuForUpdate))
+        {
+            const void         *pmeStream      = pme_gpu_get_device_stream(fr->pmedata);
+            const void         *localStream    = fr->nbv->gpu_nbv != nullptr ? Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, Nbnxm::InteractionLocality::Local) : nullptr;
+            const void         *nonLocalStream = fr->nbv->gpu_nbv != nullptr ? Nbnxm::gpu_get_command_stream(fr->nbv->gpu_nbv, Nbnxm::InteractionLocality::NonLocal) : nullptr;
+            const void         *deviceContext  = pme_gpu_get_device_context(fr->pmedata);
+            const int           paddingSize    = pme_gpu_get_padding_size(fr->pmedata);
+            GpuApiCallBehavior  transferKind   = (inputrec->eI == eiMD && !doRerun && !useModularSimulator) ? GpuApiCallBehavior::Async : GpuApiCallBehavior::Sync;
+
+            stateGpu = std::make_unique<gmx::StatePropagatorDataGpu>(pmeStream,
+                                                                     localStream,
+                                                                     nonLocalStream,
+                                                                     deviceContext,
+                                                                     transferKind,
+                                                                     paddingSize);
+            fr->stateGpu = stateGpu.get();
+        }
 
         // TODO This is not the right place to manage the lifetime of
         // this data structure, but currently it's the easiest way to
index 212fde53dd32bcb3cd52eb64b162a2d4bbbcd119..ff43c4a807c6f03e920c6d1105b4d32f02682c1b 100644 (file)
@@ -96,12 +96,16 @@ class StatePropagatorDataGpu
          * \todo A DeviceContext object is visible in CPU parts of the code so we
          *       can stop passing a void*.
          *
-         *  \param[in] commandStream  GPU stream, nullptr allowed.
-         *  \param[in] deviceContext  GPU context, nullptr allowed.
-         *  \param[in] transferKind   H2D/D2H transfer call behavior (synchronous or not).
-         *  \param[in] paddingSize    Padding size for coordinates buffer.
-         */
-        StatePropagatorDataGpu(const void        *commandStream,
+         *  \param[in] pmeStream       Device PME stream, nullptr allowed.
+         *  \param[in] localStream     Device NBNXM local stream, nullptr allowed.
+         *  \param[in] nonLocalStream  Device NBNXM non-local stream, nullptr allowed.
+         *  \param[in] deviceContext   Device context, nullptr allowed.
+         *  \param[in] transferKind    H2D/D2H transfer call behavior (synchronous or not).
+         *  \param[in] paddingSize     Padding size for coordinates buffer.
+         */
+        StatePropagatorDataGpu(const void        *pmeStream,
+                               const void        *localStream,
+                               const void        *nonLocalStream,
                                const void        *deviceContext,
                                GpuApiCallBehavior transferKind,
                                int                paddingSize);
@@ -202,9 +206,14 @@ class StatePropagatorDataGpu
          */
         void copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  h_f,
                                AtomLocality              atomLocality);
-        /*! \brief Synchronize the underlying GPU stream
+
+        /*! \brief Getter for the update stream.
+         *
+         *  \todo This is temporary here, until the management of this stream is taken over.
+         *
+         *  \returns The device command stream to use in update-constraints.
          */
-        void synchronizeStream();
+        void* getUpdateStream();
 
         /*! \brief Getter for the number of local atoms.
          *
index d66cfc552fefecaa184a2312004bef76b563786a..6a7dd647307fd47f0302fa9fb6da58f502364ba7 100644 (file)
@@ -54,10 +54,12 @@ class StatePropagatorDataGpu::Impl
 {
 };
 
-StatePropagatorDataGpu::StatePropagatorDataGpu(const void *       /* commandStream */,
-                                               const void *       /* deviceContext */,
-                                               GpuApiCallBehavior /* transferKind  */,
-                                               int                /* paddingSize   */)
+StatePropagatorDataGpu::StatePropagatorDataGpu(const void *       /* pmeStream       */,
+                                               const void *       /* localStream     */,
+                                               const void *       /* nonLocalStream  */,
+                                               const void *       /* deviceContext   */,
+                                               GpuApiCallBehavior /* transferKind    */,
+                                               int                /* paddingSize     */)
     : impl_(nullptr)
 {
 }
@@ -136,9 +138,10 @@ void StatePropagatorDataGpu::copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  /* h_f
     GMX_ASSERT(false, "A CPU stub method from GPU state propagator data was called insted of one from GPU implementation.");
 }
 
-void StatePropagatorDataGpu::synchronizeStream()
+void* StatePropagatorDataGpu::getUpdateStream()
 {
     GMX_ASSERT(false, "A CPU stub method from GPU state propagator data was called insted of one from GPU implementation.");
+    return nullptr;
 }
 
 int StatePropagatorDataGpu::numAtomsLocal()
index f32d6df9debf731fbf64c6615b285b4451b32677..16aead662a809c097ae04014741780de5d406435 100644 (file)
@@ -80,12 +80,16 @@ class StatePropagatorDataGpu::Impl
          * \todo A DeviceContext object is visible in CPU parts of the code so we
          *       can stop passing a void*.
          *
-         *  \param[in] commandStream  GPU stream, nullptr allowed.
-         *  \param[in] deviceContext  GPU context, nullptr allowed.
-         *  \param[in] transferKind   H2D/D2H transfer call behavior (synchronous or not).
-         *  \param[in] paddingSize    Padding size for coordinates buffer.
+         *  \param[in] pmeStream       Device PME stream, nullptr allowed.
+         *  \param[in] localStream     Device NBNXM local stream, nullptr allowed.
+         *  \param[in] nonLocalStream  Device NBNXM non-local stream, nullptr allowed.
+         *  \param[in] deviceContext   Device context, nullptr allowed.
+         *  \param[in] transferKind    H2D/D2H transfer call behavior (synchronous or not).
+         *  \param[in] paddingSize     Padding size for coordinates buffer.
          */
-        Impl(const void        *commandStream,
+        Impl(const void        *pmeStream,
+             const void        *localStream,
+             const void        *nonLocalStream,
              const void        *deviceContext,
              GpuApiCallBehavior transferKind,
              int                paddingSize);
@@ -98,6 +102,9 @@ class StatePropagatorDataGpu::Impl
          * The coordinates buffer is reallocated with the padding added at the end. The
          * size of padding is set by the constructor.
          *
+         * \note The PME requires clearing of the padding, which is done in the pmeStream_.
+         *       Hence the pmeStream_ should be created in the gpuContext_.
+         *
          *  \param[in] numAtomsLocal  Number of atoms in local domain.
          *  \param[in] numAtomsAll    Total number of atoms to handle.
          */
@@ -184,9 +191,13 @@ class StatePropagatorDataGpu::Impl
         void copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  h_f,
                                AtomLocality              atomLocality);
 
-        /*! \brief Synchronize the underlying GPU stream
+        /*! \brief Getter for the update stream.
+         *
+         *  \todo This is temporary here, until the management of this stream is taken over.
+         *
+         *  \returns The device command stream to use in update-constraints.
          */
-        void synchronizeStream();
+        void* getUpdateStream();
 
         /*! \brief Getter for the number of local atoms.
          *
@@ -202,10 +213,14 @@ class StatePropagatorDataGpu::Impl
 
     private:
 
-        /*! \brief GPU stream.
-         * \todo The stream should be set to non-nullptr once the synchronization points are restored
-         */
-        CommandStream        commandStream_              = nullptr;
+        //! GPU PME stream.
+        CommandStream        pmeStream_                  = nullptr;
+        //! GPU NBNXM local stream.
+        CommandStream        localStream_                = nullptr;
+        //! GPU NBNXM non-local stream
+        CommandStream        nonLocalStream_             = nullptr;
+        //! GPU Update-constreaints stream.
+        CommandStream        updateStream_               = nullptr;
         /*! \brief GPU context (for OpenCL builds)
          * \todo Make a Context class usable in CPU code
          */
@@ -245,27 +260,31 @@ class StatePropagatorDataGpu::Impl
          *
          * \todo Template on locality.
          *
-         * \param[in,out]  d_data        Device-side buffer.
-         * \param[in,out]  h_data        Host-side buffer.
-         * \param[in]      dataSize      Device-side data allocation size.
-         * \param[in]      atomLocality  If all, local or non-local ranges should be copied.
+         *  \param[out] d_data         Device-side buffer.
+         *  \param[in]  h_data         Host-side buffer.
+         *  \param[in]  dataSize       Device-side data allocation size.
+         *  \param[in]  atomLocality   If all, local or non-local ranges should be copied.
+         *  \param[in]  commandStream  GPU stream to execute copy in.
          */
-        void copyToDevice(DeviceBuffer<float>            d_data,
-                          gmx::ArrayRef<const gmx::RVec> h_data,
-                          int                            dataSize,
-                          AtomLocality                   atomLocality);
+        void copyToDevice(DeviceBuffer<float>                   d_data,
+                          const gmx::ArrayRef<const gmx::RVec>  h_data,
+                          int                                   dataSize,
+                          AtomLocality                          atomLocality,
+                          CommandStream                         commandStream);
 
         /*! \brief Performs the copy of data from device to host buffer.
          *
-         * \param[in,out]  h_data        Host-side buffer.
-         * \param[in,out]  d_data        Device-side buffer.
-         * \param[in]      dataSize      Device-side data allocation size.
-         * \param[in]      atomLocality  If all, local or non-local ranges should be copied.
+         *  \param[out] h_data         Host-side buffer.
+         *  \param[in]  d_data         Device-side buffer.
+         *  \param[in]  dataSize       Device-side data allocation size.
+         *  \param[in]  atomLocality   If all, local or non-local ranges should be copied.
+         *  \param[in]  commandStream  GPU stream to execute copy in.
          */
         void copyFromDevice(gmx::ArrayRef<gmx::RVec>  h_data,
                             DeviceBuffer<float>       d_data,
                             int                       dataSize,
-                            AtomLocality              atomLocality);
+                            AtomLocality              atomLocality,
+                            CommandStream             commandStream);
 };
 
 }      // namespace gmx
index 89446f247759253f88a37854c1e3a810930bb23b..3cc0bb7653bc4286265bb178ec67785732f5f135 100644 (file)
 namespace gmx
 {
 
-StatePropagatorDataGpu::Impl::Impl(gmx_unused const void *commandStream,
-                                   gmx_unused const void *deviceContext,
+StatePropagatorDataGpu::Impl::Impl(const void            *pmeStream,
+                                   const void            *localStream,
+                                   const void            *nonLocalStream,
+                                   const void            *deviceContext,
                                    GpuApiCallBehavior     transferKind,
                                    int                    paddingSize) :
     transferKind_(transferKind),
@@ -72,18 +74,29 @@ StatePropagatorDataGpu::Impl::Impl(gmx_unused const void *commandStream,
 
     GMX_RELEASE_ASSERT(getenv("GMX_USE_GPU_BUFFER_OPS") == nullptr, "GPU buffer ops are not supported in this build.");
 
-    // Set the stream-context pair for the OpenCL builds,
-    // use the nullptr stream for CUDA builds
-#if GMX_GPU == GMX_GPU_OPENCL
-    if (commandStream != nullptr)
+    if (pmeStream != nullptr)
+    {
+        pmeStream_ = *static_cast<const CommandStream*>(pmeStream);
+    }
+    if (localStream != nullptr)
+    {
+        localStream_ = *static_cast<const CommandStream*>(localStream);
+    }
+    if (nonLocalStream != nullptr)
     {
-        commandStream_ = *static_cast<const CommandStream*>(commandStream);
+        nonLocalStream_ = *static_cast<const CommandStream*>(nonLocalStream);
     }
+// The OpenCL build will never use the updateStream
+// TODO: The update stream should be created only when it is needed.
+#if GMX_GPU == GMX_GPU_CUDA
+    cudaError_t stat;
+    stat = cudaStreamCreate(&updateStream_);
+    CU_RET_ERR(stat, "CUDA stream creation failed in StatePropagatorDataGpu");
+#endif
     if (deviceContext != nullptr)
     {
         deviceContext_ = *static_cast<const DeviceContext*>(deviceContext);
     }
-#endif
 
 }
 
@@ -114,7 +127,10 @@ void StatePropagatorDataGpu::Impl::reinit(int numAtomsLocal, int numAtomsAll)
     const size_t paddingAllocationSize = numAtomsPadded - numAtomsAll_;
     if (paddingAllocationSize > 0)
     {
-        clearDeviceBufferAsync(&d_x_, DIM*numAtomsAll_, DIM*paddingAllocationSize, commandStream_);
+        // The PME stream is used here because:
+        // 1. The padding clearing is only needed by PME.
+        // 2. It is the stream that is created in the PME OpenCL context.
+        clearDeviceBufferAsync(&d_x_, DIM*numAtomsAll_, DIM*paddingAllocationSize, pmeStream_);
     }
 
     reallocateDeviceBuffer(&d_v_, DIM*numAtomsAll_, &d_vSize_, &d_vCapacity_, deviceContext_);
@@ -151,11 +167,16 @@ std::tuple<int, int> StatePropagatorDataGpu::Impl::getAtomRangesFromAtomLocality
 void StatePropagatorDataGpu::Impl::copyToDevice(DeviceBuffer<float>                   d_data,
                                                 const gmx::ArrayRef<const gmx::RVec>  h_data,
                                                 int                                   dataSize,
-                                                AtomLocality                          atomLocality)
+                                                AtomLocality                          atomLocality,
+                                                CommandStream                         commandStream)
 {
 
 #if GMX_GPU == GMX_GPU_OPENCL
     GMX_ASSERT(deviceContext_ != nullptr, "GPU context should be set in OpenCL builds.");
+    // The PME stream is used for OpenCL builds, because it is the context that it associated with the
+    // PME task which requires the coordinates managed here in OpenCL.
+    // TODO: This will have to be changed when the OpenCL implementation will be extended.
+    commandStream = pmeStream_;
 #endif
 
     GMX_UNUSED_VALUE(dataSize);
@@ -173,21 +194,25 @@ void StatePropagatorDataGpu::Impl::copyToDevice(DeviceBuffer<float>
         GMX_ASSERT(elementsStartAt + numElementsToCopy <= dataSize, "The device allocation is smaller than requested copy range.");
         GMX_ASSERT(atomsStartAt + numAtomsToCopy <= h_data.ssize(), "The host buffer is smaller than the requested copy range.");
 
-        // TODO: Use the proper stream
         copyToDeviceBuffer(&d_data, reinterpret_cast<const float *>(&h_data.data()[atomsStartAt]),
                            elementsStartAt, numElementsToCopy,
-                           commandStream_, transferKind_, nullptr);
+                           commandStream, transferKind_, nullptr);
     }
 }
 
 void StatePropagatorDataGpu::Impl::copyFromDevice(gmx::ArrayRef<gmx::RVec>  h_data,
                                                   DeviceBuffer<float>       d_data,
                                                   int                       dataSize,
-                                                  AtomLocality              atomLocality)
+                                                  AtomLocality              atomLocality,
+                                                  CommandStream             commandStream)
 {
 
 #if GMX_GPU == GMX_GPU_OPENCL
     GMX_ASSERT(deviceContext_ != nullptr, "GPU context should be set in OpenCL builds.");
+    // The PME stream is used for OpenCL builds, because it is the context that it associated with the
+    // PME task which requires the coordinates managed here in OpenCL.
+    // TODO: This will have to be changed when the OpenCL implementation will be extended.
+    commandStream = pmeStream_;
 #endif
 
     GMX_UNUSED_VALUE(dataSize);
@@ -205,11 +230,9 @@ void StatePropagatorDataGpu::Impl::copyFromDevice(gmx::ArrayRef<gmx::RVec>  h_da
         GMX_ASSERT(elementsStartAt + numElementsToCopy <= dataSize, "The device allocation is smaller than requested copy range.");
         GMX_ASSERT(atomsStartAt + numAtomsToCopy <= h_data.ssize(), "The host buffer is smaller than the requested copy range.");
 
-        // TODO: Use the proper stream
         copyFromDeviceBuffer(reinterpret_cast<float*>(&h_data.data()[atomsStartAt]), &d_data,
                              elementsStartAt, numElementsToCopy,
-                             commandStream_, transferKind_, nullptr);
-
+                             commandStream, transferKind_, nullptr);
     }
 }
 
@@ -221,13 +244,15 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getCoordinates()
 void StatePropagatorDataGpu::Impl::copyCoordinatesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_x,
                                                         AtomLocality                          atomLocality)
 {
-    copyToDevice(d_x_, h_x, d_xSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyToDevice(d_x_, h_x, d_xSize_, atomLocality, nullptr);
 }
 
 void StatePropagatorDataGpu::Impl::copyCoordinatesFromGpu(gmx::ArrayRef<gmx::RVec>  h_x,
                                                           AtomLocality              atomLocality)
 {
-    copyFromDevice(h_x, d_x_, d_xSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyFromDevice(h_x, d_x_, d_xSize_, atomLocality, nullptr);
 }
 
 
@@ -239,13 +264,15 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getVelocities()
 void StatePropagatorDataGpu::Impl::copyVelocitiesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_v,
                                                        AtomLocality                          atomLocality)
 {
-    copyToDevice(d_v_, h_v, d_vSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyToDevice(d_v_, h_v, d_vSize_, atomLocality, nullptr);
 }
 
 void StatePropagatorDataGpu::Impl::copyVelocitiesFromGpu(gmx::ArrayRef<gmx::RVec>  h_v,
                                                          AtomLocality              atomLocality)
 {
-    copyFromDevice(h_v, d_v_, d_vSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyFromDevice(h_v, d_v_, d_vSize_, atomLocality, nullptr);
 }
 
 
@@ -257,18 +284,20 @@ DeviceBuffer<float> StatePropagatorDataGpu::Impl::getForces()
 void StatePropagatorDataGpu::Impl::copyForcesToGpu(const gmx::ArrayRef<const gmx::RVec>  h_f,
                                                    AtomLocality                          atomLocality)
 {
-    copyToDevice(d_f_, h_f, d_fSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyToDevice(d_f_, h_f, d_fSize_, atomLocality, nullptr);
 }
 
 void StatePropagatorDataGpu::Impl::copyForcesFromGpu(gmx::ArrayRef<gmx::RVec>  h_f,
                                                      AtomLocality              atomLocality)
 {
-    copyFromDevice(h_f, d_f_, d_fSize_, atomLocality);
+    // TODO: Use the correct stream
+    copyFromDevice(h_f, d_f_, d_fSize_, atomLocality, nullptr);
 }
 
-void StatePropagatorDataGpu::Impl::synchronizeStream()
+void* StatePropagatorDataGpu::Impl::getUpdateStream()
 {
-    gpuStreamSynchronize(commandStream_);
+    return updateStream_;
 }
 
 int StatePropagatorDataGpu::Impl::numAtomsLocal()
@@ -283,11 +312,15 @@ int StatePropagatorDataGpu::Impl::numAtomsAll()
 
 
 
-StatePropagatorDataGpu::StatePropagatorDataGpu(const void        *commandStream,
+StatePropagatorDataGpu::StatePropagatorDataGpu(const void        *pmeStream,
+                                               const void        *localStream,
+                                               const void        *nonLocalStream,
                                                const void        *deviceContext,
                                                GpuApiCallBehavior transferKind,
                                                int                paddingSize)
-    : impl_(new Impl(commandStream,
+    : impl_(new Impl(pmeStream,
+                     localStream,
+                     nonLocalStream,
                      deviceContext,
                      transferKind,
                      paddingSize))
@@ -365,9 +398,9 @@ void StatePropagatorDataGpu::copyForcesFromGpu(gmx::ArrayRef<RVec>  h_f,
     return impl_->copyForcesFromGpu(h_f, atomLocality);
 }
 
-void StatePropagatorDataGpu::synchronizeStream()
+void* StatePropagatorDataGpu::getUpdateStream()
 {
-    return impl_->synchronizeStream();
+    return impl_->getUpdateStream();
 }
 
 int StatePropagatorDataGpu::numAtomsLocal()