Fix trx frame reading from modular simulator checkpoint files
authorPascal Merz <pascal.merz@me.com>
Fri, 11 Dec 2020 12:22:44 +0000 (12:22 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Fri, 11 Dec 2020 12:22:44 +0000 (12:22 +0000)
trx frames from checkpoints written by modular simulator could be
slightly off depending on the algorithms used (always up-to-date with
domain decomposition only). The global state is still written to
modular checkpoints (as there is a significant amount of setup work at
runner time that needs to be refactored before modular simulator can
get spawned completely without a global state). It is, however, not
guaranteed to be absolutely up to date, as it's not used in the
simulation continuation.

The trx frame reading is now handled by modular simulator if the
checkpoint was written by modular simulator. A test was added to
ensure that the final frame of the simulation is equivalent to the
checkpointed configuration.

Note that this was not affecting simulation restart (exact
continuation is tested also for modular simulator). It would, however,
affect other tools accessing the configuration from the checkpoint
file. The error was only present in master and 2021 beta, so it was
never in a released version.

trx frame reading by modular simulator required some refactoring of
the StatePropagatorData and FreeEnergyPerturbationData
checkpointing. The main read and write function was moved to the data
object itself rather than being owned by its element. This is likely a
better design decision anyway: The element orchestrates the
checkpointing during the simulations, but the data object is handling
the actual data reading / writing. This allows to create a dummy data
object, restore it from checkpoint, and read out the quantities
relevant for the trx frame.

Closes #3838

src/gromacs/fileio/checkpoint.cpp
src/gromacs/modularsimulator/freeenergyperturbationdata.cpp
src/gromacs/modularsimulator/freeenergyperturbationdata.h
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/modularsimulator.h
src/gromacs/modularsimulator/statepropagatordata.cpp
src/gromacs/modularsimulator/statepropagatordata.h
src/programs/mdrun/tests/CMakeLists.txt
src/programs/mdrun/tests/checkpoint.cpp [new file with mode: 0644]

index d5c116e1c88681c73536c6d0ac332cd6d1207806..600ca0bab8e093adf09f45720721271eb57b56b4 100644 (file)
@@ -72,6 +72,7 @@
 #include "gromacs/mdtypes/pullhistory.h"
 #include "gromacs/mdtypes/state.h"
 #include "gromacs/mdtypes/swaphistory.h"
+#include "gromacs/modularsimulator/modularsimulator.h"
 #include "gromacs/trajectory/trajectoryframe.h"
 #include "gromacs/utility/arrayref.h"
 #include "gromacs/utility/baseversion.h"
@@ -2816,7 +2817,8 @@ void read_checkpoint_part_and_step(const char* filename, int* simulation_part, i
 
 static CheckpointHeaderContents read_checkpoint_data(t_fileio*                         fp,
                                                      t_state*                          state,
-                                                     std::vector<gmx_file_position_t>* outputfiles)
+                                                     std::vector<gmx_file_position_t>* outputfiles,
+                                                     gmx::ReadCheckpointDataHolder* modularSimulatorCheckpointData)
 {
     CheckpointHeaderContents headerContents;
     do_cpt_header(gmx_fio_getxdr(fp), TRUE, nullptr, &headerContents);
@@ -2887,11 +2889,9 @@ static CheckpointHeaderContents read_checkpoint_data(t_fileio*
     do_cpt_mdmodules(headerContents.file_version, fp, mdModuleNotifier, nullptr);
     if (headerContents.file_version >= cptv_ModularSimulator)
     {
-        // In the scope of the current function, we can just throw away the content
-        // of the modular checkpoint, but we need to read it to move the file pointer
-        gmx::FileIOXdrSerializer      serializer(fp);
-        gmx::ReadCheckpointDataHolder modularSimulatorCheckpointData;
-        modularSimulatorCheckpointData.deserialize(&serializer);
+        // Store modular checkpoint data into modularSimulatorCheckpointData
+        gmx::FileIOXdrSerializer serializer(fp);
+        modularSimulatorCheckpointData->deserialize(&serializer);
     }
     ret = do_cpt_footer(gmx_fio_getxdr(fp), headerContents.file_version);
     if (ret)
@@ -2905,7 +2905,14 @@ void read_checkpoint_trxframe(t_fileio* fp, t_trxframe* fr)
 {
     t_state                          state;
     std::vector<gmx_file_position_t> outputfiles;
-    CheckpointHeaderContents headerContents = read_checkpoint_data(fp, &state, &outputfiles);
+    gmx::ReadCheckpointDataHolder    modularSimulatorCheckpointData;
+    CheckpointHeaderContents         headerContents =
+            read_checkpoint_data(fp, &state, &outputfiles, &modularSimulatorCheckpointData);
+    if (headerContents.isModularSimulatorCheckpoint)
+    {
+        gmx::ModularSimulator::readCheckpointToTrxFrame(fr, &modularSimulatorCheckpointData, headerContents);
+        return;
+    }
 
     fr->natoms    = state.natoms;
     fr->bStep     = TRUE;
@@ -3027,8 +3034,10 @@ void list_checkpoint(const char* fn, FILE* out)
 CheckpointHeaderContents read_checkpoint_simulation_part_and_filenames(t_fileio* fp,
                                                                        std::vector<gmx_file_position_t>* outputfiles)
 {
-    t_state                  state;
-    CheckpointHeaderContents headerContents = read_checkpoint_data(fp, &state, outputfiles);
+    t_state                       state;
+    gmx::ReadCheckpointDataHolder modularSimulatorCheckpointData;
+    CheckpointHeaderContents      headerContents =
+            read_checkpoint_data(fp, &state, outputfiles, &modularSimulatorCheckpointData);
     if (gmx_fio_close(fp) != 0)
     {
         gmx_file("Cannot read/write checkpoint; corrupt file, or maybe you are out of disk space?");
index 28d5e9d169294a4a08ffbcc37462be02d22bd40d..77037bf2426d667f2b3a26e265ab17a46a83d4e7 100644 (file)
@@ -52,6 +52,7 @@
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/state.h"
+#include "gromacs/trajectory/trajectoryframe.h"
 
 #include "modularsimulator.h"
 #include "simulatoralgorithm.h"
@@ -128,13 +129,12 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 } // namespace
 
 template<CheckpointDataOperation operation>
-void FreeEnergyPerturbationData::Element::doCheckpointData(CheckpointData<operation>* checkpointData)
+void FreeEnergyPerturbationData::doCheckpointData(CheckpointData<operation>* checkpointData)
 {
     checkpointVersion(checkpointData, "FreeEnergyPerturbationData version", c_currentVersion);
 
-    checkpointData->scalar("current FEP state", &freeEnergyPerturbationData_->currentFEPState_);
-    checkpointData->arrayRef("lambda vector",
-                             makeCheckpointArrayRef<operation>(freeEnergyPerturbationData_->lambda_));
+    checkpointData->scalar("current FEP state", &currentFEPState_);
+    checkpointData->arrayRef("lambda vector", makeCheckpointArrayRef<operation>(lambda_));
 }
 
 void FreeEnergyPerturbationData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
@@ -142,7 +142,8 @@ void FreeEnergyPerturbationData::Element::saveCheckpointState(std::optional<Writ
 {
     if (MASTER(cr))
     {
-        doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+        freeEnergyPerturbationData_->doCheckpointData<CheckpointDataOperation::Write>(
+                &checkpointData.value());
     }
 }
 
@@ -151,7 +152,8 @@ void FreeEnergyPerturbationData::Element::restoreCheckpointState(std::optional<R
 {
     if (MASTER(cr))
     {
-        doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
+        freeEnergyPerturbationData_->doCheckpointData<CheckpointDataOperation::Read>(
+                &checkpointData.value());
     }
     if (DOMAINDECOMP(cr))
     {
@@ -163,7 +165,7 @@ void FreeEnergyPerturbationData::Element::restoreCheckpointState(std::optional<R
 
 const std::string& FreeEnergyPerturbationData::Element::clientID()
 {
-    return identifier_;
+    return FreeEnergyPerturbationData::checkpointID();
 }
 
 FreeEnergyPerturbationData::Element::Element(FreeEnergyPerturbationData* freeEnergyPerturbationElement,
@@ -194,4 +196,28 @@ ISimulatorElement* FreeEnergyPerturbationData::Element::getElementPointerImpl(
     return freeEnergyPerturbationData->element();
 }
 
+void FreeEnergyPerturbationData::readCheckpointToTrxFrame(t_trxframe* trxFrame,
+                                                          std::optional<ReadCheckpointData> readCheckpointData)
+{
+    if (readCheckpointData)
+    {
+        FreeEnergyPerturbationData freeEnergyPerturbationData;
+        freeEnergyPerturbationData.doCheckpointData(&readCheckpointData.value());
+        trxFrame->lambda    = freeEnergyPerturbationData.lambda_[efptFEP];
+        trxFrame->fep_state = freeEnergyPerturbationData.currentFEPState_;
+    }
+    else
+    {
+        trxFrame->lambda    = 0;
+        trxFrame->fep_state = 0;
+    }
+    trxFrame->bLambda = true;
+}
+
+const std::string& FreeEnergyPerturbationData::checkpointID()
+{
+    static const std::string identifier = "FreeEnergyPerturbationData";
+    return identifier;
+}
+
 } // namespace gmx
index b1291dfde672b3a3004c4fbe01c4b7291e8827b4..25c2a52916e6bfaa2b6430ccfa5205fdd10d1515 100644 (file)
@@ -51,6 +51,7 @@
 #include "modularsimulatorinterfaces.h"
 
 struct t_inputrec;
+struct t_trxframe;
 
 namespace gmx
 {
@@ -91,9 +92,20 @@ public:
     //! Get pointer to element (whose lifetime is managed by this)
     Element* element();
 
+    //! Read everything that can be stored in t_trxframe from a checkpoint file
+    static void readCheckpointToTrxFrame(t_trxframe*                       trxFrame,
+                                         std::optional<ReadCheckpointData> readCheckpointData);
+    //! CheckpointHelper identifier
+    static const std::string& checkpointID();
+
 private:
+    //! Default constructor - only used internally
+    FreeEnergyPerturbationData() = default;
     //! Update the lambda values
     void updateLambdas(Step step);
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData);
 
     //! The element
     std::unique_ptr<Element> element_;
@@ -166,13 +178,6 @@ private:
     FreeEnergyPerturbationData* freeEnergyPerturbationData_;
     //! Whether lambda values are non-static
     const bool lambdasChange_;
-
-
-    //! CheckpointHelper identifier
-    const std::string identifier_ = "FreeEnergyPerturbationElement";
-    //! Helper function to read from / write to CheckpointData
-    template<CheckpointDataOperation operation>
-    void doCheckpointData(CheckpointData<operation>* checkpointData);
 };
 
 } // namespace gmx
index 0c9dfeba320bedf41fa0e87aab0983963189a35b..2450711c10f94220fa489b298d64329984bd982a 100644 (file)
@@ -48,6 +48,7 @@
 #include "gromacs/ewald/pme.h"
 #include "gromacs/ewald/pme_load_balancing.h"
 #include "gromacs/ewald/pme_pp.h"
+#include "gromacs/fileio/checkpoint.h"
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/listed_forces/listed_forces.h"
 #include "gromacs/mdlib/checkpointhandler.h"
@@ -71,7 +72,9 @@
 #include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/topology/mtop_util.h"
 #include "gromacs/topology/topology.h"
+#include "gromacs/trajectory/trajectoryframe.h"
 #include "gromacs/utility/fatalerror.h"
+#include "gromacs/utility/int64_to_int.h"
 
 #include "computeglobalselement.h"
 #include "constraintelement.h"
@@ -385,4 +388,32 @@ void ModularSimulator::checkInputForDisabledFunctionality()
     }
 }
 
+void ModularSimulator::readCheckpointToTrxFrame(t_trxframe*               fr,
+                                                ReadCheckpointDataHolder* readCheckpointDataHolder,
+                                                const CheckpointHeaderContents& checkpointHeaderContents)
+{
+    GMX_RELEASE_ASSERT(checkpointHeaderContents.isModularSimulatorCheckpoint,
+                       "ModularSimulator::readCheckpointToTrxFrame can only read checkpoints "
+                       "written by modular simulator.");
+    fr->bStep = true;
+    fr->step =
+            int64_to_int(checkpointHeaderContents.step, "conversion of checkpoint to trajectory");
+    fr->bTime = true;
+    fr->time  = checkpointHeaderContents.t;
+
+    fr->bAtoms = false;
+
+    StatePropagatorData::readCheckpointToTrxFrame(
+            fr, readCheckpointDataHolder->checkpointData(StatePropagatorData::checkpointID()));
+    if (readCheckpointDataHolder->keyExists(FreeEnergyPerturbationData::checkpointID()))
+    {
+        FreeEnergyPerturbationData::readCheckpointToTrxFrame(
+                fr, readCheckpointDataHolder->checkpointData(FreeEnergyPerturbationData::checkpointID()));
+    }
+    else
+    {
+        FreeEnergyPerturbationData::readCheckpointToTrxFrame(fr, std::nullopt);
+    }
+}
+
 } // namespace gmx
index a46b52ebdef235bc457101076b00528463f82bc5..42d170e2e7ce0c2414b333f36f7c7ff8fc1dbb27 100644 (file)
@@ -52,7 +52,9 @@
 
 #include "gromacs/mdrun/isimulator.h"
 
+struct CheckpointHeaderContents;
 struct t_fcdata;
+struct t_trxframe;
 
 namespace gmx
 {
@@ -86,6 +88,11 @@ public:
                                   bool                             doEssentialDynamics,
                                   bool                             doMembed);
 
+    //! Read everything that can be stored in t_trxframe from a checkpoint file
+    static void readCheckpointToTrxFrame(t_trxframe*                     fr,
+                                         ReadCheckpointDataHolder*       readCheckpointDataHolder,
+                                         const CheckpointHeaderContents& checkpointHeaderContents);
+
     // Only builder can construct
     friend class SimulatorBuilder;
 
index bf5bf8e10f9a7e7937b1340c968847582778f00a..92691193211a10c919416d3762627a8fda0a6003 100644 (file)
@@ -64,6 +64,7 @@
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/topology/atoms.h"
 #include "gromacs/topology/topology.h"
+#include "gromacs/trajectory/trajectoryframe.h"
 
 #include "freeenergyperturbationdata.h"
 #include "modularsimulator.h"
@@ -481,52 +482,54 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 } // namespace
 
 template<CheckpointDataOperation operation>
-void StatePropagatorData::Element::doCheckpointData(CheckpointData<operation>* checkpointData,
-                                                    const t_commrec*           cr)
+void StatePropagatorData::doCheckpointData(CheckpointData<operation>* checkpointData)
+{
+    checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
+    checkpointData->scalar("numAtoms", &totalNumAtoms_);
+
+    if (operation == CheckpointDataOperation::Read)
+    {
+        xGlobal_.resizeWithPadding(totalNumAtoms_);
+        vGlobal_.resizeWithPadding(totalNumAtoms_);
+    }
+
+    checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobal_));
+    checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobal_));
+    checkpointData->tensor("box", box_);
+    checkpointData->scalar("ddpCount", &ddpCount_);
+    checkpointData->scalar("ddpCountCgGl", &ddpCountCgGl_);
+    checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(cgGl_));
+}
+
+void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                                       const t_commrec*                   cr)
 {
-    ArrayRef<RVec> xGlobalRef;
-    ArrayRef<RVec> vGlobalRef;
     if (DOMAINDECOMP(cr))
     {
-        if (MASTER(cr))
-        {
-            xGlobalRef = statePropagatorData_->xGlobal_;
-            vGlobalRef = statePropagatorData_->vGlobal_;
-        }
-        if (operation == CheckpointDataOperation::Write)
-        {
-            dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
-                           statePropagatorData_->cgGl_, statePropagatorData_->x_, xGlobalRef);
-            dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
-                           statePropagatorData_->cgGl_, statePropagatorData_->v_, vGlobalRef);
-        }
+        // Collect state from all ranks into global vectors
+        dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
+                       statePropagatorData_->cgGl_, statePropagatorData_->x_,
+                       statePropagatorData_->xGlobal_);
+        dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
+                       statePropagatorData_->cgGl_, statePropagatorData_->v_,
+                       statePropagatorData_->vGlobal_);
     }
     else
     {
-        xGlobalRef = statePropagatorData_->x_;
-        vGlobalRef = statePropagatorData_->v_;
+        // Everything is local - copy local vectors into global ones
+        statePropagatorData_->xGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
+        statePropagatorData_->vGlobal_.resizeWithPadding(statePropagatorData_->totalNumAtoms());
+        std::copy(statePropagatorData_->x_.begin(), statePropagatorData_->x_.end(),
+                  statePropagatorData_->xGlobal_.begin());
+        std::copy(statePropagatorData_->v_.begin(), statePropagatorData_->v_.end(),
+                  statePropagatorData_->vGlobal_.begin());
     }
     if (MASTER(cr))
     {
-        GMX_ASSERT(checkpointData, "Master needs a valid pointer to a CheckpointData object");
-        checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
-
-        checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobalRef));
-        checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobalRef));
-        checkpointData->tensor("box", statePropagatorData_->box_);
-        checkpointData->scalar("ddpCount", &statePropagatorData_->ddpCount_);
-        checkpointData->scalar("ddpCountCgGl", &statePropagatorData_->ddpCountCgGl_);
-        checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(statePropagatorData_->cgGl_));
+        statePropagatorData_->doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
     }
 }
 
