Implement pull for modular simulator
authorPascal Merz <pascal.merz@me.com>
Wed, 6 Oct 2021 21:37:46 +0000 (21:37 +0000)
committerPascal Merz <pascal.merz@me.com>
Wed, 6 Oct 2021 21:37:46 +0000 (21:37 +0000)
12 files changed:
src/gromacs/mdlib/update_vv.cpp
src/gromacs/mdrun/md.cpp
src/gromacs/modularsimulator/energydata.cpp
src/gromacs/modularsimulator/energydata.h
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/pullelement.cpp [new file with mode: 0644]
src/gromacs/modularsimulator/pullelement.h [new file with mode: 0644]
src/gromacs/modularsimulator/simulatoralgorithm.cpp
src/gromacs/pulling/pull.cpp
src/gromacs/pulling/pull.h
src/gromacs/pulling/pullutil.cpp
src/programs/mdrun/tests/simulator.cpp

index 58bca1a17ee2858d9334ae941a8ae89b6d85efd7..0f04fadfd363c01db5031388dcbb469e0ef0bc90 100644 (file)
@@ -389,7 +389,7 @@ void integrateVVSecondStep(int64_t                   step,
 
     if (ir->bPull && ir->pull->bSetPbcRefToPrevStepCOM)
     {
-        updatePrevStepPullCom(pull_work, state);
+        updatePrevStepPullCom(pull_work, state->pull_com_prev_step);
     }
 
     upd->update_coords(*ir,
index 85c29398e012f82c5ae52164123238d7b92137ea..cd657319c00c8e95426c52e21527c526dbd11646 100644 (file)
@@ -1624,7 +1624,7 @@ void gmx::LegacySimulator::do_md()
 
             if (ir->bPull && ir->pull->bSetPbcRefToPrevStepCOM)
             {
-                updatePrevStepPullCom(pull_work, state);
+                updatePrevStepPullCom(pull_work, state->pull_com_prev_step);
             }
 
             enerd->term[F_DVDL_CONSTR] += dvdl_constr;
index f2a58538cbc993bc795a66814c7c3e5669ef9bc8..fd24ac18ac6662177a3efc57913870e670e1be7e 100644 (file)
@@ -89,7 +89,8 @@ EnergyData::EnergyData(StatePropagatorData*        statePropagatorData,
                        bool                        isMasterRank,
                        ObservablesHistory*         observablesHistory,
                        StartingBehavior            startingBehavior,
-                       bool                        simulationsShareState) :
+                       bool                        simulationsShareState,
+                       pull_t*                     pullWork) :
     element_(std::make_unique<Element>(this, isMasterRank)),
     isMasterRank_(isMasterRank),
     forceVirialStep_(-1),
@@ -112,7 +113,8 @@ EnergyData::EnergyData(StatePropagatorData*        statePropagatorData,
     mdModulesNotifiers_(mdModulesNotifiers),
     groups_(&globalTopology.groups),
     observablesHistory_(observablesHistory),
-    simulationsShareState_(simulationsShareState)
+    simulationsShareState_(simulationsShareState),
+    pullWork_(pullWork)
 {
     clear_mat(forceVirial_);
     clear_mat(shakeVirial_);
@@ -161,11 +163,10 @@ void EnergyData::Element::trajectoryWriterSetup(gmx_mdoutf* outf)
 
 void EnergyData::setup(gmx_mdoutf* outf)
 {
-    pull_t* pull_work = nullptr;
-    energyOutput_     = std::make_unique<EnergyOutput>(mdoutf_get_fp_ene(outf),
+    energyOutput_ = std::make_unique<EnergyOutput>(mdoutf_get_fp_ene(outf),
                                                    top_global_,
                                                    *inputrec_,
-                                                   pull_work,
+                                                   pullWork_,
                                                    mdoutf_get_fp_dhdl(outf),
                                                    false,
                                                    startingBehavior_,
index 9f6411c1f02610276d15f9bbaae43f0320fa879e..e9e7f275f7e2fdaa10ac457175faf80947dcfc95 100644 (file)
@@ -52,6 +52,7 @@ class gmx_ekindata_t;
 struct gmx_enerdata_t;
 struct gmx_mtop_t;
 struct ObservablesHistory;
+struct pull_t;
 struct t_fcdata;
 struct t_inputrec;
 struct SimulationGroups;
@@ -111,7 +112,8 @@ public:
                bool                        isMasterRank,
                ObservablesHistory*         observablesHistory,
                StartingBehavior            startingBehavior,
-               bool                        simulationsShareState);
+               bool                        simulationsShareState,
+               pull_t*                     pullWork);
 
     /*! \brief Final output
      *
@@ -331,6 +333,8 @@ private:
     ObservablesHistory* observablesHistory_;
     //! Whether simulations share the state
     bool simulationsShareState_;
+    //! The pull work object.
+    pull_t* pullWork_;
 };
 
 /*! \internal
index d3fce2c0b5563bf8bce6df5c8821aaf171f55e58..fac5ea0d86bfd2c0db059fe66598559d18d15617 100644 (file)
@@ -84,6 +84,7 @@
 #include "mttk.h"
 #include "nosehooverchains.h"
 #include "parrinellorahmanbarostat.h"
+#include "pullelement.h"
 #include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
 #include "velocityscalingtemperaturecoupling.h"
@@ -133,6 +134,12 @@ void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder*
         {
             builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
+
+        if (legacySimulatorData_->inputrec->bPull)
+        {
+            builder->add<PullElement>();
+        }
+
         builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>();
         if (legacySimulatorData_->inputrec->epc == PressureCoupling::ParrinelloRahman)
         {
@@ -181,6 +188,12 @@ void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder*
         {
             builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
+
+        if (legacySimulatorData_->inputrec->bPull)
+        {
+            builder->add<PullElement>();
+        }
+
         builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>();
         if (legacySimulatorData_->inputrec->epc == PressureCoupling::ParrinelloRahman)
         {
@@ -304,6 +317,12 @@ void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder*
         {
             builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
+
+        if (legacySimulatorData_->inputrec->bPull)
+        {
+            builder->add<PullElement>();
+        }
+
         builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>();
 
         // Propagate box from t to t+dt
@@ -380,9 +399,6 @@ bool ModularSimulator::isInputCompatible(bool                             exitOn
     isInputCompatible =
             isInputCompatible
             && conditionalAssert(!doRerun, "Rerun is not supported by the modular simulator.");
-    isInputCompatible = isInputCompatible
-                        && conditionalAssert(!inputrec->bPull,
-                                             "Pulling is not supported by the modular simulator.");
     isInputCompatible =
             isInputCompatible
             && conditionalAssert(inputrec->cos_accel == 0.0,
diff --git a/src/gromacs/modularsimulator/pullelement.cpp b/src/gromacs/modularsimulator/pullelement.cpp
new file mode 100644 (file)
index 0000000..a5ba005
--- /dev/null
@@ -0,0 +1,186 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2020,2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+/*! \internal \file
+ * \brief Defines the pull element for the modular simulator
+ *
+ * \author Pascal Merz <pascal.merz@me.com>
+ * \ingroup module_modularsimulator
+ */
+
+#include "gmxpre.h"
+
+#include "gromacs/gmxlib/network.h"
+#include "gromacs/mdlib/mdatoms.h"
+#include "gromacs/mdtypes/commrec.h"
+#include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/mdtypes/mdatom.h"
+#include "gromacs/pbcutil/pbc.h"
+#include "gromacs/pulling/output.h"
+#include "gromacs/pulling/pull.h"
+
+#include "pullelement.h"
+#include "simulatoralgorithm.h"
+#include "statepropagatordata.h"
+
+namespace gmx
+{
+
+PullElement::PullElement(bool                 setPbcRefToPrevStepCOM,
+                         PbcType              pbcType,
+                         StatePropagatorData* statePropagatorData,
+                         pull_t*              pullWork,
+                         const t_commrec*     commrec,
+                         const MDAtoms*       mdAtoms) :
+    setPbcRefToPrevStepCOM_(setPbcRefToPrevStepCOM),
+    pbcType_(pbcType),
+    restoredFromCheckpoint_(false),
+    statePropagatorData_(statePropagatorData),
+    pullWork_(pullWork),
+    commrec_(commrec),
+    mdAtoms_(mdAtoms)
+{
+}
+
+void PullElement::elementSetup()
+{
+    if (setPbcRefToPrevStepCOM_ && !restoredFromCheckpoint_)
+    {
+        preparePrevStepPullComNewSimulation(
+                commrec_,
+                pullWork_,
+                arrayRefFromArray(mdAtoms_->mdatoms()->massT, mdAtoms_->mdatoms()->nr),
+                statePropagatorData_->constPositionsView().unpaddedArrayRef(),
+                statePropagatorData_->constBox(),
+                pbcType_,
+                std::nullopt);
+    }
+}
+
+void PullElement::scheduleTask(Step /*unused*/, Time /*unused*/, const RegisterRunFunction& registerRunFunction)
+{
+    if (setPbcRefToPrevStepCOM_)
+    {
+        registerRunFunction([this]() { updatePrevStepPullCom(pullWork_, std::nullopt); });
+    }
+}
+
+void PullElement::schedulePostStep(Step step, Time time, const RegisterRunFunction& registerRunFunction)
+{
+    // Printing output must happen after all external pull potentials
+    // (currently only AWH) were applied, so execute this after step
+    if (MASTER(commrec_))
+    {
+        registerRunFunction([this, step, time]() { pull_print_output(pullWork_, step, time); });
+    }
+}
+
+namespace
+{
+/*!
+ * \brief Enum describing the contents FreeEnergyPerturbationData::Element writes to modular checkpoint
+ *
+ * When changing the checkpoint content, add a new element just above Count, and adjust the
+ * checkpoint functionality.
+ */
+enum class CheckpointVersion
+{
+    Base, //!< First version of modular checkpointing
+    Count //!< Number of entries. Add new versions right above this!
+};
+constexpr auto c_currentVersion = CheckpointVersion(int(CheckpointVersion::Count) - 1);
+} // namespace
+
+template<CheckpointDataOperation operation>
+static void doCheckpointData(CheckpointData<operation>* checkpointData, ArrayRef<double> previousStepCom)
+{
+    checkpointVersion(checkpointData, "PullElement version", c_currentVersion);
+    checkpointData->arrayRef("Previous step COM positions",
+                             makeCheckpointArrayRef<operation>(previousStepCom));
+}
+
+void PullElement::saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr)
+{
+    if (MASTER(cr))
+    {
+        auto previousStepCom = prevStepPullCom(pullWork_);
+        doCheckpointData<CheckpointDataOperation::Write>(&checkpointData.value(), previousStepCom);
+    }
+}
+
+void PullElement::restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData,
+                                         const t_commrec*                  cr)
+{
+    auto previousStepCom = prevStepPullCom(pullWork_);
+    if (MASTER(cr))
+    {
+        doCheckpointData<CheckpointDataOperation::Read>(&checkpointData.value(), previousStepCom);
+    }
+    if (haveDDAtomOrdering(*cr))
+    {
+        gmx_bcast(sizeof(double) * previousStepCom.size(), previousStepCom.data(), cr->mpi_comm_mygroup);
+    }
+    setPrevStepPullCom(pullWork_, previousStepCom);
+    restoredFromCheckpoint_ = true;
+}
+
+const std::string& PullElement::clientID()
+{
+    return identifier_;
+}
+
+ISimulatorElement* PullElement::getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                      ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                      StatePropagatorData* statePropagatorData,
+                                                      EnergyData* /*energyData*/,
+                                                      FreeEnergyPerturbationData* /*freeEnergyPerturbationData*/,
+                                                      GlobalCommunicationHelper* /*globalCommunicationHelper*/,
+                                                      ObservablesReducer* /*observablesReducer*/)
+{
+    auto* pullElement = builderHelper->storeElement(std::make_unique<PullElement>(
+            legacySimulatorData->inputrec->pull->bSetPbcRefToPrevStepCOM,
+            legacySimulatorData->inputrec->pbcType,
+            statePropagatorData,
+            legacySimulatorData->pull_work,
+            legacySimulatorData->cr,
+            legacySimulatorData->mdAtoms));
+    // Printing output is scheduled after the step
+    builderHelper->registerPostStepScheduling(
+            [pullElement](Step step, Time time, const RegisterRunFunction& registerRunFunction) {
+                pullElement->schedulePostStep(step, time, registerRunFunction);
+            });
+    return pullElement;
+}
+
+} // namespace gmx
diff --git a/src/gromacs/modularsimulator/pullelement.h b/src/gromacs/modularsimulator/pullelement.h
new file mode 100644 (file)
index 0000000..fd49e9e
--- /dev/null
@@ -0,0 +1,140 @@
+/*
+ * This file is part of the GROMACS molecular simulation package.
+ *
+ * Copyright (c) 2020,2021, by the GROMACS development team, led by
+ * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
+ * and including many others, as listed in the AUTHORS file in the
+ * top-level source directory and at http://www.gromacs.org.
+ *
+ * GROMACS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1
+ * of the License, or (at your option) any later version.
+ *
+ * GROMACS is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with GROMACS; if not, see
+ * http://www.gnu.org/licenses, or write to the Free Software Foundation,
+ * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
+ *
+ * If you want to redistribute modifications to GROMACS, please
+ * consider that scientific software is very special. Version
+ * control is crucial - bugs must be traceable. We will be happy to
+ * consider code for inclusion in the official distribution, but
+ * derived work must not be called official GROMACS. Details are found
+ * in the README & COPYING files - if they are missing, get the
+ * official version at http://www.gromacs.org.
+ *
+ * To help us fund GROMACS development, we humbly ask that you cite
+ * the research papers on the package. Check out http://www.gromacs.org.
+ */
+/*! \internal \file
+ * \brief Declares the pull element for the modular simulator
+ *
+ * \author Pascal Merz <pascal.merz@me.com>
+ * \ingroup module_modularsimulator
+ *
+ * This header is only used within the modular simulator module
+ */
+
+#ifndef GMX_MODULARSIMULATOR_PULLELEMENT_H
+#define GMX_MODULARSIMULATOR_PULLELEMENT_H
+
+#include "modularsimulatorinterfaces.h"
+
+struct gmx_mtop_t;
+struct pull_t;
+struct t_inputrec;
+
+namespace gmx
+{
+class EnergyData;
+class FreeEnergyPerturbationData;
+class GlobalCommunicationHelper;
+class LegacySimulatorData;
+class MDAtoms;
+class ModularSimulatorAlgorithmBuilderHelper;
+class ObservablesReducer;
+class StatePropagatorData;
+
+/*! \internal
+ * \brief Element calling pull functionality
+ */
+class PullElement : public ISimulatorElement, public ICheckpointHelperClient
+{
+public:
+    //! Constructor
+    PullElement(bool                 setPbcRefToPrevStepCOM,
+                PbcType              pbcType,
+                StatePropagatorData* statePropagatorData,
+                pull_t*              pullWork,
+                const t_commrec*     commrec,
+                const MDAtoms*       mdAtoms);
+    //! Update annealing temperature
+    void scheduleTask(Step step, Time time, const RegisterRunFunction& registerRunFunction) override;
+    //! Set initial annealing temperature
+    void elementSetup() override;
+    //! No teardown needed
+    void elementTeardown() override {}
+
+    //! ICheckpointHelperClient write checkpoint implementation
+    void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient read checkpoint implementation
+    void restoreCheckpointState(std::optional<ReadCheckpointData> checkpointData, const t_commrec* cr) override;
+    //! ICheckpointHelperClient key implementation
+    const std::string& clientID() override;
+
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     * \param observablesReducer          Pointer to the \c ObservablesReducer object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper,
+                                                    ObservablesReducer*        observablesReducer);
+
+private:
+    //! Schedule post step functionality
+    void schedulePostStep(Step step, Time time, const RegisterRunFunction& registerRunFunction);
+
+    //! Whether to use the COM of each group from the previous step as reference
+    const bool setPbcRefToPrevStepCOM_;
+    //! The PBC type
+    const PbcType pbcType_;
+
+    //! CheckpointHelper identifier
+    const std::string identifier_ = "PullElement";
+    //! Whether this object was restored from checkpoint
+    bool restoredFromCheckpoint_;
+
+    // TODO: Clarify relationship to data objects and find a more robust alternative to raw pointers (#3583)
+    //! Pointer to the micro state
+    StatePropagatorData* statePropagatorData_;
+
+    // Access to LegacySimulatorData
+    //! The pull work object.
+    pull_t* pullWork_;
+    //! Handles communication.
+    const t_commrec* commrec_;
+    //! Atom parameters for this domain.
+    const MDAtoms* mdAtoms_;
+};
+} // namespace gmx
+
+
+#endif // GMX_MODULARSIMULATOR_PULLELEMENT_H
index 5d07cc83aad163637f366749704de2e30e9ae487..c6f96efc3a1134c9175a802bcff898628f25bba5 100644 (file)
@@ -455,7 +455,8 @@ ModularSimulatorAlgorithmBuilder::ModularSimulatorAlgorithmBuilder(
                                                MASTER(legacySimulatorData->cr),
                                                legacySimulatorData->observablesHistory,
                                                legacySimulatorData->startingBehavior,
-                                               simulationsShareState);
+                                               simulationsShareState,
+                                               legacySimulatorData->pull_work);
     registerExistingElement(energyData_->element());
 
     // This is the only modular simulator object which changes the inputrec
