Unify coordinate copy handling across GPU platforms
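
The PME spread launch now takes the GPU direct-communication state
explicitly: pme_gpu_launch_spread() gains a useGpuDirectComm flag and a
gmx::PmeCoordinateReceiverGpu*, which it forwards to pme_gpu_spread().
With the coordinate copy handled the same way on every GPU backend, the
CUDA-only special case in the xReadyOnDevice assertion is dropped, and
sum_forces() uses the scoped ModuleMultiThread::Pme enumerator in place
of the legacy emntPME.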
diff --git a/src/gromacs/ewald/pme_gpu.cpp b/src/gromacs/ewald/pme_gpu.cpp
index 564e213af9a7918fa4cc29992d040dd47685ff59..3f8e00f9953222e72934ca985fa368dea0e03449 100644
--- a/src/gromacs/ewald/pme_gpu.cpp
+++ b/src/gromacs/ewald/pme_gpu.cpp
@@ -59,6 +59,7 @@
 #include "gromacs/utility/fatalerror.h"
 #include "gromacs/utility/gmxassert.h"
 #include "gromacs/utility/stringutil.h"
+#include "gromacs/ewald/pme_coordinate_receiver_gpu.h"
 
 #include "pme_gpu_internal.h"
 #include "pme_gpu_settings.h"
@@ -189,14 +190,15 @@ void pme_gpu_prepare_computation(gmx_pme_t*               pme,
     }
 }
 
-void pme_gpu_launch_spread(gmx_pme_t*            pme,
-                           GpuEventSynchronizer* xReadyOnDevice,
-                           gmx_wallcycle*        wcycle,
-                           const real            lambdaQ)
+void pme_gpu_launch_spread(gmx_pme_t*                     pme,
+                           GpuEventSynchronizer*          xReadyOnDevice,
+                           gmx_wallcycle*                 wcycle,
+                           const real                     lambdaQ,
+                           const bool                     useGpuDirectComm,
+                           gmx::PmeCoordinateReceiverGpu* pmeCoordinateReceiverGpu)
 {
     GMX_ASSERT(pme_gpu_active(pme), "This should be a GPU run of PME but it is not enabled.");
-    GMX_ASSERT(!GMX_GPU_CUDA || xReadyOnDevice || !pme->bPPnode,
-               "Need a valid xReadyOnDevice on PP+PME ranks with CUDA.");
+    GMX_ASSERT(xReadyOnDevice || !pme->bPPnode, "Need a valid xReadyOnDevice on PP+PME ranks.");
     GMX_ASSERT(pme->doCoulomb, "Only Coulomb PME can be run on GPU.");
 
     PmeGpu* pmeGpu = pme->gpu;
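
The dropped GMX_GPU_CUDA guard is the unification at work: every GPU
backend now records the coordinates-ready event on the producing stream,
so the consumer-side assertion no longer needs a CUDA special case. Below
is a stand-alone sketch of that producer/consumer contract using the raw
CUDA runtime API; the stream and buffer names are illustrative, and
GROMACS itself wraps this pattern in GpuEventSynchronizer rather than
using bare events:

```cpp
#include <cuda_runtime.h>

// Illustrative only: the event plays the role of xReadyOnDevice.
void copyCoordinatesAndLaunchSpread(float* d_x, const float* h_x, size_t nbytes,
                                    cudaStream_t copyStream, cudaStream_t pmeStream)
{
    cudaEvent_t xReadyOnDevice;
    cudaEventCreateWithFlags(&xReadyOnDevice, cudaEventDisableTiming);

    // Producer: enqueue the H2D coordinate copy, then record "coordinates ready".
    cudaMemcpyAsync(d_x, h_x, nbytes, cudaMemcpyHostToDevice, copyStream);
    cudaEventRecord(xReadyOnDevice, copyStream);

    // Consumer: the PME stream waits for the event before the spread kernel runs.
    cudaStreamWaitEvent(pmeStream, xReadyOnDevice, 0);
    // ... launch the spread kernel on pmeStream here ...

    cudaEventDestroy(xReadyOnDevice);
}
```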
@@ -215,7 +217,8 @@ void pme_gpu_launch_spread(gmx_pme_t*            pme,
     const bool spreadCharges  = true;
     wallcycle_start_nocount(wcycle, WallCycleCounter::LaunchGpu);
     wallcycle_sub_start_nocount(wcycle, WallCycleSubCounter::LaunchGpuPme);
-    pme_gpu_spread(pmeGpu, xReadyOnDevice, fftgrids, computeSplines, spreadCharges, lambdaQ);
+    pme_gpu_spread(
+            pmeGpu, xReadyOnDevice, fftgrids, computeSplines, spreadCharges, lambdaQ, useGpuDirectComm, pmeCoordinateReceiverGpu);
     wallcycle_sub_stop(wcycle, WallCycleSubCounter::LaunchGpuPme);
     wallcycle_stop(wcycle, WallCycleCounter::LaunchGpu);
 }
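
For callers, the launch now threads the direct-communication state
through to pme_gpu_spread(). A minimal sketch of an adapted call site
follows; simulationUsesGpuPmePpComm and pmeCoordinateReceiverGpu are
hypothetical stand-ins for whatever state the caller holds, not GROMACS
identifiers:

```cpp
// Hypothetical caller state: pass false/nullptr when PME-PP GPU direct
// communication is not in use, and the launch behaves as before.
const bool useGpuDirectComm = simulationUsesGpuPmePpComm;
gmx::PmeCoordinateReceiverGpu* receiver =
        useGpuDirectComm ? pmeCoordinateReceiverGpu.get() : nullptr;

pme_gpu_launch_spread(pme, xReadyOnDevice, wcycle, lambdaQ, useGpuDirectComm, receiver);
```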
@@ -294,7 +297,7 @@ static void sum_forces(gmx::ArrayRef<gmx::RVec> f, gmx::ArrayRef<const gmx::RVec
 {
     const int end = forceToAdd.size();
 
-    int gmx_unused nt = gmx_omp_nthreads_get(emntPME);
+    int gmx_unused nt = gmx_omp_nthreads_get(ModuleMultiThread::Pme);
 #pragma omp parallel for num_threads(nt) schedule(static)
     for (int i = 0; i < end; i++)
     {
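
The emntPME to ModuleMultiThread::Pme change only swaps a legacy
enumerator for the scoped one; the surrounding loop is the standard
OpenMP element-wise accumulation. A self-contained version of the same
idiom outside GROMACS, with a hypothetical addForces helper:

```cpp
#include <vector>

// Each index is written by exactly one thread, so a static schedule over
// the range is race-free; this mirrors the sum_forces() loop above.
void addForces(std::vector<float>& f, const std::vector<float>& forceToAdd, int numThreads)
{
    const int end = static_cast<int>(forceToAdd.size());
#pragma omp parallel for num_threads(numThreads) schedule(static)
    for (int i = 0; i < end; i++)
    {
        f[i] += forceToAdd[i];
    }
}
```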