Checkpointdata DD followup
authorPascal Merz <pascal.merz@me.com>
Mon, 21 Sep 2020 16:34:37 +0000 (16:34 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Mon, 21 Sep 2020 16:34:37 +0000 (16:34 +0000)
12 files changed:
src/gromacs/modularsimulator/checkpointhelper.cpp
src/gromacs/modularsimulator/energydata.cpp
src/gromacs/modularsimulator/energydata.h
src/gromacs/modularsimulator/freeenergyperturbationdata.cpp
src/gromacs/modularsimulator/freeenergyperturbationdata.h
src/gromacs/modularsimulator/modularsimulatorinterfaces.h
src/gromacs/modularsimulator/parrinellorahmanbarostat.cpp
src/gromacs/modularsimulator/parrinellorahmanbarostat.h
src/gromacs/modularsimulator/statepropagatordata.cpp
src/gromacs/modularsimulator/statepropagatordata.h
src/gromacs/modularsimulator/velocityscalingtemperaturecoupling.cpp
src/gromacs/modularsimulator/velocityscalingtemperaturecoupling.h

index 643e96a02d087a24021a350127d7b1d51b458930..d030904dc4f4cacdcb222228595a99881ff6f527 100644 (file)
@@ -134,7 +134,9 @@ void CheckpointHelper::writeCheckpoint(Step step, Time time)
     WriteCheckpointDataHolder checkpointDataHolder;
     for (const auto& [key, client] : clients_)
     {
-        client->writeCheckpoint(checkpointDataHolder.checkpointData(key), cr_);
+        client->saveCheckpointState(
+                MASTER(cr_) ? std::make_optional(checkpointDataHolder.checkpointData(key)) : std::nullopt,
+                cr_);
     }
 
     mdoutf_write_to_trajectory_files(fplog_, cr_, trajectoryElement_->outf_, MDOF_CPT,
@@ -177,7 +179,7 @@ void CheckpointHelperBuilder::registerClient(ICheckpointHelperClient* client)
     clientsMap_[key] = client;
     if (resetFromCheckpoint_)
     {
-        if (!checkpointDataHolder_->keyExists(key))
+        if (MASTER(cr_) && !checkpointDataHolder_->keyExists(key))
         {
             throw SimulationAlgorithmSetupError(
                     formatString(
@@ -186,7 +188,9 @@ void CheckpointHelperBuilder::registerClient(ICheckpointHelperClient* client)
                             key.c_str(), key.c_str())
                             .c_str());
         }
-        client->readCheckpoint(checkpointDataHolder_->checkpointData(key), cr_);
+        client->restoreCheckpointState(
+                MASTER(cr_) ? std::make_optional(checkpointDataHolder_->checkpointData(key)) : std::nullopt,
+                cr_);
     }
 }
 
index f3df288719164874457bd46647d1108573d12ac0..8bba774e60d1b00caf84bc3603a93509c39e99eb 100644 (file)
@@ -386,20 +386,17 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 } // namespace
 
 template<CheckpointDataOperation operation>
-void EnergyData::Element::doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr)
+void EnergyData::Element::doCheckpointData(CheckpointData<operation>* checkpointData)
 {
-    if (MASTER(cr))
-    {
-        checkpointVersion(checkpointData, "EnergyData version", c_currentVersion);
+    checkpointVersion(checkpointData, "EnergyData version", c_currentVersion);
 
-        energyData_->observablesHistory_->energyHistory->doCheckpoint<operation>(
-                checkpointData->subCheckpointData("energy history"));
-        energyData_->ekinstate_.doCheckpoint<operation>(
-                checkpointData->subCheckpointData("ekinstate"));
-    }
+    energyData_->observablesHistory_->energyHistory->doCheckpoint<operation>(
+            checkpointData->subCheckpointData("energy history"));
+    energyData_->ekinstate_.doCheckpoint<operation>(checkpointData->subCheckpointData("ekinstate"));
 }
 
-void EnergyData::Element::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+void EnergyData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                              const t_commrec*                   cr)
 {
     if (MASTER(cr))
     {
@@ -414,13 +411,17 @@ void EnergyData::Element::writeCheckpoint(WriteCheckpointData checkpointData, co
         }
         energyData_->energyOutput_->fillEnergyHistory(
                 energyData_->observablesHistory_->energyHistory.get());
+        doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
     }
-    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
 }
 