index cbef32a4029f512b7aa07b905fef558820837ee7..077bbb82ba5eac1d645fa233e891812bf971f5f7 100644 (file)
@@ -2449,6 +2449,20 @@ static void destroy_pull(struct pull_t* pull)
     delete pull;
 }
 
+void preparePrevStepPullComNewSimulation(const t_commrec*                       cr,
+                                         pull_t*                                pull_work,
+                                         ArrayRef<const real>                   masses,
+                                         ArrayRef<const RVec>                   x,
+                                         const matrix                           box,
+                                         PbcType                                pbcType,
+                                         std::optional<gmx::ArrayRef<double>>&& comPreviousStep)
+{
+    t_pbc pbc;
+    set_pbc(&pbc, pbcType, box);
+    initPullComFromPrevStep(cr, pull_work, masses, pbc, x);
+    updatePrevStepPullCom(pull_work, comPreviousStep);
+}
+
 void preparePrevStepPullCom(const t_inputrec*    ir,
                             pull_t*              pull_work,
                             ArrayRef<const real> masses,
@@ -2479,11 +2493,13 @@ void preparePrevStepPullCom(const t_inputrec*    ir,
     }
     else
     {
-        t_pbc pbc;
-        set_pbc(&pbc, ir->pbcType, state->box);
-        initPullComFromPrevStep(
-                cr, pull_work, masses, pbc, state->x.arrayRefWithPadding().unpaddedArrayRef());
-        updatePrevStepPullCom(pull_work, state);
+        preparePrevStepPullComNewSimulation(cr,
+                                            pull_work,
+                                            masses,
+                                            state->x.arrayRefWithPadding().unpaddedArrayRef(),
+                                            state->box,
+                                            ir->pbcType,
+                                            state->pull_com_prev_step);
     }
 }
 
