Introduce flexible ModularSimulatorAlgorithmBuilder: Helper structs to connect propag...
authorPascal Merz <pascal.merz@me.com>
Fri, 14 Aug 2020 12:21:48 +0000 (12:21 +0000)
committerChristian Blau <cblau.mail@gmail.com>
Fri, 14 Aug 2020 12:21:48 +0000 (12:21 +0000)
23 files changed:
src/gromacs/modularsimulator/computeglobalselement.cpp
src/gromacs/modularsimulator/computeglobalselement.h
src/gromacs/modularsimulator/constraintelement.cpp
src/gromacs/modularsimulator/constraintelement.h
src/gromacs/modularsimulator/energydata.cpp
src/gromacs/modularsimulator/energydata.h
src/gromacs/modularsimulator/forceelement.cpp
src/gromacs/modularsimulator/forceelement.h
src/gromacs/modularsimulator/freeenergyperturbationdata.cpp
src/gromacs/modularsimulator/freeenergyperturbationdata.h
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/modularsimulator.h
src/gromacs/modularsimulator/modularsimulatorinterfaces.h
src/gromacs/modularsimulator/parrinellorahmanbarostat.cpp
src/gromacs/modularsimulator/parrinellorahmanbarostat.h
src/gromacs/modularsimulator/propagator.cpp
src/gromacs/modularsimulator/propagator.h
src/gromacs/modularsimulator/simulatoralgorithm.cpp
src/gromacs/modularsimulator/simulatoralgorithm.h
src/gromacs/modularsimulator/statepropagatordata.cpp
src/gromacs/modularsimulator/statepropagatordata.h
src/gromacs/modularsimulator/vrescalethermostat.cpp
src/gromacs/modularsimulator/vrescalethermostat.h

index 5759fe1aef395842017efea5161544776ca10d67..df0853079c6912b3fcaa578edb37a3dced7e32ea 100644 (file)
 #include "computeglobalselement.h"
 
 #include "gromacs/domdec/partition.h"
+#include "gromacs/gmxlib/network.h"
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdlib/md_support.h"
 #include "gromacs/mdlib/mdatoms.h"
 #include "gromacs/mdlib/stat.h"
+#include "gromacs/mdlib/update.h"
+#include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/group.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/md_enums.h"
@@ -56,6 +59,8 @@
 #include "gromacs/topology/topology.h"
 
 #include "freeenergyperturbationdata.h"
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
 
 namespace gmx
 {
@@ -346,4 +351,113 @@ SignallerCallbackPtr ComputeGlobalsElement<algorithm>::registerTrajectorySignall
 template class ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>;
 template class ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>;
 //! @}
+
+template<>
+ISimulatorElement* ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>::getElementPointerImpl(
+        LegacySimulatorData*                    legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+        StatePropagatorData*                    statePropagatorData,
+        EnergyData*                             energyData,
+        FreeEnergyPerturbationData*             freeEnergyPerturbationData,
+        GlobalCommunicationHelper*              globalCommunicationHelper)
+{
+    /* When restarting from a checkpoint, it can be appropriate to
+     * initialize ekind from quantities in the checkpoint. Otherwise,
+     * compute_globals must initialize ekind before the simulation
+     * starts/restarts. However, only the master rank knows what was
+     * found in the checkpoint file, so we have to communicate in
+     * order to coordinate the restart.
+     *
+     * TODO This will become obsolete as soon as checkpoint reading
+     *      happens within the modular simulator framework: The energy
+     *      element will read its data from the checkpoint file pointer,
+     *      and signal to the compute globals element if it needs anything
+     *      reduced.
+     */
+    bool hasReadEkinState = MASTER(legacySimulatorData->cr)
+                                    ? legacySimulatorData->state_global->ekinstate.hasReadEkinState
+                                    : false;
+    if (PAR(legacySimulatorData->cr))
+    {
+        gmx_bcast(sizeof(hasReadEkinState), &hasReadEkinState, legacySimulatorData->cr->mpi_comm_mygroup);
+    }
+    if (hasReadEkinState)
+    {
+        restore_ekinstate_from_state(legacySimulatorData->cr, legacySimulatorData->ekind,
+                                     &legacySimulatorData->state_global->ekinstate);
+    }
+    auto* element = builderHelper->storeElement(
+            std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>(
+                    statePropagatorData, energyData, freeEnergyPerturbationData,
+                    globalCommunicationHelper->simulationSignals(),
+                    globalCommunicationHelper->nstglobalcomm(), legacySimulatorData->fplog,
+                    legacySimulatorData->mdlog, legacySimulatorData->cr,
+                    legacySimulatorData->inputrec, legacySimulatorData->mdAtoms,
+                    legacySimulatorData->nrnb, legacySimulatorData->wcycle, legacySimulatorData->fr,
+                    legacySimulatorData->top_global, legacySimulatorData->constr, hasReadEkinState));
+
+    // TODO: Remove this when DD can reduce bonded interactions independently (#3421)
+    auto* castedElement = static_cast<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>*>(element);
+    globalCommunicationHelper->setCheckBondedInteractionsCallback(
+            castedElement->getCheckNumberOfBondedInteractionsCallback());
+
+    return element;
+}
+
+template<>
+ISimulatorElement* ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>::getElementPointerImpl(
+        LegacySimulatorData*                    simulator,
+        ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+        StatePropagatorData*                    statePropagatorData,
+        EnergyData*                             energyData,
+        FreeEnergyPerturbationData*             freeEnergyPerturbationData,
+        GlobalCommunicationHelper*              globalCommunicationHelper)
+{
+    // We allow this element to be added multiple times to the call list, but we only want one
+    // actual element built
+    static thread_local ISimulatorElement* vvComputeGlobalsElement = nullptr;
+    if (!builderHelper->elementIsStored(vvComputeGlobalsElement))
+    {
+        /* When restarting from a checkpoint, it can be appropriate to
+         * initialize ekind from quantities in the checkpoint. Otherwise,
+         * compute_globals must initialize ekind before the simulation
+         * starts/restarts. However, only the master rank knows what was
+         * found in the checkpoint file, so we have to communicate in
+         * order to coordinate the restart.
+         *
+         * TODO This will become obsolete as soon as checkpoint reading
+         *      happens within the modular simulator framework: The energy
+         *      element will read its data from the checkpoint file pointer,
+         *      and signal to the compute globals element if it needs anything
+         *      reduced.
+         */
+        bool hasReadEkinState =
+                MASTER(simulator->cr) ? simulator->state_global->ekinstate.hasReadEkinState : false;
+        if (PAR(simulator->cr))
+        {
+            gmx_bcast(sizeof(hasReadEkinState), &hasReadEkinState, simulator->cr->mpi_comm_mygroup);
+        }
+        if (hasReadEkinState)
+        {
+            restore_ekinstate_from_state(simulator->cr, simulator->ekind,
+                                         &simulator->state_global->ekinstate);
+        }
+        vvComputeGlobalsElement = builderHelper->storeElement(
+                std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>(
+                        statePropagatorData, energyData, freeEnergyPerturbationData,
+                        globalCommunicationHelper->simulationSignals(),
+                        globalCommunicationHelper->nstglobalcomm(), simulator->fplog,
+                        simulator->mdlog, simulator->cr, simulator->inputrec, simulator->mdAtoms,
+                        simulator->nrnb, simulator->wcycle, simulator->fr, simulator->top_global,
+                        simulator->constr, hasReadEkinState));
+
+        // TODO: Remove this when DD can reduce bonded interactions independently (#3421)
+        auto* castedElement =
+                static_cast<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>*>(
+                        vvComputeGlobalsElement);
+        globalCommunicationHelper->setCheckBondedInteractionsCallback(
+                castedElement->getCheckNumberOfBondedInteractionsCallback());
+    }
+    return vvComputeGlobalsElement;
+}
 } // namespace gmx
index d8800bc09729615754fbad1ce1bc9c69550c79e9..a60fe246496bf1475ac21609b23eee068f51559a 100644 (file)
@@ -59,6 +59,7 @@ struct t_nrnb;
 namespace gmx
 {
 class FreeEnergyPerturbationData;
+class LegacySimulatorData;
 class MDAtoms;
 class MDLogger;
 
@@ -147,6 +148,24 @@ public:
     //! No element teardown needed
     void elementTeardown() override {}
 
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper);
+
 private:
     //! ITopologyClient implementation
     void setTopology(const gmx_localtop_t* top) override;
index 6625bf58e7dc3a84b85be179b820c05157faee42..4ac9b39a1f703fa4bcd03bbf764f279784f1696b 100644 (file)
 #include "constraintelement.h"
 
 #include "gromacs/math/vec.h"
+#include "gromacs/mdlib/mdatoms.h"
+#include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/inputrec.h"
+#include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/state.h"
 #include "gromacs/utility/fatalerror.h"
 
 #include "energydata.h"
 #include "freeenergyperturbationdata.h"
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
 
 namespace gmx
@@ -206,10 +211,23 @@ SignallerCallbackPtr ConstraintsElement<variable>::registerLoggingCallback()
             [this](Step step, Time /*unused*/) { nextLogWritingStep_ = step; });
 }
 
-//! Explicit template initialization
-//! @{
+template<ConstraintVariable variable>
+ISimulatorElement* ConstraintsElement<variable>::getElementPointerImpl(
+        LegacySimulatorData*                    legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+        StatePropagatorData*                    statePropagatorData,
+        EnergyData*                             energyData,
+        FreeEnergyPerturbationData*             freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
+{
+    return builderHelper->storeElement(std::make_unique<ConstraintsElement<variable>>(
+            legacySimulatorData->constr, statePropagatorData, energyData,
+            freeEnergyPerturbationData, MASTER(legacySimulatorData->cr), legacySimulatorData->fplog,
+            legacySimulatorData->inputrec, legacySimulatorData->mdAtoms->mdatoms()));
+}
+
+// Explicit template initializations
 template class ConstraintsElement<ConstraintVariable::Positions>;
 template class ConstraintsElement<ConstraintVariable::Velocities>;
-//! @}
 
 } // namespace gmx
index 24dc0b1a0773e6f3b881a64ae3b8044e345e0c1d..e61333a02a5616c170fbd629c35a271201cec447 100644 (file)
@@ -53,6 +53,9 @@ namespace gmx
 class Constraints;
 class EnergyData;
 class FreeEnergyPerturbationData;
