Implement modular checkpointing
authorPascal Merz <pascal.merz@me.com>
Wed, 16 Sep 2020 07:34:44 +0000 (07:34 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Wed, 16 Sep 2020 07:34:44 +0000 (07:34 +0000)
Using the CheckpointData format introduced in a parent commit, this
rewrites checkpointing for the modular simulator to completely use
the new format.

The CheckpointHelper is now passing a CheckpointData object to its
clients (instead of a legacy t_state object). Clients are now stored
in a map, as they are identified by their unique key to be able to
assign the correct CheckpointData sub-objects at reading and writing.

If checkpoint reading occured, the newly introduced
CheckpointHelperBuilder receives the CheckpointData object read at the
runner level from the ModularSimulator. It then initializes its clients
with their respective, read-only CheckpointData subobjects.

The ICheckpointHelperClient interface is adapted to reflect above
changes.

The ModularSimulatorAlgorithmBuilder is slightly simplified thanks to
to the introduction of a proper builder for the CheckpointHelper.

The ComputeGlobalsElement is simplified, as it is not required to know
about the needs of communication of the EnergyData object which
depends on checkpoint reading.

Finally, all elements which are checkpoint clients are updated to
implement the new design. Note that they all introduce their own
checkpoint versioning, as the data being checkpointed is opaque to the
checkpointing infrastructure.

Closes #3517
Closes #3422
In partial fulfillment of #3419

23 files changed:
src/gromacs/fileio/checkpoint.cpp
src/gromacs/mdrun/simulatorbuilder.h
src/gromacs/mdtypes/checkpointdata.h
src/gromacs/mdtypes/energyhistory.cpp
src/gromacs/mdtypes/state.cpp
src/gromacs/modularsimulator/checkpointhelper.cpp
src/gromacs/modularsimulator/checkpointhelper.h
src/gromacs/modularsimulator/computeglobalselement.cpp
src/gromacs/modularsimulator/computeglobalselement.h
src/gromacs/modularsimulator/energydata.cpp
src/gromacs/modularsimulator/energydata.h
src/gromacs/modularsimulator/freeenergyperturbationdata.cpp
src/gromacs/modularsimulator/freeenergyperturbationdata.h
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/modularsimulatorinterfaces.h
src/gromacs/modularsimulator/parrinellorahmanbarostat.cpp
src/gromacs/modularsimulator/parrinellorahmanbarostat.h
src/gromacs/modularsimulator/simulatoralgorithm.cpp
src/gromacs/modularsimulator/simulatoralgorithm.h
src/gromacs/modularsimulator/statepropagatordata.cpp
src/gromacs/modularsimulator/statepropagatordata.h
src/gromacs/modularsimulator/vrescalethermostat.cpp
src/gromacs/modularsimulator/vrescalethermostat.h

index 5a9e548c812969605dea7d2d8a3072e4f6b884c3..39d5b2984ae703896505932f40306cd2a2eaa84e 100644 (file)
@@ -2579,6 +2579,13 @@ static void read_checkpoint(const char*                    fn,
                   fn);
     }
 
+    GMX_ASSERT(!(headerContents->isModularSimulatorCheckpoint && !useModularSimulator),
+               "Checkpoint file was written by modular simulator, but the current simulation uses "
+               "the legacy simulator.");
+    GMX_ASSERT(!(!headerContents->isModularSimulatorCheckpoint && useModularSimulator),
+               "Checkpoint file was written by legacy simulator, but the current simulation uses "
+               "the modular simulator.");
+
     if (MASTER(cr))
     {
         check_match(fplog, cr, dd_nc, *headerContents, reproducibilityRequested);
index 74b56cec00247b5e378ade22f5828a3e4d20f5c8..f16d8f5c59ca4ac1a25f30c7687b7aa9256e7b3a 100644 (file)
@@ -335,7 +335,19 @@ public:
         boxDeformation_ = std::make_unique<BoxDeformationHandle>(boxDeformation);
     }
 
