Adaptive force scaling for densityfitting
authorChristian Blau <cblau@gwdg.de>
Tue, 24 Sep 2019 14:32:28 +0000 (16:32 +0200)
committerChristian Blau <cblau@gwdg.de>
Fri, 18 Oct 2019 06:30:11 +0000 (08:30 +0200)
Scales the force adaptively for density-guided-simulations

refs #2282

Change-Id: I96310f498cf2fae9f7385a9396f1253b760d135e

12 files changed:
src/gromacs/applied_forces/densityfitting.cpp
src/gromacs/applied_forces/densityfittingforceprovider.cpp
src/gromacs/applied_forces/densityfittingforceprovider.h
src/gromacs/applied_forces/densityfittingoptions.cpp
src/gromacs/applied_forces/densityfittingoptions.h
src/gromacs/applied_forces/densityfittingparameters.h
src/gromacs/applied_forces/tests/densityfittingoptions.cpp
src/gromacs/math/exponentialmovingaverage.cpp
src/gromacs/math/exponentialmovingaverage.h
src/gromacs/math/tests/exponentialmovingaverage.cpp
src/gromacs/mdrun/runner.cpp
src/gromacs/utility/mdmodulenotification.h

index d067bcb727caeffaa53bedeea5a5ca615a3e8638..eb981f9ebfac96a12a89a4ceaa773bbf86a46e6c 100644 (file)
@@ -194,6 +194,18 @@ class DensityFittingSimulationParameterSetup
             return *pbcType_;
         }
 
+        //! Set the simulation time step
+        void setSimulationTimeStep(double timeStep)
+        {
+            simulationTimeStep_ = timeStep;
+        }
+
+        //! Return the simulation time step
+        double simulationTimeStep()
+        {
+            return simulationTimeStep_;
+        }
+
     private:
         //! The reference density to fit to
         std::unique_ptr<MultiDimArray<std::vector<float>, dynamicExtents3D> > referenceDensity_;
@@ -203,6 +215,8 @@ class DensityFittingSimulationParameterSetup
         std::unique_ptr<LocalAtomSet>      localAtomSet_;
         //! The type of periodic boundary conditions in the simulation
         std::unique_ptr<int>               pbcType_;
+        //! The simulation time step
+        double simulationTimeStep_ = 1;
 
         GMX_DISALLOW_COPY_AND_ASSIGN(DensityFittingSimulationParameterSetup);
 };
@@ -281,6 +295,12 @@ class DensityFitting final : public IMDModule
                 };
             notifier->notifier_.subscribe(setPeriodicBoundaryContionsFunction);
 
+            // setting the simulation time step
+            const auto setSimulationTimeStepFunction = [this](const SimulationTimeStep &simulationTimeStep) {
+                    this->densityFittingSimulationParameters_.setSimulationTimeStep(simulationTimeStep.delta_t);
+                };
+            notifier->notifier_.subscribe(setSimulationTimeStepFunction);
+
             // adding output to energy file
             const auto requestEnergyOutput
                 = [this](MdModulesEnergyOutputToDensityFittingRequestChecker *energyOutputRequest) {
@@ -327,6 +347,7 @@ class DensityFitting final : public IMDModule
                             densityFittingSimulationParameters_.transformationToDensityLattice(),
                             densityFittingSimulationParameters_.localAtomSet(),
                             densityFittingSimulationParameters_.periodicBoundaryConditionType(),
+                            densityFittingSimulationParameters_.simulationTimeStep(),
                             densityFittingState_);
                 forceProviders->addForceProvider(forceProvider_.get());
             }
@@ -373,6 +394,13 @@ class DensityFitting final : public IMDModule
                 checkpointWriting.builder_.addValue<std::int64_t>(
                         DensityFittingModuleInfo::name_ + "-stepsSinceLastCalculation",
                         state.stepsSinceLastCalculation_);
