Implement alternating GPU wait
[alexxy/gromacs.git] / src / gromacs / ewald / pme-gpu.cpp
index 2a29e25be207b8d35a87f93865aabbff5371c234..45948a2e02a1331caa7a70899278341b1e6a1482 100644 (file)
@@ -316,21 +316,19 @@ void pme_gpu_launch_gather(const gmx_pme_t                 *pme,
     wallcycle_stop(wcycle, ewcLAUNCH_GPU);
 }
 
-void
-pme_gpu_wait_for_gpu(const gmx_pme_t                *pme,
-                     gmx_wallcycle_t                 wcycle,
-                     gmx::ArrayRef<const gmx::RVec> *forces,
-                     matrix                          virial,
-                     real                           *energy)
+/*! \brief Reduce staged virial and energy outputs.
+ *
+ * \param[in]  pme            The PME data structure.
+ * \param[out] forces         Output forces pointer; the internal ArrayRef pointers get assigned to it.
+ * \param[out] virial         The output virial matrix.
+ * \param[out] energy         The output energy.
+ */
+static void pme_gpu_get_staged_results(const gmx_pme_t                *pme,
+                                       gmx::ArrayRef<const gmx::RVec> *forces,
+                                       matrix                          virial,
+                                       real                           *energy)
 {
-    GMX_ASSERT(pme_gpu_active(pme), "This should be a GPU run of PME but it is not enabled.");
-
     const bool haveComputedEnergyAndVirial = pme->gpu->settings.currentFlags & GMX_PME_CALC_ENER_VIR;
-
-    wallcycle_start(wcycle, ewcWAIT_GPU_PME_GATHER);
-    pme_gpu_finish_computation(pme->gpu);
-    wallcycle_stop(wcycle, ewcWAIT_GPU_PME_GATHER);
-
     *forces = pme_gpu_get_forces(pme->gpu);
 
     if (haveComputedEnergyAndVirial)
@@ -345,3 +343,57 @@ pme_gpu_wait_for_gpu(const gmx_pme_t                *pme,
         }
     }
 }
+
+bool pme_gpu_try_finish_task(const gmx_pme_t                *pme,
+                             gmx_wallcycle_t                 wcycle,
+                             gmx::ArrayRef<const gmx::RVec> *forces,
+                             matrix                          virial,
+                             real                           *energy,
+                             GpuTaskCompletion               completionKind)
+{
+    GMX_ASSERT(pme_gpu_active(pme), "This should be a GPU run of PME but it is not enabled.");
+
+    wallcycle_start_nocount(wcycle, ewcWAIT_GPU_PME_GATHER);
+
+    if (completionKind == GpuTaskCompletion::Check)
+    {
+        // Check mode: query the PME stream for completion of all enqueued tasks and, if they
+        // are not done yet, stop the timer and return false without touching the outputs.
+        if (!pme_gpu_stream_query(pme->gpu))
+        {
+            wallcycle_stop(wcycle, ewcWAIT_GPU_PME_GATHER);
+            return false;
+        }
+    }
+    else
+    {
+        // Otherwise (e.g. Wait mode): synchronize the whole PME stream at once, including D2H result transfers.
+        pme_gpu_synchronize(pme->gpu);
+    }
+    wallcycle_stop(wcycle, ewcWAIT_GPU_PME_GATHER);
+
+    // Time the final staged data handling separately with a counting call, so that only a
+    // completed task is counted (unfinished Check polls above use the non-counting start).
+    wallcycle_start(wcycle, ewcWAIT_GPU_PME_GATHER);
+
+    // The computation has completed: do the timing accounting and reset the buffers.
+    pme_gpu_update_timings(pme->gpu);
+    // TODO: move this later and launch it together with the other
+    // non-bonded tasks at the end of the step
+    pme_gpu_reinit_computation(pme->gpu);
+
+    pme_gpu_get_staged_results(pme, forces, virial, energy);
+
+    wallcycle_stop(wcycle, ewcWAIT_GPU_PME_GATHER);
+
+    return true;
+}
+
+void pme_gpu_wait_finish_task(const gmx_pme_t                *pme,
+                              gmx_wallcycle_t                 wcycle,
+                              gmx::ArrayRef<const gmx::RVec> *forces,
+                              matrix                          virial,
+                              real                           *energy)
+{
+    pme_gpu_try_finish_task(pme, wcycle, forces, virial, energy, GpuTaskCompletion::Wait); // Result ignored: only Check mode can return false.
+}