-    //! Pass the read checkpoint data for modular simulator
+    /*!
+     * \brief Pass the read checkpoint data for modular simulator
+     *
+     * Note that this is currently the point at which the ReadCheckpointDataHolder
+     * is fully filled. Consequently it stops being an object at which read
+     * operations from file are targeted, and becomes a read-only object from
+     * which elements read their data to recreate an earlier internal state.
+     *
+     * Currently, this behavior change is not enforced. Once input reading and
+     * simulator builder have matured, these restrictions could be imposed.
+     *
+     * See #3656
+     */
     void add(std::unique_ptr<ReadCheckpointDataHolder> modularSimulatorCheckpointData);
 
     /*! \brief Build a Simulator object based on input data
index e2a2f8829c73a297171969662445de3e584a9dd4..38ee26b6775ccb0c796cf0fc9c6074f744731f88 100644 (file)
@@ -46,7 +46,9 @@
 
 #include "gromacs/math/vectypes.h"
 #include "gromacs/utility/arrayref.h"
+#include "gromacs/utility/exceptions.h"
 #include "gromacs/utility/keyvaluetreebuilder.h"
+#include "gromacs/utility/stringutil.h"
 
 namespace gmx
 {
@@ -255,6 +257,69 @@ private:
     friend class WriteCheckpointDataHolder;
 };
 
+/*! \brief Read a checkpoint version enum variable
+ *
+ * This reads the checkpoint version from file. The read version is returned.
+ *
+ * If the read version is more recent than the code version, this throws an error, since
+ * we cannot know what has changed in the meantime. Using newer checkpoint files with
+ * old code is not a functionality we can offer. Note, however, that since the checkpoint
+ * version is saved by module, older checkpoint files of all simulations that don't use
+ * that specific module can still be used.
+ *
+ * Allowing backwards compatibility of files (i.e., reading an older checkpoint file with
+ * a newer version of the code) is in the responsibility of the caller module. They can
+ * use the returned file checkpoint version to do that:
+ *
+ *     const auto fileVersion = checkpointVersion(checkpointData, "version", c_currentVersion);
+ *     if (fileVersion >= CheckpointVersion::AddedX)
+ *     {
+ *         checkpointData->scalar("x", &x_));
+ *     }
+ *
+ * @tparam VersionEnum     The type of the checkpoint version enum
+ * @param  checkpointData  A reading checkpoint data object
+ * @param  key             The key under which the version is saved - also used for error output
+ * @param  programVersion  The checkpoint version of the current code
+ * @return                 The checkpoint version read from file
+ */
+template<typename VersionEnum>
+VersionEnum checkpointVersion(const ReadCheckpointData* checkpointData,
+                              const std::string&        key,
+                              const VersionEnum         programVersion)
+{
+    VersionEnum fileVersion;
+    checkpointData->enumScalar(key, &fileVersion);
+    if (fileVersion > programVersion)
+    {
+        throw FileIOError(
+                formatString("The checkpoint file contains a %s that is more recent than the "
+                             "current program version and is not backward compatible.",
+                             key.c_str()));
+    }
+    return fileVersion;
+}
+
+/*! \brief Write the current code checkpoint version enum variable
+ *
+ * Write the current program checkpoint version to the checkpoint data object.
+ * Returns the written checkpoint version to mirror the signature of the reading version.
+ *
+ * @tparam VersionEnum     The type of the checkpoint version enum
+ * @param  checkpointData  A writing checkpoint data object
+ * @param  key             The key under which the version is saved
+ * @param  programVersion  The checkpoint version of the current code
+ * @return                 The checkpoint version written to file
+ */
+template<typename VersionEnum>
+VersionEnum checkpointVersion(WriteCheckpointData* checkpointData,
+                              const std::string&   key,
+                              const VersionEnum    programVersion)
+{
+    checkpointData->enumScalar(key, &programVersion);
+    return programVersion;
+}
+
 
 /*! \libinternal
  * \brief Holder for read checkpoint data
index 5cadd9e4cbad4f9445bddc6b031ec144d34164d2..07da6e77fda84ba2b7dad59d169633d905de94f1 100644 (file)
@@ -79,12 +79,7 @@ static void checkpointVectorSize(gmx::CheckpointData<operation>* checkpointData,
 template<gmx::CheckpointDataOperation operation>
 void delta_h_history_t::doCheckpoint(gmx::CheckpointData<operation> checkpointData)
 {
-    auto version = c_currentVersionDeltaHH;
-    checkpointData.enumScalar("version", &version);
-    if (version != c_currentVersionDeltaHH)
-    {
-        throw gmx::FileIOError("delta_h_history_t checkpoint version mismatch.");
-    }
+    gmx::checkpointVersion(&checkpointData, "delta_h_history_t version", c_currentVersionDeltaHH);
 
     checkpointVectorSize(&checkpointData, "numDeltaH", &dh);
     checkpointData.scalar("start_time", &start_time);
@@ -118,12 +113,7 @@ constexpr auto c_currentVersionEnergyHistory =
 template<gmx::CheckpointDataOperation operation>
 void energyhistory_t::doCheckpoint(gmx::CheckpointData<operation> checkpointData)
 {
-    auto version = c_currentVersionEnergyHistory;
-    checkpointData.enumScalar("version", &version);
-    if (version != c_currentVersionEnergyHistory)
-    {
-        throw gmx::FileIOError("energyhistory_t checkpoint version mismatch.");
-    }
+    gmx::checkpointVersion(&checkpointData, "energyhistory_t version", c_currentVersionEnergyHistory);
 
     bool useCheckpoint = (nsum <= 0 && nsum_sim <= 0);
     checkpointData.scalar("useCheckpoint", &useCheckpoint);
index 7cf3e29ff7dcee3aeed2d9d2964fb0f54e88194a..0f36009513b6d8f6afa60bb12ac7cc54221de96f 100644 (file)
@@ -108,12 +108,7 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 template<gmx::CheckpointDataOperation operation>
 void ekinstate_t::doCheckpoint(gmx::CheckpointData<operation> checkpointData)
 {
-    auto version = c_currentVersion;
-    checkpointData.enumScalar("version", &version);
-    if (version != c_currentVersion)
-    {
-        throw gmx::FileIOError("ekinstate_t checkpoint version mismatch.");
-    }
+    gmx::checkpointVersion(&checkpointData, "ekinstate_t version", c_currentVersion);
 
     checkpointData.scalar("bUpToDate", &bUpToDate);
     if (!bUpToDate)
index 3246aebe7aad21cc9eb7f401a07ce67b0482d433..643e96a02d087a24021a350127d7b1d51b458930 100644 (file)
 #include "gromacs/mdlib/mdoutf.h"
 #include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
+#include "gromacs/mdtypes/energyhistory.h"
+#include "gromacs/mdtypes/observableshistory.h"
+#include "gromacs/mdtypes/pullhistory.h"
 #include "gromacs/mdtypes/state.h"
+#include "gromacs/utility/stringutil.h"
 
 #include "trajectoryelement.h"
 
 namespace gmx
 {
-CheckpointHelper::CheckpointHelper(std::vector<ICheckpointHelperClient*> clients,
-                                   std::unique_ptr<CheckpointHandler>    checkpointHandler,
-                                   int                                   initStep,
-                                   TrajectoryElement*                    trajectoryElement,
-                                   int                                   globalNumAtoms,
-                                   FILE*                                 fplog,
-                                   t_commrec*                            cr,
-                                   ObservablesHistory*                   observablesHistory,
-                                   gmx_walltime_accounting*              walltime_accounting,
-                                   t_state*                              state_global,
-                                   bool                                  writeFinalCheckpoint) :
+CheckpointHelper::CheckpointHelper(std::vector<std::tuple<std::string, ICheckpointHelperClient*>>&& clients,
+                                   std::unique_ptr<CheckpointHandler> checkpointHandler,
+                                   int                                initStep,
+                                   TrajectoryElement*                 trajectoryElement,
+                                   int                                globalNumAtoms,
+                                   FILE*                              fplog,
+                                   t_commrec*                         cr,
+                                   ObservablesHistory*                observablesHistory,
+                                   gmx_walltime_accounting*           walltime_accounting,
+                                   t_state*                           state_global,
+                                   bool                               writeFinalCheckpoint) :
     clients_(std::move(clients)),
     checkpointHandler_(std::move(checkpointHandler)),
     initStep_(initStep),
@@ -78,10 +82,6 @@ CheckpointHelper::CheckpointHelper(std::vector<ICheckpointHelperClient*> clients
     walltime_accounting_(walltime_accounting),
     state_global_(state_global)
 {
-    // Get rid of nullptr in clients list
-    clients_.erase(std::remove_if(clients_.begin(), clients_.end(),
-                                  [](ICheckpointHelperClient* ptr) { return ptr == nullptr; }),
-                   clients_.end());
     if (DOMAINDECOMP(cr))
     {
         localState_ = std::make_unique<t_state>();
@@ -93,6 +93,15 @@ CheckpointHelper::CheckpointHelper(std::vector<ICheckpointHelperClient*> clients
         state_change_natoms(state_global, state_global->natoms);
         localStateInstance_ = state_global;
     }
+
+    if (!observablesHistory_->energyHistory)
+    {
+        observablesHistory_->energyHistory = std::make_unique<energyhistory_t>();
+    }
+    if (!observablesHistory_->pullHistory)
+    {
+        observablesHistory_->pullHistory = std::make_unique<PullHistory>();
+    }
 }
 
 void CheckpointHelper::run(Step step, Time time)
@@ -123,9 +132,9 @@ void CheckpointHelper::writeCheckpoint(Step step, Time time)
     localStateInstance_->flags = 0;
 
     WriteCheckpointDataHolder checkpointDataHolder;
-    for (const auto& client : clients_)
+    for (const auto& [key, client] : clients_)
     {
-        client->writeCheckpoint(localStateInstance_, state_global_);
+        client->writeCheckpoint(checkpointDataHolder.checkpointData(key), cr_);
     }
 
     mdoutf_write_to_trajectory_files(fplog_, cr_, trajectoryElement_->outf_, MDOF_CPT,
@@ -138,4 +147,52 @@ std::optional<SignallerCallback> CheckpointHelper::registerLastStepCallback()
     return [this](Step step, Time gmx_unused time) { this->lastStep_ = step; };
 }
 
+CheckpointHelperBuilder::CheckpointHelperBuilder(std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder,
+                                                 StartingBehavior startingBehavior,
+                                                 t_commrec*       cr) :
+    resetFromCheckpoint_(startingBehavior != StartingBehavior::NewSimulation),
+    checkpointDataHolder_(std::move(checkpointDataHolder)),
+    checkpointHandler_(nullptr),
+    cr_(cr),
+    state_(ModularSimulatorBuilderState::AcceptingClientRegistrations)
+{
+}
+
+void CheckpointHelperBuilder::registerClient(ICheckpointHelperClient* client)
+{
+    if (!client)
+    {
+        return;
+    }
+    if (state_ == ModularSimulatorBuilderState::NotAcceptingClientRegistrations)
+    {
+        throw SimulationAlgorithmSetupError(
+                "Tried to register to CheckpointHelper after it was built.");
+    }
+    const auto& key = client->clientID();
+    if (clientsMap_.count(key) != 0)
+    {
+        throw SimulationAlgorithmSetupError("CheckpointHelper client key is not unique.");
+    }
+    clientsMap_[key] = client;
+    if (resetFromCheckpoint_)
+    {
+        if (!checkpointDataHolder_->keyExists(key))
+        {
+            throw SimulationAlgorithmSetupError(
+                    formatString(
+                            "CheckpointHelper client with key %s registered for checkpointing, "
+                            "but %s does not exist in the input checkpoint file.",
+                            key.c_str(), key.c_str())
+                            .c_str());
+        }
+        client->readCheckpoint(checkpointDataHolder_->checkpointData(key), cr_);
+    }
+}
+
+void CheckpointHelperBuilder::setCheckpointHandler(std::unique_ptr<CheckpointHandler> checkpointHandler)
+{
+    checkpointHandler_ = std::move(checkpointHandler);
+}
+
 } // namespace gmx
index 04e564b1db9723575e45e6e697755b01363ffb0c..b7c2630d4c18418e3a9548a3e24dd512f9c10ed4 100644 (file)
 #ifndef GMX_MODULARSIMULATOR_CHECKPOINTHELPER_H
 #define GMX_MODULARSIMULATOR_CHECKPOINTHELPER_H
 
+#include <map>
 #include <vector>
 
 #include "gromacs/mdlib/checkpointhandler.h"
+#include "gromacs/mdrunutility/handlerestart.h"
 
 #include "modularsimulatorinterfaces.h"
 
@@ -55,6 +57,7 @@ struct ObservablesHistory;
 
 namespace gmx
 {
+class KeyValueTreeObject;
 class MDLogger;
 class TrajectoryElement;
 
@@ -80,33 +83,30 @@ class TrajectoryElement;
  * Checkpointing happens at the end of a simulation step, which gives a
  * straightforward re-entry point at the top of the simulator loop.
  *
- * In the current implementation, the clients of CheckpointHelper fill a
- * legacy t_state object (passed via pointer) with whatever data they need
- * to store. The CheckpointHelper then writes the t_state object to file.
- * This is an intermediate state of the code, as the long-term plan is for
- * modules to read and write from a checkpoint file directly, without the
- * need for a central object. The current implementation allows, however,
- * to define clearly which modules take part in checkpointing, while using
- * the current infrastructure for reading and writing to checkpoint.
+ * Checkpoint writing is done by passing sub-objects of a
+ * WriteCheckpointDataHolder object to the clients. Checkpoint reading is
+ * done by passing sub-objects of a ReadCheckpointDataHolder object (passed
+ * in from runner level) do the clients.
  *
- * \todo Develop this into a module solely providing a file handler to
- *       modules for checkpoint reading and writing.
+ * \see ReadCheckpointDataHolder
+ * \see WriteCheckpointDataHolder
+ * \see CheckpointData
  */
 class CheckpointHelper final : public ILastStepSignallerClient, public ISimulatorElement
 {
 public:
     //! Constructor
-    CheckpointHelper(std::vector<ICheckpointHelperClient*> clients,
-                     std::unique_ptr<CheckpointHandler>    checkpointHandler,
-                     int                                   initStep,
-                     TrajectoryElement*                    trajectoryElement,
-                     int                                   globalNumAtoms,
-                     FILE*                                 fplog,
-                     t_commrec*                            cr,
-                     ObservablesHistory*                   observablesHistory,
-                     gmx_walltime_accounting*              walltime_accounting,
-                     t_state*                              state_global,
-                     bool                                  writeFinalCheckpoint);
+    CheckpointHelper(std::vector<std::tuple<std::string, ICheckpointHelperClient*>>&& clients,
+                     std::unique_ptr<CheckpointHandler> checkpointHandler,
+                     int                                initStep,
+                     TrajectoryElement*                 trajectoryElement,
+                     int                                globalNumAtoms,
+                     FILE*                              fplog,
+                     t_commrec*                         cr,
+                     ObservablesHistory*                observablesHistory,
+                     gmx_walltime_accounting*           walltime_accounting,
+                     t_state*                           state_global,
+                     bool                               writeFinalCheckpoint);
 
     /*! \brief Run checkpointing
      *
@@ -136,7 +136,7 @@ public:
 
 private:
     //! List of checkpoint clients
-    std::vector<ICheckpointHelperClient*> clients_;
+    std::vector<std::tuple<std::string, ICheckpointHelperClient*>> clients_;
 
     //! The checkpoint handler
     std::unique_ptr<CheckpointHandler> checkpointHandler_;
@@ -178,6 +178,68 @@ private:
     t_state* state_global_;
 };
 
+/*! \internal
+ * \ingroup module_modularsimulator
+ * \brief Builder for the checkpoint helper
+ */
+class CheckpointHelperBuilder
+{
+public:
+    //! Constructor
+    CheckpointHelperBuilder(std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder,
+                            StartingBehavior                          startingBehavior,
+                            t_commrec*                                cr);
+
+    //! Register checkpointing client
+    void registerClient(ICheckpointHelperClient* client);
+
+    //! Set CheckpointHandler
+    void setCheckpointHandler(std::unique_ptr<CheckpointHandler> checkpointHandler);
+
+    //! Return CheckpointHelper
+    template<typename... Args>
+    std::unique_ptr<CheckpointHelper> build(Args&&... args);
+
+private:
+    //! Map of checkpoint clients
+    std::map<std::string, ICheckpointHelperClient*> clientsMap_;
+    //! Whether we are resetting from checkpoint
+    const bool resetFromCheckpoint_;
+    //! The input checkpoint data
+    std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder_;
+    //! The checkpoint handler
+    std::unique_ptr<CheckpointHandler> checkpointHandler_;
+    //! Handles communication.
+    t_commrec* cr_;
+    //! Whether the builder accepts registrations.
+    ModularSimulatorBuilderState state_;
+};
+
+template<typename... Args>
+std::unique_ptr<CheckpointHelper> CheckpointHelperBuilder::build(Args&&... args)
+{
+    state_ = ModularSimulatorBuilderState::NotAcceptingClientRegistrations;
+    // Make sure that we don't have unused entries in checkpoint
+    if (resetFromCheckpoint_)
+    {
+        for (const auto& key : checkpointDataHolder_->keys())
+        {
+            if (clientsMap_.count(key) == 0)
+            {
+                // We have an entry in checkpointDataHolder_ which has no matching client
+                throw CheckpointError("Checkpoint entry " + key + " was not read. This "
+                                      "likely means that you are not using the same algorithm "
+                                      "that was used to create the checkpoint file.");
+            }
+        }
+    }
+
+    std::vector<std::tuple<std::string, ICheckpointHelperClient*>>&& clients = { clientsMap_.begin(),
+                                                                                 clientsMap_.end() };
+    return std::make_unique<CheckpointHelper>(std::move(clients), std::move(checkpointHandler_),
+                                              std::forward<Args>(args)...);
+}
+
 } // namespace gmx
 
 #endif // GMX_MODULARSIMULATOR_CHECKPOINTHELPER_H
