WriteCheckpointDataHolder checkpointDataHolder;
for (const auto& [key, client] : clients_)
{
- client->writeCheckpoint(checkpointDataHolder.checkpointData(key), cr_);
+ client->saveCheckpointState(
+ MASTER(cr_) ? std::make_optional(checkpointDataHolder.checkpointData(key)) : std::nullopt,
+ cr_);
}
mdoutf_write_to_trajectory_files(fplog_, cr_, trajectoryElement_->outf_, MDOF_CPT,
clientsMap_[key] = client;
if (resetFromCheckpoint_)
{
- if (!checkpointDataHolder_->keyExists(key))
+ if (MASTER(cr_) && !checkpointDataHolder_->keyExists(key))
{
throw SimulationAlgorithmSetupError(
formatString(
key.c_str(), key.c_str())
.c_str());
}
- client->readCheckpoint(checkpointDataHolder_->checkpointData(key), cr_);
+ client->restoreCheckpointState(
+ MASTER(cr_) ? std::make_optional(checkpointDataHolder_->checkpointData(key)) : std::nullopt,
+ cr_);
}
}
} // namespace
template<CheckpointDataOperation operation>
-void EnergyData::Element::doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr)
+void EnergyData::Element::doCheckpointData(CheckpointData<operation>* checkpointData)
{
- if (MASTER(cr))
- {
- checkpointVersion(checkpointData, "EnergyData version", c_currentVersion);
+ checkpointVersion(checkpointData, "EnergyData version", c_currentVersion);
- energyData_->observablesHistory_->energyHistory->doCheckpoint<operation>(
- checkpointData->subCheckpointData("energy history"));
- energyData_->ekinstate_.doCheckpoint<operation>(
- checkpointData->subCheckpointData("ekinstate"));
- }
+ energyData_->observablesHistory_->energyHistory->doCheckpoint<operation>(
+ checkpointData->subCheckpointData("energy history"));
+ energyData_->ekinstate_.doCheckpoint<operation>(checkpointData->subCheckpointData("ekinstate"));
}
-void EnergyData::Element::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+void EnergyData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr)
{
if (MASTER(cr))
{
}
energyData_->energyOutput_->fillEnergyHistory(
energyData_->observablesHistory_->energyHistory.get());
+ doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
}
- doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
}
-void EnergyData::Element::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+void EnergyData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+ const t_commrec* cr)
{
- doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+ if (MASTER(cr))
+ {
+ doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
+ }
energyData_->hasReadEkinFromCheckpoint_ = MASTER(cr) ? energyData_->ekinstate_.bUpToDate : false;
if (PAR(cr))
{
void elementTeardown() override {}
//! ICheckpointHelperClient write checkpoint implementation
- void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+ void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient read checkpoint implementation
- void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+ void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient key implementation
const std::string& clientID() override;
const std::string identifier_ = "EnergyElement";
//! Helper function to read from / write to CheckpointData
template<CheckpointDataOperation operation>
- void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+ void doCheckpointData(CheckpointData<operation>* checkpointData);
//! Whether this is the master rank
const bool isMasterRank_;
} // namespace
template<CheckpointDataOperation operation>
-void FreeEnergyPerturbationData::Element::doCheckpointData(CheckpointData<operation>* checkpointData,
- const t_commrec* cr)
+void FreeEnergyPerturbationData::Element::doCheckpointData(CheckpointData<operation>* checkpointData)
{
- if (MASTER(cr))
- {
- checkpointVersion(checkpointData, "FreeEnergyPerturbationData version", c_currentVersion);
+ checkpointVersion(checkpointData, "FreeEnergyPerturbationData version", c_currentVersion);
- checkpointData->scalar("current FEP state", &freeEnergyPerturbationData_->currentFEPState_);
- checkpointData->arrayRef("lambda vector",
- makeCheckpointArrayRef<operation>(freeEnergyPerturbationData_->lambda_));
- }
- if (operation == CheckpointDataOperation::Read)
- {
- if (DOMAINDECOMP(cr))
- {
- dd_bcast(cr->dd, sizeof(int), &freeEnergyPerturbationData_->currentFEPState_);
- dd_bcast(cr->dd, freeEnergyPerturbationData_->lambda_.size() * sizeof(real),
- freeEnergyPerturbationData_->lambda_.data());
- }
- update_mdatoms(freeEnergyPerturbationData_->mdAtoms_->mdatoms(),
- freeEnergyPerturbationData_->lambda_[efptMASS]);
- }
+ checkpointData->scalar("current FEP state", &freeEnergyPerturbationData_->currentFEPState_);
+ checkpointData->arrayRef("lambda vector",
+ makeCheckpointArrayRef<operation>(freeEnergyPerturbationData_->lambda_));
}
-void FreeEnergyPerturbationData::Element::writeCheckpoint(WriteCheckpointData checkpointData,
- const t_commrec* cr)
+void FreeEnergyPerturbationData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr)
{
- doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+ if (MASTER(cr))
+ {
+ doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+ }
}
-void FreeEnergyPerturbationData::Element::readCheckpoint(ReadCheckpointData checkpointData,
- const t_commrec* cr)
+void FreeEnergyPerturbationData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+ const t_commrec* cr)
{
- doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+ if (MASTER(cr))
+ {
+ doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
+ }
+ if (DOMAINDECOMP(cr))
+ {
+ dd_bcast(cr->dd, sizeof(int), &freeEnergyPerturbationData_->currentFEPState_);
+ dd_bcast(cr->dd, freeEnergyPerturbationData_->lambda_.size() * sizeof(real),
+ freeEnergyPerturbationData_->lambda_.data());
+ }
+ update_mdatoms(freeEnergyPerturbationData_->mdAtoms_->mdatoms(),
+ freeEnergyPerturbationData_->lambda_[efptMASS]);
}
const std::string& FreeEnergyPerturbationData::Element::clientID()
void elementTeardown() override{};
//! ICheckpointHelperClient write checkpoint implementation
- void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+ void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient read checkpoint implementation
- void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+ void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient key implementation
const std::string& clientID() override;
const std::string identifier_ = "FreeEnergyPerturbationElement";
//! Helper function to read from / write to CheckpointData
template<CheckpointDataOperation operation>
- void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+ void doCheckpointData(CheckpointData<operation>* checkpointData);
};
} // namespace gmx
//! Standard virtual destructor
virtual ~ICheckpointHelperClient() = default;
- //! Write checkpoint
- virtual void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) = 0;
- //! Read checkpoint
- virtual void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) = 0;
+ //! Write checkpoint (CheckpointData object only passed on master rank)
+ virtual void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr) = 0;
+ //! Read checkpoint (CheckpointData object only passed on master rank)
+ virtual void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+ const t_commrec* cr) = 0;
//! Get unique client id
[[nodiscard]] virtual const std::string& clientID() = 0;
};
} // namespace
template<CheckpointDataOperation operation>
-void ParrinelloRahmanBarostat::doCheckpointData(CheckpointData<operation>* checkpointData,
- const t_commrec* cr)
+void ParrinelloRahmanBarostat::doCheckpointData(CheckpointData<operation>* checkpointData)
+{
+ checkpointVersion(checkpointData, "ParrinelloRahmanBarostat version", c_currentVersion);
+
+ checkpointData->tensor("box velocity", boxVelocity_);
+ checkpointData->tensor("relative box vector", boxRel_);
+}
+
+void ParrinelloRahmanBarostat::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr)
{
if (MASTER(cr))
{
- checkpointVersion(checkpointData, "ParrinelloRahmanBarostat version", c_currentVersion);
+ doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+ }
+}
- checkpointData->tensor("box velocity", boxVelocity_);
- checkpointData->tensor("relative box vector", boxRel_);
+void ParrinelloRahmanBarostat::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+ const t_commrec* cr)
+{
+ if (MASTER(cr))
+ {
+ doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
}
- if (operation == CheckpointDataOperation::Read && DOMAINDECOMP(cr))
+ if (DOMAINDECOMP(cr))
{
dd_bcast(cr->dd, sizeof(boxVelocity_), boxVelocity_);
dd_bcast(cr->dd, sizeof(boxRel_), boxRel_);
}
}
-void ParrinelloRahmanBarostat::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
-{
- doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
-}
-
-void ParrinelloRahmanBarostat::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
-{
- doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
-}
-
const std::string& ParrinelloRahmanBarostat::clientID()
{
return identifier_;
void connectWithPropagator(const PropagatorBarostatConnection& connectionData);
//! ICheckpointHelperClient write checkpoint implementation
- void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+ void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient read checkpoint implementation
- void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+ void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient key implementation
const std::string& clientID() override;
const std::string identifier_ = "ParrinelloRahmanBarostat";
//! Helper function to read from / write to CheckpointData
template<CheckpointDataOperation operation>
- void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+ void doCheckpointData(CheckpointData<operation>* checkpointData);
// Access to ISimulator data
//! Handles logging.
if (DOMAINDECOMP(cr) && MASTER(cr))
{
- xGlobal_.reserveWithPadding(totalNumAtoms_);
- previousXGlobal_.reserveWithPadding(totalNumAtoms_);
- vGlobal_.reserveWithPadding(totalNumAtoms_);
- fGlobal_.reserveWithPadding(totalNumAtoms_);
+ xGlobal_.resizeWithPadding(totalNumAtoms_);
+ previousXGlobal_.resizeWithPadding(totalNumAtoms_);
+ vGlobal_.resizeWithPadding(totalNumAtoms_);
+ fGlobal_.resizeWithPadding(totalNumAtoms_);
}
if (!inputrec->bContinuation)
}
if (MASTER(cr))
{
+ GMX_ASSERT(checkpointData, "Master needs a valid pointer to a CheckpointData object");
checkpointVersion(checkpointData, "StatePropagatorData version", c_currentVersion);
checkpointData->arrayRef("positions", makeCheckpointArrayRef<operation>(xGlobalRef));
}
}
-void StatePropagatorData::Element::writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr)
+void StatePropagatorData::Element::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr)
{
- doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
+ doCheckpointData<CheckpointDataOperation::Write>(
+ checkpointData ? &checkpointData.value() : nullptr, cr);
}
-void StatePropagatorData::Element::readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr)
+/*!
+ * \brief Update the legacy global state
+ *
+ * When restoring from checkpoint, data will be distributed during domain decomposition at setup stage.
+ * Domain decomposition still uses the legacy global t_state object so make sure it's up-to-date.
+ */
+static void updateGlobalState(t_state* globalState,
+ const PaddedHostVector<RVec>& x,
+ const PaddedHostVector<RVec>& v,
+ const tensor box,
+ int ddpCount,
+ int ddpCountCgGl,
+ const std::vector<int>& cgGl)
+{
+ globalState->x = x;
+ globalState->v = v;
+ copy_mat(box, globalState->box);
+ globalState->ddp_count = ddpCount;
+ globalState->ddp_count_cg_gl = ddpCountCgGl;
+ globalState->cg_gl = cgGl;
+}
+
+void StatePropagatorData::Element::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+ const t_commrec* cr)
{
- doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
+ doCheckpointData<CheckpointDataOperation::Read>(checkpointData ? &checkpointData.value() : nullptr, cr);
+
+ // Copy data to global state to be distributed by DD at setup stage
+ if (DOMAINDECOMP(cr) && MASTER(cr))
+ {
+ updateGlobalState(statePropagatorData_->globalState_, statePropagatorData_->xGlobal_,
+ statePropagatorData_->vGlobal_, statePropagatorData_->box_,
+ statePropagatorData_->ddpCount_, statePropagatorData_->ddpCountCgGl_,
+ statePropagatorData_->cgGl_);
+ }
}
const std::string& StatePropagatorData::Element::clientID()
void setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData);
//! ICheckpointHelperClient write checkpoint implementation
- void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+ void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient read checkpoint implementation
- void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+ void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient key implementation
const std::string& clientID() override;
const TemperatureCouplingData& temperatureCouplingData) = 0;
//! Write private data to checkpoint
- virtual void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) = 0;
+ virtual void writeCheckpoint(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr) = 0;
//! Read private data from checkpoint
- virtual void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) = 0;
+ virtual void readCheckpoint(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) = 0;
//! Standard virtual destructor
virtual ~ITemperatureCouplingImpl() = default;
}
//! No data to write to checkpoint
- void writeCheckpoint(WriteCheckpointData gmx_unused checkpointData, const t_commrec gmx_unused* cr) override
+ void writeCheckpoint(std::optional<WriteCheckpointData> gmx_unused checkpointData,
+ const t_commrec gmx_unused* cr) override
{
}
//! No data to read from checkpoints
- void readCheckpoint(ReadCheckpointData gmx_unused checkpointData, const t_commrec gmx_unused* cr) override
+ void readCheckpoint(std::optional<ReadCheckpointData> gmx_unused checkpointData,
+ const t_commrec gmx_unused* cr) override
{
}
} // namespace
template<CheckpointDataOperation operation>
-void VelocityScalingTemperatureCoupling::doCheckpointData(CheckpointData<operation>* checkpointData,
- const t_commrec* cr)
+void VelocityScalingTemperatureCoupling::doCheckpointData(CheckpointData<operation>* checkpointData)
+{
+ checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
+
+ checkpointData->arrayRef("thermostat integral",
+ makeCheckpointArrayRef<operation>(temperatureCouplingIntegral_));
+}
+
+void VelocityScalingTemperatureCoupling::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData,
+ const t_commrec* cr)
{
if (MASTER(cr))
{
- checkpointVersion(checkpointData, "VRescaleThermostat version", c_currentVersion);
+ doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value());
+ }
+ temperatureCouplingImpl_->writeCheckpoint(
+ checkpointData
+ ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
+ : std::nullopt,
+ cr);
+}
- checkpointData->arrayRef("thermostat integral",
- makeCheckpointArrayRef<operation>(temperatureCouplingIntegral_));
+void VelocityScalingTemperatureCoupling::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+ const t_commrec* cr)
+{
+ if (MASTER(cr))
+ {
+ doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value());
}
- if (operation == CheckpointDataOperation::Read && DOMAINDECOMP(cr))
+ if (DOMAINDECOMP(cr))
{
dd_bcast(cr->dd, temperatureCouplingIntegral_.size() * sizeof(double),
temperatureCouplingIntegral_.data());
}
-}
-
-void VelocityScalingTemperatureCoupling::writeCheckpoint(WriteCheckpointData checkpointData,
- const t_commrec* cr)
-{
- doCheckpointData<CheckpointDataOperation::Write>(&checkpointData, cr);
- temperatureCouplingImpl_->writeCheckpoint(checkpointData.subCheckpointData("thermostat impl"), cr);
-}
-
-void VelocityScalingTemperatureCoupling::readCheckpoint(ReadCheckpointData checkpointData,
- const t_commrec* cr)
-{
- doCheckpointData<CheckpointDataOperation::Read>(&checkpointData, cr);
- temperatureCouplingImpl_->readCheckpoint(checkpointData.subCheckpointData("thermostat impl"), cr);
+ temperatureCouplingImpl_->readCheckpoint(
+ checkpointData
+ ? std::make_optional(checkpointData->subCheckpointData("thermostat impl"))
+ : std::nullopt,
+ cr);
}
const std::string& VelocityScalingTemperatureCoupling::clientID()
void connectWithPropagator(const PropagatorThermostatConnection& connectionData);
//! ICheckpointHelperClient write checkpoint implementation
- void writeCheckpoint(WriteCheckpointData checkpointData, const t_commrec* cr) override;
+ void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient read checkpoint implementation
- void readCheckpoint(ReadCheckpointData checkpointData, const t_commrec* cr) override;
+ void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
//! ICheckpointHelperClient key implementation
const std::string& clientID() override;
const std::string identifier_ = "VelocityScalingTemperatureCoupling";
//! Helper function to read from / write to CheckpointData
template<CheckpointDataOperation operation>
- void doCheckpointData(CheckpointData<operation>* checkpointData, const t_commrec* cr);
+ void doCheckpointData(CheckpointData<operation>* checkpointData);
};
} // namespace gmx