Remove single-dimension limitation from GPU halo exchange
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cu
index cbcff3defbc7274da10cbbbd2119fdb4cf129476..d519445f74d48d0dd2fbea7b1f623e3b7c8c5371 100644 (file)
@@ -136,10 +136,30 @@ void GpuHaloExchange::Impl::reinitHalo(float3* d_coordinatesBuffer, float3* d_fo
     d_x_ = d_coordinatesBuffer;
     d_f_ = d_forcesBuffer;
 
-    const gmx_domdec_comm_t&     comm    = *dd_->comm;
-    const gmx_domdec_comm_dim_t& cd      = comm.cd[0];
-    const gmx_domdec_ind_t&      ind     = cd.ind[pulse_];
-    int                          newSize = ind.nsend[nzone_ + 1];
+    const gmx_domdec_comm_t&     comm = *dd_->comm;
+    const gmx_domdec_comm_dim_t& cd   = comm.cd[dimIndex_];
+    const gmx_domdec_ind_t&      ind  = cd.ind[pulse_];
+
+    numHomeAtoms_ = comm.atomRanges.numHomeAtoms(); // offset for data recieved by this rank
+
+    // Determine receive offset for the dimension index and pulse of this halo exchange object
+    int numZoneTemp   = 1;
+    int numZone       = 0;
+    int numAtomsTotal = numHomeAtoms_;
+    for (int i = 0; i <= dimIndex_; i++)
+    {
+        int pulseMax = (i == dimIndex_) ? pulse_ : (comm.cd[i].numPulses() - 1);
+        for (int p = 0; p <= pulseMax; p++)
+        {
+            atomOffset_                     = numAtomsTotal;
+            const gmx_domdec_ind_t& indTemp = comm.cd[i].ind[p];
+            numAtomsTotal += indTemp.nrecv[numZoneTemp + 1];
+        }
+        numZone = numZoneTemp;
+        numZoneTemp += numZoneTemp;
+    }
+
+    int newSize = ind.nsend[numZone + 1];
 
     GMX_ASSERT(cd.receiveInPlace, "Out-of-place receive is not yet supported in GPU halo exchange");
 
@@ -162,28 +182,24 @@ void GpuHaloExchange::Impl::reinitHalo(float3* d_coordinatesBuffer, float3* d_fo
     fSendSize_ = xRecvSize_;
     fRecvSize_ = xSendSize_;
 
-    numHomeAtoms_ = comm.atomRanges.numHomeAtoms(); // offset for data recieved by this rank
-
-    GMX_ASSERT(ind.index.size() == h_indexMap_.size(), "Size mismatch");
-    std::copy(ind.index.begin(), ind.index.end(), h_indexMap_.begin());
-
-    copyToDeviceBuffer(&d_indexMap_, h_indexMap_.data(), 0, newSize, nonLocalStream_,
-                       GpuApiCallBehavior::Async, nullptr);
+    if (newSize > 0)
+    {
+        GMX_ASSERT(ind.index.size() == h_indexMap_.size(),
+                   "Size mismatch between domain decomposition communication index array and GPU "
+                   "halo exchange index mapping array");
+        std::copy(ind.index.begin(), ind.index.end(), h_indexMap_.begin());
 
+        copyToDeviceBuffer(&d_indexMap_, h_indexMap_.data(), 0, newSize, nonLocalStream_,
+                           GpuApiCallBehavior::Async, nullptr);
+    }
     // This rank will push data to its neighbor, so needs to know
     // the remote receive address and similarly send its receive
     // address to other neighbour. We can do this here in reinit fn
     // since the pointers will not change until the next NS step.
 
     // Coordinates buffer:
+    void* recvPtr = static_cast<void*>(&d_x_[atomOffset_]);
 #if GMX_MPI
-    int pulseOffset = 0;
-    for (int p = pulse_ - 1; p >= 0; p--)
-    {
-        pulseOffset += cd.ind[p].nrecv[nzone_ + 1];
-    }
-    //    void* recvPtr = static_cast<void*>(&d_coordinatesBuffer[numHomeAtoms_ + pulseOffset]);
-    void* recvPtr = static_cast<void*>(&d_x_[numHomeAtoms_ + pulseOffset]);
     MPI_Sendrecv(&recvPtr, sizeof(void*), MPI_BYTE, recvRankX_, 0, &remoteXPtr_, sizeof(void*),
                  MPI_BYTE, sendRankX_, 0, mpi_comm_mysim_, MPI_STATUS_IGNORE);
 
@@ -228,12 +244,9 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
     // performing pressure coupling. So, for simplicity, the box
     // is used every step to pass the shift vector as an argument of
     // the packing kernel.
-    //
-    // Because only one-dimensional DD is supported, the coordinate
-    // shift only needs to handle that dimension.
-    const int    dimensionIndex = dd_->dim[0];
-    const float3 coordinateShift{ box[dimensionIndex][XX], box[dimensionIndex][YY],
-                                  box[dimensionIndex][ZZ] };
+    const int    boxDimensionIndex = dd_->dim[dimIndex_];
+    const float3 coordinateShift{ box[boxDimensionIndex][XX], box[boxDimensionIndex][YY],
+                                  box[boxDimensionIndex][ZZ] };
 
     // Avoid launching kernel when there is no work to do
     if (size > 0)
@@ -278,8 +291,11 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     wallcycle_sub_start(wcycle_, ewcsLAUNCH_GPU_MOVEF);
 
     float3* d_f = d_f_;
-
-    if (pulse_ == (dd_->comm->cd[0].numPulses() - 1))
+    // If this is the last pulse and index (noting the force halo
+    // exchanges across multiple pulses and indices are called in
+    // reverse order) then perform the following preparation
+    // activities
+    if ((pulse_ == (dd_->comm->cd[dimIndex_].numPulses() - 1)) && (dimIndex_ == (dd_->ndim - 1)))
     {
         if (!accumulateForces)
         {
@@ -311,11 +327,11 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     const int*    indexMap = d_indexMap_;
     const int     size     = fRecvSize_;
 
-    if (pulse_ > 0)
+    if (pulse_ > 0 || dd_->ndim > 1)
     {
         // We need to accumulate rather than set, since it is possible
-        // that, in this pulse, a value could be written to a location
-        // corresponding to the halo region of a following pulse.
+        // that, in this pulse/dim, a value could be written to a location
+        // corresponding to the halo region of a following pulse/dim.
         accumulateForces = true;
     }
 
@@ -374,12 +390,7 @@ void GpuHaloExchange::Impl::communicateHaloData(float3*               d_ptr,
     }
     else
     {
-        int recvOffset = dd_->comm->atomRanges.end(DDAtomRanges::Type::Zones);
-        for (int p = pulse_; p < dd_->comm->cd[0].numPulses(); p++)
-        {
-            recvOffset -= dd_->comm->cd[0].ind[p].nrecv[nzone_ + 1];
-        }
-        sendPtr   = static_cast<void*>(&(d_ptr[recvOffset]));
+        sendPtr   = static_cast<void*>(&(d_ptr[atomOffset_]));
         sendSize  = fSendSize_;
         remotePtr = remoteFPtr_;
         sendRank  = sendRankF_;
@@ -440,6 +451,7 @@ GpuEventSynchronizer* GpuHaloExchange::Impl::getForcesReadyOnDeviceEvent()
 
 /*! \brief Create Domdec GPU object */
 GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
+                            int                  dimIndex,
                             MPI_Comm             mpi_comm_mysim,
                             const DeviceContext& deviceContext,
                             const DeviceStream&  localStream,
@@ -447,11 +459,12 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
                             int                  pulse,
                             gmx_wallcycle*       wcycle) :
     dd_(dd),
-    sendRankX_(dd->neighbor[0][1]),
-    recvRankX_(dd->neighbor[0][0]),
-    sendRankF_(dd->neighbor[0][0]),
-    recvRankF_(dd->neighbor[0][1]),
-    usePBC_(dd->ci[dd->dim[0]] == 0),
+    dimIndex_(dimIndex),
+    sendRankX_(dd->neighbor[dimIndex][1]),
+    recvRankX_(dd->neighbor[dimIndex][0]),
+    sendRankF_(dd->neighbor[dimIndex][0]),
+    recvRankF_(dd->neighbor[dimIndex][1]),
+    usePBC_(dd->ci[dd->dim[dimIndex]] == 0),
     haloDataTransferLaunched_(new GpuEventSynchronizer()),
     mpi_comm_mysim_(mpi_comm_mysim),
     deviceContext_(deviceContext),
@@ -464,11 +477,6 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
     GMX_RELEASE_ASSERT(GMX_THREAD_MPI,
                        "GPU Halo exchange is currently only supported with thread-MPI enabled");
 
-    if (dd->ndim > 1)
-    {
-        gmx_fatal(FARGS, "Error: dd->ndim > 1 is not yet supported in GPU halo exchange");
-    }
-
     if (usePBC_ && dd->unitCellInfo.haveScrewPBC)
     {
         gmx_fatal(FARGS, "Error: screw is not yet supported in GPU halo exchange\n");
@@ -489,13 +497,14 @@ GpuHaloExchange::Impl::~Impl()
 }
 
 GpuHaloExchange::GpuHaloExchange(gmx_domdec_t*        dd,
+                                 int                  dimIndex,
                                  MPI_Comm             mpi_comm_mysim,
                                  const DeviceContext& deviceContext,
                                  const DeviceStream&  localStream,
                                  const DeviceStream&  nonLocalStream,
                                  int                  pulse,
                                  gmx_wallcycle*       wcycle) :
-    impl_(new Impl(dd, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse, wcycle))
+    impl_(new Impl(dd, dimIndex, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse, wcycle))
 {
 }