Fix cycle counters for "comm.coord" and "Wait + Comm. F" to support GPU halo exchange...
[alexxy/gromacs.git] / src / gromacs / domdec / gpuhaloexchange_impl.cu
index 0eebc6c0e4d3868e605a010b95e7792a4fcbdfd4..cbcff3defbc7274da10cbbbd2119fdb4cf129476 100644 (file)
@@ -61,6 +61,7 @@
 #include "gromacs/gpu_utils/vectype_ops.cuh"
 #include "gromacs/math/vectypes.h"
 #include "gromacs/pbcutil/ishift.h"
+#include "gromacs/timing/wallcycle.h"
 #include "gromacs/utility/gmxmpi.h"
 
 #include "domdec_internal.h"
@@ -205,6 +206,9 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
         coordinatesReadyOnDeviceEvent->enqueueWaitEvent(nonLocalStream_);
     }
 
+    wallcycle_start(wcycle_, ewcLAUNCH_GPU);
+    wallcycle_sub_start(wcycle_, ewcsLAUNCH_GPU_MOVEX);
+
     // launch kernel to pack send buffer
     KernelLaunchConfig config;
     config.blockSize[0]     = c_threadsPerBlock;
@@ -243,8 +247,17 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
                         "Domdec GPU Apply X Halo Exchange", kernelArgs);
     }
 
+    wallcycle_sub_stop(wcycle_, ewcsLAUNCH_GPU_MOVEX);
+    wallcycle_stop(wcycle_, ewcLAUNCH_GPU);
+
+    // Count the time spent in communicateHaloData towards the Comm.X counter
+    // TODO: This needs further refinement, as communicateHaloData includes the launch time for cudaMemcpyAsync
+    wallcycle_start(wcycle_, ewcMOVEX);
+
     communicateHaloData(d_x_, HaloQuantity::HaloCoordinates, coordinatesReadyOnDeviceEvent);
 
+    wallcycle_stop(wcycle_, ewcMOVEX);
+
     return;
 }
 
@@ -252,10 +265,18 @@ void GpuHaloExchange::Impl::communicateHaloCoordinates(const matrix          box
 // and before the local buffer operations. It operates in the non-local stream.
 void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
 {
+    // Count the time spent in communicateHaloData towards the Comm.F counter
+    // TODO: This needs further refinement, as communicateHaloData includes the launch time for cudaMemcpyAsync
+    wallcycle_start(wcycle_, ewcMOVEF);
 
     // Communicate halo data (in non-local stream)
     communicateHaloData(d_f_, HaloQuantity::HaloForces, nullptr);
 
+    wallcycle_stop(wcycle_, ewcMOVEF);
+
+    wallcycle_start_nocount(wcycle_, ewcLAUNCH_GPU);
+    wallcycle_sub_start(wcycle_, ewcsLAUNCH_GPU_MOVEF);
+
     float3* d_f = d_f_;
 
     if (pulse_ == (dd_->comm->cd[0].numPulses() - 1))
@@ -313,6 +334,9 @@ void GpuHaloExchange::Impl::communicateHaloForces(bool accumulateForces)
     {
         fReadyOnDevice_.markEvent(nonLocalStream_);
     }
+
+    wallcycle_sub_stop(wcycle_, ewcsLAUNCH_GPU_MOVEF);
+    wallcycle_stop(wcycle_, ewcLAUNCH_GPU);
 }
 
 
@@ -385,6 +409,7 @@ void GpuHaloExchange::Impl::communicateHaloDataWithCudaDirect(void* sendPtr,
     {
         stat = cudaMemcpyAsync(remotePtr, sendPtr, sendSize * DIM * sizeof(float),
                                cudaMemcpyDeviceToDevice, nonLocalStream_.stream());
+
         CU_RET_ERR(stat, "cudaMemcpyAsync on GPU Domdec CUDA direct data transfer failed");
     }
 
@@ -419,7 +444,8 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
                             const DeviceContext& deviceContext,
                             const DeviceStream&  localStream,
                             const DeviceStream&  nonLocalStream,
-                            int                  pulse) :
+                            int                  pulse,
+                            gmx_wallcycle*       wcycle) :
     dd_(dd),
     sendRankX_(dd->neighbor[0][1]),
     recvRankX_(dd->neighbor[0][0]),
@@ -431,7 +457,8 @@ GpuHaloExchange::Impl::Impl(gmx_domdec_t*        dd,
     deviceContext_(deviceContext),
     localStream_(localStream),
     nonLocalStream_(nonLocalStream),
-    pulse_(pulse)
+    pulse_(pulse),
+    wcycle_(wcycle)
 {
 
     GMX_RELEASE_ASSERT(GMX_THREAD_MPI,
@@ -466,8 +493,9 @@ GpuHaloExchange::GpuHaloExchange(gmx_domdec_t*        dd,
                                  const DeviceContext& deviceContext,
                                  const DeviceStream&  localStream,
                                  const DeviceStream&  nonLocalStream,
-                                 int                  pulse) :
-    impl_(new Impl(dd, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse))
+                                 int                  pulse,
+                                 gmx_wallcycle*       wcycle) :
+    impl_(new Impl(dd, mpi_comm_mysim, deviceContext, localStream, nonLocalStream, pulse, wcycle))
 {
 }