-void EnergyData::Element::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+void EnergyData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                                 const t_commrec*                  cr)
 {
-    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+    if (MASTER(cr))
+    {
+        doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
+    }
     energyData_->hasReadEkinFromCheckpoint_ = MASTER(cr) ? energyData_->ekinstate_.bUpToDate : false;
     if (PAR(cr))
     {
index 025702fec53f8caa5739b72ff813887de227d764..774390a39a95704d2c7074eec22dc82a4eb87029 100644 (file)
@@ -364,9 +364,9 @@ public:
     void elementTeardown() override {}
 
     //! ICheckpointHelperClient write checkpoint implementation
-    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient read checkpoint implementation
-    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient key implementation
     const std::string& clientID() override;
 
@@ -415,7 +415,7 @@ private:
     const std::string identifier_ = "EnergyElement";
     //! Helper function to read from / write to CheckpointData
     template<CheckpointDataOperation operation>
-    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+    void doCheckpointData(CheckpointData<operation>* checkpointData);
 
     //! Whether this is the master rank
     const bool isMasterRank_;
index d284eceb1f9276e503e3a8f848c25bd017a59f76..86b098d89f68a41ae60f8fc46f1cdce4fc514d5b 100644 (file)
@@ -124,40 +124,39 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 } // namespace
 
 template<CheckpointDataOperation operation>
-void FreeEnergyPerturbationData::Element::doCheckpointData(CheckpointData<operation>* checkpointData,
-                                                           const t_commrec*           cr)
+void FreeEnergyPerturbationData::Element::doCheckpointData(CheckpointData<operation>* checkpointData)
 {
-    if (MASTER(cr))
-    {
-        checkpointVersion(checkpointData, "FreeEnergyPerturbationData version", c_currentVersion);
+    checkpointVersion(checkpointData, "FreeEnergyPerturbationData version", c_currentVersion);
 
-        checkpointData->scalar("current FEP state", &freeEnergyPerturbationData_->currentFEPState_);
-        checkpointData->arrayRef("lambda vector",
-                                 makeCheckpointArrayRef<operation>(freeEnergyPerturbationData_->lambda_));
-    }
-    if (operation == CheckpointDataOperation::Read)
-    {
-        if (DOMAINDECOMP(cr))
-        {
-            dd_bcast(cr->dd, sizeof(int), &freeEnergyPerturbationData_->currentFEPState_);
-            dd_bcast(cr->dd, freeEnergyPerturbationData_->lambda_.size() * sizeof(real),
-                     freeEnergyPerturbationData_->lambda_.data());
-        }
-        update_mdatoms(freeEnergyPerturbationData_->mdAtoms_->mdatoms(),
-                       freeEnergyPerturbationData_->lambda_[efptMASS]);
-    }
+    checkpointData->scalar("current FEP state", &freeEnergyPerturbationData_->currentFEPState_);
+    checkpointData->arrayRef("lambda vector",
+                             makeCheckpointArrayRef<operation>(freeEnergyPerturbationData_->lambda_));
 }
 
-void FreeEnergyPerturbationData::Element::writeCheckpoint(WriteCheckpointData checkpointData,
-                                                          const t_commrec*    cr)
+void FreeEnergyPerturbationData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                                              const t_commrec*                   cr)
 {
-    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+    if (MASTER(cr))
+    {
+        doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+    }
 }
 
-void FreeEnergyPerturbationData::Element::readCheckpoint(ReadCheckpointData checkpointData,
-                                                         const t_commrec*   cr)
+void FreeEnergyPerturbationData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                                                 const t_commrec* cr)
 {
-    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+    if (MASTER(cr))
+    {
+        doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
+    }
+    if (DOMAINDECOMP(cr))
+    {
+        dd_bcast(cr->dd, sizeof(int), &freeEnergyPerturbationData_->currentFEPState_);
+        dd_bcast(cr->dd, freeEnergyPerturbationData_->lambda_.size() * sizeof(real),
+                 freeEnergyPerturbationData_->lambda_.data());
+    }
+    update_mdatoms(freeEnergyPerturbationData_->mdAtoms_->mdatoms(),
+                   freeEnergyPerturbationData_->lambda_[efptMASS]);
 }
 
 const std::string& FreeEnergyPerturbationData::Element::clientID()
