};
/*! \internal
- * \brief Information needed to connect a propagator to a thermostat
+ * \brief Information needed to connect a propagator to a temperature and / or pressure coupling element
*/
-struct PropagatorThermostatConnection
+struct PropagatorConnection
{
- //! Function variable for setting velocity scaling variables.
+ //! The tag of the creating propagator
+ PropagatorTag tag;
+
+ //! Whether the propagator offers start velocity scaling
+ bool hasStartVelocityScaling() const
+ {
+ return setNumVelocityScalingVariables && getVelocityScalingCallback && getViewOnStartVelocityScaling;
+ }
+ //! Whether the propagator offers end velocity scaling
+ bool hasEndVelocityScaling() const
+ {
+ return setNumVelocityScalingVariables && getVelocityScalingCallback && getViewOnEndVelocityScaling;
+ }
+ //! Whether the propagator offers position scaling
+ bool hasPositionScaling() const
+ {
+ return setNumPositionScalingVariables && getPositionScalingCallback && getViewOnPositionScaling;
+ }
+ //! Whether the propagator offers Parrinello-Rahman scaling
+ bool hasParrinelloRahmanScaling() const
+ {
+ return getPRScalingCallback && getViewOnPRScalingMatrix;
+ }
+
+ //! Function object for setting velocity scaling variables
std::function<void(int, ScaleVelocities)> setNumVelocityScalingVariables;
- //! Function variable for receiving view on velocity scaling (before step).
+ //! Function object for setting velocity scaling variables
+ std::function<void(int)> setNumPositionScalingVariables;
+ //! Function object for receiving view on velocity scaling (before step)
std::function<ArrayRef<real>()> getViewOnStartVelocityScaling;
- //! Function variable for receiving view on velocity scaling (after step).
+ //! Function object for receiving view on velocity scaling (after step)
std::function<ArrayRef<real>()> getViewOnEndVelocityScaling;
- //! Function variable for callback.
+ //! Function object for receiving view on position scaling
+ std::function<ArrayRef<real>()> getViewOnPositionScaling;
+ //! Function object to request callback allowing to signal a velocity scaling step
std::function<PropagatorCallback()> getVelocityScalingCallback;
- //! The tag of the creating propagator
- PropagatorTag tag;
-};
-
-/*! \internal
- * \brief Information needed to connect a propagator to a barostat
- */
-struct PropagatorBarostatConnection
-{
- //! Function variable for receiving view on pressure scaling matrix.
+ //! Function object to request callback allowing to signal a position scaling step
+ std::function<PropagatorCallback()> getPositionScalingCallback;
+ //! Function object for receiving view on pressure scaling matrix
std::function<ArrayRef<rvec>()> getViewOnPRScalingMatrix;
- //! Function variable for callback.
+ //! Function object to request callback allowing to signal a Parrinello-Rahman scaling step
std::function<PropagatorCallback()> getPRScalingCallback;
- //! The tag of the creating propagator
- PropagatorTag tag;
};
//! /}
});
}
-void ParrinelloRahmanBarostat::connectWithMatchingPropagator(const PropagatorBarostatConnection& connectionData,
+void ParrinelloRahmanBarostat::connectWithMatchingPropagator(const PropagatorConnection& connectionData,
const PropagatorTag& propagatorTag)
{
if (connectionData.tag == propagatorTag)
{
+ GMX_RELEASE_ASSERT(connectionData.hasParrinelloRahmanScaling(),
+ "Connection data lacks Parrinello-Rahman scaling");
scalingTensor_ = connectionData.getViewOnPRScalingMatrix();
propagatorCallback_ = connectionData.getPRScalingCallback();
}
legacySimulatorData->inputrec,
legacySimulatorData->mdAtoms));
auto* barostat = static_cast<ParrinelloRahmanBarostat*>(element);
- builderHelper->registerBarostat([barostat, propagatorTag](const PropagatorBarostatConnection& connection) {
- barostat->connectWithMatchingPropagator(connection, propagatorTag);
- });
+ builderHelper->registerTemperaturePressureControl(
+ [barostat, propagatorTag](const PropagatorConnection& connection) {
+ barostat->connectWithMatchingPropagator(connection, propagatorTag);
+ });
return element;
}
[[nodiscard]] const rvec* boxVelocities() const;
//! Connect this to propagator
- void connectWithMatchingPropagator(const PropagatorBarostatConnection& connectionData,
- const PropagatorTag& propagatorTag);
+ void connectWithMatchingPropagator(const PropagatorConnection& connectionData,
+ const PropagatorTag& propagatorTag);
//! ICheckpointHelperClient write checkpoint implementation
void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;
return [this](Step step) { scalingStepPR_ = step; };
}
+template<IntegrationStage integrationStage>
+static PropagatorConnection getConnection(Propagator<integrationStage> gmx_unused* propagator,
+ const PropagatorTag& propagatorTag)
+{
+ // gmx_unused is needed because gcc-7 & gcc-8 can't see that
+ // propagator is used for all IntegrationStage options
+
+ PropagatorConnection propagatorConnection{ propagatorTag };
+
+ // The clang-tidy version on our current CI throws 3 different warnings
+ // for the if constexpr lines, so disable linting for now. Also, this only
+ // works if the brace is on the same line, so turn off clang-format as well
+ // clang-format off
+ // NOLINTNEXTLINE
+ if constexpr (hasStartVelocityScaling<integrationStage>() || hasEndVelocityScaling<integrationStage>()) {
+ // clang-format on
+ propagatorConnection.setNumVelocityScalingVariables =
+ [propagator](int num, ScaleVelocities scaleVelocities) {
+ propagator->setNumVelocityScalingVariables(num, scaleVelocities);
+ };
+ propagatorConnection.getVelocityScalingCallback = [propagator]() {
+ return propagator->velocityScalingCallback();
+ };
+ }
+ // clang-format off
+ // NOLINTNEXTLINE
+ if constexpr (hasStartVelocityScaling<integrationStage>()) {
+ // clang-format on
+ propagatorConnection.getViewOnStartVelocityScaling = [propagator]() {
+ return propagator->viewOnStartVelocityScaling();
+ };
+ }
+ // clang-format off
+ // NOLINTNEXTLINE
+ if constexpr (hasEndVelocityScaling<integrationStage>()) {
+ // clang-format on
+ propagatorConnection.getViewOnEndVelocityScaling = [propagator]() {
+ return propagator->viewOnEndVelocityScaling();
+ };
+ }
+ // clang-format off
+ // NOLINTNEXTLINE
+ if constexpr (hasPositionScaling<integrationStage>()) {
+ // clang-format on
+ propagatorConnection.setNumPositionScalingVariables = [propagator](int num) {
+ propagator->setNumPositionScalingVariables(num);
+ };
+ propagatorConnection.getViewOnPositionScaling = [propagator]() {
+ return propagator->viewOnPositionScaling();
+ };
+ propagatorConnection.getPositionScalingCallback = [propagator]() {
+ return propagator->positionScalingCallback();
+ };
+ }
+ // clang-format off
+ // NOLINTNEXTLINE
+ if constexpr (hasParrinelloRahmanScaling<integrationStage>()) {
+ // clang-format on
+ propagatorConnection.getViewOnPRScalingMatrix = [propagator]() {
+ return propagator->viewOnPRScalingMatrix();
+ };
+ propagatorConnection.getPRScalingCallback = [propagator]() {
+ return propagator->prScalingCallback();
+ };
+ }
+
+ // NOLINTNEXTLINE(readability-misleading-indentation)
+ return propagatorConnection;
+}
+
// doxygen is confused by the two definitions
//! \cond
template<IntegrationStage integrationStage>
auto* element = builderHelper->storeElement(std::make_unique<Propagator<integrationStage>>(
timestep, statePropagatorData, legacySimulatorData->mdAtoms, legacySimulatorData->wcycle));
auto* propagator = static_cast<Propagator<integrationStage>*>(element);
- builderHelper->registerWithThermostat(
- { [propagator](int num, ScaleVelocities scaleVelocities) {
- propagator->setNumVelocityScalingVariables(num, scaleVelocities);
- },
- [propagator]() { return propagator->viewOnStartVelocityScaling(); },
- [propagator]() { return propagator->viewOnEndVelocityScaling(); },
- [propagator]() { return propagator->velocityScalingCallback(); },
- propagatorTag });
- builderHelper->registerWithBarostat(
- { [propagator]() { return propagator->viewOnPRScalingMatrix(); },
- [propagator]() { return propagator->prScalingCallback(); },
- propagatorTag });
+ builderHelper->registerPropagator(getConnection<integrationStage>(propagator, propagatorTag));
return element;
}
#include "energydata.h"
#include "freeenergyperturbationdata.h"
#include "modularsimulator.h"
-#include "parrinellorahmanbarostat.h"
#include "pmeloadbalancehelper.h"
#include "propagator.h"
#include "statepropagatordata.h"
-#include "velocityscalingtemperaturecoupling.h"
namespace gmx
{
algorithmHasBeenBuilt_ = true;
// Connect propagators with thermostat / barostat
- for (const auto& thermostatRegistration : thermostatRegistrationFunctions_)
+ for (const auto& registrationFunction : pressureTemperatureControlRegistrationFunctions_)
{
- for (const auto& connection : propagatorThermostatConnections_)
+ for (const auto& connection : propagatorConnections_)
{
- thermostatRegistration(connection);
- }
- }
- for (const auto& barostatRegistration : barostatRegistrationFunctions_)
- {
- for (const auto& connection : propagatorBarostatConnections_)
- {
- barostatRegistration(connection);
+ registrationFunction(connection);
}
}
}
}
-void ModularSimulatorAlgorithmBuilderHelper::registerThermostat(
- std::function<void(const PropagatorThermostatConnection&)> registrationFunction)
+void ModularSimulatorAlgorithmBuilderHelper::registerTemperaturePressureControl(
+ std::function<void(const PropagatorConnection&)> registrationFunction)
{
- builder_->thermostatRegistrationFunctions_.emplace_back(std::move(registrationFunction));
+ builder_->pressureTemperatureControlRegistrationFunctions_.emplace_back(std::move(registrationFunction));
}
-void ModularSimulatorAlgorithmBuilderHelper::registerBarostat(
- std::function<void(const PropagatorBarostatConnection&)> registrationFunction)
+void ModularSimulatorAlgorithmBuilderHelper::registerPropagator(PropagatorConnection connectionData)
{
- builder_->barostatRegistrationFunctions_.emplace_back(std::move(registrationFunction));
+ builder_->propagatorConnections_.emplace_back(std::move(connectionData));
}
-void ModularSimulatorAlgorithmBuilderHelper::registerWithThermostat(PropagatorThermostatConnection connectionData)
-{
- builder_->propagatorThermostatConnections_.emplace_back(std::move(connectionData));
-}
-
-void ModularSimulatorAlgorithmBuilderHelper::registerWithBarostat(PropagatorBarostatConnection connectionData)
-{
- builder_->propagatorBarostatConnections_.emplace_back(std::move(connectionData));
-}
-
-
} // namespace gmx
void storeValue(const std::string& key, const ValueType& value);
//! Get previously stored data. Returns std::nullopt if key is not found.
std::optional<std::any> getStoredValue(const std::string& key) const;
- //! Register a thermostat that accepts propagator registrations
- void registerThermostat(std::function<void(const PropagatorThermostatConnection&)> registrationFunction);
- //! Register a barostat that accepts propagator registrations
- void registerBarostat(std::function<void(const PropagatorBarostatConnection&)> registrationFunction);
- //! Register a propagator to the thermostat used
- void registerWithThermostat(PropagatorThermostatConnection connectionData);
- //! Register a propagator to the barostat used
- void registerWithBarostat(PropagatorBarostatConnection connectionData);
+ //! Register temperature / pressure control algorithm to be matched with a propagator
+ void registerTemperaturePressureControl(std::function<void(const PropagatorConnection&)> registrationFunction);
+ //! Register a propagator to be used with a temperature / pressure control algorithm
+ void registerPropagator(PropagatorConnection connectionData);
private:
//! Pointer to the associated ModularSimulatorAlgorithmBuilder
*/
std::vector<ICheckpointHelperClient*> checkpointClients_;
- //! List of thermostat registration functions
- std::vector<std::function<void(const PropagatorThermostatConnection&)>> thermostatRegistrationFunctions_;
- //! List of barostat registration functions
- std::vector<std::function<void(const PropagatorBarostatConnection&)>> barostatRegistrationFunctions_;
- //! List of data to connect propagators to thermostats
- std::vector<PropagatorThermostatConnection> propagatorThermostatConnections_;
- //! List of data to connect propagators to barostats
- std::vector<PropagatorBarostatConnection> propagatorBarostatConnections_;
+ //! List of data to connect propagators to thermostats / barostats
+ std::vector<PropagatorConnection> propagatorConnections_;
+ //! List of temperature / pressure control registration functions
+ std::vector<std::function<void(const PropagatorConnection&)>> pressureTemperatureControlRegistrationFunctions_;
};
/*! \internal
{
public:
//! Allow access to the scaling vectors
- virtual void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
- int numTemperatureGroups) = 0;
+ virtual void connectWithPropagator(const PropagatorConnection& connectionData,
+ int numTemperatureGroups) = 0;
/*! \brief Make a temperature control step
*
}
//! Connect with propagator - v-rescale only scales start step velocities
- void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
- int numTemperatureGroups) override
+ void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
{
+ GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
+ "V-Rescale requires start velocity scaling.");
connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
}
}
//! Connect with propagator - Berendsen only scales start step velocities
- void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
- int numTemperatureGroups) override
+ void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
{
+ GMX_RELEASE_ASSERT(connectionData.hasStartVelocityScaling(),
+ "Berendsen T-coupling requires start velocity scaling.");
connectionData.setNumVelocityScalingVariables(numTemperatureGroups, ScaleVelocities::PreStepOnly);
lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
}
}
//! Connect with propagator - Nose-Hoover scales start and end step velocities
- void connectWithPropagator(const PropagatorThermostatConnection& connectionData,
- int numTemperatureGroups) override
+ void connectWithPropagator(const PropagatorConnection& connectionData, int numTemperatureGroups) override
{
+ GMX_RELEASE_ASSERT(
+ connectionData.hasStartVelocityScaling() && connectionData.hasEndVelocityScaling(),
+ "Nose-Hoover T-coupling requires both start and end velocity scaling.");
connectionData.setNumVelocityScalingVariables(numTemperatureGroups,
ScaleVelocities::PreStepAndPostStep);
lambdaStartVelocities_ = connectionData.getViewOnStartVelocityScaling();
});
}
-void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorThermostatConnection& connectionData,
+void VelocityScalingTemperatureCoupling::connectWithMatchingPropagator(const PropagatorConnection& connectionData,
const PropagatorTag& propagatorTag)
{
if (connectionData.tag == propagatorTag)
legacySimulatorData->inputrec->etc));
auto* thermostat = static_cast<VelocityScalingTemperatureCoupling*>(element);
// Capturing pointer is safe because lifetime is handled by caller
- builderHelper->registerThermostat(
- [thermostat, propagatorTag](const PropagatorThermostatConnection& connection) {
+ builderHelper->registerTemperaturePressureControl(
+ [thermostat, propagatorTag](const PropagatorConnection& connection) {
thermostat->connectWithMatchingPropagator(connection, propagatorTag);
});
return element;
void elementTeardown() override {}
//! Connect this to propagator
- void connectWithMatchingPropagator(const PropagatorThermostatConnection& connectionData,
- const PropagatorTag& propagatorTag);
+ void connectWithMatchingPropagator(const PropagatorConnection& connectionData,
+ const PropagatorTag& propagatorTag);
//! ICheckpointHelperClient write checkpoint implementation
void saveCheckpointState(std::optional<WriteCheckpointData> checkpointData, const t_commrec* cr) override;