Refactor md_enums
[alexxy/gromacs.git] / src / gromacs / listed_forces / listed_forces.cpp
index 10e54fb419a2e07a43fd4aa8d8d40de8da1cf976..0364a2d7349a86b6a0f60b60aa32b2b558e85e49 100644 (file)
@@ -44,6 +44,8 @@
  */
 #include "gmxpre.h"
 
+#include "gromacs/utility/arrayref.h"
+#include "gromacs/utility/enumerationhelpers.h"
 #include "listed_forces.h"
 
 #include <cassert>
@@ -226,7 +228,7 @@ void zero_thread_output(f_thread_t* f_t)
             f_t->grpp.ener[i][j] = 0;
         }
     }
-    for (int i = 0; i < efptNR; i++)
+    for (auto i : keysOf(f_t->dvdl))
     {
         f_t->dvdl[i] = 0;
     }
@@ -298,7 +300,7 @@ void reduce_thread_forces(gmx::ArrayRef<gmx::RVec> force, const bonded_threading
 void reduce_thread_output(gmx::ForceWithShiftForces* forceWithShiftForces,
                           real*                      ener,
                           gmx_grppairener_t*         grpp,
-                          real*                      dvdl,
+                          gmx::ArrayRef<real>        dvdl,
                           const bonded_threading_t*  bt,
                           const gmx::StepWorkload&   stepWork)
 {
@@ -349,12 +351,12 @@ void reduce_thread_output(gmx::ForceWithShiftForces* forceWithShiftForces,
         }
         if (stepWork.computeDhdl)
         {
-            for (int i = 0; i < efptNR; i++)
+            for (auto i : keysOf(f_t[1]->dvdl))
             {
 
                 for (int t = 1; t < bt->nthreads; t++)
                 {
-                    dvdl[i] += f_t[t]->dvdl[i];
+                    dvdl[static_cast<int>(i)] += f_t[t]->dvdl[i];
                 }
             }
         }
@@ -414,8 +416,8 @@ real calc_one_bond(int                           thread,
                    const t_pbc*                  pbc,
                    gmx_grppairener_t*            grpp,
                    t_nrnb*                       nrnb,
-                   const real*                   lambda,
-                   real*                         dvdl,
+                   gmx::ArrayRef<const real>     lambda,
+                   gmx::ArrayRef<real>           dvdl,
                    const t_mdatoms*              md,
                    t_fcdata*                     fcd,
                    const gmx::StepWorkload&      stepWork,
@@ -428,14 +430,14 @@ real calc_one_bond(int                           thread,
             (idef.ilsort == ilsortFE_SORTED && numNonperturbedInteractions < iatoms.ssize());
     BondedKernelFlavor flavor =
             selectBondedKernelFlavor(stepWork, fr->use_simd_kernels, havePerturbedInteractions);
-    int efptFTYPE;
+    FreeEnergyPerturbationCouplingType efptFTYPE;
     if (IS_RESTRAINT_TYPE(ftype))
     {
-        efptFTYPE = efptRESTRAINT;
+        efptFTYPE = FreeEnergyPerturbationCouplingType::Restraint;
     }
     else
     {
-        efptFTYPE = efptBONDED;
+        efptFTYPE = FreeEnergyPerturbationCouplingType::Bonded;
     }
 
     const int nat1   = interaction_function[ftype].nratoms + 1;
@@ -466,8 +468,8 @@ real calc_one_bond(int                           thread,
                           f,
                           fshift,
                           pbc,
-                          lambda[efptFTYPE],
-                          &(dvdl[efptFTYPE]),
+                          lambda[static_cast<int>(efptFTYPE)],
+                          &(dvdl[static_cast<int>(efptFTYPE)]),
                           md,
                           fcd,
                           nullptr,
@@ -484,8 +486,8 @@ real calc_one_bond(int                           thread,
                                     f,
                                     fshift,
                                     pbc,
-                                    lambda[efptFTYPE],
-                                    &(dvdl[efptFTYPE]),
+                                    lambda[static_cast<int>(efptFTYPE)],
+                                    &(dvdl[static_cast<int>(efptFTYPE)]),
                                     md,
                                     fcd,
                                     fcd->disres,
@@ -507,8 +509,8 @@ real calc_one_bond(int                           thread,
                  f,
                  fshift,
                  pbc,
-                 lambda,
-                 dvdl,
+                 lambda.data(),
+                 dvdl.data(),
                  md,
                  fr,
                  havePerturbedInteractions,
@@ -537,8 +539,8 @@ static void calcBondedForces(const InteractionDefinitions& idef,
                              rvec*                         fshiftMasterBuffer,
                              gmx_enerdata_t*               enerd,
                              t_nrnb*                       nrnb,
-                             const real*                   lambda,
-                             real*                         dvdl,
+                             gmx::ArrayRef<const real>     lambda,
+                             gmx::ArrayRef<real>           dvdl,
                              const t_mdatoms*              md,
                              t_fcdata*                     fcd,
                              const gmx::StepWorkload&      stepWork,
@@ -553,9 +555,9 @@ static void calcBondedForces(const InteractionDefinitions& idef,
             int         ftype;
             real *      epot, v;
             /* thread stuff */
-            rvec*              fshift;
-            real*              dvdlt;
-            gmx_grppairener_t* grpp;
+            rvec*               fshift;
+            gmx::ArrayRef<real> dvdlt;
+            gmx_grppairener_t*  grpp;
 
             zero_thread_output(&threadBuffers);
 
@@ -643,7 +645,7 @@ void calc_listed(struct gmx_wallcycle*         wcycle,
                  const t_pbc*                  pbc,
                  gmx_enerdata_t*               enerd,
                  t_nrnb*                       nrnb,
-                 const real*                   lambda,
+                 gmx::ArrayRef<const real>     lambda,
                  const t_mdatoms*              md,
                  t_fcdata*                     fcd,
                  int*                          global_atom_index,
@@ -656,7 +658,7 @@ void calc_listed(struct gmx_wallcycle*         wcycle,
         wallcycle_sub_start(wcycle, ewcsLISTED);
         /* The dummy array is to have a place to store the dhdl at other values
            of lambda, which will be thrown away in the end */
-        real dvdl[efptNR] = { 0 };
+        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl = { 0 };
         calcBondedForces(idef,
                          bt,
                          x,
@@ -678,7 +680,7 @@ void calc_listed(struct gmx_wallcycle*         wcycle,
 
         if (stepWork.computeDhdl)
         {
-            for (int i = 0; i < efptNR; i++)
+            for (auto i : keysOf(enerd->dvdl_lin))
             {
                 enerd->dvdl_nonlin[i] += dvdl[i];
             }
@@ -709,7 +711,7 @@ void calc_listed_lambda(const InteractionDefinitions& idef,
                         real*                         epot,
                         gmx::ArrayRef<real>           dvdl,
                         t_nrnb*                       nrnb,
-                        const real*                   lambda,
+                        gmx::ArrayRef<const real>     lambda,
                         const t_mdatoms*              md,
                         t_fcdata*                     fcd,
                         int*                          global_atom_index)
@@ -766,7 +768,7 @@ void calc_listed_lambda(const InteractionDefinitions& idef,
                                        grpp,
                                        nrnb,
                                        lambda,
-                                       dvdl.data(),
+                                       dvdl,
                                        md,
                                        fcd,
                                        tempFlags,
@@ -793,7 +795,7 @@ void ListedForces::calculate(struct gmx_wallcycle*                     wcycle,
                              const struct t_pbc*                       pbc,
                              gmx_enerdata_t*                           enerd,
                              t_nrnb*                                   nrnb,
-                             const real*                               lambda,
+                             gmx::ArrayRef<const real>                 lambda,
                              const t_mdatoms*                          md,
                              int*                                      global_atom_index,
                              const gmx::StepWorkload&                  stepWork)
@@ -874,7 +876,7 @@ void ListedForces::calculate(struct gmx_wallcycle*                     wcycle,
      */
     if (fepvals->n_lambda > 0 && stepWork.computeDhdl)
     {
-        real dvdl[efptNR] = { 0 };
+        gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> dvdl = { 0 };
         if (!idef.il[F_POSRES].empty())
         {
             posres_wrapper_lambda(wcycle, fepvals, idef, &pbc_full, x, enerd, lambda, fr);
@@ -888,12 +890,12 @@ void ListedForces::calculate(struct gmx_wallcycle*                     wcycle,
             }
             for (int i = 0; i < 1 + enerd->foreignLambdaTerms.numLambdas(); i++)
             {
-                real lam_i[efptNR];
+                gmx::EnumerationArray<FreeEnergyPerturbationCouplingType, real> lam_i;
 
                 reset_foreign_enerdata(enerd);
-                for (int j = 0; j < efptNR; j++)
+                for (auto j : keysOf(lam_i))
                 {
-                    lam_i[j] = (i == 0 ? lambda[j] : fepvals->all_lambda[j][i - 1]);
+                    lam_i[j] = (i == 0 ? lambda[static_cast<int>(j)] : fepvals->all_lambda[j][i - 1]);
                 }
                 calc_listed_lambda(idef,
                                    threading_.get(),