Refactor tracking of GPU short-range work/skipping
[alexxy/gromacs.git] / src / gromacs / mdlib / sim_util.cpp
index ad61dee88acadbafcb2aa3f878076ad09b28cd8a..5323c187a92b75ba72ad1191dcafd4b6264dbbba 100644 (file)
@@ -646,7 +646,6 @@ static void launchPmeGpuFftAndGather(gmx_pme_t        *pmedata,
  * \param[in,out] enerd            Energy data structure results are reduced into
  * \param[in]     flags            Force flags
  * \param[in]     pmeFlags         PME flags
- * \param[in]     haveOtherWork    Tells whether there is other work than non-bonded in the stream(s)
  * \param[in]     wcycle           The wallcycle structure
  */
 static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t                  *nbv,
@@ -657,7 +656,6 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t                  *nbv
                                         gmx_enerdata_t                      *enerd,
                                         int                                  flags,
                                         int                                  pmeFlags,
-                                        bool                                 haveOtherWork,
                                         gmx_wallcycle_t                      wcycle)
 {
     bool isPmeGpuDone = false;
@@ -681,7 +679,6 @@ static void alternatePmeNbGpuWaitReduce(nonbonded_verlet_t                  *nbv
             isNbGpuDone = Nbnxm::gpu_try_finish_task(nbv->gpu_nbv,
                                                      flags,
                                                      Nbnxm::AtomLocality::Local,
-                                                     haveOtherWork,
                                                      enerd->grpp.ener[egLJSR].data(),
                                                      enerd->grpp.ener[egCOULSR].data(),
                                                      fshift, completionType);
@@ -1035,6 +1032,9 @@ void do_force(FILE                                     *fplog,
         /* Note that with a GPU the launch overhead of the list transfer is not timed separately */
         nbv->constructPairlist(Nbnxm::InteractionLocality::Local,
                                &top->excls, step, nrnb);
+
+        nbv->setupGpuShortRangeWork(fr->gpuBonded, Nbnxm::InteractionLocality::Local);
+
         wallcycle_sub_stop(wcycle, ewcsNBS_SEARCH_LOCAL);
         wallcycle_stop(wcycle, ewcNS);
 
@@ -1061,8 +1061,7 @@ void do_force(FILE                                     *fplog,
         if (bNS || !useGpuXBufOps)
         {
             Nbnxm::gpu_copy_xq_to_gpu(nbv->gpu_nbv, nbv->nbat.get(),
-                                      Nbnxm::AtomLocality::Local,
-                                      ppForceWorkload->haveGpuBondedWork);
+                                      Nbnxm::AtomLocality::Local);
         }
         wallcycle_sub_stop(wcycle, ewcsLAUNCH_GPU_NONBONDED);
         // with X buffer ops offloaded to the GPU on all but the search steps
@@ -1105,6 +1104,8 @@ void do_force(FILE                                     *fplog,
             /* Note that with a GPU the launch overhead of the list transfer is not timed separately */
             nbv->constructPairlist(Nbnxm::InteractionLocality::NonLocal,
                                    &top->excls, step, nrnb);
+
+            nbv->setupGpuShortRangeWork(fr->gpuBonded, Nbnxm::InteractionLocality::NonLocal);
             wallcycle_sub_stop(wcycle, ewcsNBS_SEARCH_NONLOCAL);
             wallcycle_stop(wcycle, ewcNS);
         }
@@ -1125,8 +1126,7 @@ void do_force(FILE                                     *fplog,
             {
                 wallcycle_sub_start(wcycle, ewcsLAUNCH_GPU_NONBONDED);
                 Nbnxm::gpu_copy_xq_to_gpu(nbv->gpu_nbv, nbv->nbat.get(),
-                                          Nbnxm::AtomLocality::NonLocal,
-                                          ppForceWorkload->haveGpuBondedWork);
+                                          Nbnxm::AtomLocality::NonLocal);
                 wallcycle_sub_stop(wcycle, ewcsLAUNCH_GPU_NONBONDED);
             }
 
@@ -1155,10 +1155,10 @@ void do_force(FILE                                     *fplog,
         if (havePPDomainDecomposition(cr))
         {
             Nbnxm::gpu_launch_cpyback(nbv->gpu_nbv, nbv->nbat.get(),
-                                      flags, Nbnxm::AtomLocality::NonLocal, ppForceWorkload->haveGpuBondedWork);
+                                      flags, Nbnxm::AtomLocality::NonLocal);
         }
         Nbnxm::gpu_launch_cpyback(nbv->gpu_nbv, nbv->nbat.get(),
-                                  flags, Nbnxm::AtomLocality::Local, ppForceWorkload->haveGpuBondedWork);
+                                  flags, Nbnxm::AtomLocality::Local);
         wallcycle_sub_stop(wcycle, ewcsLAUNCH_GPU_NONBONDED);
 
         if (ppForceWorkload->haveGpuBondedWork && (flags & GMX_FORCE_ENERGY))
@@ -1323,7 +1323,6 @@ void do_force(FILE                                     *fplog,
                 wallcycle_start(wcycle, ewcWAIT_GPU_NB_NL);
                 Nbnxm::gpu_wait_finish_task(nbv->gpu_nbv,
                                             flags, Nbnxm::AtomLocality::NonLocal,
-                                            ppForceWorkload->haveGpuBondedWork,
                                             enerd->grpp.ener[egLJSR].data(),
                                             enerd->grpp.ener[egCOULSR].data(),
                                             fr->fshift);
@@ -1369,7 +1368,7 @@ void do_force(FILE                                     *fplog,
     if (alternateGpuWait)
     {
         alternatePmeNbGpuWaitReduce(fr->nbv.get(), fr->pmedata, &force, &forceOut.forceWithVirial, fr->fshift, enerd,
-                                    flags, pmeFlags, ppForceWorkload->haveGpuBondedWork, wcycle);
+                                    flags, pmeFlags, wcycle);
     }
 
     if (!alternateGpuWait && useGpuPme)
@@ -1389,7 +1388,7 @@ void do_force(FILE                                     *fplog,
 
         wallcycle_start(wcycle, ewcWAIT_GPU_NB_L);
         Nbnxm::gpu_wait_finish_task(nbv->gpu_nbv,
-                                    flags, Nbnxm::AtomLocality::Local, ppForceWorkload->haveGpuBondedWork,
+                                    flags, Nbnxm::AtomLocality::Local,
                                     enerd->grpp.ener[egLJSR].data(),
                                     enerd->grpp.ener[egCOULSR].data(),
                                     fr->fshift);