index e11d9838af22beb815512413c5f48c43622fe30b..ae4635d19d2405a122f2c83800fd6ba3c3a0f9b8 100644 (file)
@@ -79,8 +79,7 @@ ComputeGlobalsElement<algorithm>::ComputeGlobalsElement(StatePropagatorData* sta
                                                         gmx_wallcycle*     wcycle,
                                                         t_forcerec*        fr,
                                                         const gmx_mtop_t*  global_top,
-                                                        Constraints*       constr,
-                                                        bool               hasReadEkinState) :
+                                                        Constraints*       constr) :
     energyReductionStep_(-1),
     virialReductionStep_(-1),
     vvSchedulingStep_(-1),
@@ -90,7 +89,6 @@ ComputeGlobalsElement<algorithm>::ComputeGlobalsElement(StatePropagatorData* sta
     lastStep_(inputrec->nsteps + inputrec->init_step),
     initStep_(inputrec->init_step),
     nullSignaller_(std::make_unique<SimulationSignaller>(nullptr, nullptr, nullptr, false, false)),
-    hasReadEkinState_(hasReadEkinState),
     totalNumberOfBondedInteractions_(0),
     shouldCheckNumberOfBondedInteractions_(false),
     statePropagatorData_(statePropagatorData),
@@ -144,7 +142,8 @@ void ComputeGlobalsElement<algorithm>::elementSetup()
         inc_nrnb(nrnb_, eNR_STOPCM, mdAtoms_->mdatoms()->homenr);
     }
 
-    unsigned int cglo_flags = (CGLO_TEMPERATURE | CGLO_GSTAT | (hasReadEkinState_ ? CGLO_READEKIN : 0));
+    unsigned int cglo_flags = (CGLO_TEMPERATURE | CGLO_GSTAT
+                               | (energyData_->hasReadEkinFromCheckpoint() ? CGLO_READEKIN : 0));
 
     if (algorithm == ComputeGlobalsAlgorithm::VelocityVerlet)
     {
@@ -356,31 +355,6 @@ ISimulatorElement* ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>::get
         FreeEnergyPerturbationData*             freeEnergyPerturbationData,
         GlobalCommunicationHelper*              globalCommunicationHelper)
 {
-    /* When restarting from a checkpoint, it can be appropriate to
-     * initialize ekind from quantities in the checkpoint. Otherwise,
-     * compute_globals must initialize ekind before the simulation
-     * starts/restarts. However, only the master rank knows what was
-     * found in the checkpoint file, so we have to communicate in
-     * order to coordinate the restart.
-     *
-     * TODO This will become obsolete as soon as checkpoint reading
-     *      happens within the modular simulator framework: The energy
-     *      element will read its data from the checkpoint file pointer,
-     *      and signal to the compute globals element if it needs anything
-     *      reduced.
-     */
-    bool hasReadEkinState = MASTER(legacySimulatorData->cr)
-                                    ? legacySimulatorData->state_global->ekinstate.hasReadEkinState
-                                    : false;
-    if (PAR(legacySimulatorData->cr))
-    {
-        gmx_bcast(sizeof(hasReadEkinState), &hasReadEkinState, legacySimulatorData->cr->mpi_comm_mygroup);
-    }
-    if (hasReadEkinState)
-    {
-        restore_ekinstate_from_state(legacySimulatorData->cr, legacySimulatorData->ekind,
-                                     &legacySimulatorData->state_global->ekinstate);
-    }
     auto* element = builderHelper->storeElement(
             std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>(
                     statePropagatorData, energyData, freeEnergyPerturbationData,
@@ -389,7 +363,7 @@ ISimulatorElement* ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>::get
                     legacySimulatorData->mdlog, legacySimulatorData->cr,
                     legacySimulatorData->inputrec, legacySimulatorData->mdAtoms,
                     legacySimulatorData->nrnb, legacySimulatorData->wcycle, legacySimulatorData->fr,
-                    legacySimulatorData->top_global, legacySimulatorData->constr, hasReadEkinState));
+                    legacySimulatorData->top_global, legacySimulatorData->constr));
 
     // TODO: Remove this when DD can reduce bonded interactions independently (#3421)
     auto* castedElement = static_cast<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>*>(element);
