Remove explicit dependency of EnergyData on thermostat / barostats
authorPascal Merz <pascal.merz@me.com>
Mon, 12 Apr 2021 15:39:51 +0000 (15:39 +0000)
committerJoe Jordan <ejjordan12@gmail.com>
Mon, 12 Apr 2021 15:39:51 +0000 (15:39 +0000)
src/gromacs/modularsimulator/energydata.cpp
src/gromacs/modularsimulator/energydata.h
src/gromacs/modularsimulator/modularsimulator.cpp
src/gromacs/modularsimulator/parrinellorahmanbarostat.cpp
src/gromacs/modularsimulator/parrinellorahmanbarostat.h
src/gromacs/modularsimulator/velocityscalingtemperaturecoupling.cpp
src/gromacs/modularsimulator/velocityscalingtemperaturecoupling.h

index 30d55811d566ae7c3cd1bc510f5f92102406e90f..32e3b1892226ed36fd285be1eb1b74261ad95291 100644 (file)
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/math/vec.h"
 #include "gromacs/mdlib/compute_io.h"
-#include "gromacs/mdlib/coupling.h"
 #include "gromacs/mdlib/enerdata_utils.h"
 #include "gromacs/mdlib/energyoutput.h"
 #include "gromacs/mdlib/mdatoms.h"
 #include "gromacs/mdlib/mdoutf.h"
 #include "gromacs/mdlib/stat.h"
 #include "gromacs/mdlib/update.h"
-#include "gromacs/mdrunutility/handlerestart.h"
-#include "gromacs/mdtypes/checkpointdata.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/enerdata.h"
 #include "gromacs/mdtypes/energyhistory.h"
 
 #include "freeenergyperturbationdata.h"
 #include "modularsimulator.h"
-#include "parrinellorahmanbarostat.h"
 #include "simulatoralgorithm.h"
 #include "statepropagatordata.h"
-#include "velocityscalingtemperaturecoupling.h"
 
 struct pull_t;
 class t_state;
