Make TrotterSequence enum class
authorejjordan <ejjordan@kth.se>
Mon, 17 May 2021 13:54:03 +0000 (15:54 +0200)
committerJoe Jordan <ejjordan12@gmail.com>
Wed, 19 May 2021 15:13:09 +0000 (15:13 +0000)
A subesquent MR can make an extended Trotter enum class, removing
dependence of coupling and update_vv headers on md_enums header.

src/gromacs/mdlib/coupling.cpp
src/gromacs/mdlib/coupling.h
src/gromacs/mdlib/update_vv.cpp
src/gromacs/mdlib/update_vv.h
src/gromacs/mdrun/md.cpp
src/gromacs/mdtypes/md_enums.h

index c9a288d8a43b8ecd670a51e5f33d6fa05b82e2f2..ef2a5c5048f4c284468d6b2718f381c513fb6096 100644 (file)
@@ -881,8 +881,8 @@ void calculateScalingMatrixImplDetail<PressureCoupling::Berendsen>(const t_input
                                                                    real              dt,
                                                                    const matrix      pres,
                                                                    const matrix      box,
-                                                                   real scalar_pressure,
-                                                                   real xy_pressure,
+                                                                   real    scalar_pressure,
+                                                                   real    xy_pressure,
                                                                    int64_t gmx_unused step)
 {
     real p_corr_z = 0;
@@ -1310,7 +1310,7 @@ void trotter_update(const t_inputrec*                   ir,
                     gmx::ArrayRef<const real>           invMass,
                     const t_extmass*                    MassQ,
                     gmx::ArrayRef<std::vector<int>>     trotter_seqlist,
-                    int                                 trotter_seqno)
+                    TrotterSequence                     trotter_seqno)
 {
 
     int              n, i, d, ngtc, gc = 0, t;
@@ -1322,7 +1322,7 @@ void trotter_update(const t_inputrec*                   ir,
     rvec             sumv = { 0, 0, 0 };
     bool             bCouple;
 
-    if (trotter_seqno <= ettTSEQ2)
+    if (trotter_seqno <= TrotterSequence::Two)
     {
         step_eff = step - 1; /* the velocity verlet calls are actually out of order -- the first
                                 half step is actually the last half step from the previous step.
@@ -1335,7 +1335,7 @@ void trotter_update(const t_inputrec*                   ir,
 
     bCouple = (ir->nsttcouple == 1 || do_per_step(step_eff + ir->nsttcouple, ir->nsttcouple));
 
-    const gmx::ArrayRef<const int> trotter_seq = trotter_seqlist[trotter_seqno];
+    const gmx::ArrayRef<const int> trotter_seq = trotter_seqlist[static_cast<int>(trotter_seqno)];
 
     if ((trotter_seq[0] == etrtSKIPALL) || (!bCouple))
     {
@@ -1537,7 +1537,7 @@ extern void init_npt_masses(const t_inputrec* ir, t_state* state, t_extmass* Mas
     }
 }
 
-std::array<std::vector<int>, ettTSEQMAX>
+gmx::EnumerationArray<TrotterSequence, std::vector<int>>
 init_npt_vars(const t_inputrec* ir, t_state* state, t_extmass* MassQ, bool bTrotter)
 {
     int              i, j, nnhpres, nh;
@@ -1556,8 +1556,8 @@ init_npt_vars(const t_inputrec* ir, t_state* state, t_extmass* MassQ, bool bTrot
     init_npt_masses(ir, state, MassQ, TRUE);
 
     /* first, initialize clear all the trotter calls */
-    std::array<std::vector<int>, ettTSEQMAX> trotter_seq;
-    for (i = 0; i < ettTSEQMAX; i++)
+    gmx::EnumerationArray<TrotterSequence, std::vector<int>> trotter_seq;
+    for (i = 0; i < static_cast<int>(TrotterSequence::Count); i++)
     {
         trotter_seq[i].resize(NTROTTERPARTS, etrtNONE);
         trotter_seq[i][0] = etrtSKIPALL;
index 92333c5cca85786e228f5e76c8c5b3b3007f0c95..15211f5a83b9ed1227bd17f97fb4a65830560786 100644 (file)
@@ -45,6 +45,7 @@
 
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/utility/enumerationhelpers.h"
 #include "gromacs/utility/real.h"
 
 class gmx_ekindata_t;
@@ -152,9 +153,9 @@ void trotter_update(const t_inputrec*                   ir,
                     gmx::ArrayRef<const real>           invMass,
                     const t_extmass*                    MassQ,
                     gmx::ArrayRef<std::vector<int>>     trotter_seqlist,
-                    int                                 trotter_seqno);
+                    TrotterSequence                     trotter_seqno);
 
-std::array<std::vector<int>, ettTSEQMAX>
+gmx::EnumerationArray<TrotterSequence, std::vector<int>>
 init_npt_vars(const t_inputrec* ir, t_state* state, t_extmass* Mass, bool bTrotter);
 
 real NPT_energy(const t_inputrec* ir, const t_state* state, const t_extmass* MassQ);
index d4d1aa416a74367c0b6161ca5e36a11adbea0b32..10feea2d1ed28bd8d202bf1131666be6ad5b5e5d 100644 (file)
 #include "gromacs/timing/wallcycle.h"
 #include "gromacs/topology/topology.h"
 
-void integrateVVFirstStep(int64_t                                  step,
-                          bool                                     bFirstStep,
-                          bool                                     bInitStep,
-                          gmx::StartingBehavior                    startingBehavior,
-                          int                                      nstglobalcomm,
-                          const t_inputrec*                        ir,
-                          t_forcerec*                              fr,
-                          t_commrec*                               cr,
-                          t_state*                                 state,
-                          t_mdatoms*                               mdatoms,
-                          const t_fcdata&                          fcdata,
-                          t_extmass*                               MassQ,
-                          t_vcm*                                   vcm,
-                          const gmx_mtop_t&                        top_global,
-                          const gmx_localtop_t&                    top,
-                          gmx_enerdata_t*                          enerd,
-                          gmx_ekindata_t*                          ekind,
-                          gmx_global_stat*                         gstat,
-                          real*                                    last_ekin,
-                          bool                                     bCalcVir,
-                          tensor                                   total_vir,
-                          tensor                                   shake_vir,
-                          tensor                                   force_vir,
-                          tensor                                   pres,
-                          matrix                                   M,
-                          bool                                     do_log,
-                          bool                                     do_ene,
-                          bool                                     bCalcEner,
-                          bool                                     bGStat,
-                          bool                                     bStopCM,
-                          bool                                     bTrotter,
-                          bool                                     bExchanged,
-                          bool*                                    bSumEkinhOld,
-                          real*                                    saved_conserved_quantity,
-                          gmx::ForceBuffers*                       f,
-                          gmx::Update*                             upd,
-                          gmx::Constraints*                        constr,
-                          gmx::SimulationSignaller*                nullSignaller,
-                          std::array<std::vector<int>, ettTSEQMAX> trotter_seq,
-                          t_nrnb*                                  nrnb,
-                          const gmx::MDLogger&                     mdlog,
-                          FILE*                                    fplog,
-                          gmx_wallcycle*                           wcycle)
+void integrateVVFirstStep(int64_t                   step,
+                          bool                      bFirstStep,
+                          bool                      bInitStep,
+                          gmx::StartingBehavior     startingBehavior,
+                          int                       nstglobalcomm,
+                          const t_inputrec*         ir,
+                          t_forcerec*               fr,
+                          t_commrec*                cr,
+                          t_state*                  state,
+                          t_mdatoms*                mdatoms,
+                          const t_fcdata&           fcdata,
+                          t_extmass*                MassQ,
+                          t_vcm*                    vcm,
+                          const gmx_mtop_t&         top_global,
+                          const gmx_localtop_t&     top,
+                          gmx_enerdata_t*           enerd,
+                          gmx_ekindata_t*           ekind,
+                          gmx_global_stat*          gstat,
+                          real*                     last_ekin,
+                          bool                      bCalcVir,
+                          tensor                    total_vir,
+                          tensor                    shake_vir,
+                          tensor                    force_vir,
+                          tensor                    pres,
+                          matrix                    M,
+                          bool                      do_log,
+                          bool                      do_ene,
+                          bool                      bCalcEner,
+                          bool                      bGStat,
+                          bool                      bStopCM,
+                          bool                      bTrotter,
+                          bool                      bExchanged,
+                          bool*                     bSumEkinhOld,
+                          real*                     saved_conserved_quantity,
+                          gmx::ForceBuffers*        f,
+                          gmx::Update*              upd,
+                          gmx::Constraints*         constr,
+                          gmx::SimulationSignaller* nullSignaller,
+                          gmx::EnumerationArray<TrotterSequence, std::vector<int>> trotter_seq,
+                          t_nrnb*                                                  nrnb,
+                          const gmx::MDLogger&                                     mdlog,
+                          FILE*                                                    fplog,
+                          gmx_wallcycle*                                           wcycle)
 {
     if (!bFirstStep || startingBehavior == gmx::StartingBehavior::NewSimulation)
     {
@@ -144,7 +144,7 @@ void integrateVVFirstStep(int64_t                                  step,
                            gmx::arrayRefFromArray(mdatoms->invmass, mdatoms->nr),
                            MassQ,
                            trotter_seq,
-                           ettTSEQ1);
+                           TrotterSequence::One);
         }
 
         upd->update_coords(*ir,
@@ -255,7 +255,7 @@ void integrateVVFirstStep(int64_t                                  step,
                                gmx::arrayRefFromArray(mdatoms->invmass, mdatoms->nr),
                                MassQ,
                                trotter_seq,
-                               ettTSEQ2);
+                               TrotterSequence::Two);
 
                 /* TODO This is only needed when we're about to write
                  * a checkpoint, because we use it after the restart
@@ -332,39 +332,39 @@ void integrateVVFirstStep(int64_t                                  step,
     }
 }
 
-void integrateVVSecondStep(int64_t                                  step,
-                           const t_inputrec*                        ir,
-                           t_forcerec*                              fr,
-                           t_commrec*                               cr,
-                           t_state*                                 state,
-                           t_mdatoms*                               mdatoms,
-                           const t_fcdata&                          fcdata,
-                           t_extmass*                               MassQ,
-                           t_vcm*                                   vcm,
-                           pull_t*                                  pull_work,
-                           gmx_enerdata_t*                          enerd,
-                           gmx_ekindata_t*                          ekind,
-                           gmx_global_stat*                         gstat,
-                           real*                                    dvdl_constr,
-                           bool                                     bCalcVir,
-                           tensor                                   total_vir,
-                           tensor                                   shake_vir,
-                           tensor                                   force_vir,
-                           tensor                                   pres,
-                           matrix                                   M,
-                           matrix                                   lastbox,
-                           bool                                     do_log,
-                           bool                                     do_ene,
-                           bool                                     bGStat,
-                           bool*                                    bSumEkinhOld,
-                           gmx::ForceBuffers*                       f,
-                           std::vector<gmx::RVec>*                  cbuf,
-                           gmx::Update*                             upd,
-                           gmx::Constraints*                        constr,
-                           gmx::SimulationSignaller*                nullSignaller,
-                           std::array<std::vector<int>, ettTSEQMAX> trotter_seq,
-                           t_nrnb*                                  nrnb,
-                           gmx_wallcycle*                           wcycle)
+void integrateVVSecondStep(int64_t                                                  step,
+                           const t_inputrec*                                        ir,
+                           t_forcerec*                                              fr,
+                           t_commrec*                                               cr,
+                           t_state*                                                 state,
+                           t_mdatoms*                                               mdatoms,
+                           const t_fcdata&                                          fcdata,
+                           t_extmass*                                               MassQ,
+                           t_vcm*                                                   vcm,
+                           pull_t*                                                  pull_work,
+                           gmx_enerdata_t*                                          enerd,
+                           gmx_ekindata_t*                                          ekind,
+                           gmx_global_stat*                                         gstat,
+                           real*                                                    dvdl_constr,
+                           bool                                                     bCalcVir,
+                           tensor                                                   total_vir,
+                           tensor                                                   shake_vir,
+                           tensor                                                   force_vir,
+                           tensor                                                   pres,
+                           matrix                                                   M,
+                           matrix                                                   lastbox,
+                           bool                                                     do_log,
+                           bool                                                     do_ene,
+                           bool                                                     bGStat,
+                           bool*                                                    bSumEkinhOld,
+                           gmx::ForceBuffers*                                       f,
+                           std::vector<gmx::RVec>*                                  cbuf,
+                           gmx::Update*                                             upd,
+                           gmx::Constraints*                                        constr,
+                           gmx::SimulationSignaller*                                nullSignaller,
+                           gmx::EnumerationArray<TrotterSequence, std::vector<int>> trotter_seq,
+                           t_nrnb*                                                  nrnb,
+                           gmx_wallcycle*                                           wcycle)
 {
     /* velocity half-step update */
     upd->update_coords(*ir,
@@ -476,7 +476,7 @@ void integrateVVSecondStep(int64_t                                  step,
                        gmx::arrayRefFromArray(mdatoms->invmass, mdatoms->nr),
                        MassQ,
                        trotter_seq,
-                       ettTSEQ4);
+                       TrotterSequence::Four);
         /* now we know the scaling, we can compute the positions again */
         std::copy(cbuf->begin(), cbuf->end(), state->x.begin());
 
index 258f8bfd2b0b527412779ea9820f85a7eb44f77f..f900c470ab90df688ea9ca03004e99b165c1f5e1 100644 (file)
@@ -45,6 +45,7 @@
 
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/utility/enumerationhelpers.h"
 
 class gmx_ekindata_t;
 struct gmx_enerdata_t;
@@ -119,49 +120,49 @@ enum class StartingBehavior : int;
  * \param[in]  fplog             Another logger.
  * \param[in]  wcycle            Wall-clock cycle counter.
  */
-void integrateVVFirstStep(int64_t                                  step,
-                          bool                                     bFirstStep,
-                          bool                                     bInitStep,
-                          gmx::StartingBehavior                    startingBehavior,
-                          int                                      nstglobalcomm,
-                          const t_inputrec*                        ir,
-                          t_forcerec*                              fr,
-                          t_commrec*                               cr,
-                          t_state*                                 state,
-                          t_mdatoms*                               mdatoms,
-                          const t_fcdata&                          fcdata,
-                          t_extmass*                               MassQ,
-                          t_vcm*                                   vcm,
-                          const gmx_mtop_t&                        top_global,
-                          const gmx_localtop_t&                    top,
-                          gmx_enerdata_t*                          enerd,
-                          gmx_ekindata_t*                          ekind,
-                          gmx_global_stat*                         gstat,
-                          real*                                    last_ekin,
-                          bool                                     bCalcVir,
-                          tensor                                   total_vir,
-                          tensor                                   shake_vir,
-                          tensor                                   force_vir,
-                          tensor                                   pres,
-                          matrix                                   M,
-                          bool                                     do_log,
-                          bool                                     do_ene,
-                          bool                                     bCalcEner,
-                          bool                                     bGStat,
-                          bool                                     bStopCM,
-                          bool                                     bTrotter,
-                          bool                                     bExchanged,
-                          bool*                                    bSumEkinhOld,
-                          real*                                    saved_conserved_quantity,
-                          gmx::ForceBuffers*                       f,
-                          gmx::Update*                             upd,
-                          gmx::Constraints*                        constr,
-                          gmx::SimulationSignaller*                nullSignaller,
-                          std::array<std::vector<int>, ettTSEQMAX> trotter_seq,
-                          t_nrnb*                                  nrnb,
-                          const gmx::MDLogger&                     mdlog,
-                          FILE*                                    fplog,
-                          gmx_wallcycle*                           wcycle);
+void integrateVVFirstStep(int64_t                   step,
+                          bool                      bFirstStep,
+                          bool                      bInitStep,
+                          gmx::StartingBehavior     startingBehavior,
+                          int                       nstglobalcomm,
+                          const t_inputrec*         ir,
+                          t_forcerec*               fr,
+                          t_commrec*                cr,
+                          t_state*                  state,
+                          t_mdatoms*                mdatoms,
+                          const t_fcdata&           fcdata,
+                          t_extmass*                MassQ,
+                          t_vcm*                    vcm,
+                          const gmx_mtop_t&         top_global,
+                          const gmx_localtop_t&     top,
+                          gmx_enerdata_t*           enerd,
+                          gmx_ekindata_t*           ekind,
+                          gmx_global_stat*          gstat,
+                          real*                     last_ekin,
+                          bool                      bCalcVir,
+                          tensor                    total_vir,
+                          tensor                    shake_vir,
+                          tensor                    force_vir,
+                          tensor                    pres,
+                          matrix                    M,
+                          bool                      do_log,
+                          bool                      do_ene,
+                          bool                      bCalcEner,
+                          bool                      bGStat,
+                          bool                      bStopCM,
+                          bool                      bTrotter,
+                          bool                      bExchanged,
+                          bool*                     bSumEkinhOld,
+                          real*                     saved_conserved_quantity,
+                          gmx::ForceBuffers*        f,
+                          gmx::Update*              upd,
+                          gmx::Constraints*         constr,
+                          gmx::SimulationSignaller* nullSignaller,
+                          gmx::EnumerationArray<TrotterSequence, std::vector<int>> trotter_seq,
+                          t_nrnb*                                                  nrnb,
+                          const gmx::MDLogger&                                     mdlog,
+                          FILE*                                                    fplog,
+                          gmx_wallcycle*                                           wcycle);
 
 
 /*! \brief Make the second step of Velocity Verlet integration
@@ -200,39 +201,39 @@ void integrateVVFirstStep(int64_t                                  step,
  * \param[in]  nrnb              Cycle counters.
  * \param[in]  wcycle            Wall-clock cycle counter.
  */
-void integrateVVSecondStep(int64_t                                  step,
-                           const t_inputrec*                        ir,
-                           t_forcerec*                              fr,
-                           t_commrec*                               cr,
-                           t_state*                                 state,
-                           t_mdatoms*                               mdatoms,
-                           const t_fcdata&                          fcdata,
-                           t_extmass*                               MassQ,
-                           t_vcm*                                   vcm,
-                           pull_t*                                  pull_work,
-                           gmx_enerdata_t*                          enerd,
-                           gmx_ekindata_t*                          ekind,
-                           gmx_global_stat*                         gstat,
-                           real*                                    dvdl_constr,
-                           bool                                     bCalcVir,
-                           tensor                                   total_vir,
-                           tensor                                   shake_vir,
-                           tensor                                   force_vir,
-                           tensor                                   pres,
-                           matrix                                   M,
-                           matrix                                   lastbox,
-                           bool                                     do_log,
-                           bool                                     do_ene,
-                           bool                                     bGStat,
-                           bool*                                    bSumEkinhOld,
-                           gmx::ForceBuffers*                       f,
-                           std::vector<gmx::RVec>*                  cbuf,
-                           gmx::Update*                             upd,
-                           gmx::Constraints*                        constr,
-                           gmx::SimulationSignaller*                nullSignaller,
-                           std::array<std::vector<int>, ettTSEQMAX> trotter_seq,
-                           t_nrnb*                                  nrnb,
-                           gmx_wallcycle*                           wcycle);
+void integrateVVSecondStep(int64_t                                                  step,
+                           const t_inputrec*                                        ir,
+                           t_forcerec*                                              fr,
+                           t_commrec*                                               cr,
+                           t_state*                                                 state,
+                           t_mdatoms*                                               mdatoms,
+                           const t_fcdata&                                          fcdata,
+                           t_extmass*                                               MassQ,
+                           t_vcm*                                                   vcm,
+                           pull_t*                                                  pull_work,
+                           gmx_enerdata_t*                                          enerd,
+                           gmx_ekindata_t*                                          ekind,
+                           gmx_global_stat*                                         gstat,
+                           real*                                                    dvdl_constr,
+                           bool                                                     bCalcVir,
+                           tensor                                                   total_vir,
+                           tensor                                                   shake_vir,
+                           tensor                                                   force_vir,
+                           tensor                                                   pres,
+                           matrix                                                   M,
+                           matrix                                                   lastbox,
+                           bool                                                     do_log,
+                           bool                                                     do_ene,
+                           bool                                                     bGStat,
+                           bool*                                                    bSumEkinhOld,
+                           gmx::ForceBuffers*                                       f,
+                           std::vector<gmx::RVec>*                                  cbuf,
+                           gmx::Update*                                             upd,
+                           gmx::Constraints*                                        constr,
+                           gmx::SimulationSignaller*                                nullSignaller,
+                           gmx::EnumerationArray<TrotterSequence, std::vector<int>> trotter_seq,
+                           t_nrnb*                                                  nrnb,
+                           gmx_wallcycle*                                           wcycle);
 
 
 #endif // GMX_MDLIB_UPDATE_VV_H
index 68428e84ef0690f840481d09fb36a7b3f8bcea24..ed1f08b7912ffb8d9ceb965cd0a007b087592c79 100644 (file)
@@ -364,8 +364,8 @@ void gmx::LegacySimulator::do_md()
 
     ForceBuffers     f(fr->useMts,
                    ((useGpuForNonbonded && useGpuForBufferOps) || useGpuForUpdate)
-                               ? PinningPolicy::PinnedIfSupported
-                               : PinningPolicy::CannotBePinned);
+                           ? PinningPolicy::PinnedIfSupported
+                           : PinningPolicy::CannotBePinned);
     const t_mdatoms* md = mdAtoms->mdatoms();
     if (DOMAINDECOMP(cr))
     {
@@ -1298,7 +1298,7 @@ void gmx::LegacySimulator::do_md()
                                               state->v.rvec_array(),
                                               md->homenr,
                                               md->cTC ? gmx::arrayRefFromArray(md->cTC, md->nr)
-                                                                      : gmx::ArrayRef<const unsigned short>());
+                                                      : gmx::ArrayRef<const unsigned short>());
             /* history is maintained in state->dfhist, but state_global is what is sent to trajectory and log output */
             if (MASTER(cr))
             {
@@ -1440,7 +1440,7 @@ void gmx::LegacySimulator::do_md()
                            gmx::arrayRefFromArray(md->invmass, md->nr),
                            &MassQ,
                            trotter_seq,
-                           ettTSEQ3);
+                           TrotterSequence::Three);
             /* We can only do Berendsen coupling after we have summed
              * the kinetic energy or virial. Since the happens
              * in global_state after update, we should only do it at
index 37d67ea13cac7d8cb343817ff599953aafc6afb0..57f2a34b72b5cbdbeeb3cd0245ff9dcd092b5c92 100644 (file)
@@ -186,14 +186,14 @@ enum
 };
 
 //! Sequenced parts of the trotter decomposition.
-enum
+enum class TrotterSequence : int
 {
-    ettTSEQ0,
-    ettTSEQ1,
-    ettTSEQ2,
-    ettTSEQ3,
-    ettTSEQ4,
-    ettTSEQMAX
+    Zero,
+    One,
+    Two,
+    Three,
+    Four,
+    Count
 };
 
 //! Pressure coupling type