Make DeviceStream into a class
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cu
index 4a44beb3e69945b3fc2517d38edbec9b656e716e..0829800111069b96822579e3b486b4750f1af0ff 100644 (file)
@@ -135,7 +135,6 @@ void GpuHaloExchange::Impl::reinitHalo(float3* d_coordinatesBuffer, float3* d_fo
     d_x_ = d_coordinatesBuffer;
     d_f_ = d_forcesBuffer;
 
-    cudaStream_t                 stream  = nonLocalStream_;
     const gmx_domdec_comm_t&     comm    = *dd_->comm;
     const gmx_domdec_comm_dim_t& cd      = comm.cd[0];
     const gmx_domdec_ind_t&      ind     = cd.ind[pulse_];
@@ -167,7 +166,7 @@ void GpuHaloExchange::Impl::reinitHalo(float3* d_coordinatesBuffer, float3* d_fo
     GMX_ASSERT(ind.index.size() == h_indexMap_.size(), "Size mismatch");
     std::copy(ind.index.begin(), ind.index.end(), h_indexMap_.begin());
 
-    copyToDeviceBuffer(&d_indexMap_, h_indexMap_.data(), 0, newSize, stream,
+    copyToDeviceBuffer(&d_indexMap_, h_indexMap_.data(), 0, newSize, nonLocalStream_,
                        GpuApiCallBehavior::Async, nullptr);
 
     // This rank will push data to its neighbor, so needs to know
@@ -215,7 +214,7 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
     config.gridSize[1]      = 1;
     config.gridSize[2]      = 1;
     config.sharedMemorySize = 0;
-    config.stream           = nonLocalStream_;
+    config.stream           = nonLocalStream_.stream();
 
     const float3* sendBuf  = d_sendBuf_;
     const float3* d_x      = d_x_;
@@ -264,7 +263,7 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
         if (!accumulateForces)
         {
             // Clear local portion of force array (in local stream)
-            cudaMemsetAsync(d_f, 0, numHomeAtoms_ * sizeof(rvec), localStream_);
+            cudaMemsetAsync(d_f, 0, numHomeAtoms_ * sizeof(rvec), localStream_.stream());
         }
 
         // ensure non-local stream waits for local stream, due to dependence on
@@ -286,7 +285,7 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     config.gridSize[1]      = 1;
     config.gridSize[2]      = 1;
     config.sharedMemorySize = 0;
-    config.stream           = nonLocalStream_;
+    config.stream           = nonLocalStream_.stream();
 
     const float3* recvBuf  = d_recvBuf_;
     const int*    indexMap = d_indexMap_;
@@ -373,8 +372,7 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(void* sendPtr,
                                                               int   recvRank)
 {
 
-    cudaError_t  stat;
-    cudaStream_t stream = nonLocalStream_;
+    cudaError_t stat;
 
     // We asynchronously push data to remote rank. The remote
     // destination pointer has already been set in the init fn.  We
@@ -386,7 +384,7 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(void* sendPtr,
     if (sendSize > 0)
     {
         stat = cudaMemcpyAsync(remotePtr, sendPtr, sendSize * DIM * sizeof(float),
-                               cudaMemcpyDeviceToDevice, stream);
+                               cudaMemcpyDeviceToDevice, nonLocalStream_.stream());
         CU_RET_ERR(stat, "cudaMemcpyAsync on GPU Domdec CUDA direct data transfer failed");
     }
 
@@ -397,13 +395,13 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(void* sendPtr,
     // to its stream.
     GpuEventSynchronizer* haloDataTransferRemote;
 
-    haloDataTransferLaunched_->markEvent(stream);
+    haloDataTransferLaunched_->markEvent(nonLocalStream_);
 
     MPI_Sendrecv(&haloDataTransferLaunched_, sizeof(GpuEventSynchronizer*), MPI_BYTE, sendRank, 0,
                  &haloDataTransferRemote, sizeof(GpuEventSynchronizer*), MPI_BYTE, recvRank, 0,
                  mpi_comm_mysim_, MPI_STATUS_IGNORE);
 
-    haloDataTransferRemote->enqueueWaitEvent(stream);
+    haloDataTransferRemote->enqueueWaitEvent(nonLocalStream_);
 #else
     GMX_UNUSED_VALUE(sendRank);
     GMX_UNUSED_VALUE(recvRank);
@@ -419,8 +417,8 @@ GpuEventSynchronizer* GpuHaloExchange::Impl::getForcesReadyOnDeviceEvent()
 GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
                             MPI_Comm             mpi_comm_mysim,
                             const DeviceContext& deviceContext,
-                            void*                localStream,
-                            void*                nonLocalStream,
+                            const DeviceStream&  localStream,
+                            const DeviceStream&  nonLocalStream,
                             int                  pulse) :
     dd_(dd),
     sendRankX_(dd->neighbor[0][1]),
@@ -431,8 +429,8 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
     haloDataTransferLaunched_(new GpuEventSynchronizer()),
     mpi_comm_mysim_(mpi_comm_mysim),
     deviceContext_(deviceContext),
-    localStream_(*static_cast<cudaStream_t*>(localStream)),
-    nonLocalStream_(*static_cast<cudaStream_t*>(nonLocalStream)),
+    localStream_(localStream),
+    nonLocalStream_(nonLocalStream),
     pulse_(pulse)
 {
 
@@ -466,8 +464,8 @@ GpuHaloExchange::Impl::~Impl()
 GpuHaloExchange::GpuHaloExchange(gmx_domdec_t*        dd,
                                  MPI_Comm             mpi_comm_mysim,
                                  const DeviceContext& deviceContext,
-                                 void*                localStream,
-                                 void*                nonLocalStream,
+                                 const DeviceStream&  localStream,
+                                 const DeviceStream&  nonLocalStream,
                                  int                  pulse) :
     impl_(new Impl(dd, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse))
 {