@@ -413,38 +387,13 @@ ISimulatorElement* ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet
     static thread_local ISimulatorElement* vvComputeGlobalsElement = nullptr;
     if (!builderHelper->elementIsStored(vvComputeGlobalsElement))
     {
-        /* When restarting from a checkpoint, it can be appropriate to
-         * initialize ekind from quantities in the checkpoint. Otherwise,
-         * compute_globals must initialize ekind before the simulation
-         * starts/restarts. However, only the master rank knows what was
-         * found in the checkpoint file, so we have to communicate in
-         * order to coordinate the restart.
-         *
-         * TODO This will become obsolete as soon as checkpoint reading
-         *      happens within the modular simulator framework: The energy
-         *      element will read its data from the checkpoint file pointer,
-         *      and signal to the compute globals element if it needs anything
-         *      reduced.
-         */
-        bool hasReadEkinState =
-                MASTER(simulator->cr) ? simulator->state_global->ekinstate.hasReadEkinState : false;
-        if (PAR(simulator->cr))
-        {
-            gmx_bcast(sizeof(hasReadEkinState), &hasReadEkinState, simulator->cr->mpi_comm_mygroup);
-        }
-        if (hasReadEkinState)
-        {
-            restore_ekinstate_from_state(simulator->cr, simulator->ekind,
-                                         &simulator->state_global->ekinstate);
-        }
         vvComputeGlobalsElement = builderHelper->storeElement(
                 std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>(
                         statePropagatorData, energyData, freeEnergyPerturbationData,
                         globalCommunicationHelper->simulationSignals(),
-                        globalCommunicationHelper->nstglobalcomm(), simulator->fplog,
-                        simulator->mdlog, simulator->cr, simulator->inputrec, simulator->mdAtoms,
-                        simulator->nrnb, simulator->wcycle, simulator->fr, simulator->top_global,
-                        simulator->constr, hasReadEkinState));
+                        globalCommunicationHelper->nstglobalcomm(), simulator->fplog, simulator->mdlog,
+                        simulator->cr, simulator->inputrec, simulator->mdAtoms, simulator->nrnb,
+                        simulator->wcycle, simulator->fr, simulator->top_global, simulator->constr));
 
         // TODO: Remove this when DD can reduce bonded interactions independently (#3421)
         auto* castedElement =
index ae7fa90409c95a21f8f8aad96c703251156e4b54..28f0dbb5c9d099a657831a626d53c0b25aa5b420 100644 (file)
@@ -119,8 +119,7 @@ public:
                           gmx_wallcycle*              wcycle,
                           t_forcerec*                 fr,
                           const gmx_mtop_t*           global_top,
-                          Constraints*                constr,
-                          bool                        hasReadEkinState);
+                          Constraints*                constr);
 
     //! Destructor
     ~ComputeGlobalsElement() override;
@@ -194,8 +193,6 @@ private:
     const Step initStep_;
     //! A dummy signaller (used for setup and VV)
     std::unique_ptr<SimulationSignaller> nullSignaller_;