index c2844dcb8cd36f430f21e60e176fc51fee5592a4..0ee6fe78ece8f17d29b0815d11ec24d0a8f3fa2a 100644 (file)
@@ -135,9 +135,9 @@ public:
     void elementTeardown() override{};
 
     //! ICheckpointHelperClient write checkpoint implementation
-    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient read checkpoint implementation
-    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient key implementation
     const std::string& clientID() override;
 
@@ -170,7 +170,7 @@ private:
     const std::string identifier_ = "FreeEnergyPerturbationElement";
     //! Helper function to read from / write to CheckpointData
     template<CheckpointDataOperation operation>
-    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+    void doCheckpointData(CheckpointData<operation>* checkpointData);
 };
 
 } // namespace gmx
index 935f690ae3924baf40dec5c972d4f9d695107129..6e0dbb6a36939349098756df592f184e3e73c27c 100644 (file)
@@ -374,10 +374,12 @@ public:
     //! Standard virtual destructor
     virtual ~ICheckpointHelperClient() = default;
 
-    //! Write checkpoint
-    virtual void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) = 0;
-    //! Read checkpoint
-    virtual void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) = 0;
+    //! Write checkpoint (CheckpointData object only passed on master rank)
+    virtual void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                     const t_commrec*                   cr) = 0;
+    //! Read checkpoint (CheckpointData object only passed on master rank)
+    virtual void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                        const t_commrec*                  cr) = 0;
     //! Get unique client id
     [[nodiscard]] virtual const std::string& clientID() = 0;
 };
index 518356401273b6d2530e045f8d71934cddb14575..7bf5c7fe999c4a2d43ac0652184097cfeaa9ba67 100644 (file)
@@ -239,33 +239,37 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 } // namespace
 
 template<CheckpointDataOperation operation>
-void ParrinelloRahmanBarostat::doCheckpointData(CheckpointData<operation>* checkpointData,
-                                                const t_commrec*           cr)
+void ParrinelloRahmanBarostat::doCheckpointData(CheckpointData<operation>* checkpointData)
+{
+    checkpointVersion(checkpointData, "ParrinelloRahmanBarostat version", c_currentVersion);
+
+    checkpointData->tensor("box velocity", boxVelocity_);
+    checkpointData->tensor("relative box vector", boxRel_);
+}
+
+void ParrinelloRahmanBarostat::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                                   const t_commrec*                   cr)
 {
     if (MASTER(cr))
     {
-        checkpointVersion(checkpointData, "ParrinelloRahmanBarostat version", c_currentVersion);
+        doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+    }
+}
 
-        checkpointData->tensor("box velocity", boxVelocity_);
-        checkpointData->tensor("relative box vector", boxRel_);
+void ParrinelloRahmanBarostat::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                                      const t_commrec*                  cr)
+{
+    if (MASTER(cr))
+    {
+        doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
     }
-    if (operation == CheckpointDataOperation::Read && DOMAINDECOMP(cr))
+    if (DOMAINDECOMP(cr))
     {
         dd_bcast(cr->dd, sizeof(boxVelocity_), boxVelocity_);
         dd_bcast(cr->dd, sizeof(boxRel_), boxRel_);
     }
 }
 
