Unify temperature and pressure coupling connection
authorPascal Merz <pascal.merz@me.com>
Thu, 6 May 2021 04:43:26 +0000 (04:43 +0000)
committerMark Abraham <mark.j.abraham@gmail.com>
Thu, 6 May 2021 04:43:26 +0000 (04:43 +0000)
src/gromacs/modularsimulator/modularsimulatorinterfaces.h
src/gromacs/modularsimulator/parrinellorahmanbarostat.cpp
src/gromacs/modularsimulator/parrinellorahmanbarostat.h
src/gromacs/modularsimulator/propagator.cpp
src/gromacs/modularsimulator/simulatoralgorithm.cpp
src/gromacs/modularsimulator/simulatoralgorithm.h
src/gromacs/modularsimulator/velocityscalingtemperaturecoupling.cpp
src/gromacs/modularsimulator/velocityscalingtemperaturecoupling.h

index 44fac47aaa4e257ed02d3f331e32ac9dd47f81e1..3dcbb6b03a2ec27c2962087699de1d62431a6301 100644 (file)
@@ -505,33 +505,52 @@ private:
 };
 
 /*! \internal
- * \brief Information needed to connect a propagator to a thermostat
+ * \brief Information needed to connect a propagator to a temperature and / or pressure coupling element
  */
-struct PropagatorThermostatConnection
+struct PropagatorConnection
 {
-    //! Function variable for setting velocity scaling variables.
+    //! The tag of the creating propagator
+    PropagatorTag tag;
+
+    //! Whether the propagator offers start velocity scaling
+    bool hasStartVelocityScaling() const
+    {
+        return setNumVelocityScalingVariables && getVelocityScalingCallback && getViewOnStartVelocityScaling;
+    }
+    //! Whether the propagator offers end velocity scaling
+    bool hasEndVelocityScaling() const
+    {
+        return setNumVelocityScalingVariables && getVelocityScalingCallback && getViewOnEndVelocityScaling;
+    }
+    //! Whether the propagator offers position scaling
+    bool hasPositionScaling() const
+    {
+        return setNumPositionScalingVariables && getPositionScalingCallback && getViewOnPositionScaling;
+    }
+    //! Whether the propagator offers Parrinello-Rahman scaling
+    bool hasParrinelloRahmanScaling() const
+    {
+        return getPRScalingCallback && getViewOnPRScalingMatrix;
+    }
+
+    //! Function object for setting velocity scaling variables
     std::function<void(int, ScaleVelocities)> setNumVelocityScalingVariables;
-    //! Function variable for receiving view on velocity scaling (before step).
+    //! Function object for setting velocity scaling variables
+    std::function<void(int)> setNumPositionScalingVariables;
+    //! Function object for receiving view on velocity scaling (before step)
     std::function<ArrayRef<real>()> getViewOnStartVelocityScaling;
-    //! Function variable for receiving view on velocity scaling (after step).
+    //! Function object for receiving view on velocity scaling (after step)
     std::function<ArrayRef<real>()> getViewOnEndVelocityScaling;
-    //! Function variable for callback.
+    //! Function object for receiving view on position scaling
+    std::function<ArrayRef<real>()> getViewOnPositionScaling;
+    //! Function object to request callback allowing to signal a velocity scaling step
     std::function<PropagatorCallback()> getVelocityScalingCallback;
-    //! The tag of the creating propagator
-    PropagatorTag tag;
-};
-
-/*! \internal
- * \brief Information needed to connect a propagator to a barostat
- */
-struct PropagatorBarostatConnection
-{
-    //! Function variable for receiving view on pressure scaling matrix.
+    //! Function object to request callback allowing to signal a position scaling step
+    std::function<PropagatorCallback()> getPositionScalingCallback;
+    //! Function object for receiving view on pressure scaling matrix
     std::function<ArrayRef<rvec>()> getViewOnPRScalingMatrix;
-    //! Function variable for callback.
+    //! Function object to request callback allowing to signal a Parrinello-Rahman scaling step
     std::function<PropagatorCallback()> getPRScalingCallback;
-    //! The tag of the creating propagator
-    PropagatorTag tag;
 };
 
 //! /}