+class GlobalCommunicationHelper;
+class LegacySimulatorData;
+class ModularSimulatorAlgorithmBuilderHelper;
 class StatePropagatorData;
 
 /*! \internal
@@ -105,6 +108,24 @@ public:
     //! No element teardown needed
     void elementTeardown() override {}
 
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper);
+
 private:
     //! The actual constraining computation
     void apply(Step step, bool calculateVirial, bool writeLog, bool writeEnergy);
index 61a971762ed8422ee6998aa819446427df2c594c..2a819f2dbe9739219392f879cdebd17244805391 100644 (file)
@@ -53,6 +53,7 @@
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdlib/update.h"
 #include "gromacs/mdrunutility/handlerestart.h"
+#include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/energyhistory.h"
 #include "gromacs/mdtypes/inputrec.h"
@@ -63,7 +64,9 @@
 #include "gromacs/topology/topology.h"
 
 #include "freeenergyperturbationdata.h"
+#include "modularsimulator.h"
 #include "parrinellorahmanbarostat.h"
+#include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
 #include "vrescalethermostat.h"
 
@@ -461,4 +464,15 @@ EnergyData::Element::Element(EnergyData* energyData, bool isMasterRank) :
 {
 }
 
+ISimulatorElement* EnergyData::Element::getElementPointerImpl(
+        LegacySimulatorData gmx_unused*        legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
+        StatePropagatorData gmx_unused* statePropagatorData,
+        EnergyData*                     energyData,
+        FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
+{
+    return energyData->element();
+}
+
 } // namespace gmx
index 7efd1cec0ff5c7639bf33addd869e5beab404349..ecc8af04f3737ce59b84ae1833f76adaaea8efe5 100644 (file)
@@ -63,7 +63,10 @@ enum class StartingBehavior;
 class Constraints;
 class EnergyOutput;
 class FreeEnergyPerturbationData;
+class GlobalCommunicationHelper;
+class LegacySimulatorData;
 class MDAtoms;
+class ModularSimulatorAlgorithmBuilderHelper;
 class ParrinelloRahmanBarostat;
 class StatePropagatorData;
 class VRescaleThermostat;
@@ -352,6 +355,24 @@ public:
     //! No element teardown needed
     void elementTeardown() override {}
 
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper);
+
 private:
     EnergyData* energyData_;
 
index 856d27ea03f9c0fc5026cd971e500efd84a89770..df6322efc723724525399400bdff5076229aeed3 100644 (file)
 #include "gromacs/mdrun/shellfc.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/mdatom.h"
+#include "gromacs/mdtypes/mdrunoptions.h"
 #include "gromacs/pbcutil/pbc.h"
 
 #include "energydata.h"
 #include "freeenergyperturbationdata.h"
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
 
 struct gmx_edsam;
@@ -248,4 +251,23 @@ SignallerCallbackPtr ForceElement::registerEnergyCallback(EnergySignallerEvent e
     }
     return nullptr;
 }
+
+ISimulatorElement*
+ForceElement::getElementPointerImpl(LegacySimulatorData*                    legacySimulatorData,
+                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                    StatePropagatorData*                    statePropagatorData,
+                                    EnergyData*                             energyData,
+                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                    GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
+{
+    const bool isVerbose    = legacySimulatorData->mdrunOptions.verbose;
+    const bool isDynamicBox = inputrecDynamicBox(legacySimulatorData->inputrec);
+    return builderHelper->storeElement(std::make_unique<ForceElement>(
+            statePropagatorData, energyData, freeEnergyPerturbationData, isVerbose, isDynamicBox,
+            legacySimulatorData->fplog, legacySimulatorData->cr, legacySimulatorData->inputrec,
+            legacySimulatorData->mdAtoms, legacySimulatorData->nrnb, legacySimulatorData->fr,
+            legacySimulatorData->wcycle, legacySimulatorData->runScheduleWork, legacySimulatorData->vsite,
+            legacySimulatorData->imdSession, legacySimulatorData->pull_work, legacySimulatorData->constr,
+            legacySimulatorData->top_global, legacySimulatorData->enforcedRotation));
+}
 } // namespace gmx
index d7a8689494809273e5af1b01a0ce56de6320e9dd..72959d99b817dfd946b14115fb8300e49049709f 100644 (file)
@@ -67,9 +67,12 @@ namespace gmx
 class Awh;
 class EnergyData;
 class FreeEnergyPerturbationData;
+class GlobalCommunicationHelper;
 class ImdSession;
+class LegacySimulatorData;
 class MDAtoms;
 class MdrunScheduleWorkload;
+class ModularSimulatorAlgorithmBuilderHelper;
 class StatePropagatorData;
 class VirtualSitesHandler;
 
@@ -99,15 +102,14 @@ public:
                  const MDAtoms*              mdAtoms,
                  t_nrnb*                     nrnb,
                  t_forcerec*                 fr,
-
-                 gmx_wallcycle*         wcycle,
-                 MdrunScheduleWorkload* runScheduleWork,
-                 VirtualSitesHandler*   vsite,
-                 ImdSession*            imdSession,
-                 pull_t*                pull_work,
-                 Constraints*           constr,
-                 const gmx_mtop_t*      globalTopology,
-                 gmx_enfrot*            enforcedRotation);
+                 gmx_wallcycle*              wcycle,
+                 MdrunScheduleWorkload*      runScheduleWork,
+                 VirtualSitesHandler*        vsite,
+                 ImdSession*                 imdSession,
+                 pull_t*                     pull_work,
+                 Constraints*                constr,
+                 const gmx_mtop_t*           globalTopology,
+                 gmx_enfrot*                 enforcedRotation);
 
     /*! \brief Register force calculation for step / time
      *
@@ -122,6 +124,24 @@ public:
     //! Print some final output
     void elementTeardown() override;
 
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper);
+
 private:
     //! ITopologyHolderClient implementation
     void setTopology(const gmx_localtop_t* top) override;
index 8b7687746b0b0e6a886c197d0d2b2b1e8d1555c6..d85343b8049a439966a87f872fc4897de0852c52 100644 (file)
@@ -49,6 +49,9 @@
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/mdtypes/state.h"
 
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
+
 namespace gmx
 {
 
@@ -119,4 +122,15 @@ FreeEnergyPerturbationData::Element* FreeEnergyPerturbationData::element()
     return element_.get();
 }
 
+ISimulatorElement* FreeEnergyPerturbationData::Element::getElementPointerImpl(
+        LegacySimulatorData gmx_unused*        legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
+        StatePropagatorData gmx_unused* statePropagatorData,
+        EnergyData gmx_unused*      energyData,
+        FreeEnergyPerturbationData* freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
+{
+    return freeEnergyPerturbationData->element();
+}
+
 } // namespace gmx
index ab89d087cfdfe59cf4f23686c837d8d35c1c4b25..23e39530c9db2f168913a1706d92315084550874 100644 (file)
@@ -54,7 +54,12 @@ struct t_inputrec;
 
 namespace gmx
 {
+class EnergyData;
+class GlobalCommunicationHelper;
+class LegacySimulatorData;
 class MDAtoms;
+class ModularSimulatorAlgorithmBuilderHelper;
+class StatePropagatorData;
 
 /*! \internal
  * \ingroup module_modularsimulator
@@ -130,6 +135,24 @@ public:
     //! No teardown needed
     void elementTeardown() override{};
 
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper);
+
 private:
     //! The free energy data
     FreeEnergyPerturbationData* freeEnergyPerturbationData_;
index 5552e94ff23993260eb6d4001531df7f7210aa97..dbb04d3fb5746b839d715b0e5a786d155a1d84b6 100644 (file)
@@ -48,7 +48,6 @@
 #include "gromacs/ewald/pme.h"
 #include "gromacs/ewald/pme_load_balancing.h"
 #include "gromacs/ewald/pme_pp.h"
-#include "gromacs/gmxlib/network.h"
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/listed_forces/listed_forces.h"
 #include "gromacs/mdlib/checkpointhandler.h"
 #include "gromacs/topology/topology.h"
 #include "gromacs/utility/fatalerror.h"
 
-#include "compositesimulatorelement.h"
 #include "computeglobalselement.h"
 #include "constraintelement.h"
-#include "energydata.h"
 #include "forceelement.h"
-#include "freeenergyperturbationdata.h"
 #include "parrinellorahmanbarostat.h"
-#include "propagator.h"
-#include "signallers.h"
 #include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
-#include "trajectoryelement.h"
 #include "vrescalethermostat.h"
 
 namespace gmx
@@ -96,8 +89,9 @@ void ModularSimulator::run()
             .asParagraph()
             .appendText("Using the modular simulator.");
 
-    ModularSimulatorAlgorithmBuilder algorithmBuilder(compat::make_not_null(legacySimulatorData_.get()));
-    auto                             algorithm = algorithmBuilder.build();
+    ModularSimulatorAlgorithmBuilder algorithmBuilder(compat::make_not_null(legacySimulatorData_));
+    addIntegrationElements(&algorithmBuilder);
+    auto algorithm = algorithmBuilder.build();
 
     while (const auto* task = algorithm.getNextTask())
     {
@@ -106,284 +100,66 @@ void ModularSimulator::run()
     }
 }
 
-std::unique_ptr<ISimulatorElement> ModularSimulatorAlgorithmBuilder::buildForces(
-        SignallerBuilder<NeighborSearchSignaller>* neighborSearchSignallerBuilder,
-        SignallerBuilder<EnergySignaller>*         energySignallerBuilder,
-        StatePropagatorData*                       statePropagatorDataPtr,
-        EnergyData*                                energyDataPtr,
-        FreeEnergyPerturbationData*                freeEnergyPerturbationDataPtr,
-        TopologyHolder::Builder*                   topologyHolderBuilder)
+void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder* builder)
 {
-    const bool isVerbose    = legacySimulatorData_->mdrunOptions.verbose;
-    const bool isDynamicBox = inputrecDynamicBox(legacySimulatorData_->inputrec);
-
-    auto forceElement = std::make_unique<ForceElement>(
-            statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr, isVerbose, isDynamicBox,
-            legacySimulatorData_->fplog, legacySimulatorData_->cr, legacySimulatorData_->inputrec,
-            legacySimulatorData_->mdAtoms, legacySimulatorData_->nrnb, legacySimulatorData_->fr,
-            legacySimulatorData_->wcycle, legacySimulatorData_->runScheduleWork,
-            legacySimulatorData_->vsite, legacySimulatorData_->imdSession,
-            legacySimulatorData_->pull_work, legacySimulatorData_->constr,
-            legacySimulatorData_->top_global, legacySimulatorData_->enforcedRotation);
-    topologyHolderBuilder->registerClient(forceElement.get());
-    neighborSearchSignallerBuilder->registerSignallerClient(compat::make_not_null(forceElement.get()));
-    energySignallerBuilder->registerSignallerClient(compat::make_not_null(forceElement.get()));
-
-    // std::move *should* not be needed with c++-14, but clang-3.6 still requires it
-    return std::move(forceElement);
-}
-
-std::unique_ptr<ISimulatorElement> ModularSimulatorAlgorithmBuilder::buildIntegrator(
-        SignallerBuilder<NeighborSearchSignaller>* neighborSearchSignallerBuilder,
-        SignallerBuilder<LastStepSignaller>*       lastStepSignallerBuilder,
-        SignallerBuilder<EnergySignaller>*         energySignallerBuilder,
-        SignallerBuilder<LoggingSignaller>*        loggingSignallerBuilder,
-        SignallerBuilder<TrajectorySignaller>*     trajectorySignallerBuilder,
-        TrajectoryElementBuilder*                  trajectoryElementBuilder,
-        std::vector<ICheckpointHelperClient*>*     checkpointClients,
-        compat::not_null<StatePropagatorData*>     statePropagatorDataPtr,
-        compat::not_null<EnergyData*>              energyDataPtr,
-        FreeEnergyPerturbationData*                freeEnergyPerturbationDataPtr,
-        bool                                       hasReadEkinState,
-        TopologyHolder::Builder*                   topologyHolderBuilder,
-        GlobalCommunicationHelper*                 globalCommunicationHelper)
-{
-    auto forceElement = buildForces(neighborSearchSignallerBuilder, energySignallerBuilder,
-                                    statePropagatorDataPtr, energyDataPtr,
-                                    freeEnergyPerturbationDataPtr, topologyHolderBuilder);
-
-    // list of elements owned by the simulator composite object
-    std::vector<std::unique_ptr<ISimulatorElement>> elementsOwnershipList;
-    // call list of the simulator composite object
-    std::vector<compat::not_null<ISimulatorElement*>> elementCallList;
-
-    std::function<void()> needToCheckNumberOfBondedInteractions;
     if (legacySimulatorData_->inputrec->eI == eiMD)
     {
-        auto computeGlobalsElement =
-                std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>(
-                        statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr,
-                        globalCommunicationHelper->simulationSignals(),
-                        globalCommunicationHelper->nstglobalcomm(), legacySimulatorData_->fplog,
-                        legacySimulatorData_->mdlog, legacySimulatorData_->cr,
-                        legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms,
-                        legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
-                        legacySimulatorData_->fr, legacySimulatorData_->top_global,
-                        legacySimulatorData_->constr, hasReadEkinState);
-        topologyHolderBuilder->registerClient(computeGlobalsElement.get());
-        energySignallerBuilder->registerSignallerClient(compat::make_not_null(computeGlobalsElement.get()));
-        trajectorySignallerBuilder->registerSignallerClient(
-                compat::make_not_null(computeGlobalsElement.get()));
-
-        globalCommunicationHelper->setCheckBondedInteractionsCallback(
-                computeGlobalsElement->getCheckNumberOfBondedInteractionsCallback());
-
-        auto propagator = std::make_unique<Propagator<IntegrationStep::LeapFrog>>(
-                legacySimulatorData_->inputrec->delta_t, statePropagatorDataPtr,
-                legacySimulatorData_->mdAtoms, legacySimulatorData_->wcycle);
-
-        addToCallListAndMove(std::move(forceElement), elementCallList, elementsOwnershipList);
-        auto stateElement = compat::make_not_null(statePropagatorDataPtr->element());
-        trajectoryElementBuilder->registerWriterClient(stateElement);
-        trajectorySignallerBuilder->registerSignallerClient(stateElement);
-        lastStepSignallerBuilder->registerSignallerClient(stateElement);
-        checkpointClients->emplace_back(stateElement);
-        // we have a full microstate at time t here!
-        addToCallList(stateElement, elementCallList);
+        // The leap frog integration algorithm
+        builder->add<ForceElement>();
+        builder->add<StatePropagatorData::Element>();
         if (legacySimulatorData_->inputrec->etc == etcVRESCALE)
         {
-            // TODO: With increased complexity of the propagator, this will need further development,
-            //       e.g. using propagators templated for velocity propagation policies and a builder
-            propagator->setNumVelocityScalingVariables(legacySimulatorData_->inputrec->opts.ngtc);
-            auto thermostat = std::make_unique<VRescaleThermostat>(
-                    legacySimulatorData_->inputrec->nsttcouple, -1, false,
-                    legacySimulatorData_->inputrec->ld_seed, legacySimulatorData_->inputrec->opts.ngtc,
-                    legacySimulatorData_->inputrec->delta_t * legacySimulatorData_->inputrec->nsttcouple,
-                    legacySimulatorData_->inputrec->opts.ref_t,
-                    legacySimulatorData_->inputrec->opts.tau_t, legacySimulatorData_->inputrec->opts.nrdf,
-                    energyDataPtr, propagator->viewOnVelocityScaling(),
-                    propagator->velocityScalingCallback(), legacySimulatorData_->state_global,
-                    legacySimulatorData_->cr, legacySimulatorData_->inputrec->bContinuation);
-            checkpointClients->emplace_back(thermostat.get());
-            energyDataPtr->setVRescaleThermostat(thermostat.get());
-            addToCallListAndMove(std::move(thermostat), elementCallList, elementsOwnershipList);
-        }
-
-        std::unique_ptr<ParrinelloRahmanBarostat> prBarostat = nullptr;
-        if (legacySimulatorData_->inputrec->epc == epcPARRINELLORAHMAN)
-        {
-            // Building the PR barostat here since it needs access to the propagator
-            // and we want to be able to move the propagator object
-            prBarostat = std::make_unique<ParrinelloRahmanBarostat>(
-                    legacySimulatorData_->inputrec->nstpcouple, -1,
-                    legacySimulatorData_->inputrec->delta_t * legacySimulatorData_->inputrec->nstpcouple,
-                    legacySimulatorData_->inputrec->init_step, propagator->viewOnPRScalingMatrix(),
-                    propagator->prScalingCallback(), statePropagatorDataPtr, energyDataPtr,
-                    legacySimulatorData_->fplog, legacySimulatorData_->inputrec,
-                    legacySimulatorData_->mdAtoms, legacySimulatorData_->state_global,
-                    legacySimulatorData_->cr, legacySimulatorData_->inputrec->bContinuation);
-            energyDataPtr->setParrinelloRahamnBarostat(prBarostat.get());
-            checkpointClients->emplace_back(prBarostat.get());
+            builder->add<VRescaleThermostat>(-1, VRescaleThermostatUseFullStepKE::No);
         }
-        addToCallListAndMove(std::move(propagator), elementCallList, elementsOwnershipList);
+        builder->add<Propagator<IntegrationStep::LeapFrog>>(legacySimulatorData_->inputrec->delta_t,
+                                                            RegisterWithThermostat::True,
+                                                            RegisterWithBarostat::True);
         if (legacySimulatorData_->constr)
         {
-            auto constraintElement = std::make_unique<ConstraintsElement<ConstraintVariable::Positions>>(
-                    legacySimulatorData_->constr, statePropagatorDataPtr, energyDataPtr,
-                    freeEnergyPerturbationDataPtr, MASTER(legacySimulatorData_->cr),
-                    legacySimulatorData_->fplog, legacySimulatorData_->inputrec,
-                    legacySimulatorData_->mdAtoms->mdatoms());
-            auto constraintElementPtr = compat::make_not_null(constraintElement.get());
-            energySignallerBuilder->registerSignallerClient(constraintElementPtr);
-            trajectorySignallerBuilder->registerSignallerClient(constraintElementPtr);
-            loggingSignallerBuilder->registerSignallerClient(constraintElementPtr);
-
-            addToCallListAndMove(std::move(constraintElement), elementCallList, elementsOwnershipList);
+            builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
-
-        addToCallListAndMove(std::move(computeGlobalsElement), elementCallList, elementsOwnershipList);
-        auto energyElement = compat::make_not_null(energyDataPtr->element());
-        trajectoryElementBuilder->registerWriterClient(energyElement);
-        trajectorySignallerBuilder->registerSignallerClient(energyElement);
-        energySignallerBuilder->registerSignallerClient(energyElement);
-        checkpointClients->emplace_back(energyElement);
-        // we have the energies at time t here!
-        addToCallList(energyElement, elementCallList);
-        if (prBarostat)
+        builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>();
+        builder->add<EnergyData::Element>();
+        if (legacySimulatorData_->inputrec->epc == epcPARRINELLORAHMAN)
         {
-            addToCallListAndMove(std::move(prBarostat), elementCallList, elementsOwnershipList);
+            builder->add<ParrinelloRahmanBarostat>(-1);
         }
     }
     else if (legacySimulatorData_->inputrec->eI == eiVV)
     {
-        auto computeGlobalsElement =
-                std::make_unique<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>(
-                        statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr,
-                        globalCommunicationHelper->simulationSignals(),
-                        globalCommunicationHelper->nstglobalcomm(), legacySimulatorData_->fplog,
-                        legacySimulatorData_->mdlog, legacySimulatorData_->cr,
-                        legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms,
-                        legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
-                        legacySimulatorData_->fr, legacySimulatorData_->top_global,
-                        legacySimulatorData_->constr, hasReadEkinState);
-        topologyHolderBuilder->registerClient(computeGlobalsElement.get());
-        energySignallerBuilder->registerSignallerClient(compat::make_not_null(computeGlobalsElement.get()));
-        trajectorySignallerBuilder->registerSignallerClient(
-                compat::make_not_null(computeGlobalsElement.get()));
-
-        globalCommunicationHelper->setCheckBondedInteractionsCallback(
-                computeGlobalsElement->getCheckNumberOfBondedInteractionsCallback());
-
-        auto propagatorVelocities = std::make_unique<Propagator<IntegrationStep::VelocitiesOnly>>(
-                legacySimulatorData_->inputrec->delta_t * 0.5, statePropagatorDataPtr,
-                legacySimulatorData_->mdAtoms, legacySimulatorData_->wcycle);
-        auto propagatorVelocitiesAndPositions =
-                std::make_unique<Propagator<IntegrationStep::VelocityVerletPositionsAndVelocities>>(
-                        legacySimulatorData_->inputrec->delta_t, statePropagatorDataPtr,
-                        legacySimulatorData_->mdAtoms, legacySimulatorData_->wcycle);
-
-        addToCallListAndMove(std::move(forceElement), elementCallList, elementsOwnershipList);
-
-        std::unique_ptr<ParrinelloRahmanBarostat> prBarostat = nullptr;
-        if (legacySimulatorData_->inputrec->epc == epcPARRINELLORAHMAN)
-        {
-            // Building the PR barostat here since it needs access to the propagator
-            // and we want to be able to move the propagator object
-            prBarostat = std::make_unique<ParrinelloRahmanBarostat>(
-                    legacySimulatorData_->inputrec->nstpcouple, -1,
-                    legacySimulatorData_->inputrec->delta_t * legacySimulatorData_->inputrec->nstpcouple,
-                    legacySimulatorData_->inputrec->init_step,
-                    propagatorVelocities->viewOnPRScalingMatrix(),
-                    propagatorVelocities->prScalingCallback(), statePropagatorDataPtr,
-                    energyDataPtr, legacySimulatorData_->fplog, legacySimulatorData_->inputrec,
-                    legacySimulatorData_->mdAtoms, legacySimulatorData_->state_global,
-                    legacySimulatorData_->cr, legacySimulatorData_->inputrec->bContinuation);
-            energyDataPtr->setParrinelloRahamnBarostat(prBarostat.get());
-            checkpointClients->emplace_back(prBarostat.get());
-        }
-        addToCallListAndMove(std::move(propagatorVelocities), elementCallList, elementsOwnershipList);
+        // The velocity verlet integration algorithm
+        builder->add<ForceElement>();
+        builder->add<Propagator<IntegrationStep::VelocitiesOnly>>(
+                0.5 * legacySimulatorData_->inputrec->delta_t, RegisterWithThermostat::False,
+                RegisterWithBarostat::True);
         if (legacySimulatorData_->constr)
         {
-            auto constraintElement = std::make_unique<ConstraintsElement<ConstraintVariable::Velocities>>(
-                    legacySimulatorData_->constr, statePropagatorDataPtr, energyDataPtr,
-                    freeEnergyPerturbationDataPtr, MASTER(legacySimulatorData_->cr),
-                    legacySimulatorData_->fplog, legacySimulatorData_->inputrec,
-                    legacySimulatorData_->mdAtoms->mdatoms());
-            energySignallerBuilder->registerSignallerClient(compat::make_not_null(constraintElement.get()));
-            trajectorySignallerBuilder->registerSignallerClient(
-                    compat::make_not_null(constraintElement.get()));
-            loggingSignallerBuilder->registerSignallerClient(
-                    compat::make_not_null(constraintElement.get()));
-
-            addToCallListAndMove(std::move(constraintElement), elementCallList, elementsOwnershipList);
+            builder->add<ConstraintsElement<ConstraintVariable::Velocities>>();
         }
-        addToCallList(compat::make_not_null(computeGlobalsElement.get()), elementCallList);
-        auto stateElement = compat::make_not_null(statePropagatorDataPtr->element());
-        trajectoryElementBuilder->registerWriterClient(stateElement);
-        trajectorySignallerBuilder->registerSignallerClient(stateElement);
-        lastStepSignallerBuilder->registerSignallerClient(stateElement);
-        checkpointClients->emplace_back(stateElement);
-        // we have a full microstate at time t here!
-        addToCallList(stateElement, elementCallList);
+        builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>();
+        builder->add<StatePropagatorData::Element>();
         if (legacySimulatorData_->inputrec->etc == etcVRESCALE)
         {
-            // TODO: With increased complexity of the propagator, this will need further development,
-            //       e.g. using propagators templated for velocity propagation policies and a builder
-            propagatorVelocitiesAndPositions->setNumVelocityScalingVariables(
-                    legacySimulatorData_->inputrec->opts.ngtc);
-            auto thermostat = std::make_unique<VRescaleThermostat>(
-                    legacySimulatorData_->inputrec->nsttcouple, 0, true,
-                    legacySimulatorData_->inputrec->ld_seed, legacySimulatorData_->inputrec->opts.ngtc,
-                    legacySimulatorData_->inputrec->delta_t * legacySimulatorData_->inputrec->nsttcouple,
-                    legacySimulatorData_->inputrec->opts.ref_t,
-                    legacySimulatorData_->inputrec->opts.tau_t, legacySimulatorData_->inputrec->opts.nrdf,
-                    energyDataPtr, propagatorVelocitiesAndPositions->viewOnVelocityScaling(),
-                    propagatorVelocitiesAndPositions->velocityScalingCallback(),
-                    legacySimulatorData_->state_global, legacySimulatorData_->cr,
-                    legacySimulatorData_->inputrec->bContinuation);
-            checkpointClients->emplace_back(thermostat.get());
-            energyDataPtr->setVRescaleThermostat(thermostat.get());
-            addToCallListAndMove(std::move(thermostat), elementCallList, elementsOwnershipList);
+            builder->add<VRescaleThermostat>(0, VRescaleThermostatUseFullStepKE::Yes);
         }
-        addToCallListAndMove(std::move(propagatorVelocitiesAndPositions), elementCallList,
-                             elementsOwnershipList);
+        builder->add<Propagator<IntegrationStep::VelocityVerletPositionsAndVelocities>>(
+                legacySimulatorData_->inputrec->delta_t, RegisterWithThermostat::True,
+                RegisterWithBarostat::False);
         if (legacySimulatorData_->constr)
         {
-            auto constraintElement = std::make_unique<ConstraintsElement<ConstraintVariable::Positions>>(
-                    legacySimulatorData_->constr, statePropagatorDataPtr, energyDataPtr,
-                    freeEnergyPerturbationDataPtr, MASTER(legacySimulatorData_->cr),
-                    legacySimulatorData_->fplog, legacySimulatorData_->inputrec,
-                    legacySimulatorData_->mdAtoms->mdatoms());
-            energySignallerBuilder->registerSignallerClient(compat::make_not_null(constraintElement.get()));
-            trajectorySignallerBuilder->registerSignallerClient(
-                    compat::make_not_null(constraintElement.get()));
-            loggingSignallerBuilder->registerSignallerClient(
-                    compat::make_not_null(constraintElement.get()));
-
-            addToCallListAndMove(std::move(constraintElement), elementCallList, elementsOwnershipList);
+            builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
-        addToCallListAndMove(std::move(computeGlobalsElement), elementCallList, elementsOwnershipList);
-        auto energyElement = compat::make_not_null(energyDataPtr->element());
-        trajectoryElementBuilder->registerWriterClient(energyElement);
-        trajectorySignallerBuilder->registerSignallerClient(energyElement);
-        energySignallerBuilder->registerSignallerClient(energyElement);
-        checkpointClients->emplace_back(energyElement);
-        // we have the energies at time t here!
-        addToCallList(energyElement, elementCallList);
-        if (prBarostat)
+        builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>();
+        builder->add<EnergyData::Element>();
+        if (legacySimulatorData_->inputrec->epc == epcPARRINELLORAHMAN)
         {
-            addToCallListAndMove(std::move(prBarostat), elementCallList, elementsOwnershipList);
+            builder->add<ParrinelloRahmanBarostat>(-1);
         }
     }
     else
     {
         gmx_fatal(FARGS, "Integrator not implemented for the modular simulator.");
     }
-
-    auto integrator = std::make_unique<CompositeSimulatorElement>(std::move(elementCallList),
-                                                                  std::move(elementsOwnershipList));
-    // std::move *should* not be needed with c++-14, but clang-3.6 still requires it
-    return std::move(integrator);
 }
 
 bool ModularSimulator::isInputCompatible(bool                             exitOnFailure,
index fbf37e06f58407f614797f0d324bb4244192384e..c0bc6d2b7d68ddb54f486ef82d6cdff31617bf96 100644 (file)
@@ -56,6 +56,7 @@ struct t_fcdata;
 
 namespace gmx
 {
+class ModularSimulatorAlgorithmBuilder;
 
 /*! \libinternal
  * \ingroup module_modularsimulator
@@ -94,6 +95,9 @@ private:
     //! Constructor
     explicit ModularSimulator(std::unique_ptr<LegacySimulatorData> legacySimulatorData);
 
+    //! Populate algorithm builder with elements
+    void addIntegrationElements(ModularSimulatorAlgorithmBuilder* builder);
+
     //! Check for disabled functionality (during construction time)
     void checkInputForDisabledFunctionality();
 
index 6af7639ad928456d2a0b932f063426cc193bf4be..db5fe9c63dee4841188c093f2b67b506ba3be541 100644 (file)
@@ -59,6 +59,7 @@
 #include <functional>
 #include <memory>
 
+#include "gromacs/math/vectypes.h"
 #include "gromacs/utility/basedefinitions.h"
 #include "gromacs/utility/exceptions.h"
 
@@ -68,6 +69,8 @@ class t_state;
 
 namespace gmx
 {
+template<typename T>
+class ArrayRef;
 template<class Signaller>
 class SignallerBuilder;
 class NeighborSearchSignaller;
@@ -422,6 +425,30 @@ enum class ModularSimulatorBuilderState
     NotAcceptingClientRegistrations
 };
 
+//! Generic callback to the propagator
+typedef std::function<void(Step)> PropagatorCallback;
+//! Pointer to generic callback to the propagator
+typedef std::unique_ptr<PropagatorCallback> PropagatorCallbackPtr;
+
+/*! \internal
+ * \brief Information needed to connect a propagator to a thermostat
+ */
+struct PropagatorThermostatConnection
+{
+    std::function<void(int)>               setNumVelocityScalingVariables;
+    std::function<ArrayRef<real>()>        getViewOnVelocityScaling;
+    std::function<PropagatorCallbackPtr()> getVelocityScalingCallback;
+};
+
+/*! \internal
+ * \brief Information needed to connect a propagator to a barostat
+ */
+struct PropagatorBarostatConnection
+{
+    std::function<ArrayRef<rvec>()>        getViewOnPRScalingMatrix;
+    std::function<PropagatorCallbackPtr()> getPRScalingCallback;
+};
+
 //! /}
 } // namespace gmx
 