-void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
-                                                       const t_commrec*                   cr)
-{
-    doCheckpointData<CheckpointDataOperation::Write>(
-            checkpointData ? &checkpointData.value() : nullptr, cr);
-}
-
 /*!
  * \brief Update the legacy global state
  *
@@ -552,7 +555,10 @@ static void updateGlobalState(t_state*                      globalState,
 void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
                                                           const t_commrec*                  cr)
 {
-    doCheckpointData<CheckpointDataOperation::Read>(checkpointData ? &checkpointData.value() : nullptr, cr);
+    if (MASTER(cr))
+    {
+        statePropagatorData_->doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
+    }
 
     // Copy data to global state to be distributed by DD at setup stage
     if (DOMAINDECOMP(cr) && MASTER(cr))
@@ -562,11 +568,21 @@ void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadChec
                           statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
                           statePropagatorData_->cgGl_);
     }
+    // Everything is local - copy global vectors to local ones
+    if (!DOMAINDECOMP(cr))
+    {
+        statePropagatorData_->x_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
+        statePropagatorData_->v_.resizeWithPadding(statePropagatorData_->totalNumAtoms_);
+        std::copy(statePropagatorData_->xGlobal_.begin(), statePropagatorData_->xGlobal_.end(),
+                  statePropagatorData_->x_.begin());
+        std::copy(statePropagatorData_->vGlobal_.begin(), statePropagatorData_->vGlobal_.end(),
+                  statePropagatorData_->v_.begin());
+    }
 }
 
 const std::string& StatePropagatorData::Element::clientID()
 {
-    return identifier_;
+    return StatePropagatorData::checkpointID();
 }
 
 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
@@ -673,4 +689,25 @@ ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
     return statePropagatorData->element();
 }
 
+void StatePropagatorData::readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData)
+{
+    StatePropagatorData statePropagatorData;
+    statePropagatorData.doCheckpointData(&readCheckpointData);
+
+    trxFrame->natoms = statePropagatorData.totalNumAtoms_;
+    trxFrame->bX     = true;
+    trxFrame->x  = makeRvecArray(statePropagatorData.xGlobal_, statePropagatorData.totalNumAtoms_);
+    trxFrame->bV = true;
+    trxFrame->v  = makeRvecArray(statePropagatorData.vGlobal_, statePropagatorData.totalNumAtoms_);
+    trxFrame->bF = false;
+    trxFrame->bBox = true;
+    copy_mat(statePropagatorData.box_, trxFrame->box);
+}
+
+const std::string& StatePropagatorData::checkpointID()
+{
+    static const std::string identifier = "StatePropagatorData";
+    return identifier;
+}
+
 } // namespace gmx
index b3c1b050cac28d488b5d59dd2af5778968f20af7..fe277200617671a5cdbf863b6123a240db7869e7 100644 (file)
@@ -60,6 +60,7 @@ struct t_commrec;
 struct t_inputrec;
 class t_state;
 struct t_mdatoms;
+struct t_trxframe;
 
 namespace gmx
 {
@@ -148,6 +149,11 @@ public:
     //! Initial set up for the associated element
     void setup();
 
+    //! Read everything that can be stored in t_trxframe from a checkpoint file
+    static void readCheckpointToTrxFrame(t_trxframe* trxFrame, ReadCheckpointData readCheckpointData);
+    //! CheckpointHelper identifier
+    static const std::string& checkpointID();
+
     //! \cond
     // (doxygen doesn't like these)
     // Classes which need access to legacy state
@@ -155,6 +161,8 @@ public:
     //! \endcond
 
 private:
+    //! Default constructor - only used internally
+    StatePropagatorData() = default;
     //! The total number of atoms in the system
     int totalNumAtoms_;
     //! The local number of atoms
@@ -195,6 +203,10 @@ private:
     //! OMP helper to move x_ to previousX_
     void copyPosition(int start, int end);
 
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData);
+
     // Access to legacy state
     //! Get a deep copy of the current state in legacy format
     std::unique_ptr<t_state> localState();
@@ -347,12 +359,6 @@ private:
     //! ITrajectoryWriterClient implementation
     std::optional<ITrajectoryWriterCallback> registerTrajectoryWriterCallback(TrajectoryEvent event) override;
 
-    //! CheckpointHelper identifier
-    const std::string identifier_ = "StatePropagatorData";
-    //! Helper function to read from / write to CheckpointData
-    template<CheckpointDataOperation operation>
-    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
-
     //! ILastStepSignallerClient implementation (used for final output only)
     std::optional<SignallerCallback> registerLastStepCallback() override;
 
index 34f05ebfb5bfa5f749b75bca8c1f4a0f2e84424e..e4eccebd696ee00f00d79b2dad0ac391dda809ab 100644 (file)
@@ -85,6 +85,7 @@ set(exename "mdrun-io-test")
 
 gmx_add_gtest_executable(${exename}
     CPP_SOURCE_FILES
+        checkpoint.cpp
         exactcontinuation.cpp
         grompp.cpp
         initialconstraints.cpp
diff --git a/src/programs/mdrun/tests/checkpoint.cpp b/src/programs/mdrun/tests/checkpoint.cpp
new file mode 100644 (file)
index 0000000..dcee170
--- /dev/null
@@ -0,0 +1,171 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2020, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+
+/*! \internal \file
+ * \brief Tests for checkpoint writing sanity checks
+ *
+ * Checks that final checkpoint is equal to final trajectory output.
+ *
+ * \author Pascal Merz <pascal.merz@me.com>
+ * \ingroup module_mdrun_integration_tests
+ */
+#include "gmxpre.h"
+
+#include "config.h"
+
+#include "gromacs/utility/strconvert.h"
+#include "gromacs/utility/stringutil.h"
+
+#include "testutils/simulationdatabase.h"
+
+#include "moduletest.h"
+#include "simulatorcomparison.h"
+#include "trajectoryreader.h"
+
+namespace gmx::test
+{
+namespace
+{
+
+class CheckpointCoordinatesSanityChecks :
+    public MdrunTestFixture,
+    public ::testing::WithParamInterface<std::tuple<std::string, std::string, std::string, std::string>>
+{
+public:
+    void runSimulation(MdpFieldValues     mdpFieldValues,
+                       int                numSteps,
+                       const std::string& trrFileName,
+                       const std::string& cptFileName)
+    {
+        mdpFieldValues["nsteps"] = toString(numSteps);
+        // Trajectories have the initial and the last frame
+        mdpFieldValues["nstxout"] = toString(numSteps);
+        mdpFieldValues["nstvout"] = toString(numSteps);
+        mdpFieldValues["nstfout"] = toString(0);
+
+        // Run grompp
+        runner_.useStringAsMdpFile(prepareMdpFileContents(mdpFieldValues));
+        runGrompp(&runner_);
+
+        // Do first mdrun
+        runner_.fullPrecisionTrajectoryFileName_ = trrFileName;
+        runMdrun(&runner_, { { "-cpo", cptFileName } });
+    }
+
+    static void compareCptAndTrr(const std::string&          trrFileName,
+                                 const std::string&          cptFileName,
+                                 const TrajectoryComparison& trajectoryComparison)
+    {
+        TrajectoryFrameReader trrReader(trrFileName);
+        TrajectoryFrameReader cptReader(cptFileName);
+        // Checkpoint has at least one frame
+        EXPECT_TRUE(cptReader.readNextFrame());
+        // Trajectory has at least two frames
+        EXPECT_TRUE(trrReader.readNextFrame());
+        EXPECT_NO_THROW(trrReader.frame());
+        EXPECT_TRUE(trrReader.readNextFrame());
+
+        // Now compare frames
+        trajectoryComparison(cptReader.frame(), trrReader.frame());
+
+        // Files had exactly 1 / 2 frames
+        EXPECT_FALSE(cptReader.readNextFrame());
+        EXPECT_FALSE(trrReader.readNextFrame());
+    }
+};
+
+TEST_P(CheckpointCoordinatesSanityChecks, WithinTolerances)
+{
+    const auto& params              = GetParam();
+    const auto& simulationName      = std::get<0>(params);
+    const auto& integrator          = std::get<1>(params);
+    const auto& temperatureCoupling = std::get<2>(params);
+    const auto& pressureCoupling    = std::get<3>(params);
+
+    // Specify how trajectory frame matching must work.
+    TrajectoryFrameMatchSettings trajectoryMatchSettings{ true,
+                                                          true,
+                                                          true,
+                                                          ComparisonConditions::MustCompare,
+                                                          ComparisonConditions::MustCompare,
+                                                          ComparisonConditions::NoComparison,
+                                                          MaxNumFrames::compareAllFrames() };
+    if (integrator == "md-vv")
+    {
+        // When using md-vv and modular simulator, the velocities are expected to be off by
+        // 1/2 dt between checkpoint (top of the loop) and trajectory (full time step state)
+        trajectoryMatchSettings.velocitiesComparison = ComparisonConditions::NoComparison;
+    }
+    const TrajectoryTolerances trajectoryTolerances{ defaultRealTolerance(), defaultRealTolerance(),
+                                                     defaultRealTolerance(), defaultRealTolerance() };
+
+    const auto mdpFieldValues =
+            prepareMdpFieldValues(simulationName, integrator, temperatureCoupling, pressureCoupling);
+    runner_.useTopGroAndNdxFromDatabase(simulationName);
+    // Set file names
+    const auto cptFileName = fileManager_.getTemporaryFilePath(".cpt");
+    const auto trrFileName = fileManager_.getTemporaryFilePath(".trr");
+
+    SCOPED_TRACE(formatString(
+            "Checking the sanity of the checkpointed coordinates using system '%s' "
+            "with integrator '%s', '%s' temperature coupling, and '%s' pressure coupling ",
+            simulationName.c_str(), integrator.c_str(), temperatureCoupling.c_str(),
+            pressureCoupling.c_str()));
+
+    SCOPED_TRACE("End of trajectory sanity");
+    // Running a few steps - we expect the checkpoint to be equal
+    // to the final configuration
+    runSimulation(mdpFieldValues, 16, trrFileName, cptFileName);
+    compareCptAndTrr(trrFileName, cptFileName, { trajectoryMatchSettings, trajectoryTolerances });
+}
+
+#if !GMX_GPU_OPENCL
+INSTANTIATE_TEST_CASE_P(CheckpointCoordinatesAreSane,
+                        CheckpointCoordinatesSanityChecks,
+                        ::testing::Combine(::testing::Values("spc2"),
+                                           ::testing::Values("md", "md-vv"),
+                                           ::testing::Values("no"),
+                                           ::testing::Values("no")));
+#else
+INSTANTIATE_TEST_CASE_P(DISABLED_CheckpointCoordinatesAreSane,
+                        CheckpointCoordinatesSanityChecks,
+                        ::testing::Combine(::testing::Values("spc2"),
+                                           ::testing::Values("md", "md-vv"),
+                                           ::testing::Values("no"),
+                                           ::testing::Values("no")));
+#endif
+
+} // namespace
+} // namespace gmx::test