index a27c65547d9e20edffdd2e38ca8bb31b05c62327..90a12967d17819df0fc5680ade9ed63d037018a0 100644 (file)
@@ -52,6 +52,7 @@
 #define GMX_PULLING_PULL_H
 
 #include <cstdio>
+#include <optional>
 
 #include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/pull_params.h"
@@ -69,6 +70,7 @@ struct t_filenm;
 struct t_inputrec;
 struct t_pbc;
 class t_state;
+enum class PbcType;
 
 namespace gmx
 {
@@ -351,12 +353,33 @@ bool pull_have_constraint(const pull_params_t& pullParameters);
  */
 real max_pull_distance2(const pull_coord_work_t& pcrd, const t_pbc& pbc);
 
-/*! \brief Sets the previous step COM in pull to the current COM and updates the pull_com_prev_step in the state
+/*! \brief Sets the previous step COM in pull to the current COM, and optionally
+ *         updates it in the provided ArrayRef
  *
- * \param[in]   pull  The COM pull force calculation data structure
- * \param[in]   state The local (to this rank) state.
+ * \param[in] pull  The COM pull force calculation data structure
+ * \param[in] comPreviousStep  The COM of the previous step of each pull group
  */
-void updatePrevStepPullCom(pull_t* pull, t_state* state);
+void updatePrevStepPullCom(pull_t* pull, std::optional<gmx::ArrayRef<double>> comPreviousStep);
+
+/*! \brief Returns a copy of the previous step pull COM as flat vector
+ *
+ * Used for modular simulator checkpointing. Allows to keep the
+ * implementation details of pull_t hidden from its users.
+ *
+ * \param[in] pull  The COM pull force calculation data structure
+ * \return A copy of the previous step COM
+ */
+std::vector<double> prevStepPullCom(const pull_t* pull);
+
+/*! \brief Set the previous step pull COM from a flat vector
+ *
+ * Used to restore modular simulator checkpoints. Allows to keep the
+ * implementation details of pull_t hidden from its users.
+ *
+ * \param[in] pull  The COM pull force calculation data structure
+ * \param[in] prevStepPullCom  The previous step COM to set
+ */
+void setPrevStepPullCom(pull_t* pull, gmx::ArrayRef<const double> prevStepPullCom);
 
 /*! \brief Allocates, initializes and communicates the previous step pull COM (if that option is set to true).
  *
@@ -392,4 +415,22 @@ void initPullComFromPrevStep(const t_commrec*               cr,
                              const t_pbc&                   pbc,
                              gmx::ArrayRef<const gmx::RVec> x);
 
+/*! \brief Initializes the previous step pull COM for new simulations (no reading from checkpoint).
+ *
+ * \param[in] cr               Struct for communication info.
+ * \param[in] pull_work        The COM pull force calculation data structure.
+ * \param[in] masses           Atoms masses.
+ * \param[in] x                The local positions.
+ * \param[in] box              The current box matrix.
+ * \param[in] pbcType          The type of periodic boundary conditions.
+ * \param[in] comPreviousStep  The COM of the previous step of each pull group.
+ */
+void preparePrevStepPullComNewSimulation(const t_commrec*                       cr,
+                                         pull_t*                                pull_work,
+                                         gmx::ArrayRef<const real>              masses,
+                                         gmx::ArrayRef<const gmx::RVec>         x,
+                                         const matrix                           box,
+                                         PbcType                                pbcType,
+                                         std::optional<gmx::ArrayRef<double>>&& comPreviousStep);
+
 #endif
index cf7a9778ecf57694a5cce484d56f2fe4d3779947..8aecd5fa3b89b58fdefadfee09d31463011fa825 100644 (file)
@@ -999,21 +999,81 @@ void setPrevStepPullComFromState(struct pull_t* pull, const t_state* state)
     }
 }
 