index 902e87ea5ac45e7c81a34f9751b4ff3415ab26d7..06986a8d48e673796894f1ec94a3ab32bb5e5d55 100644 (file)
 #include "gromacs/pbcutil/boxutilities.h"
 
 #include "energydata.h"
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
 
 namespace gmx
 {
 
-ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                   nstpcouple,
-                                                   int                   offset,
-                                                   real                  couplingTimeStep,
-                                                   Step                  initStep,
-                                                   ArrayRef<rvec>        scalingTensor,
-                                                   PropagatorCallbackPtr propagatorCallback,
-                                                   StatePropagatorData*  statePropagatorData,
-                                                   EnergyData*           energyData,
-                                                   FILE*                 fplog,
-                                                   const t_inputrec*     inputrec,
-                                                   const MDAtoms*        mdAtoms,
-                                                   const t_state*        globalState,
-                                                   t_commrec*            cr,
-                                                   bool                  isRestart) :
+ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                  nstpcouple,
+                                                   int                  offset,
+                                                   real                 couplingTimeStep,
+                                                   Step                 initStep,
+                                                   StatePropagatorData* statePropagatorData,
+                                                   EnergyData*          energyData,
+                                                   FILE*                fplog,
+                                                   const t_inputrec*    inputrec,
+                                                   const MDAtoms*       mdAtoms,
+                                                   const t_state*       globalState,
+                                                   t_commrec*           cr,
+                                                   bool                 isRestart) :
     nstpcouple_(nstpcouple),
     offset_(offset),
     couplingTimeStep_(couplingTimeStep),
     initStep_(initStep),
-    scalingTensor_(scalingTensor),
-    propagatorCallback_(std::move(propagatorCallback)),
+    propagatorCallback_(nullptr),
+    mu_{ { 0 } },
+    boxRel_{ { 0 } },
+    boxVelocity_{ { 0 } },
     statePropagatorData_(statePropagatorData),
     energyData_(energyData),
     fplog_(fplog),
     inputrec_(inputrec),
     mdAtoms_(mdAtoms)
 {
-    clear_mat(mu_);
-    clear_mat(boxRel_);
-    clear_mat(boxVelocity_);
-
+    energyData->setParrinelloRahamnBarostat(this);
     // TODO: This is only needed to restore the thermostatIntegral_ from cpt. Remove this when
     //       switching to purely client-based checkpointing.
     if (isRestart)
@@ -107,6 +106,12 @@ ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                   nstpcou
     }
 }
 
+void ParrinelloRahmanBarostat::connectWithPropagator(const PropagatorBarostatConnection& connectionData)
+{
+    scalingTensor_      = connectionData.getViewOnPRScalingMatrix();
+    propagatorCallback_ = connectionData.getPRScalingCallback();
+}
+
 void ParrinelloRahmanBarostat::scheduleTask(gmx::Step step,
                                             gmx::Time gmx_unused               time,
                                             const gmx::RegisterRunFunctionPtr& registerRunFunction)