@@ -104,8 +99,6 @@ EnergyData::EnergyData(StatePropagatorData*        statePropagatorData,
     startingBehavior_(startingBehavior),
     statePropagatorData_(statePropagatorData),
     freeEnergyPerturbationData_(freeEnergyPerturbationData),
-    velocityScalingTemperatureCoupling_(nullptr),
-    parrinelloRahmanBarostat_(nullptr),
     inputrec_(inputrec),
     top_global_(globalTopology),
     mdAtoms_(mdAtoms),
@@ -140,8 +133,8 @@ void EnergyData::Element::scheduleTask(Step step, Time time, const RegisterRunFu
     auto isFreeEnergyCalculationStep = freeEnergyCalculationStep_ == step;
     if (isEnergyCalculationStep || writeEnergy)
     {
-        registerRunFunction([this, time, isEnergyCalculationStep, isFreeEnergyCalculationStep]() {
-            energyData_->doStep(time, isEnergyCalculationStep, isFreeEnergyCalculationStep);
+        registerRunFunction([this, step, time, isEnergyCalculationStep, isFreeEnergyCalculationStep]() {
+            energyData_->doStep(step, time, isEnergyCalculationStep, isFreeEnergyCalculationStep);
         });
     }
     else
@@ -239,7 +232,7 @@ std::optional<SignallerCallback> EnergyData::Element::registerEnergyCallback(Ene
     return std::nullopt;
 }
 
-void EnergyData::doStep(Time time, bool isEnergyCalculationStep, bool isFreeEnergyCalculationStep)
+void EnergyData::doStep(Step step, Time time, bool isEnergyCalculationStep, bool isFreeEnergyCalculationStep)
 {
     enerd_->term[F_ETOT] = enerd_->term[F_EPOT] + enerd_->term[F_EKIN];
     if (freeEnergyPerturbationData_)
@@ -249,12 +242,11 @@ void EnergyData::doStep(Time time, bool isEnergyCalculationStep, bool isFreeEner
     }
     if (integratorHasConservedEnergyQuantity(inputrec_))
     {
-        enerd_->term[F_ECONSERVED] =
-                enerd_->term[F_ETOT]
-                + (velocityScalingTemperatureCoupling_
-                           ? velocityScalingTemperatureCoupling_->conservedEnergyContribution()
-                           : 0)
-                + (parrinelloRahmanBarostat_ ? parrinelloRahmanBarostat_->conservedEnergyContribution() : 0);
+        enerd_->term[F_ECONSERVED] = enerd_->term[F_ETOT];
+        for (const auto& energyContibution : conservedEnergyContributions_)
+        {
+            enerd_->term[F_ECONSERVED] += energyContibution(step, time);
+        }
     }
     matrix nullMatrix = {};
     energyOutput_->addDataAtEnergyStep(
@@ -266,7 +258,7 @@ void EnergyData::doStep(Time time, bool isEnergyCalculationStep, bool isFreeEner
             inputrec_->fepvals.get(),
             inputrec_->expandedvals.get(),
             statePropagatorData_->constPreviousBox(),
-            PTCouplingArrays({ parrinelloRahmanBarostat_ ? parrinelloRahmanBarostat_->boxVelocities() : nullMatrix,
+            PTCouplingArrays({ parrinelloRahmanBoxVelocities_ ? parrinelloRahmanBoxVelocities_() : nullMatrix,
                                {},
                                {},
                                {},
@@ -505,14 +497,16 @@ void EnergyData::initializeEnergyHistory(StartingBehavior    startingBehavior,
     energyOutput->fillEnergyHistory(observablesHistory->energyHistory.get());
 }
 
-void EnergyData::setVelocityScalingTemperatureCoupling(const VelocityScalingTemperatureCoupling* velocityScalingTemperatureCoupling)
+void EnergyData::addConservedEnergyContribution(EnergyContribution&& energyContribution)
 {
-    velocityScalingTemperatureCoupling_ = velocityScalingTemperatureCoupling;
+    conservedEnergyContributions_.emplace_back(std::move(energyContribution));
 }
 
-void EnergyData::setParrinelloRahamnBarostat(const gmx::ParrinelloRahmanBarostat* parrinelloRahmanBarostat)
+void EnergyData::setParrinelloRahmanBoxVelocities(std::function<const rvec*()>&& parrinelloRahmanBoxVelocities)
 {
-    parrinelloRahmanBarostat_ = parrinelloRahmanBarostat;
+    GMX_RELEASE_ASSERT(!parrinelloRahmanBoxVelocities_,
+                       "Received a second callback to the Parrinello-Rahman velocities");
+    parrinelloRahmanBoxVelocities_ = parrinelloRahmanBoxVelocities;
 }
 
 EnergyData::Element* EnergyData::element()
index 0020dd04da771c4d84b81830fdc548b20f2c5ed0..1359c68bb62c11cca07a049d7d345d4ba4c65e00 100644 (file)
@@ -44,7 +44,6 @@
 #ifndef GMX_ENERGYELEMENT_MICROSTATE_H
 #define GMX_ENERGYELEMENT_MICROSTATE_H
 
-#include "gromacs/math/vectypes.h"
 #include "gromacs/mdtypes/state.h"
 
 #include "modularsimulatorinterfaces.h"
@@ -72,6 +71,9 @@ class StatePropagatorData;
 class VelocityScalingTemperatureCoupling;
 struct MdModulesNotifier;
 
+//! Function type for elements contributing energy
+using EnergyContribution = std::function<real(Step, Time)>;
+
 /*! \internal
  * \ingroup module_modularsimulator
  * \brief Data class managing energies
@@ -188,21 +190,19 @@ public:
      */
     [[nodiscard]] bool hasReadEkinFromCheckpoint() const;
 
-    /*! \brief Set velocity scaling temperature coupling
+    /*! \brief Add conserved energy contribution
      *
-     * This allows to set a pointer to a velocity scaling temperature coupling
-     * element used to obtain contributions to the conserved energy.
-     * TODO: This should be made obsolete my a more modular energy element
+     * This allows other elements to register callbacks for contributions to
+     * the conserved energy term.
      */
-    void setVelocityScalingTemperatureCoupling(const VelocityScalingTemperatureCoupling* velocityScalingTemperatureCoupling);
+    void addConservedEnergyContribution(EnergyContribution&& energyContribution);
 
     /*! \brief set Parrinello-Rahman barostat
      *
      * This allows to set a pointer to the Parrinello-Rahman barostat used to
      * print the box velocities.
-     * TODO: This should be made obsolete my a more modular energy element
      */
-    void setParrinelloRahamnBarostat(const ParrinelloRahmanBarostat* parrinelloRahmanBarostat);
+    void setParrinelloRahmanBoxVelocities(std::function<const rvec*()>&& parrinelloRahmanBoxVelocities);
 
     /*! \brief Initialize energy history
      *
@@ -233,7 +233,7 @@ private:
      * \param isEnergyCalculationStep  Whether the current step is an energy calculation step
      * \param isFreeEnergyCalculationStep  Whether the current step is a free energy calculation step
      */
-    void doStep(Time time, bool isEnergyCalculationStep, bool isFreeEnergyCalculationStep);
+    void doStep(Step step, Time time, bool isEnergyCalculationStep, bool isFreeEnergyCalculationStep);
 
     /*! \brief Write to energy trajectory
      *
@@ -290,10 +290,12 @@ private:
     StatePropagatorData* statePropagatorData_;
     //! Pointer to the free energy perturbation data
     FreeEnergyPerturbationData* freeEnergyPerturbationData_;
-    //! Pointer to the vrescale thermostat
-    const VelocityScalingTemperatureCoupling* velocityScalingTemperatureCoupling_;
-    //! Pointer to the Parrinello-Rahman barostat
-    const ParrinelloRahmanBarostat* parrinelloRahmanBarostat_;
+
+    //! Callbacks contributing to the conserved energy term
+    std::vector<EnergyContribution> conservedEnergyContributions_;
+    //! Callback to the Parrinello-Rahman box velocities
+    std::function<const rvec*()> parrinelloRahmanBoxVelocities_;
+
     //! Contains user input mdp options.
     const t_inputrec* inputrec_;
     //! Full system topology.
index 0f079dc1ce9b856a669661f7d649bbe5e7cfa5ac..6a25af5285a4c7994ce2ccc5468cca119b1bcf9a 100644 (file)
@@ -126,7 +126,6 @@ void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder*
             builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
         builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::LeapFrog>>();
-        builder->add<EnergyData::Element>();
         if (legacySimulatorData_->inputrec->epc == PressureCoupling::ParrinelloRahman)
         {
             builder->add<ParrinelloRahmanBarostat>(-1, PropagatorTag("LeapFrogPropagator"));
@@ -160,7 +159,6 @@ void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder*
             builder->add<ConstraintsElement<ConstraintVariable::Positions>>();
         }
         builder->add<ComputeGlobalsElement<ComputeGlobalsAlgorithm::VelocityVerlet>>();
-        builder->add<EnergyData::Element>();
         if (legacySimulatorData_->inputrec->epc == PressureCoupling::ParrinelloRahman)
         {
             builder->add<ParrinelloRahmanBarostat>(-1, PropagatorTag("VelocityHalfStep"));
@@ -170,6 +168,7 @@ void ModularSimulator::addIntegrationElements(ModularSimulatorAlgorithmBuilder*
     {
         gmx_fatal(FARGS, "Integrator not implemented for the modular simulator.");
     }
+    builder->add<EnergyData::Element>();
 }
 
 bool ModularSimulator::isInputCompatible(bool                             exitOnFailure,
index d01e6fd4d82e2d5355329fa59c306edc8763de6b..f35d1846ae27e7110e5e3c91b662093424fcfba7 100644 (file)
@@ -81,11 +81,17 @@ ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                  nstpcoup
     boxVelocity_{ { 0 } },
     statePropagatorData_(statePropagatorData),
     energyData_(energyData),
+    nextEnergyCalculationStep_(-1),
     fplog_(fplog),
     inputrec_(inputrec),
     mdAtoms_(mdAtoms)
 {
-    energyData->setParrinelloRahamnBarostat(this);
+    energyData->setParrinelloRahmanBoxVelocities([this]() { return boxVelocity_; });
+    energyData->addConservedEnergyContribution([this](Step gmx_used_in_debug step, Time /*unused*/) {
+        GMX_ASSERT(conservedEnergyContributionStep_ == step,
+                   "Parrinello-Rahman conserved energy step mismatch.");
+        return conservedEnergyContribution_;
+    });
 }
 
 void ParrinelloRahmanBarostat::connectWithMatchingPropagator(const PropagatorBarostatConnection& connectionData,
@@ -104,7 +110,16 @@ void ParrinelloRahmanBarostat::scheduleTask(Step step,
 {
     const bool scaleOnNextStep = do_per_step(step + nstpcouple_ + offset_ + 1, nstpcouple_);
     const bool scaleOnThisStep = do_per_step(step + nstpcouple_ + offset_, nstpcouple_);
+    const bool contributeEnergyThisStep = (step == nextEnergyCalculationStep_);
 
+    if (contributeEnergyThisStep)
+    {
+        // For compatibility with legacy md, we store this before integrating the box velocities
+        registerRunFunction([this, step]() {
+            conservedEnergyContribution_     = conservedEnergyContribution();
+            conservedEnergyContributionStep_ = step;
+        });
+    }
     if (scaleOnThisStep)
     {
         registerRunFunction([this]() { scaleBoxAndPositions(); });
@@ -298,6 +313,15 @@ const std::string& ParrinelloRahmanBarostat::clientID()
     return identifier_;
 }
 
+std::optional<SignallerCallback> ParrinelloRahmanBarostat::registerEnergyCallback(EnergySignallerEvent event)
+{
+    if (event == EnergySignallerEvent::EnergyCalculationStep)
+    {
+        return [this](Step step, Time /*unused*/) { nextEnergyCalculationStep_ = step; };
+    }
+    return std::nullopt;
+}
+
 ISimulatorElement* ParrinelloRahmanBarostat::getElementPointerImpl(
         LegacySimulatorData*                    legacySimulatorData,
         ModularSimulatorAlgorithmBuilderHelper* builderHelper,
index ee11ececf122107112c145dcf7850548af60f408..91bc1a1223f5036e169acd0acc27f339c7ea9deb 100644 (file)
@@ -70,7 +70,7 @@ class StatePropagatorData;
  *     scaling factor, and
  *   * scales the box and the positions of the system.
  */
-class ParrinelloRahmanBarostat final : public ISimulatorElement, public ICheckpointHelperClient
+class ParrinelloRahmanBarostat final : public ISimulatorElement, public ICheckpointHelperClient, public IEnergySignallerClient
 {
 public:
     //! Constructor
@@ -99,8 +99,6 @@ public:
 
     //! Getter for the box velocities
     [[nodiscard]] const rvec* boxVelocities() const;
-    //! Contribution to the conserved energy (called by energy data)
-    [[nodiscard]] real conservedEnergyContribution() const;
 
     //! Connect this to propagator
     void connectWithMatchingPropagator(const PropagatorBarostatConnection& connectionData,
@@ -158,6 +156,11 @@ private:
     //! Box velocity
     tensor boxVelocity_;
 
+    //! Current conserved energy contribution
+    real conservedEnergyContribution_;
+    //! Step of current conserved energy contribution
+    Step conservedEnergyContributionStep_;
+
     // TODO: Clarify relationship to data objects and find a more robust alternative to raw pointers (#3583)
     //! Pointer to the micro state
     StatePropagatorData* statePropagatorData_;
@@ -175,6 +178,14 @@ private:
     template<CheckpointDataOperation operation>
     void doCheckpointData(CheckpointData<operation>* checkpointData);
 
+    //! IEnergySignallerClient implementation
+    std::optional<SignallerCallback> registerEnergyCallback(EnergySignallerEvent event) override;
+    //! The next communicated energy calculation step
+    Step nextEnergyCalculationStep_;
+
+    //! Contribution to the conserved energy
+    [[nodiscard]] real conservedEnergyContribution() const;
+
     // Access to ISimulator data
     //! Handles logging.
     FILE* fplog_;
index 58a0fed73fcbe6c8a4416f07827b33bc51950280..c7cef4917beeb48ac889a5db33dd09c8ffe97734 100644 (file)
@@ -294,13 +294,9 @@ VelocityScalingTemperatureCoupling::VelocityScalingTemperatureCoupling(
     couplingTime_(couplingTime, couplingTime + numTemperatureGroups),
     numDegreesOfFreedom_(numDegreesOfFreedom, numDegreesOfFreedom + numTemperatureGroups),
     temperatureCouplingIntegral_(numTemperatureGroups, 0.0),
-    energyData_(energyData)
+    energyData_(energyData),
+    nextEnergyCalculationStep_(-1)
 {
-    if (reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::Yes)
-    {
-        temperatureCouplingIntegralPreviousStep_ = temperatureCouplingIntegral_;
-    }
-    energyData->setVelocityScalingTemperatureCoupling(this);
     if (couplingType == TemperatureCoupling::VRescale)
     {
         temperatureCouplingImpl_ = std::make_unique<VRescaleTemperatureCoupling>(seed);
@@ -314,6 +310,11 @@ VelocityScalingTemperatureCoupling::VelocityScalingTemperatureCoupling(
         throw NotImplementedError("Temperature coupling " + std::string(enumValueToString(couplingType))
                                   + " is not implemented for modular simulator.");
     }
+    energyData->addConservedEnergyContribution([this](Step gmx_used_in_debug step, Time /*unused*/) {
+        GMX_ASSERT(conservedEnergyContributionStep_ == step,
+                   "VelocityScalingTemperatureCoupling conserved energy step mismatch.");
+        return conservedEnergyContribution_;
+    });
 }
 
 void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorThermostatConnection& connectionData,
@@ -351,6 +352,15 @@ void VelocityScalingTemperatureCoupling::scheduleTask(Step step,
      *       of the kinetic energy is needed.
      *
      */
+    if (step == nextEnergyCalculationStep_
+        && reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::Yes)
+    {
+        // add conserved energy before we do T-coupling
+        registerRunFunction([this, step]() {
+            conservedEnergyContribution_     = conservedEnergyContribution();
+            conservedEnergyContributionStep_ = step;
+        });
+    }
     if (do_per_step(step + nstcouple_ + offset_, nstcouple_))
     {
         // do T-coupling this step
@@ -359,16 +369,19 @@ void VelocityScalingTemperatureCoupling::scheduleTask(Step step,
         // Let propagator know that we want to do T-coupling
         propagatorCallback_(step);
     }
+    if (step == nextEnergyCalculationStep_
+        && reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::No)
+    {
+        // add conserved energy after we did T-coupling
+        registerRunFunction([this, step]() {
+            conservedEnergyContribution_     = conservedEnergyContribution();
+            conservedEnergyContributionStep_ = step;
+        });
+    }
 }
 
 void VelocityScalingTemperatureCoupling::setLambda(Step step)
 {
-    // if we report the previous energy, calculate before the step
-    if (reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::Yes)
-    {
-        temperatureCouplingIntegralPreviousStep_ = temperatureCouplingIntegral_;
-    }
-
     const auto*             ekind          = energyData_->ekindata();
     TemperatureCouplingData thermostatData = {
         couplingTimeStep_, referenceTemperature_, couplingTime_, numDegreesOfFreedom_, temperatureCouplingIntegral_
@@ -454,17 +467,16 @@ const std::string& VelocityScalingTemperatureCoupling::clientID()
 
 real VelocityScalingTemperatureCoupling::conservedEnergyContribution() const
 {
-    if (reportPreviousConservedEnergy_ == ReportPreviousStepConservedEnergy::Yes)
-    {
-        return std::accumulate(temperatureCouplingIntegralPreviousStep_.begin(),
-                               temperatureCouplingIntegralPreviousStep_.end(),
-                               0.0);
-    }
-    else
+    return std::accumulate(temperatureCouplingIntegral_.begin(), temperatureCouplingIntegral_.end(), 0.0);
+}
+
+std::optional<SignallerCallback> VelocityScalingTemperatureCoupling::registerEnergyCallback(EnergySignallerEvent event)
+{
+    if (event == EnergySignallerEvent::EnergyCalculationStep)
     {
-        return std::accumulate(
-                temperatureCouplingIntegral_.begin(), temperatureCouplingIntegral_.end(), 0.0);
+        return [this](Step step, Time /*unused*/) { nextEnergyCalculationStep_ = step; };
     }
+    return std::nullopt;
 }
 
 ISimulatorElement* VelocityScalingTemperatureCoupling::getElementPointerImpl(
index 4b8947c2264507f0368d9a4969f5c0e681f22d81..6b70509a91f15bed981552bbcf79a2b12902c6ed 100644 (file)
@@ -86,7 +86,10 @@ enum class ReportPreviousStepConservedEnergy
  * implementations of the ITemperatureCouplingImpl interface, while the element
  * handles the scheduling and interfacing with other elements.
  */
-class VelocityScalingTemperatureCoupling final : public ISimulatorElement, public ICheckpointHelperClient
+class VelocityScalingTemperatureCoupling final :
+    public ISimulatorElement,
+    public ICheckpointHelperClient,
+    public IEnergySignallerClient
 {
 public:
     //! Constructor
@@ -116,9 +119,6 @@ public:
     //! No element teardown needed
     void elementTeardown() override {}
 
-    //! Contribution to the conserved energy (called by energy data)
-    [[nodiscard]] real conservedEnergyContribution() const;
-
     //! Connect this to propagator
     void connectWithMatchingPropagator(const PropagatorThermostatConnection& connectionData,
                                        const PropagatorTag&                  propagatorTag);
@@ -179,8 +179,11 @@ private:
     const std::vector<real> numDegreesOfFreedom_;
     //! Work exerted by thermostat per group
     std::vector<double> temperatureCouplingIntegral_;
-    //! Work exerted by thermostat per group (backup from previous step)
-    std::vector<double> temperatureCouplingIntegralPreviousStep_;
+
+    //! Current conserved energy contribution
+    real conservedEnergyContribution_;
+    //! Step of current conserved energy contribution
+    Step conservedEnergyContributionStep_;
 
     // TODO: Clarify relationship to data objects and find a more robust alternative to raw pointers (#3583)
     //! Pointer to the energy data (for ekindata)
@@ -191,6 +194,8 @@ private:
 
     //! Set new lambda value (at T-coupling steps)
     void setLambda(Step step);
+    //! Contribution to the conserved energy
+    [[nodiscard]] real conservedEnergyContribution() const;
 
     //! The temperature coupling implementation
     std::unique_ptr<ITemperatureCouplingImpl> temperatureCouplingImpl_;
@@ -200,6 +205,11 @@ private:
     //! Helper function to read from / write to CheckpointData
     template<CheckpointDataOperation operation>
     void doCheckpointData(CheckpointData<operation>* checkpointData);
+
+    //! IEnergySignallerClient implementation
+    std::optional<SignallerCallback> registerEnergyCallback(EnergySignallerEvent event) override;
+    //! The next communicated energy calculation step
+    Step nextEnergyCalculationStep_;
 };
 
 } // namespace gmx