-void updatePrevStepPullCom(struct pull_t* pull, t_state* state)
+/*! \brief Whether pull functions save a backup to the t_state object
+ *
+ * Saving to the state object is only used for checkpointing in the legacy simulator.
+ * Modular simulator doesn't use the t_state object for checkpointing.
+ */
+enum class PullBackupCOM
 {
-    for (size_t g = 0; g < pull->group.size(); g++)
+    Yes, //<! Save a copy of the previous step COM to state
+    No,  //<! Don't save a copy of the previous step COM to state
+};
+
+/*! \brief Sets the previous step COM in pull to the current COM and optionally
+ *         stores it in the provided ArrayRef
+ *
+ * \tparam     pullBackupToState  Whether we're storing the previous COM to state
+ * \param[in]  pull  The COM pull force calculation data structure
+ * \param[in]   comPreviousStep  The COM of the previous step of each pull group
+ */
+template<PullBackupCOM pullBackupToState>
+static void updatePrevStepPullComImpl(pull_t* pull, gmx::ArrayRef<double> comPreviousStep)
+{
+    for (gmx::index g = 0; g < gmx::ssize(pull->group); g++)
     {
         if (pull->group[g].needToCalcCom)
         {
             for (int j = 0; j < DIM; j++)
             {
-                pull->group[g].x_prev_step[j]          = pull->group[g].x[j];
-                state->pull_com_prev_step[g * DIM + j] = pull->group[g].x[j];
+                pull->group[g].x_prev_step[j] = pull->group[g].x[j];
+                if (pullBackupToState == PullBackupCOM::Yes)
+                {
+                    comPreviousStep[g * DIM + j] = pull->group[g].x[j];
+                }
             }
         }
     }
 }
 
