Pass gmx::ForceFlags to CPU nbnxm dispatch code
[alexxy/gromacs.git] / src / gromacs / nbnxm / kerneldispatch.cpp
index 1303a20e483d67467a635795deb1578efab32274..2d53a92dc752691fca58e0ee902d8f577a22c68e 100644 (file)
@@ -42,8 +42,8 @@
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdlib/enerdata_utils.h"
 #include "gromacs/mdlib/force.h"
-#include "gromacs/mdlib/force_flags.h"
 #include "gromacs/mdlib/gmx_omp_nthreads.h"
+#include "gromacs/mdlib/ppforceworkload.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/forceoutput.h"
 #include "gromacs/mdtypes/inputrec.h"
@@ -155,7 +155,7 @@ nbnxn_kernel_cpu(const PairlistSet              &pairlistSet,
                  nbnxn_atomdata_t               *nbat,
                  const interaction_const_t      &ic,
                  rvec                           *shiftVectors,
-                 int                             forceFlags,
+                 const gmx::ForceFlags          &forceFlags,
                  int                             clearF,
                  real                           *vCoulomb,
                  real                           *vVdw,
@@ -266,7 +266,7 @@ nbnxn_kernel_cpu(const PairlistSet              &pairlistSet,
         // TODO: Change to reference
         const NbnxnPairlistCpu *pairlist = &pairlists[nb];
 
-        if (!(forceFlags & GMX_FORCE_ENERGY))
+        if (!forceFlags.computeEnergy)
         {
             /* Don't calculate energies */
             switch (kernelSetup.kernelType)
@@ -396,7 +396,7 @@ nbnxn_kernel_cpu(const PairlistSet              &pairlistSet,
     }
     wallcycle_sub_stop(wcycle, ewcsNONBONDED_KERNEL);
 
-    if (forceFlags & GMX_FORCE_ENERGY)
+    if (forceFlags.computeEnergy)
     {
         reduce_energies_over_lists(nbat, pairlists.ssize(), vVdw, vCoulomb);
     }
@@ -406,7 +406,7 @@ static void accountFlops(t_nrnb                           *nrnb,
                          const PairlistSet                &pairlistSet,
                          const nonbonded_verlet_t         &nbv,
                          const interaction_const_t        &ic,
-                         const int                         forceFlags)
+                         const gmx::ForceFlags            &forceFlags)
 {
     const bool usingGpuKernels = nbv.useGpu();
 
@@ -425,7 +425,7 @@ static void accountFlops(t_nrnb                           *nrnb,
         enr_nbnxn_kernel_ljc = eNR_NBNXN_LJ_TAB;
     }
     int enr_nbnxn_kernel_lj = eNR_NBNXN_LJ;
-    if (forceFlags & GMX_FORCE_ENERGY)
+    if (forceFlags.computeEnergy)
     {
         /* In eNR_??? the nbnxn F+E kernels are always the F kernel + 1 */
         enr_nbnxn_kernel_ljc += 1;
@@ -440,23 +440,22 @@ static void accountFlops(t_nrnb                           *nrnb,
     inc_nrnb(nrnb, enr_nbnxn_kernel_ljc-eNR_NBNXN_LJ_RF+eNR_NBNXN_RF,
              pairlistSet.natpair_q_);
 
-    const bool calcEnergy = ((forceFlags & GMX_FORCE_ENERGY) != 0);
     if (ic.vdw_modifier == eintmodFORCESWITCH)
     {
         /* We add up the switch cost separately */
-        inc_nrnb(nrnb, eNR_NBNXN_ADD_LJ_FSW + (calcEnergy ? 1 : 0),
+        inc_nrnb(nrnb, eNR_NBNXN_ADD_LJ_FSW + (forceFlags.computeEnergy ? 1 : 0),
                  pairlistSet.natpair_ljq_ + pairlistSet.natpair_lj_);
     }
     if (ic.vdw_modifier == eintmodPOTSWITCH)
     {
         /* We add up the switch cost separately */
-        inc_nrnb(nrnb, eNR_NBNXN_ADD_LJ_PSW + (calcEnergy ? 1 : 0),
+        inc_nrnb(nrnb, eNR_NBNXN_ADD_LJ_PSW + (forceFlags.computeEnergy ? 1 : 0),
                  pairlistSet.natpair_ljq_ + pairlistSet.natpair_lj_);
     }
     if (ic.vdwtype == evdwPME)
     {
         /* We add up the LJ Ewald cost separately */
-        inc_nrnb(nrnb, eNR_NBNXN_ADD_LJ_EWALD + (calcEnergy ? 1 : 0),
+        inc_nrnb(nrnb, eNR_NBNXN_ADD_LJ_EWALD + (forceFlags.computeEnergy ? 1 : 0),
                  pairlistSet.natpair_ljq_ + pairlistSet.natpair_lj_);
     }
 }
@@ -464,7 +463,6 @@ static void accountFlops(t_nrnb                           *nrnb,
 void
 nonbonded_verlet_t::dispatchNonbondedKernel(Nbnxm::InteractionLocality iLocality,
                                             const interaction_const_t &ic,
-                                            int                        legacyForceFlags,
                                             const gmx::ForceFlags     &forceFlags,
                                             int                        clearF,
                                             const t_forcerec          &fr,
@@ -483,7 +481,7 @@ nonbonded_verlet_t::dispatchNonbondedKernel(Nbnxm::InteractionLocality iLocality
                              nbat.get(),
                              ic,
                              fr.shift_vec,
-                             legacyForceFlags,
+                             forceFlags,
                              clearF,
                              enerd->grpp.ener[egCOULSR].data(),
                              fr.bBHAM ?
@@ -500,7 +498,7 @@ nonbonded_verlet_t::dispatchNonbondedKernel(Nbnxm::InteractionLocality iLocality
             nbnxn_kernel_gpu_ref(pairlistSet.gpuList(),
                                  nbat.get(), &ic,
                                  fr.shift_vec,
-                                 legacyForceFlags,
+                                 forceFlags,
                                  clearF,
                                  nbat->out[0].f,
                                  nbat->out[0].fshift.data(),
@@ -515,7 +513,7 @@ nonbonded_verlet_t::dispatchNonbondedKernel(Nbnxm::InteractionLocality iLocality
 
     }
 
-    accountFlops(nrnb, pairlistSet, *this, ic, legacyForceFlags);
+    accountFlops(nrnb, pairlistSet, *this, ic, forceFlags);
 }
 
 void
@@ -527,7 +525,7 @@ nonbonded_verlet_t::dispatchFreeEnergyKernel(Nbnxm::InteractionLocality  iLocali
                                              t_lambda                   *fepvals,
                                              real                       *lambda,
                                              gmx_enerdata_t             *enerd,
-                                             const int                   forceFlags,
+                                             const gmx::ForceFlags      &forceFlags,
                                              t_nrnb                     *nrnb)
 {
     const auto nbl_fep = pairlistSets().pairlistSet(iLocality).fepLists();
@@ -543,15 +541,15 @@ nonbonded_verlet_t::dispatchFreeEnergyKernel(Nbnxm::InteractionLocality  iLocali
     donb_flags |= GMX_NONBONDED_DO_SR;
 
     /* Currently all group scheme kernels always calculate (shift-)forces */
-    if (forceFlags & GMX_FORCE_FORCES)
+    if (forceFlags.computeForces)
     {
         donb_flags |= GMX_NONBONDED_DO_FORCE;
     }
-    if (forceFlags & GMX_FORCE_VIRIAL)
+    if (forceFlags.computeVirial)
     {
         donb_flags |= GMX_NONBONDED_DO_SHIFTFORCE;
     }
-    if (forceFlags & GMX_FORCE_ENERGY)
+    if (forceFlags.computeEnergy)
     {
         donb_flags |= GMX_NONBONDED_DO_POTENTIAL;
     }
@@ -594,7 +592,7 @@ nonbonded_verlet_t::dispatchFreeEnergyKernel(Nbnxm::InteractionLocality  iLocali
     /* If we do foreign lambda and we have soft-core interactions
      * we have to recalculate the (non-linear) energies contributions.
      */
-    if (fepvals->n_lambda > 0 && (forceFlags & GMX_FORCE_DHDL) && fepvals->sc_alpha != 0)
+    if (fepvals->n_lambda > 0 && forceFlags.computeDhdl && fepvals->sc_alpha != 0)
     {
         real lam_i[efptNR];
         kernel_data.flags          = (donb_flags & ~(GMX_NONBONDED_DO_FORCE | GMX_NONBONDED_DO_SHIFTFORCE)) | GMX_NONBONDED_DO_FOREIGNLAMBDA;