Take over management of OpenCL context from PME and NBNXM
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cu
index 92a1d9f3d5d2d2235099eb0ae230e01331da19e0..4a44beb3e69945b3fc2517d38edbec9b656e716e 100644 (file)
@@ -54,6 +54,7 @@
 #include "gromacs/domdec/domdec_struct.h"
 #include "gromacs/domdec/gpuhaloexchange.h"
 #include "gromacs/gpu_utils/cudautils.cuh"
+#include "gromacs/gpu_utils/device_context.h"
 #include "gromacs/gpu_utils/devicebuffer.h"
 #include "gromacs/gpu_utils/gpueventsynchronizer.cuh"
 #include "gromacs/gpu_utils/typecasts.cuh"
@@ -415,11 +416,12 @@ GpuEventSynchronizer* GpuHaloExchange::Impl::getForcesReadyOnDeviceEvent()
 }
 
 /*! \brief Create Domdec GPU object */
-GpuHaloExchange::Impl::Impl(gmx_domdec_t* dd,
-                            MPI_Comm      mpi_comm_mysim,
-                            void*         localStream,
-                            void*         nonLocalStream,
-                            int           pulse) :
+GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
+                            MPI_Comm             mpi_comm_mysim,
+                            const DeviceContext& deviceContext,
+                            void*                localStream,
+                            void*                nonLocalStream,
+                            int                  pulse) :
     dd_(dd),
     sendRankX_(dd->neighbor[0][1]),
     recvRankX_(dd->neighbor[0][0]),
@@ -428,6 +430,7 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t* dd,
     usePBC_(dd->ci[dd->dim[0]] == 0),
     haloDataTransferLaunched_(new GpuEventSynchronizer()),
     mpi_comm_mysim_(mpi_comm_mysim),
+    deviceContext_(deviceContext),
     localStream_(*static_cast<cudaStream_t*>(localStream)),
     nonLocalStream_(*static_cast<cudaStream_t*>(nonLocalStream)),
     pulse_(pulse)
@@ -460,12 +463,13 @@ GpuHaloExchange::Impl::~Impl()
     delete haloDataTransferLaunched_;
 }
 
-GpuHaloExchange::GpuHaloExchange(gmx_domdec_t* dd,
-                                 MPI_Comm      mpi_comm_mysim,
-                                 void*         localStream,
-                                 void*         nonLocalStream,
-                                 int           pulse) :
-    impl_(new Impl(dd, mpi_comm_mysim, localStream, nonLocalStream, pulse))
+GpuHaloExchange::GpuHaloExchange(gmx_domdec_t*        dd,
+                                 MPI_Comm             mpi_comm_mysim,
+                                 const DeviceContext& deviceContext,
+                                 void*                localStream,
+                                 void*                nonLocalStream,
+                                 int                  pulse) :
+    impl_(new Impl(dd, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse))
 {
 }