Implement alternating GPU wait
[alexxy/gromacs.git] / src / gromacs / mdlib / nbnxn_gpu_common.h
index 63519e0daaf157975539b01fc95992282bd875bb..b11be12411a0482775b820dcf5a8f4d10379ba41 100644 (file)
@@ -55,6 +55,7 @@
 #include "nbnxn_ocl/nbnxn_ocl_types.h"
 #endif
 
+#include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdlib/nbnxn_gpu_types.h"
 #include "gromacs/pbcutil/ishift.h"
@@ -227,6 +228,10 @@ static inline void nbnxn_gpu_reduce_staged_outputs(const StagingData &nbst,
  *  nonbonded tasks have completed with the exception of the rolling pruning kernels
  *  that are accounted for during the following step.
  *
+ * NOTE: if timing with multiple GPUs (streams) becomes possible, the
+ *      counters could end up being inconsistent due to not being incremented
+ *      on some of the node when this is skipped on empty local domains!
+ *
  * \tparam     GpuTimers         GPU timers type
  * \tparam     GpuPairlist       Pair list type
  * \param[out] timings           Pointer to the NB GPU timings data
@@ -294,27 +299,35 @@ static inline void nbnxn_gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timin
     }
 }
 
-// Documented in nbnxn_gpu.h
-void nbnxn_gpu_wait_for_gpu(gmx_nbnxn_gpu_t *nb,
-                            int              flags,
-                            int              aloc,
-                            real            *e_lj,
-                            real            *e_el,
-                            rvec            *fshift)
+bool nbnxn_gpu_try_finish_task(gmx_nbnxn_gpu_t  *nb,
+                               int               flags,
+                               int               aloc,
+                               real             *e_lj,
+                               real             *e_el,
+                               rvec             *fshift,
+                               GpuTaskCompletion completionKind)
 {
     /* determine interaction locality from atom locality */
     int iLocality = gpuAtomToInteractionLocality(aloc);
 
-    /* Launch wait/update timers & counters and do reduction into staging buffers
-       BUT skip it when during the non-local phase there was actually no work to do.
-       This is consistent with nbnxn_gpu_launch_kernel.
-
-       NOTE: if timing with multiple GPUs (streams) becomes possible, the
-       counters could end up being inconsistent due to not being incremented
-       on some of the nodes! */
+    //  We skip when during the non-local phase there was actually no work to do.
+    //  This is consistent with nbnxn_gpu_launch_kernel.
     if (!canSkipWork(nb, iLocality))
     {
-        gpuStreamSynchronize(nb->stream[iLocality]);
+        // Query the state of the GPU stream and return early if we're not done
+        if (completionKind == GpuTaskCompletion::Check)
+        {
+            if (!haveStreamTasksCompleted(nb->stream[iLocality]))
+            {
+                // Early return to skip the steps below that we have to do only
+                // after the NB task completed
+                return false;
+            }
+        }
+        else
+        {
+            gpuStreamSynchronize(nb->stream[iLocality]);
+        }
 
         bool calcEner   = flags & GMX_FORCE_ENERGY;
         bool calcFshift = flags & GMX_FORCE_VIRIAL;
@@ -329,6 +342,34 @@ void nbnxn_gpu_wait_for_gpu(gmx_nbnxn_gpu_t *nb,
 
     /* Turn off initial list pruning (doesn't hurt if this is not pair-search step). */
     nb->plist[iLocality]->haveFreshList = false;
+
+    return true;
+}
+
+/*! \brief
+ * Wait for the asynchronously launched nonbonded tasks and data
+ * transfers to finish.
+ *
+ * Also does timing accounting and reduction of the internal staging buffers.
+ * As this is called at the end of the step, it also resets the pair list and
+ * pruning flags.
+ *
+ * \param[in] nb The nonbonded data GPU structure
+ * \param[in] flags Force flags
+ * \param[in] aloc Atom locality identifier
+ * \param[out] e_lj Pointer to the LJ energy output to accumulate into
+ * \param[out] e_el Pointer to the electrostatics energy output to accumulate into
+ * \param[out] fshift Pointer to the shift force buffer to accumulate into
+ */
+void nbnxn_gpu_wait_finish_task(gmx_nbnxn_gpu_t *nb,
+                                int              flags,
+                                int              aloc,
+                                real            *e_lj,
+                                real            *e_el,
+                                rvec            *fshift)
+{
+    nbnxn_gpu_try_finish_task(nb, flags, aloc, e_lj, e_el, fshift,
+                              GpuTaskCompletion::Wait);
 }
 
 #endif