-    //! Whether we read kinetic energy from checkpoint
-    const bool hasReadEkinState_;
 
     /*! \brief Check that DD doesn't miss bonded interactions
      *
index 3898cc29000eb1a02bfeff6b7ce4f7548e9a2bec..1619484851c76390e7460571bdb0a27b3077ade0 100644 (file)
@@ -43,6 +43,7 @@
 
 #include "energydata.h"
 
+#include "gromacs/gmxlib/network.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdlib/compute_io.h"
 #include "gromacs/mdlib/coupling.h"
@@ -53,6 +54,7 @@
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdlib/update.h"
 #include "gromacs/mdrunutility/handlerestart.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/energyhistory.h"
@@ -60,7 +62,6 @@
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/observableshistory.h"
 #include "gromacs/mdtypes/pullhistory.h"
-#include "gromacs/mdtypes/state.h"
 #include "gromacs/topology/topology.h"
 
 #include "freeenergyperturbationdata.h"
@@ -98,6 +99,7 @@ EnergyData::EnergyData(StatePropagatorData*        statePropagatorData,
     totalVirialStep_(-1),
     pressureStep_(-1),
     needToSumEkinhOld_(false),
+    hasReadEkinFromCheckpoint_(false),
     startingBehavior_(startingBehavior),
     statePropagatorData_(statePropagatorData),
     freeEnergyPerturbationData_(freeEnergyPerturbationData),
@@ -120,6 +122,9 @@ EnergyData::EnergyData(StatePropagatorData*        statePropagatorData,
     clear_mat(totalVirial_);
     clear_mat(pressure_);
     clear_rvec(muTot_);
+
+    init_ekinstate(&ekinstate_, inputrec_);
+    observablesHistory_->energyHistory = std::make_unique<energyhistory_t>();
 }
 
 void EnergyData::Element::scheduleTask(Step step, Time time, const RegisterRunFunction& registerRunFunction)
@@ -357,22 +362,79 @@ bool* EnergyData::needToSumEkinhOld()
     return &needToSumEkinhOld_;
 }
 
-void EnergyData::Element::writeCheckpoint(t_state gmx_unused* localState, t_state* globalState)
+bool EnergyData::hasReadEkinFromCheckpoint() const
+{
+    return hasReadEkinFromCheckpoint_;
+}
+
+namespace
+{
+/*!
+ * \brief Enum describing the contents EnergyData::Element writes to modular checkpoint
+ *
+ * When changing the checkpoint content, add a new element just above Count, and adjust the
+ * checkpoint functionality.
+ */
+enum class CheckpointVersion
+{
+    Base, //!< First version of modular checkpointing
+    Count //!< Number of entries. Add new versions right above this!
+};
+constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
+} // namespace
+
+template<CheckpointDataOperation operation>
+void EnergyData::Element::doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr)
+{
+    if (MASTER(cr))
+    {
+        checkpointVersion(checkpointData, "EnergyData version", c_currentVersion);
+
+        energyData_->observablesHistory_->energyHistory->doCheckpoint<operation>(
+                checkpointData->subCheckpointData("energy history"));
+        energyData_->ekinstate_.doCheckpoint<operation>(
+                checkpointData->subCheckpointData("ekinstate"));
+    }
+}
+
+void EnergyData::Element::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
 {
-    if (isMasterRank_)
+    if (MASTER(cr))
     {
         if (energyData_->needToSumEkinhOld_)
         {
-            globalState->ekinstate.bUpToDate = false;
+            energyData_->ekinstate_.bUpToDate = false;
         }
         else
         {
-            update_ekinstate(&globalState->ekinstate, energyData_->ekind_);
-            globalState->ekinstate.bUpToDate = true;
+            update_ekinstate(&energyData_->ekinstate_, energyData_->ekind_);
+            energyData_->ekinstate_.bUpToDate = true;
         }
         energyData_->energyOutput_->fillEnergyHistory(
                 energyData_->observablesHistory_->energyHistory.get());
     }
+    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+}
+
+void EnergyData::Element::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+    energyData_->hasReadEkinFromCheckpoint_ = MASTER(cr) ? energyData_->ekinstate_.bUpToDate : false;
+    if (PAR(cr))
+    {
+        gmx_bcast(sizeof(hasReadEkinFromCheckpoint_), &energyData_->hasReadEkinFromCheckpoint_,
+                  cr->mpi_comm_mygroup);
+    }
+    if (energyData_->hasReadEkinFromCheckpoint_)
+    {
+        // this takes care of broadcasting from master to agents
+        restore_ekinstate_from_state(cr, energyData_->ekind_, &energyData_->ekinstate_);
+    }
+}
+
+const std::string& EnergyData::Element::clientID()
+{
+    return identifier_;
 }
 
 void EnergyData::initializeEnergyHistory(StartingBehavior    startingBehavior,
index 0c2588efdcb9c7df6f86d297bf47c41b338857a1..660f9f07757cc1cfbd67d9377d0ee5bb78917f82 100644 (file)
@@ -45,6 +45,7 @@
 #define GMX_ENERGYELEMENT_MICROSTATE_H
 
 #include "gromacs/math/vectypes.h"
+#include "gromacs/mdtypes/state.h"
 
 #include "modularsimulatorinterfaces.h"
 
@@ -179,6 +180,13 @@ public:
      */
     bool* needToSumEkinhOld();
 
+    /*! \brief Whether kinetic energy was read from checkpoint
+     *
+     * This is needed by the compute globals element
+     * TODO: Remove this when moving global reduction to client system (#3421)
+     */
+    [[nodiscard]] bool hasReadEkinFromCheckpoint() const;
+
     /*! \brief set vrescale thermostat
      *
      * This allows to set a pointer to the vrescale thermostat used to
@@ -239,6 +247,8 @@ private:
     std::unique_ptr<Element> element_;
     //! The energy output object
     std::unique_ptr<EnergyOutput> energyOutput_;
+    //! Helper object to checkpoint kinetic energy data
+    ekinstate_t ekinstate_;
 
     //! Whether this is the master rank
     const bool isMasterRank_;
@@ -265,6 +275,8 @@ private:
 
     //! Whether ekinh_old needs to be summed up (set by compute globals)
     bool needToSumEkinhOld_;
+    //! Whether we have read ekin from checkpoint
+    bool hasReadEkinFromCheckpoint_;
 
     //! Describes how the simulation (re)starts
     const StartingBehavior startingBehavior_;
@@ -351,6 +363,13 @@ public:
     //! No element teardown needed
     void elementTeardown() override {}
 
+    //! ICheckpointHelperClient write checkpoint implementation
+    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient read checkpoint implementation
+    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient key implementation
+    const std::string& clientID() override;
+
     /*! \brief Factory method implementation
      *
      * \param legacySimulatorData  Pointer allowing access to simulator level data
@@ -391,8 +410,12 @@ private:
     //! IEnergySignallerClient implementation
     std::optional<SignallerCallback> registerEnergyCallback(EnergySignallerEvent event) override;
 
-    //! ICheckpointHelperClient implementation
-    void writeCheckpoint(t_state* localState, t_state* globalState) override;
+
+    //! CheckpointHelper identifier
+    const std::string identifier_ = "EnergyElement";
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
 
     //! Whether this is the master rank
     const bool isMasterRank_;
index 99f8d3205b2833662a934400bcc9ea8c8e98e41a..d284eceb1f9276e503e3a8f848c25bd017a59f76 100644 (file)
 
 #include "freeenergyperturbationdata.h"
 
+#include "gromacs/domdec/domdec_network.h"
 #include "gromacs/mdlib/freeenergyparameters.h"
 #include "gromacs/mdlib/md_support.h"
 #include "gromacs/mdlib/mdatoms.h"
+#include "gromacs/mdtypes/checkpointdata.h"
+#include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/state.h"
@@ -65,6 +68,9 @@ FreeEnergyPerturbationData::FreeEnergyPerturbationData(FILE* fplog, const t_inpu
     mdAtoms_(mdAtoms)
 {
     lambda_.fill(0);
+    // The legacy implementation only filled the lambda vector in state_global, which is only
+    // available on master. We have the lambda vector available everywhere, so we pass a `true`
+    // for isMaster on all ranks. See #3647.
     initialize_lambdas(fplog_, *inputrec_, true, &currentFEPState_, lambda_);
     update_mdatoms(mdAtoms_->mdatoms(), lambda_[efptMASS]);
 }
@@ -101,11 +107,62 @@ int FreeEnergyPerturbationData::currentFEPState()
     return currentFEPState_;
 }
 
-void FreeEnergyPerturbationData::Element::writeCheckpoint(t_state* localState, t_state gmx_unused* globalState)
+namespace
 {
-    localState->fep_state = freeEnergyPerturbationData_->currentFEPState_;
-    localState->lambda    = freeEnergyPerturbationData_->lambda_;
-    localState->flags |= (1U << estLAMBDA) | (1U << estFEPSTATE);
+/*!
+ * \brief Enum describing the contents FreeEnergyPerturbationData::Element writes to modular checkpoint
+ *
+ * When changing the checkpoint content, add a new element just above Count, and adjust the
+ * checkpoint functionality.
+ */
+enum class CheckpointVersion
+{
+    Base, //!< First version of modular checkpointing
+    Count //!< Number of entries. Add new versions right above this!
+};
+constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
+} // namespace
+
+template<CheckpointDataOperation operation>
+void FreeEnergyPerturbationData::Element::doCheckpointData(CheckpointData<operation>* checkpointData,
+                                                           const t_commrec*           cr)
+{
+    if (MASTER(cr))
+    {
+        checkpointVersion(checkpointData, "FreeEnergyPerturbationData version", c_currentVersion);
+
+        checkpointData->scalar("current FEP state", &freeEnergyPerturbationData_->currentFEPState_);
+        checkpointData->arrayRef("lambda vector",
+                                 makeCheckpointArrayRef<operation>(freeEnergyPerturbationData_->lambda_));
+    }
+    if (operation == CheckpointDataOperation::Read)
+    {
+        if (DOMAINDECOMP(cr))
+        {
+            dd_bcast(cr->dd, sizeof(int), &freeEnergyPerturbationData_->currentFEPState_);
+            dd_bcast(cr->dd, freeEnergyPerturbationData_->lambda_.size() * sizeof(real),
+                     freeEnergyPerturbationData_->lambda_.data());
+        }
+        update_mdatoms(freeEnergyPerturbationData_->mdAtoms_->mdatoms(),
+                       freeEnergyPerturbationData_->lambda_[efptMASS]);
+    }
+}
+
+void FreeEnergyPerturbationData::Element::writeCheckpoint(WriteCheckpointData checkpointData,
+                                                          const t_commrec*    cr)
+{
+    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+}
+
+void FreeEnergyPerturbationData::Element::readCheckpoint(ReadCheckpointData checkpointData,
+                                                         const t_commrec*   cr)
+{
+    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+}
+
+const std::string& FreeEnergyPerturbationData::Element::clientID()
+{
+    return identifier_;
 }
 
 FreeEnergyPerturbationData::Element::Element(FreeEnergyPerturbationData* freeEnergyPerturbationElement,
index 69ab744b0297a6dadbf530b9f82efa67f9afbde8..c2844dcb8cd36f430f21e60e176fc51fee5592a4 100644 (file)
@@ -54,6 +54,7 @@ struct t_inputrec;
 
 namespace gmx
 {
+enum class CheckpointDataOperation;
 class EnergyData;
 class GlobalCommunicationHelper;
 class LegacySimulatorData;
@@ -133,6 +134,13 @@ public:
     //! No teardown needed
     void elementTeardown() override{};
 
+    //! ICheckpointHelperClient write checkpoint implementation
+    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient read checkpoint implementation
+    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient key implementation
+    const std::string& clientID() override;
+
     /*! \brief Factory method implementation
      *
      * \param legacySimulatorData  Pointer allowing access to simulator level data
@@ -157,8 +165,12 @@ private:
     //! Whether lambda values are non-static
     const bool lambdasChange_;
 
-    //! ICheckpointHelperClient implementation
-    void writeCheckpoint(t_state* localState, t_state* globalState) override;
+
+    //! CheckpointHelper identifier
+    const std::string identifier_ = "FreeEnergyPerturbationElement";
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
 };
 
 } // namespace gmx
index bb2cd7e5a53008e1bdbaee5ed72b8a4016b67b3f..f10d038808c45d63c24277e8cd2ad5c7aa6952bd 100644 (file)
@@ -89,7 +89,8 @@ void ModularSimulator::run()
             .asParagraph()
             .appendText("Using the modular simulator.");
 
-    ModularSimulatorAlgorithmBuilder algorithmBuilder(compat::make_not_null(legacySimulatorData_));
+    ModularSimulatorAlgorithmBuilder algorithmBuilder(compat::make_not_null(legacySimulatorData_),
+                                                      std::move(checkpointDataHolder_));
     addIntegrationElements(&algorithmBuilder);
     auto algorithm = algorithmBuilder.build();
 
index 05c069164193889cb63e50ae51a65feaa0c53563..935f690ae3924baf40dec5c972d4f9d695107129 100644 (file)
 #include <optional>
 
 #include "gromacs/math/vectypes.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/utility/basedefinitions.h"
 #include "gromacs/utility/exceptions.h"
 
 struct gmx_localtop_t;
 struct gmx_mdoutf;
+struct t_commrec;
 class t_state;
 
 namespace gmx
@@ -341,28 +343,43 @@ protected:
 /*! \internal
  * \brief Client that needs to store data during checkpointing
  *
- * The current checkpointing helper uses the legacy t_state object to collect
- * the data to be checkpointed. Clients get queried for their contributions
- * using pointers to t_state objects.
- * \todo Add checkpoint reading
- * \todo Evolve this to a model in which the checkpoint helper passes a file
- *       pointer rather than a t_state object, and the clients are responsible
- *       to read / write.
+ * Clients receive a CheckpointData object for reading and writing.
+ * Note that `ReadCheckpointData` is a typedef for
+ * `CheckpointData<CheckpointDataOperation::Read>`, and
+ * `WriteCheckpointData` is a typedef for
+ * `CheckpointData<CheckpointDataOperation::Write>`. This allows clients
+ * to write a single templated function, e.g.
+ *     template<CheckpointDataOperation operation>
+ *     void doCheckpointData(CheckpointData<operation>* checkpointData,
+ *                           const t_commrec* cr)
+ *     {
+ *         checkpointData->scalar("important value", &value_);
+ *     }
+ * for both checkpoint reading and writing. This function can then be
+ * dispatched from the interface functions,
+ *     void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+ *     {
+ *         doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+ *     }
+ *     void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+ *     {
+ *         doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+ *     }
+ * This reduces code duplication and ensures that reading and writing
+ * operations will not get out of sync.
  */
 class ICheckpointHelperClient
 {
 public:
-    //! \cond
-    // (doxygen doesn't like these...)
-    //! Allow CheckpointHelper to interact
-    friend class CheckpointHelper;
-    //! \endcond
     //! Standard virtual destructor
     virtual ~ICheckpointHelperClient() = default;
 
-protected:
     //! Write checkpoint
-    virtual void writeCheckpoint(t_state* localState, t_state* globalState) = 0;
+    virtual void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) = 0;
+    //! Read checkpoint
+    virtual void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) = 0;
+    //! Get unique client id
+    [[nodiscard]] virtual const std::string& clientID() = 0;
 };
 
 /*! \brief
index 1b2da2b84098e358a41c904fb535b19bbe855519..518356401273b6d2530e045f8d71934cddb14575 100644 (file)
 #include "gromacs/mdlib/coupling.h"
 #include "gromacs/mdlib/mdatoms.h"
 #include "gromacs/mdlib/stat.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/mdatom.h"
-#include "gromacs/mdtypes/state.h"
 #include "gromacs/pbcutil/boxutilities.h"
 
 #include "energydata.h"
@@ -71,10 +71,7 @@ ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                  nstpcoup
                                                    EnergyData*          energyData,
                                                    FILE*                fplog,
                                                    const t_inputrec*    inputrec,
-                                                   const MDAtoms*       mdAtoms,
-                                                   const t_state*       globalState,
-                                                   t_commrec*           cr,
-                                                   bool                 isRestart) :
+                                                   const MDAtoms*       mdAtoms) :
     nstpcouple_(nstpcouple),
     offset_(offset),
     couplingTimeStep_(couplingTimeStep),
@@ -89,21 +86,6 @@ ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                  nstpcoup
     mdAtoms_(mdAtoms)
 {
     energyData->setParrinelloRahamnBarostat(this);
-    // TODO: This is only needed to restore the thermostatIntegral_ from cpt. Remove this when
-    //       switching to purely client-based checkpointing.
-    if (isRestart)
-    {
-        if (MASTER(cr))
-        {
-            copy_mat(globalState->boxv, boxVelocity_);
-            copy_mat(globalState->box_rel, boxRel_);
-        }
-        if (DOMAINDECOMP(cr))
-        {
-            dd_bcast(cr->dd, sizeof(boxVelocity_), boxVelocity_);
-            dd_bcast(cr->dd, sizeof(boxRel_), boxRel_);
-        }
-    }
 }
 
 void ParrinelloRahmanBarostat::connectWithPropagator(const PropagatorBarostatConnection& connectionData)
@@ -240,11 +222,53 @@ real ParrinelloRahmanBarostat::conservedEnergyContribution() const
     return energy;
 }
 
-void ParrinelloRahmanBarostat::writeCheckpoint(t_state* localState, t_state gmx_unused* globalState)
+namespace
+{
+/*!
+ * \brief Enum describing the contents ParrinelloRahmanBarostat writes to modular checkpoint
+ *
+ * When changing the checkpoint content, add a new element just above Count, and adjust the
+ * checkpoint functionality.
+ */
+enum class CheckpointVersion
+{
+    Base, //!< First version of modular checkpointing
+    Count //!< Number of entries. Add new versions right above this!
+};
+constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
+} // namespace
+
+template<CheckpointDataOperation operation>
+void ParrinelloRahmanBarostat::doCheckpointData(CheckpointData<operation>* checkpointData,
+                                                const t_commrec*           cr)
+{
+    if (MASTER(cr))
+    {
+        checkpointVersion(checkpointData, "ParrinelloRahmanBarostat version", c_currentVersion);
+
+        checkpointData->tensor("box velocity", boxVelocity_);
+        checkpointData->tensor("relative box vector", boxRel_);
+    }
+    if (operation == CheckpointDataOperation::Read && DOMAINDECOMP(cr))
+    {
+        dd_bcast(cr->dd, sizeof(boxVelocity_), boxVelocity_);
+        dd_bcast(cr->dd, sizeof(boxRel_), boxRel_);
+    }
+}
+
+void ParrinelloRahmanBarostat::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+}
+
+void ParrinelloRahmanBarostat::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+}
+
+const std::string& ParrinelloRahmanBarostat::clientID()
 {
-    copy_mat(boxVelocity_, localState->boxv);
-    copy_mat(boxRel_, localState->box_rel);
-    localState->flags |= (1U << estBOXV) | (1U << estBOX_REL);
+    return identifier_;
 }
 
 ISimulatorElement* ParrinelloRahmanBarostat::getElementPointerImpl(
@@ -260,9 +284,7 @@ ISimulatorElement* ParrinelloRahmanBarostat::getElementPointerImpl(
             legacySimulatorData->inputrec->nstpcouple, offset,
             legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nstpcouple,
             legacySimulatorData->inputrec->init_step, statePropagatorData, energyData,
-            legacySimulatorData->fplog, legacySimulatorData->inputrec, legacySimulatorData->mdAtoms,
-            legacySimulatorData->state_global, legacySimulatorData->cr,
-            legacySimulatorData->inputrec->bContinuation));
+            legacySimulatorData->fplog, legacySimulatorData->inputrec, legacySimulatorData->mdAtoms));
     auto* barostat = static_cast<ParrinelloRahmanBarostat*>(element);
     builderHelper->registerBarostat([barostat](const PropagatorBarostatConnection& connection) {
         barostat->connectWithPropagator(connection);
index 6f20f4be6b53c1342c9c7602f6bbe877379694f7..980675572559bfaf544f72382d332a94a7ac747c 100644 (file)
@@ -54,6 +54,7 @@ struct t_commrec;
 
 namespace gmx
 {
+enum class CheckpointDataOperation;
 class EnergyData;
 class LegacySimulatorData;
 class MDAtoms;
@@ -81,10 +82,7 @@ public:
                              EnergyData*          energyData,
                              FILE*                fplog,
                              const t_inputrec*    inputrec,
-                             const MDAtoms*       mdAtoms,
-                             const t_state*       globalState,
-                             t_commrec*           cr,
-                             bool                 isRestart);
+                             const MDAtoms*       mdAtoms);
 
     /*! \brief Register run function for step / time
      *
@@ -107,6 +105,13 @@ public:
     //! Connect this to propagator
     void connectWithPropagator(const PropagatorBarostatConnection& connectionData);
 
+    //! ICheckpointHelperClient write checkpoint implementation
+    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient read checkpoint implementation
+    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient key implementation
+    const std::string& clientID() override;
+
     /*! \brief Factory method implementation
      *
      * \param legacySimulatorData  Pointer allowing access to simulator level data
@@ -160,8 +165,11 @@ private:
     //! Scale box and positions
     void scaleBoxAndPositions();
 
-    //! ICheckpointHelperClient implementation
-    void writeCheckpoint(t_state* localState, t_state* globalState) override;
+    //! CheckpointHelper identifier
+    const std::string identifier_ = "ParrinelloRahmanBarostat";
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
 
     // Access to ISimulator data
     //! Handles logging.
index 44886b83ca1c2280eaa92ee2df97b08ca00774c2..26c8f02d7971d7e321c3cddcdfd42408127cc875 100644 (file)
@@ -382,14 +382,18 @@ void ModularSimulatorAlgorithm::populateTaskQueue()
 }
 
 ModularSimulatorAlgorithmBuilder::ModularSimulatorAlgorithmBuilder(
-        compat::not_null<LegacySimulatorData*> legacySimulatorData) :
+        compat::not_null<LegacySimulatorData*>    legacySimulatorData,
+        std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder) :
     legacySimulatorData_(legacySimulatorData),
     signals_(std::make_unique<SimulationSignals>()),
     elementAdditionHelper_(this),
     globalCommunicationHelper_(computeGlobalCommunicationPeriod(legacySimulatorData->mdlog,
                                                                 legacySimulatorData->inputrec,
                                                                 legacySimulatorData->cr),
-                               signals_.get())
+                               signals_.get()),
+    checkpointHelperBuilder_(std::move(checkpointDataHolder),
+                             legacySimulatorData->startingBehavior,
+                             legacySimulatorData->cr)
 {
     if (legacySimulatorData->inputrec->efep != efepNO)
     {
@@ -528,14 +532,12 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::build()
 
     // Build checkpoint helper (do this last so everyone else can be a checkpoint client!)
     {
-        auto checkpointHandler = std::make_unique<CheckpointHandler>(
-                compat::make_not_null<SimulationSignal*>(
-                        &(*globalCommunicationHelper_.simulationSignals())[eglsCHKPT]),
+        checkpointHelperBuilder_.setCheckpointHandler(std::make_unique<CheckpointHandler>(
+                compat::make_not_null<SimulationSignal*>(&(*algorithm.signals_)[eglsCHKPT]),
                 simulationsShareState, legacySimulatorData_->inputrec->nstlist == 0,
                 MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.writeConfout,
-                legacySimulatorData_->mdrunOptions.checkpointOptions.period);
-        algorithm.checkpointHelper_ = std::make_unique<CheckpointHelper>(
-                std::move(checkpointClients_), std::move(checkpointHandler),
+                legacySimulatorData_->mdrunOptions.checkpointOptions.period));
+        algorithm.checkpointHelper_ = checkpointHelperBuilder_.build(
                 legacySimulatorData_->inputrec->init_step, trajectoryElement.get(),
                 legacySimulatorData_->top_global->natoms, legacySimulatorData_->fplog,
                 legacySimulatorData_->cr, legacySimulatorData_->observablesHistory,
index 46009cdabfd11aa02003dcc0be56185657cafdaf..ab4146bfa556f89d028ae5d8406043ec49183a1d 100644 (file)
@@ -351,7 +351,8 @@ class ModularSimulatorAlgorithmBuilder final
 {
 public:
     //! Constructor
-    explicit ModularSimulatorAlgorithmBuilder(compat::not_null<LegacySimulatorData*> legacySimulatorData);
+    ModularSimulatorAlgorithmBuilder(compat::not_null<LegacySimulatorData*>    legacySimulatorData,
+                                     std::unique_ptr<ReadCheckpointDataHolder> checkpointDataHolder);
     //! Build algorithm
     ModularSimulatorAlgorithm build();
 
@@ -455,6 +456,8 @@ private:
     TrajectoryElementBuilder trajectoryElementBuilder_;
     //! Builder for the TopologyHolder
     TopologyHolder::Builder topologyHolderBuilder_;
+    //! Builder for the CheckpointHelper
+    CheckpointHelperBuilder checkpointHelperBuilder_;
 
     /*! \brief List of clients for the CheckpointHelper
      *
@@ -594,10 +597,7 @@ void ModularSimulatorAlgorithmBuilder::registerWithInfrastructureAndSignallers(E
     // Register element to topology holder (if applicable)
     topologyHolderBuilder_.registerClient(castOrNull<ITopologyHolderClient, Element>(element));
     // Register element to checkpoint client (if applicable)
-    if (auto castedElement = castOrNull<ICheckpointHelperClient, Element>(element))
-    {
-        checkpointClients_.emplace_back(castedElement);
-    }
+    checkpointHelperBuilder_.registerClient(castOrNull<ICheckpointHelperClient, Element>(element));
 }
 
 } // namespace gmx
index 93fba49cb36a08c93754f4e00ac305705f0ad205..e66483718d926cbe0863aa112cb01d1031fb5900 100644 (file)
@@ -61,7 +61,6 @@
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/mdrunoptions.h"
 #include "gromacs/mdtypes/state.h"
-#include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/topology/atoms.h"
 #include "gromacs/topology/topology.h"
@@ -132,6 +131,14 @@ StatePropagatorData::StatePropagatorData(int                numAtoms,
         changePinningPolicy(&x_, gmx::PinningPolicy::PinnedIfSupported);
     }
 
+    if (DOMAINDECOMP(cr) && MASTER(cr))
+    {
+        xGlobal_.reserveWithPadding(totalNumAtoms_);
+        previousXGlobal_.reserveWithPadding(totalNumAtoms_);
+        vGlobal_.reserveWithPadding(totalNumAtoms_);
+        fGlobal_.reserveWithPadding(totalNumAtoms_);
+    }
+
     if (!inputrec->bContinuation)
     {
         if (stateHasVelocities)
@@ -254,7 +261,9 @@ std::unique_ptr<t_state> StatePropagatorData::localState()
     state->x = x_;
     state->v = v_;
     copy_mat(box_, state->box);
-    state->ddp_count = ddpCount_;
+    state->ddp_count       = ddpCount_;
+    state->ddp_count_cg_gl = ddpCountCgGl_;
+    state->cg_gl           = cgGl_;
     return state;
 }
 
@@ -268,7 +277,9 @@ void StatePropagatorData::setLocalState(std::unique_ptr<t_state> state)
     v_ = state->v;
     copy_mat(state->box, box_);
     copyPosition();
-    ddpCount_ = state->ddp_count;
+    ddpCount_     = state->ddp_count;
+    ddpCountCgGl_ = state->ddp_count_cg_gl;
+    cgGl_         = state->cg_gl;
 
     if (vvResetVelocities_)
     {
@@ -453,14 +464,74 @@ void StatePropagatorData::resetVelocities()
     v_ = velocityBackup_;
 }
 
-void StatePropagatorData::Element::writeCheckpoint(t_state* localState, t_state gmx_unused* globalState)
+namespace
+{
+/*!
+ * \brief Enum describing the contents StatePropagatorData::Element writes to modular checkpoint
+ *
+ * When changing the checkpoint content, add a new element just above Count, and adjust the
+ * checkpoint functionality.
+ */
+enum class CheckpointVersion
+{
+    Base, //!< First version of modular checkpointing
+    Count //!< Number of entries. Add new versions right above this!
+};
+constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
+} // namespace
+
+template<CheckpointDataOperation operation>
+void StatePropagatorData::Element::doCheckpointData(CheckpointData<operation>* checkpointData,
+                                                    const t_commrec*           cr)
+{
+    ArrayRef<RVec> xGlobalRef;
+    ArrayRef<RVec> vGlobalRef;
+    if (DOMAINDECOMP(cr))
+    {
+        if (MASTER(cr))
+        {
+            xGlobalRef = statePropagatorData_->xGlobal_;
+            vGlobalRef = statePropagatorData_->vGlobal_;
+        }
+        if (operation == CheckpointDataOperation::Write)
+        {
+            dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
+                           statePropagatorData_->cgGl_, statePropagatorData_->x_, xGlobalRef);
+            dd_collect_vec(cr->dd, statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
+                           statePropagatorData_->cgGl_, statePropagatorData_->v_, vGlobalRef);
+        }
+    }
+    else
+    {
+        xGlobalRef = statePropagatorData_->x_;
+        vGlobalRef = statePropagatorData_->v_;
+    }
+    if (MASTER(cr))
+    {
+        checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
+
+        checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobalRef));
+        checkpointData->arrayRef("velocities", makeCheckpointArrayRef<operation>(vGlobalRef));
+        checkpointData->tensor("box", statePropagatorData_->box_);
+        checkpointData->scalar("ddpCount", &statePropagatorData_->ddpCount_);
+        checkpointData->scalar("ddpCountCgGl", &statePropagatorData_->ddpCountCgGl_);
+        checkpointData->arrayRef("cgGl", makeCheckpointArrayRef<operation>(statePropagatorData_->cgGl_));
+    }
+}
+
+void StatePropagatorData::Element::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+}
+
+void StatePropagatorData::Element::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+}
+
+const std::string& StatePropagatorData::Element::clientID()
 {
-    state_change_natoms(localState, statePropagatorData_->localNAtoms_);
-    localState->x = statePropagatorData_->x_;
-    localState->v = statePropagatorData_->v_;
-    copy_mat(statePropagatorData_->box_, localState->box);
-    localState->ddp_count = statePropagatorData_->ddpCount_;
-    localState->flags |= (1U << estX) | (1U << estV) | (1U << estBOX);
+    return identifier_;
 }
 
 void StatePropagatorData::Element::trajectoryWriterTeardown(gmx_mdoutf* gmx_unused outf)
