Refactor MTS preprocessing and checking
authorBerk Hess <hess@kth.se>
Mon, 19 Oct 2020 07:53:59 +0000 (09:53 +0200)
committerPaul Bauer <paul.bauer.q@gmail.com>
Tue, 27 Oct 2020 08:54:07 +0000 (08:54 +0000)
Separated the checking the MTS requirements on other parts of inputrec
from the setup of MTS. Replaced assertMtsRequirements() by the same
checking function. Added tests for all aspects of both the MTS setup
and requirements checking.

src/gromacs/gmxpreprocess/readir.cpp
src/gromacs/gmxpreprocess/readir.h
src/gromacs/mdlib/forcerec.cpp
src/gromacs/mdtypes/multipletimestepping.cpp
src/gromacs/mdtypes/multipletimestepping.h
src/gromacs/mdtypes/tests/CMakeLists.txt
src/gromacs/mdtypes/tests/multipletimestepping.cpp [new file with mode: 0644]

index 2edc32384d8d18298892a67c7b8fd88b918b74aa..cdb0ef93cd82214df4a860733b27a5e4d4ecbb1d 100644 (file)
@@ -232,102 +232,6 @@ static void process_interaction_modifier(int* eintmod)
     }
 }
 
-static void checkMtsRequirement(const t_inputrec& ir, const char* param, const int nstValue, warninp_t wi)
-{
-    GMX_RELEASE_ASSERT(ir.mtsLevels.size() >= 2, "Need at least two levels for MTS");
-    const int mtsFactor = ir.mtsLevels.back().stepFactor;
-    if (nstValue % mtsFactor != 0)
-    {
-        auto message = gmx::formatString(
-                "With MTS, %s = %d should be a multiple of mts-factor = %d", param, nstValue, mtsFactor);
-        warning_error(wi, message.c_str());
-    }
-}
-
-static void setupMtsLevels(gmx::ArrayRef<gmx::MtsLevel> mtsLevels,
-                           const t_inputrec&            ir,
-                           const t_gromppopts&          opts,
-                           warninp_t                    wi)
-{
-    if (!(ir.eI == eiMD || ir.eI == eiSD1))
-    {
-        auto message = gmx::formatString(
-                "Multiple time stepping is only supported with integrators %s and %s",
-                ei_names[eiMD], ei_names[eiSD1]);
-        warning_error(wi, message.c_str());
-    }
-    if (opts.numMtsLevels != 2)
-    {
-        warning_error(wi, "Only mts-levels = 2 is supported");
-    }
-    else
-    {
-        const std::vector<std::string> inputForceGroups = gmx::splitString(opts.mtsLevel2Forces);
-        auto&                          forceGroups      = mtsLevels[1].forceGroups;
-        for (const auto& inputForceGroup : inputForceGroups)
-        {
-            bool found     = false;
-            int  nameIndex = 0;
-            for (const auto& forceGroupName : gmx::mtsForceGroupNames)
-            {
-                if (gmx::equalCaseInsensitive(inputForceGroup, forceGroupName))
-                {
-                    forceGroups.set(nameIndex);
-                    found = true;
-                }
-                nameIndex++;
-            }
-            if (!found)
-            {
-                auto message =
-                        gmx::formatString("Unknown MTS force group '%s'", inputForceGroup.c_str());
-                warning_error(wi, message.c_str());
-            }
-        }
-
-        if (mtsLevels[1].stepFactor <= 1)
-        {
-            gmx_fatal(FARGS, "mts-factor should be larger than 1");
-        }
-
-        // Make the level 0 use the complement of the force groups of group 1
-        mtsLevels[0].forceGroups = ~mtsLevels[1].forceGroups;
-        mtsLevels[0].stepFactor  = 1;
-
-        if ((EEL_FULL(ir.coulombtype) || EVDW_PME(ir.vdwtype))
-            && !mtsLevels[1].forceGroups[static_cast<int>(gmx::MtsForceGroups::LongrangeNonbonded)])
-        {
-            warning_error(wi,
-                          "With long-range electrostatics and/or LJ treatment, the long-range part "
-                          "has to be part of the mts-level2-forces");
-        }
-
-        if (ir.nstcalcenergy > 0)
-        {
-            checkMtsRequirement(ir, "nstcalcenergy", ir.nstcalcenergy, wi);
-        }
-        checkMtsRequirement(ir, "nstenergy", ir.nstenergy, wi);
-        checkMtsRequirement(ir, "nstlog", ir.nstlog, wi);
-        if (ir.efep != efepNO)
-        {
-            checkMtsRequirement(ir, "nstdhdl", ir.fepvals->nstdhdl, wi);
-        }
-
-        if (ir.bPull)
-        {
-            const int pullMtsLevel = gmx::forceGroupMtsLevel(ir.mtsLevels, gmx::MtsForceGroups::Pull);
-            if (ir.pull->nstxout % ir.mtsLevels[pullMtsLevel].stepFactor != 0)
-            {
-                warning_error(wi, "pull-nstxout should be a multiple of mts-factor");
-            }
-            if (ir.pull->nstfout % ir.mtsLevels[pullMtsLevel].stepFactor != 0)
-            {
-                warning_error(wi, "pull-nstfout should be a multiple of mts-factor");
-            }
-        }
-    }
-}
-
 void check_ir(const char*                   mdparin,
               const gmx::MdModulesNotifier& mdModulesNotifier,
               t_inputrec*                   ir,
@@ -351,6 +255,19 @@ void check_ir(const char*                   mdparin,
 
     set_warning_line(wi, mdparin, -1);
 
+    /* We cannot check MTS requirements with an invalid MTS setup
+     * and we will already have generated errors with an invalid MTS setup.
+     */
+    if (gmx::haveValidMtsSetup(*ir))
+    {
+        std::vector<std::string> errorMessages = gmx::checkMtsRequirements(*ir);
+
+        for (const auto& errorMessage : errorMessages)
+        {
+            warning_error(wi, errorMessage.c_str());
+        }
+    }
+
     if (ir->coulombtype == eelRF_NEC_UNSUPPORTED)
     {
         sprintf(warn_buf, "%s electrostatics is no longer supported", eel_names[eelRF_NEC_UNSUPPORTED]);
@@ -1989,18 +1906,17 @@ void get_ir(const char*     mdparin,
     ir->useMts = (get_eeenum(&inp, "mts", yesno_names, wi) != 0);
     if (ir->useMts)
     {
-        opts->numMtsLevels = get_eint(&inp, "mts-levels", 2, wi);
+        gmx::GromppMtsOpts& mtsOpts = opts->mtsOpts;
+        mtsOpts.numLevels           = get_eint(&inp, "mts-levels", 2, wi);
         ir->mtsLevels.resize(2);
-        gmx::MtsLevel& mtsLevel = ir->mtsLevels[1];
-        opts->mtsLevel2Forces   = setStringEntry(&inp, "mts-level2-forces",
-                                               "longrange-nonbonded nonbonded pair dihedral");
-        mtsLevel.stepFactor     = get_eint(&inp, "mts-level2-factor", 2, wi);
+        mtsOpts.level2Forces = setStringEntry(&inp, "mts-level2-forces",
+                                              "longrange-nonbonded nonbonded pair dihedral");
+        mtsOpts.level2Factor = get_eint(&inp, "mts-level2-factor", 2, wi);
 
         // We clear after reading without dynamics to not force the user to remove MTS mdp options
         if (!EI_DYNAMICS(ir->eI))
         {
             ir->useMts = false;
-            ir->mtsLevels.clear();
         }
     }
     printStringNoNewline(&inp, "mode for center of mass motion removal");
@@ -2755,7 +2671,13 @@ void get_ir(const char*     mdparin,
     /* Set up MTS levels, this needs to happen before checking AWH parameters */
     if (ir->useMts)
     {
-        setupMtsLevels(ir->mtsLevels, *ir, *opts, wi);
+        std::vector<std::string> errorMessages;
+        ir->mtsLevels = gmx::setupMtsLevels(opts->mtsOpts, &errorMessages);
+
+        for (const auto& errorMessage : errorMessages)
+        {
+            warning_error(wi, errorMessage.c_str());
+        }
     }
 
     if (ir->bDoAwh)
@@ -4182,7 +4104,7 @@ static void check_combination_rules(const t_inputrec* ir, const gmx_mtop_t* mtop
 void triple_check(const char* mdparin, t_inputrec* ir, gmx_mtop_t* sys, warninp_t wi)
 {
     // Not meeting MTS requirements should have resulted in a fatal error, so we can assert here
-    gmx::assertMtsRequirements(*ir);
+    GMX_ASSERT(gmx::checkMtsRequirements(*ir).empty(), "All MTS requirements should be met here");
 
     char                      err_buf[STRLEN];
     int                       i, m, c, nmol;
index 65050546ff4003ed05215ab1662986c553a95a47..a936ef43d5c9263179043cd7d45664fad55cd8d3 100644 (file)
@@ -43,6 +43,7 @@
 
 #include "gromacs/fileio/readinp.h"
 #include "gromacs/math/vectypes.h"
+#include "gromacs/mdtypes/multipletimestepping.h"
 #include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/real.h"
 
@@ -87,23 +88,22 @@ enum
 
 struct t_gromppopts
 {
-    int         warnings     = 0;
-    int         nshake       = 0;
-    char*       include      = nullptr;
-    char*       define       = nullptr;
-    bool        bGenVel      = false;
-    bool        bGenPairs    = false;
-    real        tempi        = 0;
-    int         seed         = 0;
-    int         numMtsLevels = 0;
-    std::string mtsLevel2Forces;
-    bool        bOrire           = false;
-    bool        bMorse           = false;
-    char*       wall_atomtype[2] = { nullptr, nullptr };
-    char*       couple_moltype   = nullptr;
-    int         couple_lam0      = 0;
-    int         couple_lam1      = 0;
-    bool        bCoupleIntra     = false;
+    int                warnings  = 0;
+    int                nshake    = 0;
+    char*              include   = nullptr;
+    char*              define    = nullptr;
+    bool               bGenVel   = false;
+    bool               bGenPairs = false;
+    real               tempi     = 0;
+    int                seed      = 0;
+    gmx::GromppMtsOpts mtsOpts;
+    bool               bOrire           = false;
+    bool               bMorse           = false;
+    char*              wall_atomtype[2] = { nullptr, nullptr };
+    char*              couple_moltype   = nullptr;
+    int                couple_lam0      = 0;
+    int                couple_lam1      = 0;
+    bool               bCoupleIntra     = false;
 };
 
 /*! \brief Initialise object to hold strings parsed from an .mdp file */
index 0a9f0dc89f6d760a31fcef32f4cdacd9a7982daf..ad0e645c4e19cc40309c7be7ac9a2de652c0b64c 100644 (file)
@@ -1163,7 +1163,8 @@ void init_forcerec(FILE*                            fp,
 
     if (fr->useMts)
     {
-        gmx::assertMtsRequirements(*ir);
+        GMX_ASSERT(gmx::checkMtsRequirements(*ir).empty(),
+                   "All MTS requirements should be met here");
     }
 
     const bool haveDirectVirialContributionsFast =
index 02bcf41be33c7f765b74dd9de4eb892a8bb1d52c..ef25f9cd0e4c4b67e3ae3ed9683ee233afaf652d 100644 (file)
 
 #include "multipletimestepping.h"
 
+#include <optional>
+
 #include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/mdtypes/pull_params.h"
+#include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/gmxassert.h"
+#include "gromacs/utility/stringutil.h"
 
 namespace gmx
 {
@@ -56,41 +61,163 @@ int nonbondedMtsFactor(const t_inputrec& ir)
     }
 }
 
-void assertMtsRequirements(const t_inputrec& ir)
+std::vector<MtsLevel> setupMtsLevels(const GromppMtsOpts& mtsOpts, std::vector<std::string>* errorMessages)
+{
+    std::vector<MtsLevel> mtsLevels;
+
+    if (mtsOpts.numLevels != 2)
+    {
+        if (errorMessages)
+        {
+            errorMessages->push_back("Only mts-levels = 2 is supported");
+        }
+    }
+    else
+    {
+        mtsLevels.resize(2);
+
+        const std::vector<std::string> inputForceGroups = gmx::splitString(mtsOpts.level2Forces);
+        auto&                          forceGroups      = mtsLevels[1].forceGroups;
+        for (const auto& inputForceGroup : inputForceGroups)
+        {
+            bool found     = false;
+            int  nameIndex = 0;
+            for (const auto& forceGroupName : gmx::mtsForceGroupNames)
+            {
+                if (gmx::equalCaseInsensitive(inputForceGroup, forceGroupName))
+                {
+                    forceGroups.set(nameIndex);
+                    found = true;
+                }
+                nameIndex++;
+            }
+            if (!found && errorMessages)
+            {
+                errorMessages->push_back(
+                        gmx::formatString("Unknown MTS force group '%s'", inputForceGroup.c_str()));
+            }
+        }
+
+        // Make the level 0 use the complement of the force groups of group 1
+        mtsLevels[0].forceGroups = ~mtsLevels[1].forceGroups;
+        mtsLevels[0].stepFactor  = 1;
+
+        mtsLevels[1].stepFactor = mtsOpts.level2Factor;
+
+        if (errorMessages && mtsLevels[1].stepFactor <= 1)
+        {
+            errorMessages->push_back("mts-factor should be larger than 1");
+        }
+    }
+
+    return mtsLevels;
+}
+
+bool haveValidMtsSetup(const t_inputrec& ir)
+{
+    return (ir.useMts && ir.mtsLevels.size() == 2 && ir.mtsLevels[1].stepFactor > 1);
+}
+
+namespace
+{
+
+//! Checks whether \p nstValue is a multiple of the largest MTS step, returns an error string for parameter \p param when this is not the case
+std::optional<std::string> checkMtsInterval(ArrayRef<const MtsLevel> mtsLevels, const char* param, const int nstValue)
+{
+    GMX_RELEASE_ASSERT(mtsLevels.size() >= 2, "Need at least two levels for MTS");
+
+    const int mtsFactor = mtsLevels.back().stepFactor;
+    if (nstValue % mtsFactor == 0)
+    {
+        return {};
+    }
+    else
+    {
+        return gmx::formatString("With MTS, %s = %d should be a multiple of mts-factor = %d", param,
+                                 nstValue, mtsFactor);
+    }
+}
+
+} // namespace
+
+std::vector<std::string> checkMtsRequirements(const t_inputrec& ir)
 {
+    std::vector<std::string> errorMessages;
+
     if (!ir.useMts)
     {
-        return;
+        return errorMessages;
     }
 
-    GMX_RELEASE_ASSERT(ir.mtsLevels.size() >= 2, "Need at least two levels for MTS");
+    GMX_RELEASE_ASSERT(haveValidMtsSetup(ir), "MTS setup should be valid here");
+
+    ArrayRef<const MtsLevel> mtsLevels = ir.mtsLevels;
+
+    if (!(ir.eI == eiMD || ir.eI == eiSD1))
+    {
+        errorMessages.push_back(gmx::formatString(
+                "Multiple time stepping is only supported with integrators %s and %s",
+                ei_names[eiMD], ei_names[eiSD1]));
+    }
 
-    GMX_RELEASE_ASSERT(ir.mtsLevels[0].stepFactor == 1, "Base MTS step should be 1");
+    if ((EEL_FULL(ir.coulombtype) || EVDW_PME(ir.vdwtype))
+        && forceGroupMtsLevel(ir.mtsLevels, MtsForceGroups::LongrangeNonbonded) == 0)
+    {
+        errorMessages.emplace_back(
+                "With long-range electrostatics and/or LJ treatment, the long-range part "
+                "has to be part of the mts-level2-forces");
+    }
 
-    GMX_RELEASE_ASSERT((!EEL_FULL(ir.coulombtype) && !EVDW_PME(ir.vdwtype))
-                               || forceGroupMtsLevel(ir.mtsLevels, MtsForceGroups::LongrangeNonbonded) > 0,
-                       "Long-range nonbondeds should be in the highest MTS level");
+    std::optional<std::string> mesg;
+    if (ir.nstcalcenergy > 0)
+    {
+        if ((mesg = checkMtsInterval(mtsLevels, "nstcalcenergy", ir.nstcalcenergy)))
+        {
+            errorMessages.push_back(mesg.value());
+        }
+    }
+    if ((mesg = checkMtsInterval(mtsLevels, "nstenergy", ir.nstenergy)))
+    {
+        errorMessages.push_back(mesg.value());
+    }
+    if ((mesg = checkMtsInterval(mtsLevels, "nstlog", ir.nstlog)))
+    {
+        errorMessages.push_back(mesg.value());
+    }
+    if ((mesg = checkMtsInterval(mtsLevels, "nstfout", ir.nstfout)))
+    {
+        errorMessages.push_back(mesg.value());
+    }
+    if (ir.efep != efepNO)
+    {
+        if ((mesg = checkMtsInterval(mtsLevels, "nstdhdl", ir.fepvals->nstdhdl)))
+        {
+            errorMessages.push_back(mesg.value());
+        }
+    }
+    if (mtsLevels.back().forceGroups[static_cast<int>(gmx::MtsForceGroups::Nonbonded)])
+    {
+        if ((mesg = checkMtsInterval(mtsLevels, "nstlist", ir.nstlist)))
+        {
+            errorMessages.push_back(mesg.value());
+        }
+    }
 
-    for (const auto& mtsLevel : ir.mtsLevels)
+    if (ir.bPull)
     {
-        const int mtsFactor = mtsLevel.stepFactor;
-        GMX_RELEASE_ASSERT(ir.nstcalcenergy % mtsFactor == 0,
-                           "nstcalcenergy should be a multiple of mtsFactor");
-        GMX_RELEASE_ASSERT(ir.nstenergy % mtsFactor == 0,
-                           "nstenergy should be a multiple of mtsFactor");
-        GMX_RELEASE_ASSERT(ir.nstlog % mtsFactor == 0, "nstlog should be a multiple of mtsFactor");
-        GMX_RELEASE_ASSERT(ir.epc == epcNO || ir.nstpcouple % mtsFactor == 0,
-                           "nstpcouple should be a multiple of mtsFactor");
-        GMX_RELEASE_ASSERT(ir.efep == efepNO || ir.fepvals->nstdhdl % mtsFactor == 0,
-                           "nstdhdl should be a multiple of mtsFactor");
-        if (ir.mtsLevels.back().forceGroups[static_cast<int>(gmx::MtsForceGroups::Nonbonded)])
+        const int pullMtsLevel  = gmx::forceGroupMtsLevel(ir.mtsLevels, gmx::MtsForceGroups::Pull);
+        const int mtsStepFactor = ir.mtsLevels[pullMtsLevel].stepFactor;
+        if (ir.pull->nstxout % mtsStepFactor != 0)
+        {
+            errorMessages.emplace_back("pull-nstxout should be a multiple of mts-factor");
+        }
+        if (ir.pull->nstfout % mtsStepFactor != 0)
         {
-            GMX_RELEASE_ASSERT(ir.nstlist % ir.mtsLevels.back().stepFactor == 0,
-                               "With multiple time stepping for the non-bonded pair interactions, "
-                               "nstlist should be a "
-                               "multiple of mtsFactor");
+            errorMessages.emplace_back("pull-nstfout should be a multiple of mts-factor");
         }
     }
+
+    return errorMessages;
 }
 
 } // namespace gmx
index 5796114005aea21d695b4281488294ac672efd4d..f59a5b49f8d7eeab4b6d739f13d33ac33dc139f2 100644 (file)
@@ -35,6 +35,9 @@
 #ifndef GMX_MULTIPLETIMESTEPPING_H
 #define GMX_MULTIPLETIMESTEPPING_H
 
+#include <string>
+#include <vector>
+
 #include <bitset>
 
 #include "gromacs/utility/arrayref.h"
@@ -58,6 +61,7 @@ enum class MtsForceGroups : int
     Count               //!< The number of groups above
 };
 
+//! Names for the MTS force groups
 static const gmx::EnumerationArray<MtsForceGroups, std::string> mtsForceGroupNames = {
     "longrange-nonbonded", "nonbonded", "pair", "dihedral", "angle", "pull", "awh"
 };
@@ -89,8 +93,42 @@ static inline int forceGroupMtsLevel(ArrayRef<const MtsLevel> mtsLevels, const M
  */
 int nonbondedMtsFactor(const t_inputrec& ir);
 
-//! (Release) Asserts that all multiple time-stepping requirements on \p ir are fulfilled
-void assertMtsRequirements(const t_inputrec& ir);
+//! Struct for passing the MTS mdp options to setupMtsLevels()
+struct GromppMtsOpts
+{
+    //! The number of MTS levels
+    int numLevels = 0;
+    //! The names of the force groups assigned by the user to level 2, internal index 1
+    std::string level2Forces;
+    //! The step factor assigned by the user to level 2, internal index 1
+    int level2Factor = 0;
+};
+
+/*! \brief Sets up and returns the MTS levels and checks requirements of MTS
+ *
+ * Appends errors about allowed input values ir to errorMessages, when not nullptr.
+ *
+ * \param[in]     mtsOpts        Options for setting the MTS levels
+ * \param[in,out] errorMessages  List of error messages, can be nullptr
+ */
+std::vector<MtsLevel> setupMtsLevels(const GromppMtsOpts& mtsOpts, std::vector<std::string>* errorMessages);
+
+/*! \brief Returns whether we use MTS and the MTS setup is internally valid
+ *
+ * Note that setupMtsLevels would have returned at least one error message
+ * when this function returns false
+ */
+bool haveValidMtsSetup(const t_inputrec& ir);
+
+/*! \brief Checks whether the MTS requirements on other algorithms and output frequencies are met
+ *
+ * Note: exits with an assertion failure when
+ * ir.useMts == true && haveValidMtsSetup(ir) == false
+ *
+ * \param[in] ir  Complete input record
+ * \returns list of error messages, empty when all MTS requirements are met
+ */
+std::vector<std::string> checkMtsRequirements(const t_inputrec& ir);
 
 } // namespace gmx
 
index 69e1b6b00b17e8bebf82e474a0a5d423843ab274..90d37b1d51389edbc0d15720cd30f7b4f16e7d8d 100644 (file)
@@ -36,4 +36,5 @@ gmx_add_unit_test(MdtypesUnitTest mdtypes-test
     CPP_SOURCE_FILES
         checkpointdata.cpp
         forcebuffers.cpp
+        multipletimestepping.cpp
         )
diff --git a/src/gromacs/mdtypes/tests/multipletimestepping.cpp b/src/gromacs/mdtypes/tests/multipletimestepping.cpp
new file mode 100644 (file)
index 0000000..f69a45a
--- /dev/null
@@ -0,0 +1,232 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2020, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+/*! \internal \file
+ * \brief
+ * Tests for the MultipleTimeStepping class and stand-alone functions.
+ *
+ * \author berk Hess <hess@kth.se>
+ * \ingroup module_mdtypes
+ */
+#include "gmxpre.h"
+
+#include "gromacs/mdtypes/multipletimestepping.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/utility/gmxassert.h"
+#include "gromacs/utility/smalloc.h"
+
+#include "testutils/testasserts.h"
+
+namespace gmx
+{
+
+namespace test
+{
+
+namespace
+{
+
+//! brief Sets up the MTS levels in \p ir and tests whether the number of errors matches \p numExpectedErrors
+void setAndCheckMtsLevels(const GromppMtsOpts& mtsOpts, t_inputrec* ir, const int numExpectedErrors)
+{
+    std::vector<std::string> errorMessages;
+    ir->useMts    = true;
+    ir->mtsLevels = setupMtsLevels(mtsOpts, &errorMessages);
+
+    if (haveValidMtsSetup(*ir))
+    {
+        std::vector<std::string> errorMessagesCheck = checkMtsRequirements(*ir);
+
+        // Concatenate the two lists with error messages
+        errorMessages.insert(errorMessages.end(), errorMessagesCheck.begin(), errorMessagesCheck.end());
+    }
+
+    EXPECT_EQ(errorMessages.size(), numExpectedErrors);
+}
+
+} // namespace
+
+//! Checks that only numLevels = 2 does not produce an error
+TEST(MultipleTimeStepping, ChecksNumLevels)
+{
+    for (int numLevels = 0; numLevels <= 3; numLevels++)
+    {
+        GromppMtsOpts mtsOpts;
+        mtsOpts.numLevels    = numLevels;
+        mtsOpts.level2Factor = 2;
+
+        t_inputrec ir;
+
+        setAndCheckMtsLevels(mtsOpts, &ir, numLevels != 2 ? 1 : 0);
+    }
+}
+
+//! Test that each force group works
+TEST(MultipleTimeStepping, SelectsForceGroups)
+{
+    for (int forceGroupIndex = 0; forceGroupIndex < static_cast<int>(MtsForceGroups::Count);
+         forceGroupIndex++)
+    {
+        const MtsForceGroups forceGroup = static_cast<MtsForceGroups>(forceGroupIndex);
+        SCOPED_TRACE("Testing force group " + mtsForceGroupNames[forceGroup]);
+
+        GromppMtsOpts mtsOpts;
+        mtsOpts.numLevels    = 2;
+        mtsOpts.level2Forces = mtsForceGroupNames[forceGroup];
+        mtsOpts.level2Factor = 2;
+
+        t_inputrec ir;
+
+        setAndCheckMtsLevels(mtsOpts, &ir, 0);
+
+        EXPECT_EQ(ir.mtsLevels[1].forceGroups.count(), 1);
+        EXPECT_EQ(ir.mtsLevels[1].forceGroups[forceGroupIndex], true);
+    }
+}
+
+//! Checks that factor is checked
+TEST(MultipleTimeStepping, ChecksStepFactor)
+{
+    for (int stepFactor = 0; stepFactor <= 3; stepFactor++)
+    {
+        GromppMtsOpts mtsOpts;
+        mtsOpts.numLevels    = 2;
+        mtsOpts.level2Factor = stepFactor;
+
+        t_inputrec ir;
+
+        setAndCheckMtsLevels(mtsOpts, &ir, stepFactor < 2 ? 1 : 0);
+    }
+}
+
+namespace
+{
+
+GromppMtsOpts simpleMtsOpts()
+{
+    GromppMtsOpts mtsOpts;
+    mtsOpts.numLevels    = 2;
+    mtsOpts.level2Forces = "nonbonded";
+    mtsOpts.level2Factor = 4;
+
+    return mtsOpts;
+}
+
+} // namespace
+
+TEST(MultipleTimeStepping, ChecksPmeIsAtLastLevel)
+{
+    const GromppMtsOpts mtsOpts = simpleMtsOpts();
+
+    t_inputrec ir;
+    ir.coulombtype = eelPME;
+
+    setAndCheckMtsLevels(mtsOpts, &ir, 1);
+}
+
+//! Test fixture base for parametrizing interval tests
+using MtsIntervalTestParams = std::tuple<std::string, int>;
+class MtsIntervalTest : public ::testing::Test, public ::testing::WithParamInterface<MtsIntervalTestParams>
+{
+public:
+    MtsIntervalTest()
+    {
+        const auto  params        = GetParam();
+        const auto& parameterName = std::get<0>(params);
+        const auto  interval      = std::get<1>(params);
+        numExpectedErrors_        = (interval == 4 ? 0 : 1);
+
+        if (parameterName == "nstcalcenergy")
+        {
+            ir_.nstcalcenergy = interval;
+        }
+        else if (parameterName == "nstenergy")
+        {
+            ir_.nstenergy = interval;
+        }
+        else if (parameterName == "nstfout")
+        {
+            ir_.nstfout = interval;
+        }
+        else if (parameterName == "nstlist")
+        {
+            ir_.nstlist = interval;
+        }
+        else if (parameterName == "nstdhdl")
+        {
+            ir_.efep             = efepYES;
+            ir_.fepvals->nstdhdl = interval;
+        }
+        else
+
+        {
+            GMX_RELEASE_ASSERT(false, "unknown parameter name");
+        }
+    }
+
+    t_inputrec ir_;
+    int        numExpectedErrors_;
+};
+
+TEST_P(MtsIntervalTest, Works)
+{
+    const GromppMtsOpts mtsOpts = simpleMtsOpts();
+
+    setAndCheckMtsLevels(mtsOpts, &ir_, numExpectedErrors_);
+}
+
+INSTANTIATE_TEST_CASE_P(
+        ChecksStepInterval,
+        MtsIntervalTest,
+        ::testing::Combine(
+                ::testing::Values("nstcalcenergy", "nstenergy", "nstfout", "nstlist", "nstdhdl"),
+                ::testing::Values(3, 4, 5)));
+
+// Check that correct input does not produce errors
+TEST(MultipleTimeStepping, ChecksIntegrator)
+{
+    const GromppMtsOpts mtsOpts = simpleMtsOpts();
+
+    t_inputrec ir;
+    ir.eI = eiBD;
+
+    setAndCheckMtsLevels(mtsOpts, &ir, 1);
+}
+
+} // namespace test
+} // namespace gmx