+void updatePrevStepPullCom(pull_t* pull, std::optional<gmx::ArrayRef<double>> comPreviousStep)
+{
+    if (comPreviousStep.has_value())
+    {
+        updatePrevStepPullComImpl<PullBackupCOM::Yes>(pull, comPreviousStep.value());
+    }
+    else
+    {
+        updatePrevStepPullComImpl<PullBackupCOM::No>(pull, gmx::ArrayRef<double>());
+    }
+}
+
+std::vector<double> prevStepPullCom(const pull_t* pull)
+{
+    std::vector<double> pullCom(pull->group.size() * DIM, 0.0);
+    for (gmx::index g = 0; g < gmx::ssize(pull->group); g++)
+    {
+        for (int j = 0; j < DIM; j++)
+        {
+            pullCom[g * DIM + j] = pull->group[g].x_prev_step[j];
+        }
+    }
+    return pullCom;
+}
+
+void setPrevStepPullCom(pull_t* pull, gmx::ArrayRef<const double> prevStepPullCom)
+{
+    GMX_RELEASE_ASSERT(prevStepPullCom.size() >= pull->group.size() * DIM,
+                       "Pull COM vector size mismatch.");
+    for (gmx::index g = 0; g < gmx::ssize(pull->group); g++)
+    {
+        for (int j = 0; j < DIM; j++)
+        {
+            pull->group[g].x_prev_step[j] = prevStepPullCom[g * DIM + j];
+        }
+    }
+}
+
 void allocStatePrevStepPullCom(t_state* state, const pull_t* pull)
 {
     if (!pull)
index 0e5a7047821ed353fe30c372ff2d02402e3dcabd..b33bf5019809704a8ba1284a4959815ce889af8b 100644 (file)
@@ -314,6 +314,22 @@ INSTANTIATE_TEST_SUITE_P(
                         ::testing::Values("no", "Parrinello-Rahman", "berendsen", "c-rescale"),
                         ::testing::Values(MdpParameterDatabase::Default)),
                 ::testing::Values("GMX_USE_MODULAR_SIMULATOR")));