-void ParrinelloRahmanBarostat::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
-{
-    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
-}
-
-void ParrinelloRahmanBarostat::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
-{
-    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
-}
-
 const std::string& ParrinelloRahmanBarostat::clientID()
 {
     return identifier_;
index 980675572559bfaf544f72382d332a94a7ac747c..77ede5316557e06407ec222c53b6426d77e99aef 100644 (file)
@@ -106,9 +106,9 @@ public:
     void connectWithPropagator(const PropagatorBarostatConnection& connectionData);
 
     //! ICheckpointHelperClient write checkpoint implementation
-    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient read checkpoint implementation
-    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient key implementation
     const std::string& clientID() override;
 
@@ -169,7 +169,7 @@ private:
     const std::string identifier_ = "ParrinelloRahmanBarostat";
     //! Helper function to read from / write to CheckpointData
     template<CheckpointDataOperation operation>
-    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+    void doCheckpointData(CheckpointData<operation>* checkpointData);
 
     // Access to ISimulator data
     //! Handles logging.
index e66483718d926cbe0863aa112cb01d1031fb5900..bf5bf8e10f9a7e7937b1340c968847582778f00a 100644 (file)
@@ -133,10 +133,10 @@ StatePropagatorData::StatePropagatorData(int                numAtoms,
 
     if (DOMAINDECOMP(cr) && MASTER(cr))
     {
-        xGlobal_.reserveWithPadding(totalNumAtoms_);
-        previousXGlobal_.reserveWithPadding(totalNumAtoms_);
-        vGlobal_.reserveWithPadding(totalNumAtoms_);
-        fGlobal_.reserveWithPadding(totalNumAtoms_);
+        xGlobal_.resizeWithPadding(totalNumAtoms_);
+        previousXGlobal_.resizeWithPadding(totalNumAtoms_);
+        vGlobal_.resizeWithPadding(totalNumAtoms_);
+        fGlobal_.resizeWithPadding(totalNumAtoms_);
     }
 
     if (!inputrec->bContinuation)
@@ -508,6 +508,7 @@ void StatePropagatorData::Element::doCheckpointData(CheckpointData<operation>* c
     }
     if (MASTER(cr))
     {
+        GMX_ASSERT(checkpointData, "Master needs a valid pointer to a CheckpointData object");
         checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
 
         checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobalRef));
@@ -519,14 +520,48 @@ void StatePropagatorData::Element::doCheckpointData(CheckpointData<operation>* c
     }
 }
 
-void StatePropagatorData::Element::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                                       const t_commrec*                   cr)
 {
-    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+    doCheckpointData<CheckpointDataOperation::Write>(
+            checkpointData ? &checkpointData.value() : nullptr, cr);
 }
 
-void StatePropagatorData::Element::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+/*!
+ * \brief Update the legacy global state
+ *
+ * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
+ * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
+ */
+static void updateGlobalState(t_state*                      globalState,
+                              const PaddedHostVector<RVec>& x,
+                              const PaddedHostVector<RVec>& v,
+                              const tensor                  box,
+                              int                           ddpCount,
+                              int                           ddpCountCgGl,
+                              const std::vector<int>&       cgGl)
+{
+    globalState->x = x;
+    globalState->v = v;
+    copy_mat(box, globalState->box);
+    globalState->ddp_count       = ddpCount;
+    globalState->ddp_count_cg_gl = ddpCountCgGl;
+    globalState->cg_gl           = cgGl;
+}
+
+void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                                          const t_commrec*                  cr)
 {
-    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+    doCheckpointData<CheckpointDataOperation::Read>(checkpointData ? &checkpointData.value() : nullptr, cr);
+
+    // Copy data to global state to be distributed by DD at setup stage
+    if (DOMAINDECOMP(cr) && MASTER(cr))
+    {
+        updateGlobalState(statePropagatorData_->globalState_, statePropagatorData_->xGlobal_,
+                          statePropagatorData_->vGlobal_, statePropagatorData_->box_,
+                          statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
+                          statePropagatorData_->cgGl_);
+    }
 }
 
 const std::string& StatePropagatorData::Element::clientID()
index 2a8ba2314af0427563b25faaab1c7598da5baa17..b3c1b050cac28d488b5d59dd2af5778968f20af7 100644 (file)
@@ -297,9 +297,9 @@ public:
     void setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData);
 
     //! ICheckpointHelperClient write checkpoint implementation
-    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient read checkpoint implementation
-    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient key implementation
     const std::string& clientID() override;
 
index 8a9bf53ca990baba0bfc81c8bc07b9f39396c160..8ceaf49255bddcaf3a4e44623df76648cee6a3f6 100644 (file)
@@ -107,9 +107,10 @@ public:
                                      const TemperatureCouplingData& temperatureCouplingData) = 0;
 
     //! Write private data to checkpoint
-    virtual void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) = 0;
+    virtual void writeCheckpoint(std::optional<WriteCheckpointData> checkpointData,
+                                 const t_commrec*                   cr) = 0;
     //! Read private data from checkpoint