index 2bafcfad8ad6ee4a5a373d3c83b8e7f3fa211acc..2a8ba2314af0427563b25faaab1c7598da5baa17 100644 (file)
@@ -49,6 +49,7 @@
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/forcebuffers.h"
+#include "gromacs/utility/keyvaluetree.h"
 
 #include "modularsimulatorinterfaces.h"
 #include "topologyholder.h"
@@ -62,6 +63,7 @@ struct t_mdatoms;
 
 namespace gmx
 {
+enum class CheckpointDataOperation;
 enum class ConstraintVariable;
 class EnergyData;
 class FreeEnergyPerturbationData;
@@ -169,8 +171,21 @@ private:
     matrix box_;
     //! The box matrix of the previous step
     matrix previousBox_;
-    //! The DD partitioning count for legacy t_state compatibility
+    //! The DD partitioning count
     int ddpCount_;
+    //! The DD partitioning count for index_gl
+    int ddpCountCgGl_;
+    //! The global cg number of the local cgs
+    std::vector<int> cgGl_;
+
+    //! The global position vector
+    PaddedHostVector<RVec> xGlobal_;
+    //! The global position vector of the previous step
+    PaddedHostVector<RVec> previousXGlobal_;
+    //! The global velocity vector
+    PaddedHostVector<RVec> vGlobal_;
+    //! The global force vector
+    PaddedHostVector<RVec> fGlobal_;
 
     //! The element
     std::unique_ptr<Element> element_;
@@ -281,6 +296,13 @@ public:
     //! Set free energy data
     void setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData);
 
