Pass gmx::ForceFlags to CPU nbnxm dispatch code
[alexxy/gromacs.git] / src / gromacs / nbnxm / kernels_reference / kernel_gpu_ref.cpp
index 91bad8a61a414ab43e233752b6bdd1acb4f67b5a..16a439711295b359e5276ee0508ad8462e0b13a4 100644 (file)
@@ -43,7 +43,7 @@
 #include "gromacs/math/functions.h"
 #include "gromacs/math/utilities.h"
 #include "gromacs/math/vec.h"
-#include "gromacs/mdlib/force_flags.h"
+#include "gromacs/mdlib/ppforceworkload.h"
 #include "gromacs/mdtypes/md_enums.h"
 #include "gromacs/nbnxm/atomdata.h"
 #include "gromacs/nbnxm/nbnxm.h"
@@ -59,14 +59,13 @@ nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu     *nbl,
                      const nbnxn_atomdata_t     *nbat,
                      const interaction_const_t  *iconst,
                      rvec                       *shift_vec,
-                     int                         force_flags,
+                     const gmx::ForceFlags      &forceFlags,
                      int                         clearF,
                      gmx::ArrayRef<real>         f,
                      real  *                     fshift,
                      real  *                     Vc,
                      real  *                     Vvdw)
 {
-    gmx_bool            bEner;
     gmx_bool            bEwald;
     const real         *Ftab = nullptr;
     real                rcut2, rvdw2, rlist2;
@@ -114,8 +113,6 @@ nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu     *nbl,
         }
     }
 
-    bEner = ((force_flags & GMX_FORCE_ENERGY) != 0);
-
     bEwald = EEL_FULL(iconst->eeltype);
     if (bEwald)
     {
@@ -265,7 +262,7 @@ nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu     *nbl,
                                     /* Reaction-field */
                                     krsq  = iconst->k_rf*rsq;
                                     fscal = qq*(int_bit*rinv - 2*krsq)*rinvsq;
-                                    if (bEner)
+                                    if (forceFlags.computeEnergy)
                                     {
                                         vcoul = qq*(int_bit*rinv + krsq - iconst->c_rf);
                                     }
@@ -281,7 +278,7 @@ nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu     *nbl,
 
                                     fscal = qq*(int_bit*rinvsq - fexcl)*rinv;
 
-                                    if (bEner)
+                                    if (forceFlags.computeEnergy)
                                     {
                                         vcoul = qq*((int_bit - std::erf(iconst->ewaldcoeff_q*r))*rinv - int_bit*iconst->sh_ewald);
                                     }
@@ -300,7 +297,7 @@ nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu     *nbl,
                                     Vvdw_rep  = c12*rinvsix*rinvsix;
                                     fscal    += (Vvdw_rep - Vvdw_disp)*rinvsq;
 
-                                    if (bEner)
+                                    if (forceFlags.computeEnergy)
                                     {
                                         vctot   += vcoul;
 
@@ -350,7 +347,7 @@ nbnxn_kernel_gpu_ref(const NbnxnPairlistGpu     *nbl,
             }
         }
 
-        if (bEner)
+        if (forceFlags.computeEnergy)
         {
             ggid             = 0;
             Vc[ggid]         = Vc[ggid]   + vctot;