Refactor tracking of GPU short-range work/skipping
[alexxy/gromacs.git] / src / gromacs / nbnxm / cuda / nbnxm_cuda.cu
index a0117a5cb1a23725798898d7b66e4ef7e1dafb8b..e8e6c5b5cc1748701beff7a8145c402db2e80ea5 100644 (file)
@@ -309,9 +309,10 @@ void nbnxnInsertNonlocalGpuDependency(const gmx_nbnxn_cuda_t   *nb,
 /*! \brief Launch asynchronously the xq buffer host to device copy. */
 void gpu_copy_xq_to_gpu(gmx_nbnxn_cuda_t       *nb,
                         const nbnxn_atomdata_t *nbatom,
-                        const AtomLocality      atomLocality,
-                        const bool              haveOtherWork)
+                        const AtomLocality      atomLocality)
 {
+    GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
+
     GMX_ASSERT(atomLocality == AtomLocality::Local || atomLocality == AtomLocality::NonLocal,
                "Only local and non-local xq transfers are supported");
 
@@ -335,7 +336,7 @@ void gpu_copy_xq_to_gpu(gmx_nbnxn_cuda_t       *nb,
        we always call the local local x+q copy (and the rest of the local
        work in nbnxn_gpu_launch_kernel().
      */
-    if (!haveOtherWork && canSkipWork(*nb, iloc))
+    if ((iloc == InteractionLocality::NonLocal) && !haveGpuShortRangeWork(*nb, iloc))
     {
         plist->haveFreshList = false;
 
@@ -418,7 +419,7 @@ void gpu_launch_kernel(gmx_nbnxn_cuda_t          *nb,
        clearing. All these operations, except for the local interaction kernel,
        are needed for the non-local interactions. The skip of the local kernel
        call is taken care of later in this function. */
-    if (canSkipWork(*nb, iloc))
+    if (canSkipNonbondedWork(*nb, iloc))
     {
         plist->haveFreshList = false;
 
@@ -639,9 +640,10 @@ void gpu_launch_kernel_pruneonly(gmx_nbnxn_cuda_t          *nb,
 void gpu_launch_cpyback(gmx_nbnxn_cuda_t       *nb,
                         nbnxn_atomdata_t       *nbatom,
                         const int               flags,
-                        const AtomLocality      atomLocality,
-                        const bool              haveOtherWork)
+                        const AtomLocality      atomLocality)
 {
+    GMX_ASSERT(nb, "Need a valid nbnxn_gpu object");
+
     cudaError_t stat;
     int         adat_begin, adat_len; /* local/nonlocal offset and length used for xq and f */
 
@@ -658,7 +660,7 @@ void gpu_launch_cpyback(gmx_nbnxn_cuda_t       *nb,
     bool             bCalcFshift = flags & GMX_FORCE_VIRIAL;
 
     /* don't launch non-local copy-back if there was no non-local work to do */
-    if (!haveOtherWork && canSkipWork(*nb, iloc))
+    if ((iloc == InteractionLocality::NonLocal) && !haveGpuShortRangeWork(*nb, iloc))
     {
         return;
     }