Pass the new gmx::ForceFlags to the nbnxm module
[alexxy/gromacs.git] / src / gromacs / nbnxm / gpu_common.h
index 3d7871c96900d91702d832559d3153270ca1d151..599c97edd4d249c5d9980b9abc601a1baf2205ce 100644 (file)
@@ -58,7 +58,7 @@
 #include "gromacs/gpu_utils/gpu_utils.h"
 #include "gromacs/listed_forces/gpubonded.h"
 #include "gromacs/math/vec.h"
-#include "gromacs/mdlib/force_flags.h"
+#include "gromacs/mdlib/ppforceworkload.h"
 #include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/pbcutil/ishift.h"
 #include "gromacs/timing/gpu_timing.h"
@@ -367,7 +367,7 @@ gpu_accumulate_timings(gmx_wallclock_gpu_nbnxn_t *timings,
 //TODO: move into shared source file with gmx_compile_cpp_as_cuda
 //NOLINTNEXTLINE(misc-definitions-in-headers)
 bool gpu_try_finish_task(gmx_nbnxn_gpu_t          *nb,
-                         const int                 flags,
+                         const gmx::ForceFlags    &forceFlags,
                          const AtomLocality        aloc,
                          real                     *e_lj,
                          real                     *e_el,
@@ -410,13 +410,10 @@ bool gpu_try_finish_task(gmx_nbnxn_gpu_t          *nb,
             gpuStreamSynchronize(nb->stream[iLocality]);
         }
 
-        bool calcEner   = (flags & GMX_FORCE_ENERGY) != 0;
-        bool calcFshift = (flags & GMX_FORCE_VIRIAL) != 0;
-
-        gpu_accumulate_timings(nb->timings, nb->timers, nb->plist[iLocality], aloc, calcEner,
+        gpu_accumulate_timings(nb->timings, nb->timers, nb->plist[iLocality], aloc, forceFlags.computeEnergy,
                                nb->bDoTime != 0);
 
-        gpu_reduce_staged_outputs(nb->nbst, iLocality, calcEner, calcFshift,
+        gpu_reduce_staged_outputs(nb->nbst, iLocality, forceFlags.computeEnergy, forceFlags.computeVirial,
                                   e_lj, e_el, as_rvec_array(shiftForces.data()));
     }
 
@@ -438,7 +435,7 @@ bool gpu_try_finish_task(gmx_nbnxn_gpu_t          *nb,
  * pruning flags.
  *
  * \param[in] nb The nonbonded data GPU structure
- * \param[in] flags Force flags
+ * \param[in]  forceFlags     Force schedule flags
  * \param[in] aloc Atom locality identifier
  * \param[out] e_lj Pointer to the LJ energy output to accumulate into
  * \param[out] e_el Pointer to the electrostatics energy output to accumulate into
@@ -448,7 +445,7 @@ bool gpu_try_finish_task(gmx_nbnxn_gpu_t          *nb,
  */
 //NOLINTNEXTLINE(misc-definitions-in-headers) TODO: move into source file
 float gpu_wait_finish_task(gmx_nbnxn_gpu_t         *nb,
-                           int                      flags,
+                           const gmx::ForceFlags   &forceFlags,
                            AtomLocality             aloc,
                            real                    *e_lj,
                            real                    *e_el,
@@ -459,7 +456,7 @@ float gpu_wait_finish_task(gmx_nbnxn_gpu_t         *nb,
         (gpuAtomToInteractionLocality(aloc) == InteractionLocality::Local) ? ewcWAIT_GPU_NB_L : ewcWAIT_GPU_NB_NL;
 
     wallcycle_start(wcycle, cycleCounter);
-    gpu_try_finish_task(nb, flags, aloc, e_lj, e_el, shiftForces,
+    gpu_try_finish_task(nb, forceFlags, aloc, e_lj, e_el, shiftForces,
                         GpuTaskCompletion::Wait, wcycle);
     float waitTime = wallcycle_stop(wcycle, cycleCounter);