index 72f2d3bf8b3295824b8c540d5a59d13f7d83f498..b06b8744dc1af25c251c38dd80c737ede651dcfe 100644 (file)
@@ -94,11 +94,13 @@ ParrinelloRahmanBarostat::ParrinelloRahmanBarostat(int                  nstpcoup
     });
 }
 
-void ParrinelloRahmanBarostat::connectWithMatchingPropagator(const PropagatorBarostatConnection& connectionData,
+void ParrinelloRahmanBarostat::connectWithMatchingPropagator(const PropagatorConnection& connectionData,
                                                              const PropagatorTag& propagatorTag)
 {
     if (connectionData.tag == propagatorTag)
     {
+        GMX_RELEASE_ASSERT(connectionData.hasParrinelloRahmanScaling(),
+                           "Connection data lacks Parrinello-Rahman scaling");
         scalingTensor_      = connectionData.getViewOnPRScalingMatrix();
         propagatorCallback_ = connectionData.getPRScalingCallback();
     }
@@ -343,9 +345,10 @@ ISimulatorElement* ParrinelloRahmanBarostat::getElementPointerImpl(
             legacySimulatorData->inputrec,
             legacySimulatorData->mdAtoms));
     auto* barostat = static_cast<ParrinelloRahmanBarostat*>(element);
-    builderHelper->registerBarostat([barostat, propagatorTag](const PropagatorBarostatConnection& connection) {
-        barostat->connectWithMatchingPropagator(connection, propagatorTag);
-    });
+    builderHelper->registerTemperaturePressureControl(
+            [barostat, propagatorTag](const PropagatorConnection& connection) {
+                barostat->connectWithMatchingPropagator(connection, propagatorTag);
+            });
     return element;
 }
 
index 211ac669dd5b0677bcf5f0f032d10f9e4b0316d1..b49ba6de17e129bdd5199971780219ed2a9925fd 100644 (file)
@@ -101,8 +101,8 @@ public:
     [[nodiscard]] const rvec* boxVelocities() const;
 
     //! Connect this to propagator
-    void connectWithMatchingPropagator(const PropagatorBarostatConnection& connectionData,
-                                       const PropagatorTag&                propagatorTag);
+    void connectWithMatchingPropagator(const PropagatorConnection& connectionData,
+                                       const PropagatorTag&        propagatorTag);
 
     //! ICheckpointHelperClient write checkpoint implementation
     void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
index 2b2df893482483c8709076ad48052442993a803a..0e56ed06608b66aee1cf4442d8d00c9903c0b157 100644 (file)
@@ -896,6 +896,76 @@ PropagatorCallback Propagator<integrationStage>::prScalingCallback()
     return [this](Step step) { scalingStepPR_ = step; };
 }
 