-    virtual void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) = 0;
+    virtual void readCheckpoint(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) = 0;
 
     //! Standard virtual destructor
     virtual ~ITemperatureCouplingImpl() = default;
@@ -177,11 +178,13 @@ public:
     }
 
     //! No data to write to checkpoint
-    void writeCheckpoint(WriteCheckpointData gmx_unused checkpointData, const t_commrec gmx_unused* cr) override
+    void writeCheckpoint(std::optional<WriteCheckpointData> gmx_unused checkpointData,
+                         const t_commrec gmx_unused* cr) override
     {
     }
     //! No data to read from checkpoints
-    void readCheckpoint(ReadCheckpointData gmx_unused checkpointData, const t_commrec gmx_unused* cr) override
+    void readCheckpoint(std::optional<ReadCheckpointData> gmx_unused checkpointData,
+                        const t_commrec gmx_unused* cr) override
     {
     }
 
@@ -320,35 +323,45 @@ constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count
 } // namespace
 
 template<CheckpointDataOperation operation>
-void VelocityScalingTemperatureCoupling::doCheckpointData(CheckpointData<operation>* checkpointData,
-                                                          const t_commrec*           cr)
+void VelocityScalingTemperatureCoupling::doCheckpointData(CheckpointData<operation>* checkpointData)
+{
+    checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
+
+    checkpointData->arrayRef("thermostat integral",
+                             makeCheckpointArrayRef<operation>(temperatureCouplingIntegral_));
+}
+
+void VelocityScalingTemperatureCoupling::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+                                                             const t_commrec*                   cr)
 {
     if (MASTER(cr))
     {
-        checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
+        doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+    }
+    temperatureCouplingImpl_->writeCheckpoint(
+            checkpointData
+                    ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
+                    : std::nullopt,
+            cr);
+}
 
-        checkpointData->arrayRef("thermostat integral",
-                                 makeCheckpointArrayRef<operation>(temperatureCouplingIntegral_));
+void VelocityScalingTemperatureCoupling::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                                                const t_commrec* cr)
+{
+    if (MASTER(cr))
+    {
+        doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
     }
-    if (operation == CheckpointDataOperation::Read && DOMAINDECOMP(cr))
+    if (DOMAINDECOMP(cr))
     {
         dd_bcast(cr->dd, temperatureCouplingIntegral_.size() * sizeof(double),
                  temperatureCouplingIntegral_.data());
     }
-}
-
-void VelocityScalingTemperatureCoupling::writeCheckpoint(WriteCheckpointData checkpointData,
-                                                         const t_commrec*    cr)
-{
-    doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
-    temperatureCouplingImpl_->writeCheckpoint(checkpointData.subCheckpointData("thermostat impl"), cr);
-}
-
-void VelocityScalingTemperatureCoupling::readCheckpoint(ReadCheckpointData checkpointData,
-                                                        const t_commrec*   cr)
-{
-    doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
-    temperatureCouplingImpl_->readCheckpoint(checkpointData.subCheckpointData("thermostat impl"), cr);
+    temperatureCouplingImpl_->readCheckpoint(
+            checkpointData
+                    ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
+                    : std::nullopt,
+            cr);
 }
 
 const std::string& VelocityScalingTemperatureCoupling::clientID()
index e68ab6411a14c683a1f4d5a123318cb2d2059340..c984fadd0dfcde63b1b80c3cc1f66ef370779bdd 100644 (file)
@@ -125,9 +125,9 @@ public:
     void connectWithPropagator(const PropagatorThermostatConnection& connectionData);
 
     //! ICheckpointHelperClient write checkpoint implementation
-    void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+    void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient read checkpoint implementation
-    void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+    void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
     //! ICheckpointHelperClient key implementation
     const std::string& clientID() override;
 
@@ -198,7 +198,7 @@ private:
     const std::string identifier_ = "VelocityScalingTemperatureCoupling";
     //! Helper function to read from / write to CheckpointData
     template<CheckpointDataOperation operation>
-    void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+    void doCheckpointData(CheckpointData<operation>* checkpointData);
 };
 
 } // namespace gmx