+                checkpointWriting.builder_.addValue<real>(
+                        DensityFittingModuleInfo::name_ + "-adaptiveForceConstantScale",
+                        state.adaptiveForceConstantScale_);
+                KeyValueTreeObjectBuilder exponentialMovingAverageKvtEntry =
+                    checkpointWriting.builder_.addObject(
+                            DensityFittingModuleInfo::name_ + "-exponentialMovingAverageState");
+                exponentialMovingAverageStateAsKeyValueTree(exponentialMovingAverageKvtEntry, state.exponentialMovingAverageState_);
             }
         }
 
@@ -390,6 +418,19 @@ class DensityFitting final : public IMDModule
                         checkpointReading.checkpointedData_[
                             DensityFittingModuleInfo::name_ + "-stepsSinceLastCalculation"].cast<std::int64_t>();
                 }
+                if (checkpointReading.checkpointedData_.keyExists(DensityFittingModuleInfo::name_ + "-adaptiveForceConstantScale"))
+                {
+                    densityFittingState_.adaptiveForceConstantScale_
+                        = checkpointReading
+                                .checkpointedData_[DensityFittingModuleInfo::name_ + "-adaptiveForceConstantScale"].cast<real>();
+                }
+                if (checkpointReading.checkpointedData_.keyExists(
+                            DensityFittingModuleInfo::name_ + "-exponentialMovingAverageState"))
+                {
+                    densityFittingState_.exponentialMovingAverageState_ =
+                        exponentialMovingAverageStateFromKeyValueTree(checkpointReading
+                                                                          .checkpointedData_[DensityFittingModuleInfo::name_ + "-exponentialMovingAverageState"].asObject());
+                }
             }
         }
 
@@ -404,6 +445,8 @@ class DensityFitting final : public IMDModule
                 if (PAR(&(checkpointBroadcast.cr_)))
                 {
                     block_bc(&(checkpointBroadcast.cr_), densityFittingState_.stepsSinceLastCalculation_);
+                    block_bc(&(checkpointBroadcast.cr_), densityFittingState_.adaptiveForceConstantScale_);
+                    block_bc(&(checkpointBroadcast.cr_), densityFittingState_.exponentialMovingAverageState_);
                 }
             }
         }
index b96dd0285056f11b0274d422baef055498717844..5815001897d65a3a452d2fc41595e5cf026529ae 100644 (file)
 
 #include <numeric>
 
+#include "gromacs/compat/optional.h"
 #include "gromacs/gmxlib/network.h"
 #include "gromacs/math/densityfit.h"
 #include "gromacs/math/densityfittingforce.h"
+#include "gromacs/math/exponentialmovingaverage.h"
 #include "gromacs/math/gausstransform.h"
 #include "gromacs/mdtypes/commrec.h"
 #include "gromacs/mdtypes/enerdata.h"
@@ -100,6 +102,7 @@ class DensityFittingForceProvider::Impl
              const TranslateAndScale &transformationToDensityLattice,
              const LocalAtomSet &localAtomSet,
              int pbcType,
+             double simulationTimeStep,
              const DensityFittingForceProviderState &state);
         ~Impl();
         void calculateForces(const ForceProviderInput &forceProviderInput, ForceProviderOutput *forceProviderOutput);
@@ -122,8 +125,9 @@ class DensityFittingForceProvider::Impl
         TranslateAndScale                     transformationToDensityLattice_;
         RVec                                  referenceDensityCenter_;
         int                                   pbcType_;
-        real                                  forceConstantScale_;
 
+        //! Optionally scale the force according to a moving average of the similarity
+        compat::optional<ExponentialMovingAverage> expAverageSimilarity_;
 };
 
 DensityFittingForceProvider::Impl::~Impl() = default;
@@ -133,6 +137,7 @@ DensityFittingForceProvider::Impl::Impl(const DensityFittingParameters &paramete
                                         const TranslateAndScale &transformationToDensityLattice,
                                         const LocalAtomSet &localAtomSet,
                                         int pbcType,
+                                        double simulationTimeStep,
                                         const DensityFittingForceProviderState &state) :
     parameters_(parameters),
     state_(state),