+INSTANTIATE_TEST_SUITE_P(SimulatorsAreEquivalentDefaultModularPull,
+                         SimulatorComparisonTest,
+                         ::testing::Combine(::testing::Combine(::testing::Values("spc2"),
+                                                               ::testing::Values("md-vv"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values(MdpParameterDatabase::Pull)),
+                                            ::testing::Values("GMX_DISABLE_MODULAR_SIMULATOR")));
+INSTANTIATE_TEST_SUITE_P(SimulatorsAreEquivalentDefaultLegacyPull,
+                         SimulatorComparisonTest,
+                         ::testing::Combine(::testing::Combine(::testing::Values("spc2"),
+                                                               ::testing::Values("md"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values(MdpParameterDatabase::Pull)),
+                                            ::testing::Values("GMX_USE_MODULAR_SIMULATOR")));
 #else
 INSTANTIATE_TEST_SUITE_P(
         DISABLED_SimulatorsAreEquivalentDefaultModular,
@@ -341,6 +357,22 @@ INSTANTIATE_TEST_SUITE_P(
                         ::testing::Values("no", "Parrinello-Rahman", "berendsen", "c-rescale"),
                         ::testing::Values(MdpParameterDatabase::Default)),
                 ::testing::Values("GMX_USE_MODULAR_SIMULATOR")));
+INSTANTIATE_TEST_SUITE_P(DISABLED_SimulatorsAreEquivalentDefaultModularPull,
+                         SimulatorComparisonTest,
+                         ::testing::Combine(::testing::Combine(::testing::Values("spc2"),
+                                                               ::testing::Values("md-vv"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values(MdpParameterDatabase::Pull)),
+                                            ::testing::Values("GMX_DISABLE_MODULAR_SIMULATOR")));
+INSTANTIATE_TEST_SUITE_P(DISABLED_SimulatorsAreEquivalentDefaultLegacyPull,
+                         SimulatorComparisonTest,
+                         ::testing::Combine(::testing::Combine(::testing::Values("spc2"),
+                                                               ::testing::Values("md"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values("no"),
+                                                               ::testing::Values(MdpParameterDatabase::Pull)),
+                                            ::testing::Values("GMX_USE_MODULAR_SIMULATOR")));
 #endif
 
 } // namespace