Move computeSlowForces into stepWork
[alexxy/gromacs.git] / src / gromacs / mdtypes / inputrec.cpp
index e1f25314bd22b8e281d7ad94f8cb34dd23afba03..4f0ec9ca8feb62887dc2e51971ce07ae6f48e846 100644 (file)
 #include <cstring>
 
 #include <algorithm>
+#include <numeric>
 
 #include "gromacs/math/veccompare.h"
 #include "gromacs/math/vecdump.h"
 #include "gromacs/mdtypes/awh_params.h"
 #include "gromacs/mdtypes/md_enums.h"
+#include "gromacs/mdtypes/multipletimestepping.h"
 #include "gromacs/mdtypes/pull_params.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/utility/compare.h"
@@ -75,7 +77,7 @@ const int nstmin_berendsen_tcouple = 5;
 const int nstmin_berendsen_pcouple = 10;
 const int nstmin_harmonic          = 20;
 
-/* Default values for nst T- and P- coupling intervals, used when the are no other
+/* Default values for T- and P- coupling intervals, used when the are no other
  * restrictions.
  */
 constexpr int c_defaultNstTCouple = 10;
@@ -109,6 +111,11 @@ int ir_optimal_nstcalcenergy(const t_inputrec* ir)
         nst = c_defaultNstCalcEnergy;
     }
 
+    if (ir->useMts)
+    {
+        nst = std::lcm(nst, ir->mtsLevels.back().stepFactor);
+    }
+
     return nst;
 }
 
@@ -196,29 +203,41 @@ int pcouple_min_integration_steps(int epc)
 
 int ir_optimal_nstpcouple(const t_inputrec* ir)
 {
-    int nmin, nwanted, n;
+    const int minIntegrationSteps = pcouple_min_integration_steps(ir->epc);
 
-    nmin = pcouple_min_integration_steps(ir->epc);
+    const int nwanted = c_defaultNstPCouple;
 
-    nwanted = c_defaultNstPCouple;
+    // With multiple time-stepping we can only compute the pressure at slowest steps
+    const int minNstPCouple = (ir->useMts ? ir->mtsLevels.back().stepFactor : 1);
 
-    if (nmin == 0 || ir->delta_t * nwanted <= ir->tau_p)
+    int n;
+    if (minIntegrationSteps == 0 || ir->delta_t * nwanted <= ir->tau_p)
     {
         n = nwanted;
     }
     else
     {
-        n = static_cast<int>(ir->tau_p / (ir->delta_t * nmin) + 0.001);
-        if (n < 1)
+        n = static_cast<int>(ir->tau_p / (ir->delta_t * minIntegrationSteps) + 0.001);
+        if (n < minNstPCouple)
         {
-            n = 1;
+            n = minNstPCouple;
         }
-        while (nwanted % n != 0)
+        // Without MTS we try to make nstpcouple a "nice" number
+        if (!ir->useMts)
         {
-            n--;
+            while (nwanted % n != 0)
+            {
+                n--;
+            }
         }
     }
 
+    // With MTS, nstpcouple should be a multiple of the slowest MTS interval
+    if (ir->useMts)
+    {
+        n = n - (n % minNstPCouple);
+    }
+
     return n;
 }
 
@@ -819,6 +838,30 @@ void pr_inputrec(FILE* fp, int indent, const char* title, const t_inputrec* ir,
         PSTEP("nsteps", ir->nsteps);
         PSTEP("init-step", ir->init_step);
         PI("simulation-part", ir->simulation_part);
+        PS("mts", EBOOL(ir->useMts));
+        if (ir->useMts)
+        {
+            for (int mtsIndex = 1; mtsIndex < static_cast<int>(ir->mtsLevels.size()); mtsIndex++)
+            {
+                const auto&       mtsLevel = ir->mtsLevels[mtsIndex];
+                const std::string forceKey = gmx::formatString("mts-level%d-forces", mtsIndex + 1);
+                std::string       forceGroups;
+                for (int i = 0; i < static_cast<int>(gmx::MtsForceGroups::Count); i++)
+                {
+                    if (mtsLevel.forceGroups[i])
+                    {
+                        if (!forceGroups.empty())
+                        {
+                            forceGroups += " ";
+                        }
+                        forceGroups += gmx::mtsForceGroupNames[i];
+                    }
+                }
+                PS(forceKey.c_str(), forceGroups.c_str());
+                const std::string factorKey = gmx::formatString("mts-level%d-factor", mtsIndex + 1);
+                PI(factorKey.c_str(), mtsLevel.stepFactor);
+            }
+        }
         PS("comm-mode", ECOM(ir->comm_mode));
         PI("nstcomm", ir->nstcomm);
 
@@ -1287,6 +1330,15 @@ void cmp_inputrec(FILE* fp, const t_inputrec* ir1, const t_inputrec* ir2, real f
     cmp_int64(fp, "inputrec->nsteps", ir1->nsteps, ir2->nsteps);
     cmp_int64(fp, "inputrec->init_step", ir1->init_step, ir2->init_step);
     cmp_int(fp, "inputrec->simulation_part", -1, ir1->simulation_part, ir2->simulation_part);
+    cmp_int(fp, "inputrec->mts", -1, static_cast<int>(ir1->useMts), static_cast<int>(ir2->useMts));
+    if (ir1->useMts && ir2->useMts)
+    {
+        cmp_int(fp, "inputrec->mts-levels", -1, ir1->mtsLevels.size(), ir2->mtsLevels.size());
+        cmp_int(fp, "inputrec->mts-level2-forces", -1, ir1->mtsLevels[1].forceGroups.to_ulong(),
+                ir2->mtsLevels[1].forceGroups.to_ulong());
+        cmp_int(fp, "inputrec->mts-level2-factor", -1, ir1->mtsLevels[1].stepFactor,
+                ir2->mtsLevels[1].stepFactor);
+    }
     cmp_int(fp, "inputrec->pbcType", -1, static_cast<int>(ir1->pbcType), static_cast<int>(ir2->pbcType));
     cmp_bool(fp, "inputrec->bPeriodicMols", -1, ir1->bPeriodicMols, ir2->bPeriodicMols);
     cmp_int(fp, "inputrec->cutoff_scheme", -1, ir1->cutoff_scheme, ir2->cutoff_scheme);