@@ -162,6 +167,15 @@ void ParrinelloRahmanBarostat::scaleBoxAndPositions()
 
 void ParrinelloRahmanBarostat::elementSetup()
 {
+    if (propagatorCallback_ == nullptr || scalingTensor_.empty())
+    {
+        throw MissingElementConnectionError(
+                "Parrinello-Rahman barostat was not connected to a propagator.\n"
+                "Connection to a propagator element is needed to scale the velocities.\n"
+                "Use connectWithPropagator(...) before building the ModularSimulatorAlgorithm "
+                "object.");
+    }
+
     if (inputrecPreserveShape(inputrec_))
     {
         auto      box  = statePropagatorData_->box();
@@ -202,5 +216,27 @@ void ParrinelloRahmanBarostat::writeCheckpoint(t_state* localState, t_state gmx_
     localState->flags |= (1U << estBOXV) | (1U << estBOX_REL);
 }
 
+ISimulatorElement* ParrinelloRahmanBarostat::getElementPointerImpl(
+        LegacySimulatorData*                    legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+        StatePropagatorData*                    statePropagatorData,
+        EnergyData*                             energyData,
+        FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
+        int                                   offset)
+{
+    auto* element  = builderHelper->storeElement(std::make_unique<ParrinelloRahmanBarostat>(
+            legacySimulatorData->inputrec->nstpcouple, offset,
+            legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nstpcouple,
+            legacySimulatorData->inputrec->init_step, statePropagatorData, energyData,
+            legacySimulatorData->fplog, legacySimulatorData->inputrec, legacySimulatorData->mdAtoms,
+            legacySimulatorData->state_global, legacySimulatorData->cr,
+            legacySimulatorData->inputrec->bContinuation));
+    auto* barostat = static_cast<ParrinelloRahmanBarostat*>(element);
+    builderHelper->registerBarostat([barostat](const PropagatorBarostatConnection& connection) {
+        barostat->connectWithPropagator(connection);
+    });
+    return element;
+}
 
 } // namespace gmx
index c4a189a5a023bac5c362dff071eaa0d323f31740..3d59bf7dea56a7d313ac4e50be1f926fe1cf4f54 100644 (file)
@@ -55,6 +55,7 @@ struct t_commrec;
 namespace gmx
 {
 class EnergyData;
+class LegacySimulatorData;
 class MDAtoms;
 class StatePropagatorData;
 
@@ -72,20 +73,18 @@ class ParrinelloRahmanBarostat final : public ISimulatorElement, public ICheckpo
 {
 public:
     //! Constructor
-    ParrinelloRahmanBarostat(int                   nstpcouple,
-                             int                   offset,
-                             real                  couplingTimeStep,
-                             Step                  initStep,
-                             ArrayRef<rvec>        scalingTensor,
-                             PropagatorCallbackPtr propagatorCallback,
-                             StatePropagatorData*  statePropagatorData,
-                             EnergyData*           energyData,
-                             FILE*                 fplog,
-                             const t_inputrec*     inputrec,
-                             const MDAtoms*        mdAtoms,
-                             const t_state*        globalState,
-                             t_commrec*            cr,
-                             bool                  isRestart);
+    ParrinelloRahmanBarostat(int                  nstpcouple,
+                             int                  offset,
+                             real                 couplingTimeStep,
+                             Step                 initStep,
+                             StatePropagatorData* statePropagatorData,
+                             EnergyData*          energyData,
+                             FILE*                fplog,
+                             const t_inputrec*    inputrec,
+                             const MDAtoms*       mdAtoms,
+                             const t_state*       globalState,
+                             t_commrec*           cr,
+                             bool                 isRestart);
 
     /*! \brief Register run function for step / time
      *
@@ -103,6 +102,29 @@ public:
     //! Getter for the box velocities
     const rvec* boxVelocities() const;
 
+    //! Connect this to propagator
+    void connectWithPropagator(const PropagatorBarostatConnection& connectionData);
+
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     * \param offset  The step offset at which the barostat is applied
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper,
+                                                    int                        offset);
+
 private:
     //! The frequency at which the barostat is applied
     const int nstpcouple_;
index cfd9331f3cce9f8d79b75af6b348032c87ec835f..3a14a55b684b753ca3479056ca7e045c207eda59 100644 (file)
 #include "gromacs/mdlib/gmx_omp_nthreads.h"
 #include "gromacs/mdlib/mdatoms.h"
 #include "gromacs/mdlib/update.h"
+#include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/mdatom.h"
 #include "gromacs/timing/wallcycle.h"
 #include "gromacs/utility/fatalerror.h"
 
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
 
 namespace gmx
@@ -572,12 +575,42 @@ PropagatorCallbackPtr Propagator<algorithm>::prScalingCallback()
     return std::make_unique<PropagatorCallback>([this](Step step) { scalingStepPR_ = step; });
 }
 
-//! Explicit template initialization
-//! @{
+template<IntegrationStep algorithm>
+ISimulatorElement* Propagator<algorithm>::getElementPointerImpl(
+        LegacySimulatorData*                    legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+        StatePropagatorData*                    statePropagatorData,
+        EnergyData gmx_unused*     energyData,
+        FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
+        double                                timestep,
+        RegisterWithThermostat                registerWithThermostat,
+        RegisterWithBarostat                  registerWithBarostat)
+{
+    auto* element = builderHelper->storeElement(std::make_unique<Propagator<algorithm>>(
+            timestep, statePropagatorData, legacySimulatorData->mdAtoms, legacySimulatorData->wcycle));
+    if (registerWithThermostat == RegisterWithThermostat::True)
+    {
+        auto* propagator = static_cast<Propagator<algorithm>*>(element);
+        builderHelper->registerWithThermostat(
+                { [propagator](int num) { propagator->setNumVelocityScalingVariables(num); },
+                  [propagator]() { return propagator->viewOnVelocityScaling(); },
+                  [propagator]() { return propagator->velocityScalingCallback(); } });
+    }
+    if (registerWithBarostat == RegisterWithBarostat::True)
+    {
+        auto* propagator = static_cast<Propagator<algorithm>*>(element);
+        builderHelper->registerWithBarostat(
+                { [propagator]() { return propagator->viewOnPRScalingMatrix(); },
+                  [propagator]() { return propagator->prScalingCallback(); } });
+    }
+    return element;
+}
+
+// Explicit template initializations
 template class Propagator<IntegrationStep::PositionsOnly>;
 template class Propagator<IntegrationStep::VelocitiesOnly>;
 template class Propagator<IntegrationStep::LeapFrog>;
 template class Propagator<IntegrationStep::VelocityVerletPositionsAndVelocities>;
-//! @}
 
 } // namespace gmx
index 1175c611b318593230762e36797525dc429fa681..adced5a0230d0ab3764222059641a4c23838f1da 100644 (file)
@@ -56,12 +56,30 @@ struct gmx_wallcycle;
 
 namespace gmx
 {
+class EnergyData;
+class FreeEnergyPerturbationData;
+class GlobalCommunicationHelper;
+class LegacySimulatorData;
 class MDAtoms;
+class ModularSimulatorAlgorithmBuilderHelper;
 class StatePropagatorData;
 
 //! \addtogroup module_modularsimulator
 //! \{
 
+//! Whether built propagator should be registered with thermostat
+enum class RegisterWithThermostat
+{
+    True,
+    False
+};
+//! Whether built propagator should be registered with barostat
+enum class RegisterWithBarostat
+{
+    True,
+    False
+};
+
 /*! \brief The different integration types we know about
  *
  * PositionsOnly:
@@ -102,11 +120,6 @@ enum class ParrinelloRahmanVelocityScaling
     Count
 };
 
-//! Generic callback to the propagator
-typedef std::function<void(Step)> PropagatorCallback;
-//! Pointer to generic callback to the propagator
-typedef std::unique_ptr<PropagatorCallback> PropagatorCallbackPtr;
-
 /*! \internal
  * \brief Propagator element
  *
@@ -153,6 +166,30 @@ public:
     //! Get PR scaling callback
     PropagatorCallbackPtr prScalingCallback();
 
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     * \param timestep  The time step the propagator uses
+     * \param registerWithThermostat  Whether this propagator should be registered with the thermostat
+     * \param registerWithBarostat  Whether this propagator should be registered with the barostat
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper,
+                                                    double                     timestep,
+                                                    RegisterWithThermostat registerWithThermostat,
+                                                    RegisterWithBarostat   registerWithBarostat);
+
 private:
     //! The actual propagation
     template<NumVelocityScalingValues numVelocityScalingValues, ParrinelloRahmanVelocityScaling parrinelloRahmanVelocityScaling>
index 0488095fa5b449b79aaf1689a25498174526948a..39c355031710cc22cbce11848899893e15926618 100644 (file)
@@ -48,7 +48,6 @@
 #include "gromacs/ewald/pme.h"
 #include "gromacs/ewald/pme_load_balancing.h"
 #include "gromacs/ewald/pme_pp.h"
-#include "gromacs/gmxlib/network.h"
 #include "gromacs/gmxlib/nrnb.h"
 #include "gromacs/listed_forces/listed_forces.h"
 #include "gromacs/mdlib/checkpointhandler.h"
@@ -58,7 +57,6 @@
 #include "gromacs/mdlib/mdatoms.h"
 #include "gromacs/mdlib/resethandler.h"
 #include "gromacs/mdlib/stat.h"
-#include "gromacs/mdlib/update.h"
 #include "gromacs/mdrun/replicaexchange.h"
 #include "gromacs/mdrun/shellfc.h"
 #include "gromacs/mdrunutility/handlerestart.h"
 #include "gromacs/utility/cstringutil.h"
 #include "gromacs/utility/fatalerror.h"
 
+#include "checkpointhelper.h"
 #include "domdechelper.h"
+#include "energydata.h"
 #include "freeenergyperturbationdata.h"
 #include "modularsimulator.h"
 #include "parrinellorahmanbarostat.h"
-#include "signallers.h"
-#include "trajectoryelement.h"
+#include "pmeloadbalancehelper.h"
+#include "propagator.h"
+#include "statepropagatordata.h"
 #include "vrescalethermostat.h"
 
 namespace gmx
@@ -381,270 +382,258 @@ void ModularSimulatorAlgorithm::populateTaskQueue()
     }
 }
 
-ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::constructElementsAndSignallers()
+ModularSimulatorAlgorithmBuilder::ModularSimulatorAlgorithmBuilder(
+        compat::not_null<LegacySimulatorData*> legacySimulatorData) :
+    legacySimulatorData_(legacySimulatorData),
+    signals_(std::make_unique<SimulationSignals>()),
+    elementAdditionHelper_(this),
+    globalCommunicationHelper_(computeGlobalCommunicationPeriod(legacySimulatorData->mdlog,
+                                                                legacySimulatorData->inputrec,
+                                                                legacySimulatorData->cr),
+                               signals_.get())
 {
-    ModularSimulatorAlgorithm algorithm(
-            *(legacySimulatorData_->top_global->name), legacySimulatorData_->fplog,
-            legacySimulatorData_->cr, legacySimulatorData_->mdlog, legacySimulatorData_->mdrunOptions,
-            legacySimulatorData_->inputrec, legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
-            legacySimulatorData_->fr, legacySimulatorData_->walltime_accounting);
-    GlobalCommunicationHelper globalCommunicationHelper(nstglobalcomm_, &algorithm.signals_);
-    /* When restarting from a checkpoint, it can be appropriate to
-     * initialize ekind from quantities in the checkpoint. Otherwise,
-     * compute_globals must initialize ekind before the simulation
-     * starts/restarts. However, only the master rank knows what was
-     * found in the checkpoint file, so we have to communicate in
-     * order to coordinate the restart.
-     *
-     * TODO (modular) This should become obsolete when checkpoint reading
-     *      happens within the modular simulator framework: The energy
-     *      element should read its data from the checkpoint file pointer,
-     *      and signal to the compute globals element if it needs anything
-     *      reduced.
-     *
-     * TODO (legacy) Consider removing this communication if/when checkpoint
-     *      reading directly follows .tpr reading, because all ranks can
-     *      agree on hasReadEkinState at that time.
-     */
-    bool hasReadEkinState = MASTER(legacySimulatorData_->cr)
-                                    ? legacySimulatorData_->state_global->ekinstate.hasReadEkinState
-                                    : false;
-    if (PAR(legacySimulatorData_->cr))
+    if (legacySimulatorData->inputrec->efep != efepNO)
     {
-        gmx_bcast(sizeof(hasReadEkinState), &hasReadEkinState, legacySimulatorData_->cr->mpi_comm_mygroup);
+        freeEnergyPerturbationData_ = std::make_unique<FreeEnergyPerturbationData>(
+                legacySimulatorData->fplog, legacySimulatorData->inputrec, legacySimulatorData->mdAtoms);
     }
-    if (hasReadEkinState)
+
+    statePropagatorData_ = std::make_unique<StatePropagatorData>(
+            legacySimulatorData->top_global->natoms, legacySimulatorData->fplog, legacySimulatorData->cr,
+            legacySimulatorData->state_global, legacySimulatorData->fr->nbv->useGpu(),
+            legacySimulatorData->fr->bMolPBC, legacySimulatorData->mdrunOptions.writeConfout,
+            opt2fn("-c", legacySimulatorData->nfile, legacySimulatorData->fnm), legacySimulatorData->inputrec,
+            legacySimulatorData->mdAtoms->mdatoms(), legacySimulatorData->top_global);
+
+    energyData_ = std::make_unique<EnergyData>(
+            statePropagatorData_.get(), freeEnergyPerturbationData_.get(),
+            legacySimulatorData->top_global, legacySimulatorData->inputrec, legacySimulatorData->mdAtoms,
+            legacySimulatorData->enerd, legacySimulatorData->ekind, legacySimulatorData->constr,
+            legacySimulatorData->fplog, &legacySimulatorData->fr->listedForces->fcdata(),
+            legacySimulatorData->mdModulesNotifier, MASTER(legacySimulatorData->cr),
+            legacySimulatorData->observablesHistory, legacySimulatorData->startingBehavior);
+}
+
+ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::build()
+{
+    if (algorithmHasBeenBuilt_)
     {
-        restore_ekinstate_from_state(legacySimulatorData_->cr, legacySimulatorData_->ekind,
-                                     &legacySimulatorData_->state_global->ekinstate);
+        throw SimulationAlgorithmSetupError(
+                "Tried to build ModularSimulationAlgorithm more than once.");
     }
+    algorithmHasBeenBuilt_ = true;
 
-    /*
-     * Build data structures
-     */
+    // Connect propagators with thermostat / barostat
+    for (const auto& thermostatRegistration : thermostatRegistrationFunctions_)
+    {
+        for (const auto& connection : propagatorThermostatConnections_)
+        {
+            thermostatRegistration(connection);
+        }
+    }
+    for (const auto& barostatRegistration : barostatRegistrationFunctions_)
+    {
+        for (const auto& connection : propagatorBarostatConnections_)
+        {
+            barostatRegistration(connection);
+        }
+    }
 
-    if (legacySimulatorData_->inputrec->efep != efepNO)
-    {
-        algorithm.freeEnergyPerturbationData_ = std::make_unique<FreeEnergyPerturbationData>(
-                legacySimulatorData_->fplog, legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms);
-    }
-    FreeEnergyPerturbationData* freeEnergyPerturbationDataPtr =
-            algorithm.freeEnergyPerturbationData_.get();
-
-    algorithm.statePropagatorData_ = std::make_unique<StatePropagatorData>(
-            legacySimulatorData_->top_global->natoms, legacySimulatorData_->fplog,
-            legacySimulatorData_->cr, legacySimulatorData_->state_global,
-            legacySimulatorData_->fr->nbv->useGpu(), freeEnergyPerturbationDataPtr,
-            legacySimulatorData_->fr->bMolPBC, legacySimulatorData_->mdrunOptions.writeConfout,
-            opt2fn("-c", legacySimulatorData_->nfile, legacySimulatorData_->fnm),
-            legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms->mdatoms(),
-            legacySimulatorData_->top_global);
-    auto statePropagatorDataPtr = compat::make_not_null(algorithm.statePropagatorData_.get());
-
-    algorithm.energyData_ = std::make_unique<EnergyData>(
-            statePropagatorDataPtr, freeEnergyPerturbationDataPtr, legacySimulatorData_->top_global,
-            legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms,
-            legacySimulatorData_->enerd, legacySimulatorData_->ekind, legacySimulatorData_->constr,
-            legacySimulatorData_->fplog, &legacySimulatorData_->fr->listedForces->fcdata(),
-            legacySimulatorData_->mdModulesNotifier, MASTER(legacySimulatorData_->cr),
-            legacySimulatorData_->observablesHistory, legacySimulatorData_->startingBehavior);
-    auto energyDataPtr = compat::make_not_null(algorithm.energyData_.get());
+    ModularSimulatorAlgorithm algorithm(
+            *(legacySimulatorData_->top_global->name), legacySimulatorData_->fplog,
+            legacySimulatorData_->cr, legacySimulatorData_->mdlog, legacySimulatorData_->mdrunOptions,
+            legacySimulatorData_->inputrec, legacySimulatorData_->nrnb, legacySimulatorData_->wcycle,
+            legacySimulatorData_->fr, legacySimulatorData_->walltime_accounting);
+    registerWithInfrastructureAndSignallers(algorithm.signalHelper_.get());
+    algorithm.statePropagatorData_        = std::move(statePropagatorData_);
+    algorithm.energyData_                 = std::move(energyData_);
+    algorithm.freeEnergyPerturbationData_ = std::move(freeEnergyPerturbationData_);
+    algorithm.signals_                    = std::move(signals_);
 
-    /*
-     * Build stop handler
-     */
+    // Multi sim is turned off
     const bool simulationsShareState = false;
-    algorithm.stopHandler_           = legacySimulatorData_->stopHandlerBuilder->getStopHandlerMD(
-            compat::not_null<SimulationSignal*>(&(*globalCommunicationHelper.simulationSignals())[eglsSTOPCOND]),
+
+    // Build stop handler
+    algorithm.stopHandler_ = legacySimulatorData_->stopHandlerBuilder->getStopHandlerMD(
+            compat::not_null<SimulationSignal*>(
+                    &(*globalCommunicationHelper_.simulationSignals())[eglsSTOPCOND]),
             simulationsShareState, MASTER(legacySimulatorData_->cr),
             legacySimulatorData_->inputrec->nstlist, legacySimulatorData_->mdrunOptions.reproducible,
-            globalCommunicationHelper.nstglobalcomm(), legacySimulatorData_->mdrunOptions.maximumHoursToRun,
+            globalCommunicationHelper_.nstglobalcomm(), legacySimulatorData_->mdrunOptions.maximumHoursToRun,
             legacySimulatorData_->inputrec->nstlist == 0, legacySimulatorData_->fplog,
             algorithm.stophandlerCurrentStep_, algorithm.stophandlerIsNSStep_,
             legacySimulatorData_->walltime_accounting);
 
-    /*
-     * Create simulator builders
-     */
-    SignallerBuilder<NeighborSearchSignaller> neighborSearchSignallerBuilder;
-    SignallerBuilder<LastStepSignaller>       lastStepSignallerBuilder;
-    SignallerBuilder<LoggingSignaller>        loggingSignallerBuilder;
-    SignallerBuilder<EnergySignaller>         energySignallerBuilder;
-    SignallerBuilder<TrajectorySignaller>     trajectorySignallerBuilder;
-    TrajectoryElementBuilder                  trajectoryElementBuilder;
-    TopologyHolder::Builder                   topologyHolderBuilder;
-
-    // Register the simulator itself to the neighbor search / last step signaller
-    neighborSearchSignallerBuilder.registerSignallerClient(
-            compat::make_not_null(algorithm.signalHelper_.get()));
-    lastStepSignallerBuilder.registerSignallerClient(compat::make_not_null(algorithm.signalHelper_.get()));
-
-    /*
-     * Build integrator - this takes care of force calculation, propagation,
-     * constraining, and of the place the statePropagatorData and the energy element
-     * have a full timestep state.
-     */
-    // TODO: Make a CheckpointHelperBuilder
-    std::vector<ICheckpointHelperClient*> checkpointClients;
-    auto                                  integrator = buildIntegrator(
-            &neighborSearchSignallerBuilder, &lastStepSignallerBuilder, &energySignallerBuilder,
-            &loggingSignallerBuilder, &trajectorySignallerBuilder, &trajectoryElementBuilder,
-            &checkpointClients, statePropagatorDataPtr, energyDataPtr, freeEnergyPerturbationDataPtr,
-            hasReadEkinState, &topologyHolderBuilder, &globalCommunicationHelper);
-
-    FreeEnergyPerturbationData::Element* freeEnergyPerturbationElement = nullptr;
-    if (algorithm.freeEnergyPerturbationData_)
-    {
-        freeEnergyPerturbationElement = algorithm.freeEnergyPerturbationData_->element();
-        checkpointClients.emplace_back(freeEnergyPerturbationElement);
-    }
+    // Build reset handler
+    const bool simulationsShareResetCounters = false;
+    algorithm.resetHandler_                  = std::make_unique<ResetHandler>(
+            compat::make_not_null<SimulationSignal*>(
+                    &(*globalCommunicationHelper_.simulationSignals())[eglsRESETCOUNTERS]),
+            simulationsShareResetCounters, legacySimulatorData_->inputrec->nsteps,
+            MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.timingOptions.resetHalfway,
+            legacySimulatorData_->mdrunOptions.maximumHoursToRun, legacySimulatorData_->mdlog,
+            legacySimulatorData_->wcycle, legacySimulatorData_->walltime_accounting);
 
-    /*
-     * Build infrastructure elements
-     */
     // Build topology holder
-    algorithm.topologyHolder_ = topologyHolderBuilder.build(
+    algorithm.topologyHolder_ = topologyHolderBuilder_.build(
             *legacySimulatorData_->top_global, legacySimulatorData_->cr,
             legacySimulatorData_->inputrec, legacySimulatorData_->fr, legacySimulatorData_->mdAtoms,
             legacySimulatorData_->constr, legacySimulatorData_->vsite);
 
+    // Build PME load balance helper
     if (PmeLoadBalanceHelper::doPmeLoadBalancing(legacySimulatorData_->mdrunOptions,
                                                  legacySimulatorData_->inputrec,
                                                  legacySimulatorData_->fr))
     {
         algorithm.pmeLoadBalanceHelper_ = std::make_unique<PmeLoadBalanceHelper>(
-                legacySimulatorData_->mdrunOptions.verbose, statePropagatorDataPtr,
+                legacySimulatorData_->mdrunOptions.verbose, algorithm.statePropagatorData_.get(),
                 legacySimulatorData_->fplog, legacySimulatorData_->cr, legacySimulatorData_->mdlog,
                 legacySimulatorData_->inputrec, legacySimulatorData_->wcycle, legacySimulatorData_->fr);
-        neighborSearchSignallerBuilder.registerSignallerClient(
-                compat::make_not_null(algorithm.pmeLoadBalanceHelper_.get()));
+        registerWithInfrastructureAndSignallers(algorithm.pmeLoadBalanceHelper_.get());
     }
-
+    // Build domdec helper
     if (DOMAINDECOMP(legacySimulatorData_->cr))
     {
         algorithm.domDecHelper_ = std::make_unique<DomDecHelper>(
                 legacySimulatorData_->mdrunOptions.verbose,
-                legacySimulatorData_->mdrunOptions.verboseStepPrintInterval, statePropagatorDataPtr,
-                algorithm.topologyHolder_.get(),
-                globalCommunicationHelper.moveCheckBondedInteractionsCallback(),
-                globalCommunicationHelper.nstglobalcomm(), legacySimulatorData_->fplog,
+                legacySimulatorData_->mdrunOptions.verboseStepPrintInterval,
+                algorithm.statePropagatorData_.get(), algorithm.topologyHolder_.get(),
+                globalCommunicationHelper_.moveCheckBondedInteractionsCallback(),
+                globalCommunicationHelper_.nstglobalcomm(), legacySimulatorData_->fplog,
                 legacySimulatorData_->cr, legacySimulatorData_->mdlog, legacySimulatorData_->constr,
                 legacySimulatorData_->inputrec, legacySimulatorData_->mdAtoms, legacySimulatorData_->nrnb,
                 legacySimulatorData_->wcycle, legacySimulatorData_->fr, legacySimulatorData_->vsite,
                 legacySimulatorData_->imdSession, legacySimulatorData_->pull_work);
-        neighborSearchSignallerBuilder.registerSignallerClient(
-                compat::make_not_null(algorithm.domDecHelper_.get()));
+        registerWithInfrastructureAndSignallers(algorithm.domDecHelper_.get());
     }
 
-    const bool simulationsShareResetCounters = false;
-    algorithm.resetHandler_                  = std::make_unique<ResetHandler>(
-            compat::make_not_null<SimulationSignal*>(
-                    &(*globalCommunicationHelper.simulationSignals())[eglsRESETCOUNTERS]),
-            simulationsShareResetCounters, legacySimulatorData_->inputrec->nsteps,
-            MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.timingOptions.resetHalfway,
-            legacySimulatorData_->mdrunOptions.maximumHoursToRun, legacySimulatorData_->mdlog,
-            legacySimulatorData_->wcycle, legacySimulatorData_->walltime_accounting);
-
-    /*
-     * Build signaller list
-     *
-     * Note that as signallers depend on each others, the order of calling the signallers
-     * matters. It is the responsibility of this builder to ensure that the order is
-     * maintained.
-     */
-    auto energySignaller = energySignallerBuilder.build(legacySimulatorData_->inputrec->nstcalcenergy,
-                                                        legacySimulatorData_->inputrec->fepvals->nstdhdl,
-                                                        legacySimulatorData_->inputrec->nstpcouple);
-    trajectorySignallerBuilder.registerSignallerClient(compat::make_not_null(energySignaller.get()));
-    loggingSignallerBuilder.registerSignallerClient(compat::make_not_null(energySignaller.get()));
-    auto trajectoryElement = trajectoryElementBuilder.build(
+    // Build trajectory element
+    auto trajectoryElement = trajectoryElementBuilder_.build(
             legacySimulatorData_->fplog, legacySimulatorData_->nfile, legacySimulatorData_->fnm,
             legacySimulatorData_->mdrunOptions, legacySimulatorData_->cr,
             legacySimulatorData_->outputProvider, legacySimulatorData_->mdModulesNotifier,
             legacySimulatorData_->inputrec, legacySimulatorData_->top_global,
             legacySimulatorData_->oenv, legacySimulatorData_->wcycle,
             legacySimulatorData_->startingBehavior, simulationsShareState);
-    loggingSignallerBuilder.registerSignallerClient(compat::make_not_null(trajectoryElement.get()));
-    trajectorySignallerBuilder.registerSignallerClient(compat::make_not_null(trajectoryElement.get()));
-    auto trajectorySignaller = trajectorySignallerBuilder.build(
-            legacySimulatorData_->inputrec->nstxout, legacySimulatorData_->inputrec->nstvout,
-            legacySimulatorData_->inputrec->nstfout,
-            legacySimulatorData_->inputrec->nstxout_compressed, trajectoryElement->tngBoxOut(),
-            trajectoryElement->tngLambdaOut(), trajectoryElement->tngBoxOutCompressed(),
-            trajectoryElement->tngLambdaOutCompressed(), legacySimulatorData_->inputrec->nstenergy);
-
-    // Add checkpoint helper here since we need a pointer to the trajectory element and
-    // need to register it with the lastStepSignallerBuilder
-    auto checkpointHandler = std::make_unique<CheckpointHandler>(
-            compat::make_not_null<SimulationSignal*>(
-                    &(*globalCommunicationHelper.simulationSignals())[eglsCHKPT]),
-            simulationsShareState, legacySimulatorData_->inputrec->nstlist == 0,
-            MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.writeConfout,
-            legacySimulatorData_->mdrunOptions.checkpointOptions.period);
-    algorithm.checkpointHelper_ = std::make_unique<CheckpointHelper>(
-            std::move(checkpointClients), std::move(checkpointHandler),
-            legacySimulatorData_->inputrec->init_step, trajectoryElement.get(),
-            legacySimulatorData_->top_global->natoms, legacySimulatorData_->fplog,
-            legacySimulatorData_->cr, legacySimulatorData_->observablesHistory,
-            legacySimulatorData_->walltime_accounting, legacySimulatorData_->state_global,
-            legacySimulatorData_->mdrunOptions.writeConfout);
-    lastStepSignallerBuilder.registerSignallerClient(
-            compat::make_not_null(algorithm.checkpointHelper_.get()));
-
-    lastStepSignallerBuilder.registerSignallerClient(compat::make_not_null(trajectorySignaller.get()));
-    auto loggingSignaller = loggingSignallerBuilder.build(legacySimulatorData_->inputrec->nstlog,
-                                                          legacySimulatorData_->inputrec->init_step,
-                                                          legacySimulatorData_->inputrec->init_t);
-    lastStepSignallerBuilder.registerSignallerClient(compat::make_not_null(loggingSignaller.get()));
-    auto lastStepSignaller = lastStepSignallerBuilder.build(legacySimulatorData_->inputrec->nsteps,
-                                                            legacySimulatorData_->inputrec->init_step,
-                                                            algorithm.stopHandler_.get());
-    neighborSearchSignallerBuilder.registerSignallerClient(compat::make_not_null(lastStepSignaller.get()));
-    auto neighborSearchSignaller = neighborSearchSignallerBuilder.build(
-            legacySimulatorData_->inputrec->nstlist, legacySimulatorData_->inputrec->init_step,
-            legacySimulatorData_->inputrec->init_t);
-
-    algorithm.signallerList_.emplace_back(std::move(neighborSearchSignaller));
-    algorithm.signallerList_.emplace_back(std::move(lastStepSignaller));
-    algorithm.signallerList_.emplace_back(std::move(loggingSignaller));
-    algorithm.signallerList_.emplace_back(std::move(trajectorySignaller));
-    algorithm.signallerList_.emplace_back(std::move(energySignaller));
+    registerWithInfrastructureAndSignallers(trajectoryElement.get());
 
-    /*
-     * Build the element list
-     *
-     * This is the actual sequence of (non-infrastructure) elements to be run.
-     * For NVE, the trajectory element is used outside of the integrator
-     * (composite) element, as well as the checkpoint helper. The checkpoint
-     * helper should be on top of the loop, and is only part of the simulator
-     * call list to be able to react to the last step being signalled.
-     */
-    addToCallList(algorithm.checkpointHelper_, algorithm.elementCallList_);
+    // Build free energy element
+    std::unique_ptr<FreeEnergyPerturbationData::Element> freeEnergyPerturbationElement = nullptr;
+    if (algorithm.freeEnergyPerturbationData_)
+    {
+        freeEnergyPerturbationElement = std::make_unique<FreeEnergyPerturbationData::Element>(
+                algorithm.freeEnergyPerturbationData_.get(),
+                legacySimulatorData_->inputrec->fepvals->delta_lambda);
+        registerWithInfrastructureAndSignallers(freeEnergyPerturbationElement.get());
+    }
+
+    // Build checkpoint helper (do this last so everyone else can be a checkpoint client!)
+    {
+        auto checkpointHandler = std::make_unique<CheckpointHandler>(
+                compat::make_not_null<SimulationSignal*>(
+                        &(*globalCommunicationHelper_.simulationSignals())[eglsCHKPT]),
+                simulationsShareState, legacySimulatorData_->inputrec->nstlist == 0,
+                MASTER(legacySimulatorData_->cr), legacySimulatorData_->mdrunOptions.writeConfout,
+                legacySimulatorData_->mdrunOptions.checkpointOptions.period);
+        algorithm.checkpointHelper_ = std::make_unique<CheckpointHelper>(
+                std::move(checkpointClients_), std::move(checkpointHandler),
+                legacySimulatorData_->inputrec->init_step, trajectoryElement.get(),
+                legacySimulatorData_->top_global->natoms, legacySimulatorData_->fplog,
+                legacySimulatorData_->cr, legacySimulatorData_->observablesHistory,
+                legacySimulatorData_->walltime_accounting, legacySimulatorData_->state_global,
+                legacySimulatorData_->mdrunOptions.writeConfout);
+        registerWithInfrastructureAndSignallers(algorithm.checkpointHelper_.get());
+    }
+
+    // Build signallers
+    {
+        /* Signallers need to be called in an exact order. Some signallers are clients
+         * of other signallers, which requires the clients signallers to be called
+         * _after_ any signaller they are registered to - otherwise, they couldn't
+         * adapt their behavior to the information they got signalled.
+         *
+         * Signallers being clients of other signallers require registration.
+         * That registration happens during construction, which in turn means that
+         * we want to construct the signallers in the reverse order of their later
+         * call order.
+         *
+         * For the above reasons, the `addSignaller` lambda defined below emplaces
+         * added signallers at the beginning of the signaller list, which will yield
+         * a signaller list which is inverse to the build order (and hence equal to
+         * the intended call order).
+         */
+        auto addSignaller = [this, &algorithm](auto signaller) {
+            registerWithInfrastructureAndSignallers(signaller.get());
+            algorithm.signallerList_.emplace(algorithm.signallerList_.begin(), std::move(signaller));
+        };
+        const auto* inputrec = legacySimulatorData_->inputrec;
+        addSignaller(energySignallerBuilder_.build(
+                inputrec->nstcalcenergy, inputrec->fepvals->nstdhdl, inputrec->nstpcouple));
+        addSignaller(trajectorySignallerBuilder_.build(
+                inputrec->nstxout, inputrec->nstvout, inputrec->nstfout,
+                inputrec->nstxout_compressed, trajectoryElement->tngBoxOut(),
+                trajectoryElement->tngLambdaOut(), trajectoryElement->tngBoxOutCompressed(),
+                trajectoryElement->tngLambdaOutCompressed(), inputrec->nstenergy));
+        addSignaller(loggingSignallerBuilder_.build(inputrec->nstlog, inputrec->init_step, inputrec->init_t));
+        addSignaller(lastStepSignallerBuilder_.build(inputrec->nsteps, inputrec->init_step,
+                                                     algorithm.stopHandler_.get()));
+        addSignaller(neighborSearchSignallerBuilder_.build(inputrec->nstlist, inputrec->init_step,
+                                                           inputrec->init_t));
+    }
+
+    // Create element list
+    // Checkpoint helper needs to be in the call list (as first element!) to react to last step
+    algorithm.elementCallList_.emplace_back(algorithm.checkpointHelper_.get());
+    // Next, update the free energy lambda vector if needed
     if (freeEnergyPerturbationElement)
     {
-        addToCallList(freeEnergyPerturbationElement, algorithm.elementCallList_);
+        algorithm.elementsOwnershipList_.emplace_back(std::move(freeEnergyPerturbationElement));
+        algorithm.elementCallList_.emplace_back(algorithm.elementsOwnershipList_.back().get());
     }
-    addToCallListAndMove(std::move(integrator), algorithm.elementCallList_, algorithm.elementsOwnershipList_);
-    addToCallListAndMove(std::move(trajectoryElement), algorithm.elementCallList_,
-                         algorithm.elementsOwnershipList_);
+    // Then, move the built algorithm
+    algorithm.elementsOwnershipList_.insert(algorithm.elementsOwnershipList_.end(),
+                                            std::make_move_iterator(elements_.begin()),
+                                            std::make_move_iterator(elements_.end()));
+    algorithm.elementCallList_.insert(algorithm.elementCallList_.end(),
+                                      std::make_move_iterator(callList_.begin()),
+                                      std::make_move_iterator(callList_.end()));
+    // Finally, all trajectory writing is happening after the step
+    // (relevant data was stored by elements through energy signaller)
+    algorithm.elementsOwnershipList_.emplace_back(std::move(trajectoryElement));
+    algorithm.elementCallList_.emplace_back(algorithm.elementsOwnershipList_.back().get());
 
+    algorithm.setup();
     return algorithm;
 }
 
-ModularSimulatorAlgorithmBuilder::ModularSimulatorAlgorithmBuilder(
-        compat::not_null<LegacySimulatorData*> legacySimulatorData) :
-    legacySimulatorData_(legacySimulatorData),
-    nstglobalcomm_(computeGlobalCommunicationPeriod(legacySimulatorData->mdlog,
-                                                    legacySimulatorData->inputrec,
-                                                    legacySimulatorData->cr))
+ISimulatorElement* ModularSimulatorAlgorithmBuilder::addElementToSimulatorAlgorithm(
+        std::unique_ptr<ISimulatorElement> element)
 {
+    elements_.emplace_back(std::move(element));
+    return elements_.back().get();
 }
 
-ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::build()
+bool ModularSimulatorAlgorithmBuilder::elementExists(const ISimulatorElement* element) const
 {
-    auto algorithm = constructElementsAndSignallers();
-    algorithm.setup();
-    return algorithm;
+    // Check whether element exists in element list
+    if (std::any_of(elements_.begin(), elements_.end(),
+                    [element](auto& existingElement) { return element == existingElement.get(); }))
+    {
+        return true;
+    }
+    // Check whether element exists in other places controlled by *this
+    return (statePropagatorData_->element() == element || energyData_->element() == element
+            || (freeEnergyPerturbationData_ && freeEnergyPerturbationData_->element() == element));
+}
+
+void ModularSimulatorAlgorithmBuilder::addElementToSetupTeardownList(ISimulatorElement* element)
+{
+    // Add element if it's not already in the list
+    if (std::find(setupAndTeardownList_.begin(), setupAndTeardownList_.end(), element)
+        == setupAndTeardownList_.end())
+    {
+        setupAndTeardownList_.emplace_back(element);
+    }
 }
 
 SignallerCallbackPtr ModularSimulatorAlgorithm::SignalHelper::registerLastStepCallback()
@@ -685,4 +674,43 @@ CheckBondedInteractionsCallbackPtr GlobalCommunicationHelper::moveCheckBondedInt
     return std::move(checkBondedInteractionsCallbackPtr_);
 }
 
+ModularSimulatorAlgorithmBuilderHelper::ModularSimulatorAlgorithmBuilderHelper(
+        ModularSimulatorAlgorithmBuilder* builder) :
+    builder_(builder)
+{
+}
+
+ISimulatorElement* ModularSimulatorAlgorithmBuilderHelper::storeElement(std::unique_ptr<ISimulatorElement> element)
+{
+    return builder_->addElementToSimulatorAlgorithm(std::move(element));
+}
+
+bool ModularSimulatorAlgorithmBuilderHelper::elementIsStored(const ISimulatorElement* element) const
+{
+    return builder_->elementExists(element);
+}
+
+void ModularSimulatorAlgorithmBuilderHelper::registerThermostat(
+        std::function<void(const PropagatorThermostatConnection&)> registrationFunction)
+{
+    builder_->thermostatRegistrationFunctions_.emplace_back(std::move(registrationFunction));
+}
+
+void ModularSimulatorAlgorithmBuilderHelper::registerBarostat(
+        std::function<void(const PropagatorBarostatConnection&)> registrationFunction)
+{
+    builder_->barostatRegistrationFunctions_.emplace_back(std::move(registrationFunction));
+}
+
+void ModularSimulatorAlgorithmBuilderHelper::registerWithThermostat(PropagatorThermostatConnection connectionData)
+{
+    builder_->propagatorThermostatConnections_.emplace_back(std::move(connectionData));
+}
+
+void ModularSimulatorAlgorithmBuilderHelper::registerWithBarostat(PropagatorBarostatConnection connectionData)
+{
+    builder_->propagatorBarostatConnections_.emplace_back(std::move(connectionData));
+}
+
+
 } // namespace gmx
index 29087ccd9fe6908c7d49c7e26f82ed53639321d9..9c6b952a0e164a1510a748771b3f8ad81121df28 100644 (file)
 #include "gromacs/mdrun/isimulator.h"
 
 #include "checkpointhelper.h"
-#include "computeglobalselement.h"
 #include "domdechelper.h"
 #include "freeenergyperturbationdata.h"
 #include "modularsimulatorinterfaces.h"
 #include "pmeloadbalancehelper.h"
+#include "signallers.h"
+#include "topologyholder.h"
+#include "trajectoryelement.h"
 
 namespace gmx
 {
+enum class IntegrationStep;
 class EnergyData;
-class EnergySignaller;
-class LoggingSignaller;
 class ModularSimulator;
-class NeighborSearchSignaller;
 class ResetHandler;
+template<IntegrationStep algorithm>
+class Propagator;
 class TopologyHolder;
-class TrajectoryElementBuilder;
 
 /*! \internal
  * \ingroup module_modularsimulator
@@ -178,7 +179,7 @@ private:
     //! List of schedulerElements (ownership)
     std::vector<std::unique_ptr<ISimulatorElement>> elementsOwnershipList_;
     //! List of schedulerElements (calling sequence)
-    std::vector<compat::not_null<ISimulatorElement*>> elementCallList_;
+    std::vector<ISimulatorElement*> elementCallList_;
 
     // Infrastructure elements
     //! The domain decomposition element
@@ -192,7 +193,7 @@ private:
     //! The reset handler
     std::unique_ptr<ResetHandler> resetHandler_;
     //! Signal vector (used by stop / reset / checkpointing signaller)
-    SimulationSignals signals_;
+    std::unique_ptr<SimulationSignals> signals_;
     //! The topology
     std::unique_ptr<TopologyHolder> topologyHolder_;
 
@@ -294,22 +295,58 @@ private:
     CheckBondedInteractionsCallbackPtr checkBondedInteractionsCallbackPtr_;
 };
 
+class ModularSimulatorAlgorithmBuilder;
+
+/*! \internal
+ * \brief Helper for element addition
+ *
+ * Such an object will be given to each invocation of getElementPointer
+ *
+ * Note: It would be nicer to define this as a member type of
+ * ModularSimulatorAlgorithmBuilder, but this would break forward declaration.
+ * This object is therefore defined as friend class.
+ */
+class ModularSimulatorAlgorithmBuilderHelper
+{
+public:
+    //! Constructor
+    ModularSimulatorAlgorithmBuilderHelper(ModularSimulatorAlgorithmBuilder* builder);
+    //! Store an element to the ModularSimulatorAlgorithmBuilder
+    ISimulatorElement* storeElement(std::unique_ptr<ISimulatorElement> element);
+    //! Check if an element is stored in the ModularSimulatorAlgorithmBuilder
+    bool elementIsStored(const ISimulatorElement* element) const;
+    //! Register a thermostat that accepts propagator registrations
+    void registerThermostat(std::function<void(const PropagatorThermostatConnection&)> registrationFunction);
+    //! Register a barostat that accepts propagator registrations
+    void registerBarostat(std::function<void(const PropagatorBarostatConnection&)> registrationFunction);
+    //! Register a propagator to the thermostat used
+    void registerWithThermostat(PropagatorThermostatConnection connectionData);
+    //! Register a propagator to the barostat used
+    void registerWithBarostat(PropagatorBarostatConnection connectionData);
+
+private:
+    //! Pointer to the associated ModularSimulatorAlgorithmBuilder
+    ModularSimulatorAlgorithmBuilder* builder_;
+};
+
 /*!\internal
  * \brief Builder for ModularSimulatorAlgorithm objects
  *
- * TODO: The current builder automatically builds a simulator algorithm based on the
- *       input. This is only an intemediate step towards a builder that will create
- *       algorithms designed by the user of ModularSimulatorAlgorithm (for now, the
- *       only user is the ModularSimulator).
- * TODO: This mirrors all protected members of ISimulator. This hack allows to keep
- *       the number of line changes minimal, and will be removed as soon as the builder
- *       allows the user to compose the integrator algorithm.
- *       For the same reason, the constructElementsAndSignallers(), buildIntegrator(...),
- *       and buildForces(...) implementations were left in modularsimulator.cpp. See other
- *       to do - as the ModularSimulator will eventually design the algorithm, moving it
- *       would only cause unnecessary noise.
+ * This builds a ModularSimulatorAlgorithm.
+ *
+ * Users can add elements and define their call order by calling the templated
+ * add<Element> function. Note that only elements that have a static
+ * getElementPointerImpl factory method can be built in that way.
+ *
+ * Note that each ModularSimulatorAlgorithmBuilder can only be used to build
+ * one ModularSimulatorAlgorithm object, i.e. build() can only be called once.
+ * During the call to build, all elements and other infrastructure objects will
+ * be moved to the built ModularSimulatorAlgorithm object, such that further use
+ * of the builder would not make sense.
+ * Any access to the build or add<> methods after the first call to
+ * build() will result in an exception being thrown.
  */
-class ModularSimulatorAlgorithmBuilder
+class ModularSimulatorAlgorithmBuilder final
 {
 public:
     //! Constructor
@@ -317,110 +354,250 @@ public:
     //! Build algorithm
     ModularSimulatorAlgorithm build();
 
+    /*! \brief  Add element to the modular simulator algorithm builder
+     *
+     * This function has a general implementation, which will call the getElementPointer(...)
+     * factory function.
+     *
+     * \tparam Element  The element type
+     * \tparam Args     A variable number of argument types
+     * \param args      A variable number of arguments
+     */
+    template<typename Element, typename... Args>
+    void add(Args&&... args);
+
+    //! Allow access from helper
+    friend class ModularSimulatorAlgorithmBuilderHelper;
+
 private:
-    /*! \brief The initialisation
+    //! The state of the builder
+    bool algorithmHasBeenBuilt_ = false;
+
+    // Data structures
+    //! The state propagator data
+    std::unique_ptr<StatePropagatorData> statePropagatorData_;
+    //! The energy data
+    std::unique_ptr<EnergyData> energyData_;
+    //! The free energy data
+    std::unique_ptr<FreeEnergyPerturbationData> freeEnergyPerturbationData_;
+
+    //! Pointer to the LegacySimulatorData object
+    compat::not_null<LegacySimulatorData*> legacySimulatorData_;
+
+    // Helper objects
+    //! Signal vector (used by stop / reset / checkpointing signaller)
+    std::unique_ptr<SimulationSignals> signals_;
+    //! Helper object passed to element factory functions
+    ModularSimulatorAlgorithmBuilderHelper elementAdditionHelper_;
+    //! Container for global computation data
+    GlobalCommunicationHelper globalCommunicationHelper_;
+
+    /*! \brief  Register an element to all applicable signallers and infrastructure elements
      *
-     * This builds all signallers and elements, and is responsible to put
-     * them in the correct order.
+     * \tparam Element  Type of the Element
+     * \param element   Pointer to the element
      */
-    ModularSimulatorAlgorithm constructElementsAndSignallers();
+    template<typename Element>
+    void registerWithInfrastructureAndSignallers(Element* element);
 
-    /*! \brief Build the integrator part of the simulator
+    /*! \brief Take ownership of element
      *
-     * This includes the force calculation, state propagation, constraints,
-     * global computation, and the points during the process at which valid
-     * micro state / energy states are found. Currently, buildIntegrator
-     * knows about NVE md and md-vv algorithms.
+     * This function returns a non-owning pointer to the new location of that
+     * element, allowing further usage (e.g. adding the element to the call list).
+     * Note that simply addin an element using this function will not call it
+     * during the simulation - it needs to be added to the call list separately.
+     * Note that generally, users will want to add elements to the call list, but
+     * it might not be practical to do this in the same order.
+     *
+     * \param element  A unique pointer to the element
+     * \return  A non-owning (raw) pointer to the element for further usage
      */
-    std::unique_ptr<ISimulatorElement>
-    buildIntegrator(SignallerBuilder<NeighborSearchSignaller>* neighborSearchSignallerBuilder,
-                    SignallerBuilder<LastStepSignaller>*       lastStepSignallerBuilder,
-                    SignallerBuilder<EnergySignaller>*         energySignallerBuilder,
-                    SignallerBuilder<LoggingSignaller>*        loggingSignallerBuilder,
-                    SignallerBuilder<TrajectorySignaller>*     trajectorySignallerBuilder,
-                    TrajectoryElementBuilder*                  trajectoryElementBuilder,
-                    std::vector<ICheckpointHelperClient*>*     checkpointClients,
-                    compat::not_null<StatePropagatorData*>     statePropagatorDataPtr,
-                    compat::not_null<EnergyData*>              energyDataPtr,
-                    FreeEnergyPerturbationData*                freeEnergyPerturbationDataPtr,
-                    bool                                       hasReadEkinState,
-                    TopologyHolder::Builder*                   topologyHolderBuilder,
-                    GlobalCommunicationHelper*                 globalCommunicationHelper);
-
-    //! Build the force element - can be normal forces or shell / flex constraints
-    std::unique_ptr<ISimulatorElement>
-    buildForces(SignallerBuilder<NeighborSearchSignaller>* neighborSearchSignallerBuilder,
-                SignallerBuilder<EnergySignaller>*         energySignallerBuilder,
-                StatePropagatorData*                       statePropagatorDataPtr,
-                EnergyData*                                energyDataPtr,
-                FreeEnergyPerturbationData*                freeEnergyPerturbationDataPtr,
-                TopologyHolder::Builder*                   topologyHolderBuilder);
+    ISimulatorElement* addElementToSimulatorAlgorithm(std::unique_ptr<ISimulatorElement> element);
 
-    //! Pointer to the LegacySimulatorData object
-    compat::not_null<LegacySimulatorData*> legacySimulatorData_;
+    /*! \brief Check if element is owned by *this
+     *
+     * \param element  Pointer to the element
+     * \return  Bool indicating whether element is owned by *this
+     */
+    [[nodiscard]] bool elementExists(const ISimulatorElement* element) const;
 
-    //! \cond
-    //! Helper function to add elements or signallers to the call list via raw pointer
-    template<typename T, typename U>
-    static void addToCallList(U* element, std::vector<compat::not_null<T*>>& callList);
-    //! Helper function to add elements or signallers to the call list via non-null raw pointer
-    template<typename T, typename U>
-    static void addToCallList(compat::not_null<U*> element, std::vector<compat::not_null<T*>>& callList);
-    //! Helper function to add elements or signallers to the call list via smart pointer
-    template<typename T, typename U>
-    static void addToCallList(std::unique_ptr<U>& element, std::vector<compat::not_null<T*>>& callList);
-    /*! \brief Helper function to add elements or signallers to the call list
-     *         and move the ownership to the ownership list
+    /*! \brief Add element to setupAndTeardownList_ if it's not already there
+     *
+     * \param element  Element pointer to be added
      */
-    template<typename T, typename U>
-    static void addToCallListAndMove(std::unique_ptr<U>                 element,
-                                     std::vector<compat::not_null<T*>>& callList,
-                                     std::vector<std::unique_ptr<T>>&   elementList);
-    //! \endcond
+    void addElementToSetupTeardownList(ISimulatorElement* element);
 
-    //! Compute globals communication period
-    const int nstglobalcomm_;
+    //! Vector to store elements, allowing the SimulatorAlgorithm to control their lifetime
+    std::vector<std::unique_ptr<ISimulatorElement>> elements_;
+    /*! \brief List defining in which order elements are called every step
+     *
+     * Elements may be referenced more than once if they should be called repeatedly
+     */
+    std::vector<ISimulatorElement*> callList_;
+    /*! \brief  List defining in which order elements are set up and torn down
+     *
+     * Elements should only appear once in this list
+     */
+    std::vector<ISimulatorElement*> setupAndTeardownList_;
+
+    //! Builder for the NeighborSearchSignaller
+    SignallerBuilder<NeighborSearchSignaller> neighborSearchSignallerBuilder_;
+    //! Builder for the LastStepSignaller
+    SignallerBuilder<LastStepSignaller> lastStepSignallerBuilder_;
+    //! Builder for the LoggingSignaller
+    SignallerBuilder<LoggingSignaller> loggingSignallerBuilder_;
+    //! Builder for the EnergySignaller
+    SignallerBuilder<EnergySignaller> energySignallerBuilder_;
+    //! Builder for the TrajectorySignaller
+    SignallerBuilder<TrajectorySignaller> trajectorySignallerBuilder_;
+    //! Builder for the TrajectoryElementBuilder
+    TrajectoryElementBuilder trajectoryElementBuilder_;
+    //! Builder for the TopologyHolder
+    TopologyHolder::Builder topologyHolderBuilder_;
+
+    /*! \brief List of clients for the CheckpointHelper
+     *
+     * \todo Replace this by proper builder (#3422)
+     */
+    std::vector<ICheckpointHelperClient*> checkpointClients_;
+
+    //! List of thermostat registration functions
+    std::vector<std::function<void(const PropagatorThermostatConnection&)>> thermostatRegistrationFunctions_;
+    //! List of barostat registration functions
+    std::vector<std::function<void(const PropagatorBarostatConnection&)>> barostatRegistrationFunctions_;
+    //! List of data to connect propagators to thermostats
+    std::vector<PropagatorThermostatConnection> propagatorThermostatConnections_;
+    //! List of data to connect propagators to barostats
+    std::vector<PropagatorBarostatConnection> propagatorBarostatConnections_;
 };
 
-//! \cond
-template<typename T, typename U>
-void ModularSimulatorAlgorithmBuilder::addToCallList(U* element, std::vector<compat::not_null<T*>>& callList)
+/*! \internal
+ * \brief Factory function for elements that can be added via ModularSimulatorAlgorithmBuilder:
+ *        Get a pointer to an object of type \c Element to add to the call list
+ *
+ * This allows elements to be built via the templated ModularSimulatorAlgorithmBuilder::add<Element>
+ * method. Elements buildable throught this factor function are required to implement a static
+ * function with minimal signature
+ *
+ *     static ISimulatorElement* getElementPointerImpl(
+ *             LegacySimulatorData*                    legacySimulatorData,
+ *             ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+ *             StatePropagatorData*                    statePropagatorData,
+ *             EnergyData*                             energyData,
+ *             FreeEnergyPerturbationData*             freeEnergyPerturbationData,
+ *             GlobalCommunicationHelper*              globalCommunicationHelper)
+ *
+ * This function may also accept additional parameters which are passed using the variadic
+ * template parameter pack forwarded in getElementPointer.
+ *
+ * This function returns a pointer to an object of the Element type. Note that the caller will
+ * check whether the returned object has previously been stored using the `storeElement`
+ * function, and throw an exception if the element is not found.
+ * The function can check whether a previously stored pointer is valid using
+ * the `checkElementExistence` function. Most implementing functions will simply want
+ * to create an object, store it using `storeElement`, and then use the return value of
+ * `storeElement` as a return value to the caller. However, this setup allows the function
+ * to store a created element (using a static pointer inside the function) and return it
+ * in case that the factory function is called repeatedly. This allows to create an element
+ * once, but have it called multiple times during the simulation run.
+ *
+ * \see ModularSimulatorAlgorithmBuilder::add
+ *      Function using this functionality
+ * \see ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>::getElementPointerImpl
+ *      Implementation using the single object / multiple call sites functionality
+ *
+ * \tparam Element The type of the element
+ * \tparam Args  Variable number of argument types allowing specific implementations to have
+ *               additional arguments
+ *
+ * \param legacySimulatorData  Pointer allowing access to simulator level data
+ * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+ * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+ * \param energyData  Pointer to the \c EnergyData object
+ * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+ * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+ * \param args  Variable number of additional parameters to be forwarded
+ *
+ * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+ */
+template<typename Element, typename... Args>
+ISimulatorElement* getElementPointer(LegacySimulatorData*                    legacySimulatorData,
+                                     ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                     StatePropagatorData*                    statePropagatorData,
+                                     EnergyData*                             energyData,
+                                     FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                     GlobalCommunicationHelper*  globalCommunicationHelper,
+                                     Args&&... args)
+{
+    return Element::getElementPointerImpl(legacySimulatorData, builderHelper, statePropagatorData,
+                                          energyData, freeEnergyPerturbationData,
+                                          globalCommunicationHelper, std::forward<Args>(args)...);
+}
+
+template<typename Element, typename... Args>
+void ModularSimulatorAlgorithmBuilder::add(Args&&... args)
 {
-    if (element)
+    if (algorithmHasBeenBuilt_)
+    {
+        throw SimulationAlgorithmSetupError(
+                "Tried to add an element after ModularSimulationAlgorithm was built.");
+    }
+
+    // Get element from factory method
+    auto* element = static_cast<Element*>(getElementPointer<Element>(
+            legacySimulatorData_, &elementAdditionHelper_, statePropagatorData_.get(),
+            energyData_.get(), freeEnergyPerturbationData_.get(), &globalCommunicationHelper_,
+            std::forward<Args>(args)...));
+
+    // Make sure returned element pointer is owned by *this
+    // Ensuring this makes sure we can control the life time
+    if (!elementExists(element))
     {
-        callList.emplace_back(element);
+        throw ElementNotFoundError("Tried to append non-existing element to call list.");
     }
+    // Add to call list
+    callList_.emplace_back(element);
+    // Add to setup / teardown list if element hasn't been added yet
+    addElementToSetupTeardownList(element);
+    // Register element to all applicable signallers
+    registerWithInfrastructureAndSignallers(element);
 }
 
-template<typename T, typename U>
-void ModularSimulatorAlgorithmBuilder::addToCallList(gmx::compat::not_null<U*>          element,
-                                                     std::vector<compat::not_null<T*>>& callList)
+//! Returns a pointer casted to type Base if the Element is derived from Base
+template<typename Base, typename Element>
+static std::enable_if_t<std::is_base_of<Base, Element>::value, Base*> castOrNull(Element* element)
 {
-    callList.emplace_back(element);
+    return static_cast<Base*>(element);
 }
 
-template<typename T, typename U>
-void ModularSimulatorAlgorithmBuilder::addToCallList(std::unique_ptr<U>&                element,
-                                                     std::vector<compat::not_null<T*>>& callList)
+//! Returns a nullptr of type Base if Element is not derived from Base
+template<typename Base, typename Element>
+static std::enable_if_t<!std::is_base_of<Base, Element>::value, Base*> castOrNull(Element gmx_unused* element)
 {
-    if (element)
-    {
-        callList.emplace_back(compat::make_not_null(element.get()));
-    }
+    return nullptr;
 }
 
-template<typename T, typename U>
-void ModularSimulatorAlgorithmBuilder::addToCallListAndMove(std::unique_ptr<U> element,
-                                                            std::vector<compat::not_null<T*>>& callList,
-                                                            std::vector<std::unique_ptr<T>>& elementList)
+template<typename Element>
+void ModularSimulatorAlgorithmBuilder::registerWithInfrastructureAndSignallers(Element* element)
 {
-    if (element)
+    // Register element to all applicable signallers
+    neighborSearchSignallerBuilder_.registerSignallerClient(
+            castOrNull<INeighborSearchSignallerClient, Element>(element));
+    lastStepSignallerBuilder_.registerSignallerClient(castOrNull<ILastStepSignallerClient, Element>(element));
+    loggingSignallerBuilder_.registerSignallerClient(castOrNull<ILoggingSignallerClient, Element>(element));
+    energySignallerBuilder_.registerSignallerClient(castOrNull<IEnergySignallerClient, Element>(element));
+    trajectorySignallerBuilder_.registerSignallerClient(
+            castOrNull<ITrajectorySignallerClient, Element>(element));
+    // Register element to trajectory element (if applicable)
+    trajectoryElementBuilder_.registerWriterClient(castOrNull<ITrajectoryWriterClient, Element>(element));
+    // Register element to topology holder (if applicable)
+    topologyHolderBuilder_.registerClient(castOrNull<ITopologyHolderClient, Element>(element));
+    // Register element to checkpoint client (if applicable)
+    if (auto castedElement = castOrNull<ICheckpointHelperClient, Element>(element))
     {
-        callList.emplace_back(compat::make_not_null(element.get()));
-        elementList.emplace_back(std::move(element));
+        checkpointClients_.emplace_back(castedElement);
     }
 }
-//! \endcond
 
 } // namespace gmx
 
index 2e43a870b5b4291ff0d3f08df30f1661c0656709..090742f96b4ac1f4b8b133290fc63fe34eb88163 100644 (file)
 
 #include "statepropagatordata.h"
 
+#include "gromacs/commandline/filenm.h"
 #include "gromacs/domdec/collect.h"
 #include "gromacs/domdec/domdec.h"
 #include "gromacs/fileio/confio.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdlib/gmx_omp_nthreads.h"
+#include "gromacs/mdlib/mdatoms.h"
 #include "gromacs/mdlib/mdoutf.h"
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdlib/update.h"
 #include "gromacs/mdtypes/commrec.h"
+#include "gromacs/mdtypes/forcerec.h"
 #include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/mdatom.h"
+#include "gromacs/mdtypes/mdrunoptions.h"
 #include "gromacs/mdtypes/state.h"
+#include "gromacs/nbnxm/nbnxm.h"
 #include "gromacs/pbcutil/pbc.h"
 #include "gromacs/topology/atoms.h"
 #include "gromacs/topology/topology.h"
 
 #include "freeenergyperturbationdata.h"
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
 
 namespace gmx
 {
-StatePropagatorData::StatePropagatorData(int                         numAtoms,
-                                         FILE*                       fplog,
-                                         const t_commrec*            cr,
-                                         t_state*                    globalState,
-                                         bool                        useGPU,
-                                         FreeEnergyPerturbationData* freeEnergyPerturbationData,
+StatePropagatorData::StatePropagatorData(int                numAtoms,
+                                         FILE*              fplog,
+                                         const t_commrec*   cr,
+                                         t_state*           globalState,
+                                         bool               useGPU,
                                          bool               canMoleculesBeDistributedOverPBC,
                                          bool               writeFinalConfiguration,
                                          const std::string& finalConfigurationFilename,
@@ -87,7 +93,6 @@ StatePropagatorData::StatePropagatorData(int                         numAtoms,
                                        inputrec->nstvout,
                                        inputrec->nstfout,
                                        inputrec->nstxout_compressed,
-                                       freeEnergyPerturbationData,
                                        canMoleculesBeDistributedOverPBC,
                                        writeFinalConfiguration,
                                        finalConfigurationFilename,
@@ -512,26 +517,25 @@ SignallerCallbackPtr StatePropagatorData::Element::registerLastStepCallback()
     });
 }
 
-StatePropagatorData::Element::Element(StatePropagatorData*        statePropagatorData,
-                                      FILE*                       fplog,
-                                      const t_commrec*            cr,
-                                      int                         nstxout,
-                                      int                         nstvout,
-                                      int                         nstfout,
-                                      int                         nstxout_compressed,
-                                      FreeEnergyPerturbationData* freeEnergyPerturbationData,
-                                      bool                        canMoleculesBeDistributedOverPBC,
-                                      bool                        writeFinalConfiguration,
-                                      std::string                 finalConfigurationFilename,
-                                      const t_inputrec*           inputrec,
-                                      const gmx_mtop_t*           globalTop) :
+StatePropagatorData::Element::Element(StatePropagatorData* statePropagatorData,
+                                      FILE*                fplog,
+                                      const t_commrec*     cr,
+                                      int                  nstxout,
+                                      int                  nstvout,
+                                      int                  nstfout,
+                                      int                  nstxout_compressed,
+                                      bool                 canMoleculesBeDistributedOverPBC,
+                                      bool                 writeFinalConfiguration,
+                                      std::string          finalConfigurationFilename,
+                                      const t_inputrec*    inputrec,
+                                      const gmx_mtop_t*    globalTop) :
     statePropagatorData_(statePropagatorData),
     nstxout_(nstxout),
     nstvout_(nstvout),
     nstfout_(nstfout),
     nstxout_compressed_(nstxout_compressed),
     writeOutStep_(-1),
-    freeEnergyPerturbationData_(freeEnergyPerturbationData),
+    freeEnergyPerturbationData_(nullptr),
     isRegularSimulationEnd_(false),
     lastStep_(-1),
     canMoleculesBeDistributedOverPBC_(canMoleculesBeDistributedOverPBC),
@@ -545,5 +549,21 @@ StatePropagatorData::Element::Element(StatePropagatorData*        statePropagato
     top_global_(globalTop)
 {
 }
+void StatePropagatorData::Element::setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData)
+{
+    freeEnergyPerturbationData_ = freeEnergyPerturbationData;
+}
+
+ISimulatorElement* StatePropagatorData::Element::getElementPointerImpl(
+        LegacySimulatorData gmx_unused*        legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper gmx_unused* builderHelper,
+        StatePropagatorData*                               statePropagatorData,
+        EnergyData gmx_unused*      energyData,
+        FreeEnergyPerturbationData* freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper)
+{
+    statePropagatorData->element()->setFreeEnergyPerturbationData(freeEnergyPerturbationData);
+    return statePropagatorData->element();
+}
 
 } // namespace gmx
index 0487be3a5071c0ceeb9051386fb02a35e634d1b9..08a3033775b3c2ad9ec983bd487a1b2e59d07679 100644 (file)
@@ -61,7 +61,11 @@ struct t_mdatoms;
 namespace gmx
 {
 enum class ConstraintVariable;
+class EnergyData;
 class FreeEnergyPerturbationData;
+class GlobalCommunicationHelper;
+class LegacySimulatorData;
+class ModularSimulatorAlgorithmBuilderHelper;
 
 /*! \internal
  * \ingroup module_modularsimulator
@@ -91,18 +95,17 @@ class StatePropagatorData final
 {
 public:
     //! Constructor
-    StatePropagatorData(int                         numAtoms,
-                        FILE*                       fplog,
-                        const t_commrec*            cr,
-                        t_state*                    globalState,
-                        bool                        useGPU,
-                        FreeEnergyPerturbationData* freeEnergyPerturbationData,
-                        bool                        canMoleculesBeDistributedOverPBC,
-                        bool                        writeFinalConfiguration,
-                        const std::string&          finalConfigurationFilename,
-                        const t_inputrec*           inputrec,
-                        const t_mdatoms*            mdatoms,
-                        const gmx_mtop_t*           globalTop);
+    StatePropagatorData(int                numAtoms,
+                        FILE*              fplog,
+                        const t_commrec*   cr,
+                        t_state*           globalState,
+                        bool               useGPU,
+                        bool               canMoleculesBeDistributedOverPBC,
+                        bool               writeFinalConfiguration,
+                        const std::string& finalConfigurationFilename,
+                        const t_inputrec*  inputrec,
+                        const t_mdatoms*   mdatoms,
+                        const gmx_mtop_t*  globalTop);
 
     // Allow access to state
     //! Get write access to position vector
@@ -231,19 +234,18 @@ class StatePropagatorData::Element final :
 {
 public:
     //! Constructor
-    Element(StatePropagatorData*        statePropagatorData,
-            FILE*                       fplog,
-            const t_commrec*            cr,
-            int                         nstxout,
-            int                         nstvout,
-            int                         nstfout,
-            int                         nstxout_compressed,
-            FreeEnergyPerturbationData* freeEnergyPerturbationData,
-            bool                        canMoleculesBeDistributedOverPBC,
-            bool                        writeFinalConfiguration,
-            std::string                 finalConfigurationFilename,
-            const t_inputrec*           inputrec,
-            const gmx_mtop_t*           globalTop);
+    Element(StatePropagatorData* statePropagatorData,
+            FILE*                fplog,
+            const t_commrec*     cr,
+            int                  nstxout,
+            int                  nstvout,
+            int                  nstfout,
+            int                  nstxout_compressed,
+            bool                 canMoleculesBeDistributedOverPBC,
+            bool                 writeFinalConfiguration,
+            std::string          finalConfigurationFilename,
+            const t_inputrec*    inputrec,
+            const gmx_mtop_t*    globalTop);
 
     /*! \brief Register run function for step / time
      *
@@ -274,6 +276,27 @@ public:
     //! No element teardown needed
     void elementTeardown() override {}
 
+    //! Set free energy data
+    void setFreeEnergyPerturbationData(FreeEnergyPerturbationData* freeEnergyPerturbationData);
+
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper);
+
 private:
     //! Pointer to the associated StatePropagatorData
     StatePropagatorData* statePropagatorData_;
index d437db21769b05490809e39b9a969c95aec613a9..ceb85a07c6d1f5b2d0bb2a77aaa2343eee0b5670 100644 (file)
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/group.h"
+#include "gromacs/mdtypes/inputrec.h"
 #include "gromacs/mdtypes/state.h"
 #include "gromacs/utility/fatalerror.h"
 
+#include "modularsimulator.h"
+#include "simulatoralgorithm.h"
+
 namespace gmx
 {
 
-VRescaleThermostat::VRescaleThermostat(int                   nstcouple,
-                                       int                   offset,
-                                       bool                  useFullStepKE,
-                                       int64_t               seed,
-                                       int                   numTemperatureGroups,
-                                       double                couplingTimeStep,
-                                       const real*           referenceTemperature,
-                                       const real*           couplingTime,
-                                       const real*           numDegreesOfFreedom,
-                                       EnergyData*           energyData,
-                                       ArrayRef<real>        lambdaView,
-                                       PropagatorCallbackPtr propagatorCallback,
-                                       const t_state*        globalState,
-                                       t_commrec*            cr,
-                                       bool                  isRestart) :
+VRescaleThermostat::VRescaleThermostat(int            nstcouple,
+                                       int            offset,
+                                       bool           useFullStepKE,
+                                       int64_t        seed,
+                                       int            numTemperatureGroups,
+                                       double         couplingTimeStep,
+                                       const real*    referenceTemperature,
+                                       const real*    couplingTime,
+                                       const real*    numDegreesOfFreedom,
+                                       EnergyData*    energyData,
+                                       const t_state* globalState,
+                                       t_commrec*     cr,
+                                       bool           isRestart) :
     nstcouple_(nstcouple),
     offset_(offset),
     useFullStepKE_(useFullStepKE),
@@ -82,9 +84,9 @@ VRescaleThermostat::VRescaleThermostat(int                   nstcouple,
     numDegreesOfFreedom_(numDegreesOfFreedom, numDegreesOfFreedom + numTemperatureGroups),
     thermostatIntegral_(numTemperatureGroups, 0.0),
     energyData_(energyData),
-    lambda_(lambdaView),
-    propagatorCallback_(std::move(propagatorCallback))
+    propagatorCallback_(nullptr)
 {
+    energyData->setVRescaleThermostat(this);
     // TODO: This is only needed to restore the thermostatIntegral_ from cpt. Remove this when
     //       switching to purely client-based checkpointing.
     if (isRestart)
@@ -103,6 +105,25 @@ VRescaleThermostat::VRescaleThermostat(int                   nstcouple,
     }
 }
 
+void VRescaleThermostat::connectWithPropagator(const PropagatorThermostatConnection& connectionData)
+{
+    connectionData.setNumVelocityScalingVariables(numTemperatureGroups_);
+    lambda_             = connectionData.getViewOnVelocityScaling();
+    propagatorCallback_ = connectionData.getVelocityScalingCallback();
+}
+
+void VRescaleThermostat::elementSetup()
+{
+    if (propagatorCallback_ == nullptr || lambda_.empty())
+    {
+        throw MissingElementConnectionError(
+                "V-rescale thermostat was not connected to a propagator.\n"
+                "Connection to a propagator element is needed to scale the velocities.\n"
+                "Use connectWithPropagator(...) before building the ModularSimulatorAlgorithm "
+                "object.");
+    }
+}
+
 void VRescaleThermostat::scheduleTask(Step step, Time gmx_unused time, const RegisterRunFunctionPtr& registerRunFunction)
 {
     /* The thermostat will need a valid kinetic energy when it is running.
@@ -185,4 +206,29 @@ const std::vector<double>& VRescaleThermostat::thermostatIntegral() const
     return thermostatIntegral_;
 }
 
+ISimulatorElement* VRescaleThermostat::getElementPointerImpl(
+        LegacySimulatorData*                    legacySimulatorData,
+        ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+        StatePropagatorData gmx_unused* statePropagatorData,
+        EnergyData gmx_unused*     energyData,
+        FreeEnergyPerturbationData gmx_unused* freeEnergyPerturbationData,
+        GlobalCommunicationHelper gmx_unused* globalCommunicationHelper,
+        int                                   offset,
+        VRescaleThermostatUseFullStepKE       useFullStepKE)
+{
+    auto* element    = builderHelper->storeElement(std::make_unique<VRescaleThermostat>(
+            legacySimulatorData->inputrec->nsttcouple, offset,
+            useFullStepKE == VRescaleThermostatUseFullStepKE::Yes,
+            legacySimulatorData->inputrec->ld_seed, legacySimulatorData->inputrec->opts.ngtc,
+            legacySimulatorData->inputrec->delta_t * legacySimulatorData->inputrec->nsttcouple,
+            legacySimulatorData->inputrec->opts.ref_t, legacySimulatorData->inputrec->opts.tau_t,
+            legacySimulatorData->inputrec->opts.nrdf, energyData, legacySimulatorData->state_global,
+            legacySimulatorData->cr, legacySimulatorData->inputrec->bContinuation));
+    auto* thermostat = static_cast<VRescaleThermostat*>(element);
+    builderHelper->registerThermostat([thermostat](const PropagatorThermostatConnection& connection) {
+        thermostat->connectWithPropagator(connection);
+    });
+    return element;
+}
+
 } // namespace gmx
index 9c5906eb76de6d3bf7bc6e1e68e565feb4746a4d..d7c3a2bdbfd5fe804180182d4dbe406fc8262135 100644 (file)
@@ -54,6 +54,14 @@ struct t_commrec;
 
 namespace gmx
 {
+class LegacySimulatorData;
+
+//! Enum describing whether the thermostat is using full or half step kinetic energy
+enum class VRescaleThermostatUseFullStepKE
+{
+    Yes,
+    No
+};
 
 /*! \internal
  * \ingroup module_modularsimulator
@@ -66,21 +74,19 @@ class VRescaleThermostat final : public ISimulatorElement, public ICheckpointHel
 {
 public:
     //! Constructor
-    VRescaleThermostat(int                   nstcouple,
-                       int                   offset,
-                       bool                  useFullStepKE,
-                       int64_t               seed,
-                       int                   numTemperatureGroups,
-                       double                couplingTimeStep,
-                       const real*           referenceTemperature,
-                       const real*           couplingTime,
-                       const real*           numDegreesOfFreedom,
-                       EnergyData*           energyData,
-                       ArrayRef<real>        lambdaView,
-                       PropagatorCallbackPtr propagatorCallback,
-                       const t_state*        globalState,
-                       t_commrec*            cr,
-                       bool                  isRestart);
+    VRescaleThermostat(int            nstcouple,
+                       int            offset,
+                       bool           useFullStepKE,
+                       int64_t        seed,
+                       int            numTemperatureGroups,
+                       double         couplingTimeStep,
+                       const real*    referenceTemperature,
+                       const real*    couplingTime,
+                       const real*    numDegreesOfFreedom,
+                       EnergyData*    energyData,
+                       const t_state* globalState,
+                       t_commrec*     cr,
+                       bool           isRestart);
 
     /*! \brief Register run function for step / time
      *
@@ -90,14 +96,39 @@ public:
      */
     void scheduleTask(Step step, Time time, const RegisterRunFunctionPtr& registerRunFunction) override;
 
-    //! No element setup needed
-    void elementSetup() override {}
+    //! Sanity check at setup time
+    void elementSetup() override;
     //! No element teardown needed
     void elementTeardown() override {}
 
     //! Getter for the thermostatIntegral
     const std::vector<double>& thermostatIntegral() const;
 
+    //! Connect this to propagator
+    void connectWithPropagator(const PropagatorThermostatConnection& connectionData);
+
+    /*! \brief Factory method implementation
+     *
+     * \param legacySimulatorData  Pointer allowing access to simulator level data
+     * \param builderHelper  ModularSimulatorAlgorithmBuilder helper object
+     * \param statePropagatorData  Pointer to the \c StatePropagatorData object
+     * \param energyData  Pointer to the \c EnergyData object
+     * \param freeEnergyPerturbationData  Pointer to the \c FreeEnergyPerturbationData object
+     * \param globalCommunicationHelper  Pointer to the \c GlobalCommunicationHelper object
+     * \param offset  The step offset at which the thermostat is applied
+     * \param useFullStepKE  Whether full step or half step KE is used
+     *
+     * \return  Pointer to the element to be added. Element needs to have been stored using \c storeElement
+     */
+    static ISimulatorElement* getElementPointerImpl(LegacySimulatorData* legacySimulatorData,
+                                                    ModularSimulatorAlgorithmBuilderHelper* builderHelper,
+                                                    StatePropagatorData*        statePropagatorData,
+                                                    EnergyData*                 energyData,
+                                                    FreeEnergyPerturbationData* freeEnergyPerturbationData,
+                                                    GlobalCommunicationHelper* globalCommunicationHelper,
+                                                    int                        offset,
+                                                    VRescaleThermostatUseFullStepKE useFullStepKE);
+
 private:
     //! The frequency at which the thermostat is applied
     const int nstcouple_;