+template<IntegrationStage integrationStage>
+static PropagatorConnection getConnection(Propagator<integrationStage> gmx_unused* propagator,
+                                          const PropagatorTag&                     propagatorTag)
+{
+    // gmx_unused is needed because gcc-7 & gcc-8 can't see that
+    // propagator is used for all IntegrationStage options
+
+    PropagatorConnection propagatorConnection{ propagatorTag };
+
+    // The clang-tidy version on our current CI throws 3 different warnings
+    // for the if constexpr lines, so disable linting for now. Also, this only
+    // works if the brace is on the same line, so turn off clang-format as well
+    // clang-format off
+    // NOLINTNEXTLINE
+    if constexpr (hasStartVelocityScaling<integrationStage>() || hasEndVelocityScaling<integrationStage>()) {
+        // clang-format on
+        propagatorConnection.setNumVelocityScalingVariables =
+                [propagator](int num, ScaleVelocities scaleVelocities) {
+                    propagator->setNumVelocityScalingVariables(num, scaleVelocities);
+                };
+        propagatorConnection.getVelocityScalingCallback = [propagator]() {
+            return propagator->velocityScalingCallback();
+        };
+    }
+    // clang-format off
+    // NOLINTNEXTLINE
+    if constexpr (hasStartVelocityScaling<integrationStage>()) {
+        // clang-format on
+        propagatorConnection.getViewOnStartVelocityScaling = [propagator]() {
+            return propagator->viewOnStartVelocityScaling();
+        };
+    }
+    // clang-format off
+    // NOLINTNEXTLINE
+    if constexpr (hasEndVelocityScaling<integrationStage>()) {
+        // clang-format on
+        propagatorConnection.getViewOnEndVelocityScaling = [propagator]() {
+            return propagator->viewOnEndVelocityScaling();
+        };
+    }
+    // clang-format off
+    // NOLINTNEXTLINE
+    if constexpr (hasPositionScaling<integrationStage>()) {
+        // clang-format on
+        propagatorConnection.setNumPositionScalingVariables = [propagator](int num) {
+            propagator->setNumPositionScalingVariables(num);
+        };
+        propagatorConnection.getViewOnPositionScaling = [propagator]() {
+            return propagator->viewOnPositionScaling();
+        };
+        propagatorConnection.getPositionScalingCallback = [propagator]() {
+            return propagator->positionScalingCallback();
+        };
+    }
+    // clang-format off
+    // NOLINTNEXTLINE
+    if constexpr (hasParrinelloRahmanScaling<integrationStage>()) {
+        // clang-format on
+        propagatorConnection.getViewOnPRScalingMatrix = [propagator]() {
+            return propagator->viewOnPRScalingMatrix();
+        };
+        propagatorConnection.getPRScalingCallback = [propagator]() {
+            return propagator->prScalingCallback();
+        };
+    }
+
+    // NOLINTNEXTLINE(readability-misleading-indentation)
+    return propagatorConnection;
+}
+
 // doxygen is confused by the two definitions
 //! \cond
 template<IntegrationStage integrationStage>
@@ -916,18 +986,7 @@ ISimulatorElement* Propagator<integrationStage>::getElementPointerImpl(
     auto* element    = builderHelper->storeElement(std::make_unique<Propagator<integrationStage>>(
             timestep, statePropagatorData, legacySimulatorData->mdAtoms, legacySimulatorData->wcycle));
     auto* propagator = static_cast<Propagator<integrationStage>*>(element);
-    builderHelper->registerWithThermostat(
-            { [propagator](int num, ScaleVelocities scaleVelocities) {
-                 propagator->setNumVelocityScalingVariables(num, scaleVelocities);
-             },
-              [propagator]() { return propagator->viewOnStartVelocityScaling(); },
-              [propagator]() { return propagator->viewOnEndVelocityScaling(); },
-              [propagator]() { return propagator->velocityScalingCallback(); },
-              propagatorTag });
-    builderHelper->registerWithBarostat(
-            { [propagator]() { return propagator->viewOnPRScalingMatrix(); },
-              [propagator]() { return propagator->prScalingCallback(); },
-              propagatorTag });
+    builderHelper->registerPropagator(getConnection<integrationStage>(propagator, propagatorTag));
     return element;
 }
 
index b6aeee83327468240621722a9ecc1602c5efb327..bfb6c5cd28a6fb104ba2a0edd5caa7127a1b4ece 100644 (file)
 #include "energydata.h"
 #include "freeenergyperturbationdata.h"
 #include "modularsimulator.h"
-#include "parrinellorahmanbarostat.h"
 #include "pmeloadbalancehelper.h"
 #include "propagator.h"
 #include "statepropagatordata.h"