@@ -147,8 +152,16 @@ DensityFittingForceProvider::Impl::Impl(const DensityFittingParameters &paramete
     amplitudeLookup_(parameters_.amplitudeLookupMethod_),
     transformationToDensityLattice_(transformationToDensityLattice),
     pbcType_(pbcType),
-    forceConstantScale_(parameters_.calculationIntervalInSteps_)
+    expAverageSimilarity_(compat::nullopt)
 {
+    if (parameters_.adaptiveForceScaling_)
+    {
+        GMX_ASSERT(simulationTimeStep > 0, "Simulation time step must be larger than zero for adaptive for scaling.");
+        expAverageSimilarity_.emplace(ExponentialMovingAverage(
+                                              parameters_.adaptiveForceScalingTimeConstant_
+                                              / (simulationTimeStep * parameters_.calculationIntervalInSteps_),
+                                              state.exponentialMovingAverageState_));
+    }
     referenceDensityCenter_  = {
         real(referenceDensity.extent(XX))/2,
         real(referenceDensity.extent(YY))/2,
@@ -256,23 +269,43 @@ void DensityFittingForceProvider::Impl::calculateForces(const ForceProviderInput
 
     transformationToDensityLattice_.scaleOperationOnly().inverseIgnoringZeroScale(forces_);
 
-    auto densityForceIterator = forces_.cbegin();
+    auto       densityForceIterator   = forces_.cbegin();
+    const real effectiveForceConstant = state_.adaptiveForceConstantScale_ *
+        parameters_.calculationIntervalInSteps_ * parameters_.forceConstant_;
     for (const auto localAtomIndex : localAtomSet_.localIndex())
     {
-        forceProviderOutput->forceWithVirial_.force_[localAtomIndex] +=
-            forceConstantScale_ * parameters_.forceConstant_ * *densityForceIterator;
+        forceProviderOutput->forceWithVirial_.force_[localAtomIndex]
+            += effectiveForceConstant * *densityForceIterator;
         ++densityForceIterator;
     }
 
     // calculate corresponding potential energy
     const float similarity  = measure_.similarity(gaussTransform_.constView());
-    const real  energy      = -similarity * parameters_.forceConstant_;
+    const real  energy      = -similarity * parameters_.forceConstant_ * state_.adaptiveForceConstantScale_;
     forceProviderOutput->enerd_.term[F_DENSITYFITTING] += energy;
+
+    if (expAverageSimilarity_.has_value())
+    {
+        expAverageSimilarity_->updateWithDataPoint(similarity);
+
+        if (expAverageSimilarity_->increasing())
+        {
+            state_.adaptiveForceConstantScale_ /= 1._real + expAverageSimilarity_->inverseTimeConstant();
+        }
+        else
+        {
+            state_.adaptiveForceConstantScale_ *= 1._real + expAverageSimilarity_->inverseTimeConstant();
+        }
+    }
 }
 
 DensityFittingForceProviderState
 DensityFittingForceProvider::Impl::state()
 {
+    if (expAverageSimilarity_.has_value())
+    {
+        state_.exponentialMovingAverageState_ = expAverageSimilarity_->state();
+    }
     return state_;
 }
 
@@ -287,8 +320,9 @@ DensityFittingForceProvider::DensityFittingForceProvider(const DensityFittingPar
                                                          const TranslateAndScale &transformationToDensityLattice,
                                                          const LocalAtomSet &localAtomSet,
                                                          int pbcType,
+                                                         double simulationTimeStep,
                                                          const DensityFittingForceProviderState &state)
-    : impl_(new Impl(parameters, referenceDensity, transformationToDensityLattice, localAtomSet, pbcType, state))
+    : impl_(new Impl(parameters, referenceDensity, transformationToDensityLattice, localAtomSet, pbcType, simulationTimeStep, state))
 {}
 
 void DensityFittingForceProvider::calculateForces(const ForceProviderInput  &forceProviderInput,
index 70b8dbf71e53bde677bd2f2adeeef7b60ba417a0..766fbef58e51584f095d9cdf0665b84024aa5d86 100644 (file)
@@ -46,6 +46,7 @@
 
 #include "gromacs/domdec/localatomset.h"
 #include "gromacs/math/coordinatetransformation.h"
+#include "gromacs/math/exponentialmovingaverage.h"
 #include "gromacs/mdspan/extensions.h"
 #include "gromacs/mdtypes/iforceprovider.h"
 #include "gromacs/utility/classhelpers.h"
@@ -63,7 +64,11 @@ struct DensityFittingForceProviderState
     /*! \brief The steps since the last force calculation.
      *  Used if density fitting is to be calculated every N steps.
      */
-    std::int64_t stepsSinceLastCalculation_ = 0;
+    std::int64_t                  stepsSinceLastCalculation_ = 0;
+    //! The state of the exponential moving average of the similarity measure
+    ExponentialMovingAverageState exponentialMovingAverageState_ = {};
+    //! An additional factor scaling the force for adaptive force scaling
+    real                          adaptiveForceConstantScale_ = 1.0_real;
 };
 
 /*! \internal \brief
@@ -78,6 +83,7 @@ class DensityFittingForceProvider final : public IForceProvider
                                     const TranslateAndScale &transformationToDensityLattice,
                                     const LocalAtomSet &localAtomSet,
                                     int pbcType,
+                                    double simulationTimeStep,
                                     const DensityFittingForceProviderState &state);
         ~DensityFittingForceProvider();
         /*!\brief Calculate forces that maximise goodness-of-fit with a reference density map.
index bf9f89aa127a98f38e5840478ca0a330f115e55f..4243ba0446faddf6e0cd7c87b628c61b42b4ee2e 100644 (file)
@@ -140,6 +140,8 @@ void DensityFittingOptions::initMdpTransform(IKeyValueTreeTransformRules * rules
     densityfittingMdpTransformFromString<std::string>(rules, stringIdentityTransform, c_referenceDensityFileNameTag_);
     densityfittingMdpTransformFromString<std::int64_t>(rules, &fromStdString<std::int64_t>, c_everyNStepsTag_);
     densityfittingMdpTransformFromString<bool>(rules, &fromStdString<bool>, c_normalizeDensitiesTag_);
+    densityfittingMdpTransformFromString<bool>(rules, &fromStdString<bool>, c_adaptiveForceScalingTag_);
+    densityfittingMdpTransformFromString<real>(rules, &fromStdString<real>, c_adaptiveForceScalingTimeConstantTag_);
 }
 
 void DensityFittingOptions::buildMdpOutput(KeyValueTreeObjectBuilder *builder) const
@@ -171,6 +173,10 @@ void DensityFittingOptions::buildMdpOutput(KeyValueTreeObjectBuilder *builder) c
         addDensityFittingMdpOutputValue(builder, parameters_.calculationIntervalInSteps_, c_everyNStepsTag_);
         addDensityFittingMdpOutputValueComment(builder, "; Normalize the sum of density voxel values to one", c_normalizeDensitiesTag_);
         addDensityFittingMdpOutputValue(builder, parameters_.normalizeDensities_, c_normalizeDensitiesTag_);
+        addDensityFittingMdpOutputValueComment(builder, "; Apply adaptive force scaling", c_adaptiveForceScalingTag_);
+        addDensityFittingMdpOutputValue(builder, parameters_.adaptiveForceScaling_, c_adaptiveForceScalingTag_);
+        addDensityFittingMdpOutputValueComment(builder, "; Time constant for adaptive force scaling in ps", c_adaptiveForceScalingTimeConstantTag_);
+        addDensityFittingMdpOutputValue(builder, parameters_.adaptiveForceScalingTimeConstant_, c_adaptiveForceScalingTimeConstantTag_);
     }
 }
 
@@ -195,6 +201,8 @@ void DensityFittingOptions::initMdpOptions(IOptionsContainerWithSections *option
     section.addOption(StringOption(c_referenceDensityFileNameTag_.c_str()).store(&referenceDensityFileName_));
     section.addOption(Int64Option(c_everyNStepsTag_.c_str()).store(&parameters_.calculationIntervalInSteps_));
     section.addOption(BooleanOption(c_normalizeDensitiesTag_.c_str()).store(&parameters_.normalizeDensities_));
+    section.addOption(BooleanOption(c_adaptiveForceScalingTag_.c_str()).store(&parameters_.adaptiveForceScaling_));
+    section.addOption(RealOption(c_adaptiveForceScalingTimeConstantTag_.c_str()).store(&parameters_.adaptiveForceScalingTimeConstant_));
 }
 
 bool DensityFittingOptions::active() const
index 034f3d3a3354bfbe4840d7058a07bceae19c674b..af5d24df02af5b3f5127b044ec6d611992b0577e 100644 (file)
@@ -131,6 +131,10 @@ class DensityFittingOptions final : public IMdpOptionProvider
 
         const std::string c_normalizeDensitiesTag_ = "normalize-densities";
 
+        const std::string c_adaptiveForceScalingTag_ = "adaptive-force-scaling";
+
+        const std::string c_adaptiveForceScalingTimeConstantTag_ = "adaptive-force-scaling-time-constant";
+
 
         DensityFittingParameters parameters_;
 };
index d81f12e2e049440df364a15e8f2462b29ee57a0b..9e576fcb02a4d546a4bf0771d01e3f80e8d0ad1e 100644 (file)
@@ -77,6 +77,10 @@ struct DensityFittingParameters
     std::int64_t                   calculationIntervalInSteps_ = 1;
     //! Normalize reference and simulated densities
     bool                           normalizeDensities_ = true;
+    //! Perform adaptive force scaling during the simulation
+    bool                           adaptiveForceScaling_ = false;
+    //! The time constant for the adaptive force scaling in ps
+    real                           adaptiveForceScalingTimeConstant_ = 4;
 };
 
 /*!\brief Check if two structs holding density fitting parameters are equal.
index 6999b133badc29b111c9fe8ad23f6f4e6b445908..88514dc30f1ad61b759f80322378045d09004e59 100644 (file)
@@ -207,6 +207,10 @@ TEST_F(DensityFittingOptionsTest, OutputDefaultValuesWhenActive)
         "density-guided-simulation-nst = 1\n"
         "; Normalize the sum of density voxel values to one\n"
         "density-guided-simulation-normalize-densities = true\n"
+        "; Apply adaptive force scaling\n"
+        "density-guided-simulation-adaptive-force-scaling = false\n"
+        "; Time constant for adaptive force scaling in ps\n"
+        "density-guided-simulation-adaptive-force-scaling-time-constant = 4\n"
         };
 
     EXPECT_EQ(expectedString, stream.toString());
index 36275e4d29c56c77c28d91f6038fff56db6a5fd5..59646f6aae57d2825a85c5ac8a24db138536e810 100644 (file)
 #include "exponentialmovingaverage.h"
 
 #include "gromacs/utility/exceptions.h"
+#include "gromacs/utility/keyvaluetree.h"
 
 namespace gmx
 {
 
-ExponentialMovingAverage::ExponentialMovingAverage(real timeConstant, const ExponentialMovingAverageState &state) :
-    state_(state)
+//! Convert the exponential moving average state as key-value-tree object
+void exponentialMovingAverageStateAsKeyValueTree(KeyValueTreeObjectBuilder builder, const ExponentialMovingAverageState &state)
+{
+    builder.addValue<real>("weighted-sum", state.weightedSum_);
+    builder.addValue<real>("weighted-count", state.weightedCount_);
+    builder.addValue<bool>("increasing", state.increasing_);
+}
+
+//! Sets the exponential moving average state from a key-value-tree object
+ExponentialMovingAverageState
+exponentialMovingAverageStateFromKeyValueTree(const KeyValueTreeObject &object)
+{
+    const real weightedSum   = object["weighted-sum"].cast<real>();
+    const real weightedCount = object["weighted-count"].cast<real>();
+    const bool increasing    = object["increasing"].cast<bool>();
+    return {weightedSum, weightedCount, increasing};
+}
+
+ExponentialMovingAverage::ExponentialMovingAverage(real timeConstant, const ExponentialMovingAverageState &state)
+    : state_(state)
 {
     if (timeConstant < 1)
     {
index 2c11bc7b858ca572756b650c746bd3803fa7372b..f2c705d947ad40f16fafa70f13e1666078a57aad 100644 (file)
@@ -43,6 +43,7 @@
 #ifndef GMX_MATH_EXPONENTIALMOVINGAVERAGE_H
 #define GMX_MATH_EXPONENTIALMOVINGAVERAGE_H
 
+#include "gromacs/utility/keyvaluetreebuilder.h"
 #include "gromacs/utility/real.h"
 
 namespace gmx
@@ -62,6 +63,13 @@ struct ExponentialMovingAverageState
     bool increasing_ = false;
 };
 
+//! Convert the exponential moving average state as key-value-tree object
+void exponentialMovingAverageStateAsKeyValueTree(KeyValueTreeObjectBuilder builder, const ExponentialMovingAverageState &state);
+
+//! Sets the expoential moving average state from a key-value-tree object
+ExponentialMovingAverageState
+exponentialMovingAverageStateFromKeyValueTree(const KeyValueTreeObject &object);
+
 /*! \libinternal
  * \brief Evaluate the exponential moving average with bias correction.
  *
index 5aa361534314de84214e82c9169cbc084f2c108d..338cb03ffcb404a6667d293c3213e8db39366b5f 100644 (file)
@@ -139,6 +139,22 @@ TEST(ExponentialMovingAverage, InverseLagTimeCorrect)
     EXPECT_REAL_EQ(0.5, exponentialMovingAverage.inverseTimeConstant());
 }
 
+TEST(ExponentialMovingAverage, RoundTripAsKeyValueTree)
+{
+    KeyValueTreeBuilder           builder;
+    const real                    weightedSum   = 9;
+    const real                    weightedCount = 1;
+    const bool                    increasing    = true;
+    ExponentialMovingAverageState state         = {weightedSum, weightedCount, increasing};
+    exponentialMovingAverageStateAsKeyValueTree(builder.rootObject(), state);
+    state = {};
+    KeyValueTreeObject result = builder.build();
+    state = exponentialMovingAverageStateFromKeyValueTree(result);
+    EXPECT_EQ(weightedSum, state.weightedSum_);
+    EXPECT_EQ(weightedCount, state.weightedCount_);
+    EXPECT_EQ(increasing, state.increasing_);
+}
+
 } // namespace
 
 } // namespace test
index 170e1a3f590bd38e852f96b8e33dba2f4d33c81c..719d300d073384807ed71a874e631cb55a76578e 100644 (file)
@@ -1291,6 +1291,7 @@ int Mdrunner::mdrunner()
         mdModulesNotifier.notify(*cr);
         mdModulesNotifier.notify(&atomSets);
         mdModulesNotifier.notify(PeriodicBoundaryConditionType {inputrec->ePBC});
+        mdModulesNotifier.notify(SimulationTimeStep { inputrec->delta_t });
         /* Initiate forcerecord */
         fr                 = new t_forcerec;
         fr->forceProviders = mdModules_->initForceProviders();
index 654007ecae860fb1d05b5741242c6a0625dbd1b6..0d8ecb39532cb4159afee4604208f32588695ba5 100644 (file)
@@ -248,6 +248,12 @@ class EnergyCalculationFrequencyErrors
         std::vector<std::string> errorMessages_;
 };
 
+struct SimulationTimeStep
+{
+    //! Time step (ps)
+    const double delta_t;
+};
+
 struct MdModulesNotifier
 {
 //! Register callback function types for MdModule
@@ -262,7 +268,8 @@ struct MdModulesNotifier
         MdModulesCheckpointReadingDataOnMaster,
         MdModulesCheckpointReadingBroadcast,
         MdModulesWriteCheckpointData,
-        PeriodicBoundaryConditionType>::type notifier_;
+        PeriodicBoundaryConditionType,
+        const SimulationTimeStep &>::type notifier_;
 };
 
 } // namespace gmx