+    //! ICheckpointHelperClient write checkpoint implementation
+    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient read checkpoint implementation
+    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient key implementation
+    const std::string& clientID() override;
+
     /*! \brief Factory method implementation
      *
      * \param legacySimulatorData  Pointer allowing access to simulator level data
@@ -325,8 +347,11 @@ private:
     //! ITrajectoryWriterClient implementation
     std::optional<ITrajectoryWriterCallback> registerTrajectoryWriterCallback(TrajectoryEvent event) override;
 
-    //! ICheckpointHelperClient implementation
-    void writeCheckpoint(t_state* localState, t_state* globalState) override;
+    //! CheckpointHelper identifier
+    const std::string identifier_ = "StatePropagatorData";
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
 
     //! ILastStepSignallerClient implementation (used for final output only)
     std::optional<SignallerCallback> registerLastStepCallback() override;
index e2e56dafc9193be92f0a1503ed2bbc6805ae7044..40f87857e954c1c2ce9c918687ebf8ce1630bfc2 100644 (file)
 #include "gromacs/math/vec.h"
 #include "gromacs/mdlib/coupling.h"
 #include "gromacs/mdlib/stat.h"
+#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/group.h"
 #include "gromacs/mdtypes/inputrec.h"
-#include "gromacs/mdtypes/state.h"
 #include "gromacs/utility/fatalerror.h"
 
 #include "modularsimulator.h"
@@ -72,10 +72,7 @@ VRescaleThermostat::VRescaleThermostat(int                               nstcoup
                                        const real*                       referenceTemperature,
                                        const real*                       couplingTime,
                                        const real*                       numDegreesOfFreedom,
-                                       EnergyData*                       energyData,
-                                       const t_state*                    globalState,
-                                       t_commrec*                        cr,
-                                       bool                              isRestart) :
+                                       EnergyData*                       energyData) :
     nstcouple_(nstcouple),
     offset_(offset),
     useFullStepKE_(useFullStepKE),
@@ -94,22 +91,6 @@ VRescaleThermostat::VRescaleThermostat(int                               nstcoup
         thermostatIntegralPreviousStep_ = thermostatIntegral_;
     }
     energyData->setVRescaleThermostat(this);
-    // TODO: This is only needed to restore the thermostatIntegral_ from cpt. Remove this when
-    //       switching to purely client-based checkpointing.
-    if (isRestart)
-    {
-        if (MASTER(cr))
-        {
-            for (unsigned long i = 0; i < thermostatIntegral_.size(); ++i)
-            {
-                thermostatIntegral_[i] = globalState->therm_integral[i];
-            }
-        }
-        if (DOMAINDECOMP(cr))
-        {
-            dd_bcast(cr->dd, int(thermostatIntegral_.size() * sizeof(double)), thermostatIntegral_.data());
-        }
-    }
 }
 
 void VRescaleThermostat::connectWithPropagator(const PropagatorThermostatConnection& connectionData)
@@ -207,10 +188,51 @@ void VRescaleThermostat::setLambda(Step step)
     }
 }
 
-void VRescaleThermostat::writeCheckpoint(t_state* localState, t_state gmx_unused* globalState)
+namespace
+{
+/*!
+ * \brief Enum describing the contents VRescaleThermostat writes to modular checkpoint
+ *
+ * When changing the checkpoint content, add a new element just above Count, and adjust the
+ * checkpoint functionality.
+ */
+enum class CheckpointVersion
+{
+    Base, //!< First version of modular checkpointing
+    Count //!< Number of entries. Add new versions right above this!
+};
+constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
+} // namespace
+
+template<CheckpointDataOperation operation>
+void VRescaleThermostat::doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr)
+{
+    if (MASTER(cr))
+    {
+        checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
+
+        checkpointData->arrayRef("thermostat integral",
+                                 makeCheckpointArrayRef<operation>(thermostatIntegral_));
+    }
+    if (operation == CheckpointDataOperation::Read && DOMAINDECOMP(cr))
+    {
+        dd_bcast(cr->dd, thermostatIntegral_.size() * sizeof(double), thermostatIntegral_.data());
+    }
+}
+
+void VRescaleThermostat::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+}
+
+void VRescaleThermostat::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+{
+    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+}
+
+const std::string& VRescaleThermostat::clientID()
 {
-    localState->therm_integral = thermostatIntegral_;
-    localState->flags |= (1U << estTHERM_INT);
+    return identifier_;
 }
 
 real VRescaleThermostat::conservedEnergyContribution() const