-#include "velocityscalingtemperaturecoupling.h"
 
 namespace gmx
 {
@@ -453,18 +451,11 @@ ModularSimulatorAlgorithm ModularSimulatorAlgorithmBuilder::build()
     algorithmHasBeenBuilt_ = true;
 
     // Connect propagators with thermostat / barostat
-    for (const auto& thermostatRegistration : thermostatRegistrationFunctions_)
+    for (const auto& registrationFunction : pressureTemperatureControlRegistrationFunctions_)
     {
-        for (const auto& connection : propagatorThermostatConnections_)
+        for (const auto& connection : propagatorConnections_)
         {
-            thermostatRegistration(connection);
-        }
-    }
-    for (const auto& barostatRegistration : barostatRegistrationFunctions_)
-    {
-        for (const auto& connection : propagatorBarostatConnections_)
-        {
-            barostatRegistration(connection);
+            registrationFunction(connection);
         }
     }
 
@@ -778,27 +769,15 @@ std::optional<std::any> ModularSimulatorAlgorithmBuilderHelper::getStoredValue(c
     }
 }
 
-void ModularSimulatorAlgorithmBuilderHelper::registerThermostat(
-        std::function<void(const PropagatorThermostatConnection&)> registrationFunction)
+void ModularSimulatorAlgorithmBuilderHelper::registerTemperaturePressureControl(
+        std::function<void(const PropagatorConnection&)> registrationFunction)
 {
-    builder_->thermostatRegistrationFunctions_.emplace_back(std::move(registrationFunction));
+    builder_->pressureTemperatureControlRegistrationFunctions_.emplace_back(std::move(registrationFunction));
 }
 
-void ModularSimulatorAlgorithmBuilderHelper::registerBarostat(
-        std::function<void(const PropagatorBarostatConnection&)> registrationFunction)
+void ModularSimulatorAlgorithmBuilderHelper::registerPropagator(PropagatorConnection connectionData)
 {
-    builder_->barostatRegistrationFunctions_.emplace_back(std::move(registrationFunction));
+    builder_->propagatorConnections_.emplace_back(std::move(connectionData));
 }
 
-void ModularSimulatorAlgorithmBuilderHelper::registerWithThermostat(PropagatorThermostatConnection connectionData)
-{
-    builder_->propagatorThermostatConnections_.emplace_back(std::move(connectionData));
-}
-
-void ModularSimulatorAlgorithmBuilderHelper::registerWithBarostat(PropagatorBarostatConnection connectionData)
-{
-    builder_->propagatorBarostatConnections_.emplace_back(std::move(connectionData));
-}
-
-
 } // namespace gmx
index d2ce8534cc22138056799d74371f177002a15d87..a7e9e2a62f6331fd7ff191f82a4e614a7dd82edc 100644 (file)
@@ -324,14 +324,10 @@ public:
     void storeValue(const std::string& key, const ValueType& value);
     //! Get previously stored data. Returns std::nullopt if key is not found.
     std::optional<std::any> getStoredValue(const std::string& key) const;
-    //! Register a thermostat that accepts propagator registrations
-    void registerThermostat(std::function<void(const PropagatorThermostatConnection&)> registrationFunction);
-    //! Register a barostat that accepts propagator registrations
-    void registerBarostat(std::function<void(const PropagatorBarostatConnection&)> registrationFunction);
-    //! Register a propagator to the thermostat used
-    void registerWithThermostat(PropagatorThermostatConnection connectionData);
-    //! Register a propagator to the barostat used
-    void registerWithBarostat(PropagatorBarostatConnection connectionData);
+    //! Register temperature / pressure control algorithm to be matched with a propagator
+    void registerTemperaturePressureControl(std::function<void(const PropagatorConnection&)> registrationFunction);
+    //! Register a propagator to be used with a temperature / pressure control algorithm
+    void registerPropagator(PropagatorConnection connectionData);
 
 private:
     //! Pointer to the associated ModularSimulatorAlgorithmBuilder
@@ -474,14 +470,10 @@ private:
      */
     std::vector<ICheckpointHelperClient*> checkpointClients_;
 
-    //! List of thermostat registration functions
-    std::vector<std::function<void(const PropagatorThermostatConnection&)>> thermostatRegistrationFunctions_;
-    //! List of barostat registration functions
-    std::vector<std::function<void(const PropagatorBarostatConnection&)>> barostatRegistrationFunctions_;
-    //! List of data to connect propagators to thermostats
-    std::vector<PropagatorThermostatConnection> propagatorThermostatConnections_;
-    //! List of data to connect propagators to barostats
-    std::vector<PropagatorBarostatConnection> propagatorBarostatConnections_;
+    //! List of data to connect propagators to thermostats / barostats
+    std::vector<PropagatorConnection> propagatorConnections_;
+    //! List of temperature / pressure control registration functions
+    std::vector<std::function<void(const PropagatorConnection&)>> pressureTemperatureControlRegistrationFunctions_;
 };
 
 /*! \internal
index 9eef0d2707655c7f7e30c34573e2cfd84fd2d106..3f323c1acf04f1651d68632550bc743105eb694e 100644 (file)
@@ -88,8 +88,8 @@ class ITemperatureCouplingImpl
 {
 public:
     //! Allow access to the scaling vectors
-    virtual void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
-                                       int numTemperatureGroups) = 0;
+    virtual void connectWithPropagator(const PropagatorConnection& connectionData,
+                                       int                         numTemperatureGroups) = 0;
 
     /*! \brief Make a temperature control step
      *
@@ -177,9 +177,10 @@ public:
     }
 
     //! Connect with propagator - v-rescale only scales start step velocities
-    void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
-                               int                                   numTemperatureGroups) override
+    void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
     {
+        GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
+                           "V-Rescale requires start velocity scaling.");
         connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
     }
@@ -249,9 +250,10 @@ public:
     }
 
     //! Connect with propagator - Berendsen only scales start step velocities
-    void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
-                               int                                   numTemperatureGroups) override
+    void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
     {
+        GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
+                           "Berendsen T-coupling requires start velocity scaling.");
         connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
     }
@@ -363,9 +365,11 @@ public:
     }
 
     //! Connect with propagator - Nose-Hoover scales start and end step velocities
-    void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
-                               int                                   numTemperatureGroups) override
+    void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
     {
+        GMX_RELEASE_ASSERT(
+                connectionData.hasStartVelocityScaling() && connectionData.hasEndVelocityScaling(),
+                "Nose-Hoover T-coupling requires both start and end velocity scaling.");
         connectionData.setNumVelocityScalingVariables(numTemperatureGroups,
                                                       ScaleVelocities::PreStepAndPostStep);
         lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
@@ -489,7 +493,7 @@ VelocityScalingTemperatureCoupling::VelocityScalingTemperatureCoupling(
     });
 }
 
-void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorThermostatConnection& connectionData,
+void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorConnection& connectionData,
                                                                        const PropagatorTag& propagatorTag)
 {
     if (connectionData.tag == propagatorTag)
@@ -679,8 +683,8 @@ ISimulatorElement* VelocityScalingTemperatureCoupling::getElementPointerImpl(
             legacySimulatorData->inputrec->etc));
     auto* thermostat = static_cast<VelocityScalingTemperatureCoupling*>(element);
     // Capturing pointer is safe because lifetime is handled by caller
-    builderHelper->registerThermostat(
-            [thermostat, propagatorTag](const PropagatorThermostatConnection& connection) {
+    builderHelper->registerTemperaturePressureControl(
+            [thermostat, propagatorTag](const PropagatorConnection& connection) {
                 thermostat->connectWithMatchingPropagator(connection, propagatorTag);
             });
     return element;
index 8f068fa3315b68e2d031c0df066736dd4f1fdde3..489c9262adc09042ad4353af3ad57099ebdb8f60 100644 (file)
@@ -120,8 +120,8 @@ public:
     void elementTeardown() override {}
 
     //! Connect this to propagator
-    void connectWithMatchingPropagator(const PropagatorThermostatConnection& connectionData,
-                                       const PropagatorTag&                  propagatorTag);
+    void connectWithMatchingPropagator(const PropagatorConnection& connectionData,
+                                       const PropagatorTag&        propagatorTag);
 
     //! ICheckpointHelperClient write checkpoint implementation
     void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;