@@ -237,8 +259,7 @@ ISimulatorElement* VRescaleThermostat::getElementPointerImpl(
             legacySimulatorData->inputrec->ld_seed, legacySimulatorData->inputrec->opts.ngtc,
             legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nsttcouple,
             legacySimulatorData->inputrec->opts.ref_t, legacySimulatorData->inputrec->opts.tau_t,
-            legacySimulatorData->inputrec->opts.nrdf, energyData, legacySimulatorData->state_global,
-            legacySimulatorData->cr, legacySimulatorData->inputrec->bContinuation));
+            legacySimulatorData->inputrec->opts.nrdf, energyData));
     auto* thermostat = static_cast<VRescaleThermostat*>(element);
     builderHelper->registerThermostat([thermostat](const PropagatorThermostatConnection& connection) {
         thermostat->connectWithPropagator(connection);
index f019bcd771abc1f788c1a79858bc236bf5feb81d..7730b8e38e42912b0040c0dc34f76e6f9bfcc512 100644 (file)
@@ -93,10 +93,7 @@ public:
                        const real*                       referenceTemperature,
                        const real*                       couplingTime,
                        const real*                       numDegreesOfFreedom,
-                       EnergyData*                       energyData,
-                       const t_state*                    globalState,
-                       t_commrec*                        cr,
-                       bool                              isRestart);
+                       EnergyData*                       energyData);
 
     /*! \brief Register run function for step / time
      *
@@ -117,6 +114,13 @@ public:
     //! Connect this to propagator
     void connectWithPropagator(const PropagatorThermostatConnection& connectionData);
 
+    //! ICheckpointHelperClient write checkpoint implementation
+    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient read checkpoint implementation
+    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient key implementation
+    const std::string& clientID() override;
+
     /*! \brief Factory method implementation
      *
      * \param legacySimulatorData  Pointer allowing access to simulator level data
@@ -181,8 +185,11 @@ private:
     //! Set new lambda value (at T-coupling steps)
     void setLambda(Step step);
 
-    //! ICheckpointHelperClient implementation
-    void writeCheckpoint(t_state* localState, t_state* globalState) override;
+    //! CheckpointHelper identifier
+    const std::string identifier_ = "VRescaleThermostat";
+    //! Helper function to read from / write to CheckpointData
+    template<CheckpointDataOperation operation>
+    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
 };
 
 } // namespace gmx