Make it possible to use FEP lambda states as a reaction coordinate in AWH. Atom masse...
[alexxy/gromacs.git] / src / gromacs / awh / biasstate.cpp
index c63749e889b35327d51193b0039ffa2493f870ee..24dec6dca849b0f8e883c1f10b4c717f8bda88fd 100644 (file)
@@ -53,6 +53,7 @@
 #include <cstring>
 
 #include <algorithm>
+#include <optional>
 
 #include "gromacs/fileio/gmxfio.h"
 #include "gromacs/fileio/xvgr.h"
@@ -207,12 +208,15 @@ double freeEnergyMinimumValue(gmx::ArrayRef<const PointState> pointState)
  * w(point|value) = exp(bias(point) - U(value,point)),
  * where U is a harmonic umbrella potential.
  *
- * \param[in] dimParams     The bias dimensions parameters
- * \param[in] points        The point state.
- * \param[in] grid          The grid.
- * \param[in] pointIndex    Point to evaluate probability weight for.
- * \param[in] pointBias     Bias for the point (as a log weight).
- * \param[in] value         Coordinate value.
+ * \param[in] dimParams              The bias dimensions parameters
+ * \param[in] points                 The point state.
+ * \param[in] grid                   The grid.
+ * \param[in] pointIndex             Point to evaluate probability weight for.
+ * \param[in] pointBias              Bias for the point (as a log weight).
+ * \param[in] value                  Coordinate value.
+ * \param[in] neighborLambdaEnergies The energy of the system in neighboring lambda states. Can be
+ * empty when there are no free energy lambda state dimensions.
+ * \param[in] gridpointIndex         The index of the current grid point.
  * \returns the log of the biased probability weight.
  */
 double biasedLogWeightFromPoint(const std::vector<DimParams>&  dimParams,
@@ -220,11 +224,13 @@ double biasedLogWeightFromPoint(const std::vector<DimParams>&  dimParams,
                                 const BiasGrid&                grid,
                                 int                            pointIndex,
                                 double                         pointBias,
-                                const awh_dvec                 value)
+                                const awh_dvec                 value,
+                                gmx::ArrayRef<const double>    neighborLambdaEnergies,
+                                int                            gridpointIndex)
 {
     double logWeight = detail::c_largeNegativeExponent;
 
-    /* Only points in the target reigon have non-zero weight */
+    /* Only points in the target region have non-zero weight */
     if (points[pointIndex].inTargetRegion())
     {
         logWeight = pointBias;
@@ -232,14 +238,62 @@ double biasedLogWeightFromPoint(const std::vector<DimParams>&  dimParams,
         /* Add potential for all parameter dimensions */
         for (size_t d = 0; d < dimParams.size(); d++)
         {
-            double dev = getDeviationFromPointAlongGridAxis(grid, d, pointIndex, value[d]);
-            logWeight -= 0.5 * dimParams[d].betak * dev * dev;
+            if (dimParams[d].isFepLambdaDimension())
+            {
+                /* If this is not a sampling step or if this function is called from
+                 * calcConvolvedBias(), when writing energy subblocks, neighborLambdaEnergies will
+                 * be empty. No convolution is required along the lambda dimension. */
+                if (!neighborLambdaEnergies.empty())
+                {
+                    const int pointLambdaIndex     = grid.point(pointIndex).coordValue[d];
+                    const int gridpointLambdaIndex = grid.point(gridpointIndex).coordValue[d];
+                    logWeight -= dimParams[d].beta
+                                 * (neighborLambdaEnergies[pointLambdaIndex]
+                                    - neighborLambdaEnergies[gridpointLambdaIndex]);
+                }
+            }
+            else
+            {
+                double dev = getDeviationFromPointAlongGridAxis(grid, d, pointIndex, value[d]);
+                logWeight -= 0.5 * dimParams[d].betak * dev * dev;
+            }
         }
     }
-
     return logWeight;
 }
 
+/*! \brief
+ * Calculates the marginal distribution (marginal probability) for each value along
+ * a free energy lambda axis.
+ * The marginal probability of a lambda state is the sum of the probability weights of all
+ * points (herein all neighbor points) that have that same lambda state value in the dimension
+ * of interest.
+ * \param[in] grid               The bias grid.
+ * \param[in] neighbors          The points to use for the calculation of the marginal distribution.
+ * \param[in] probWeightNeighbor Probability weights of the neighbors.
+ * \returns The calculated marginal distribution in a 1D array with
+ * as many elements as there are points along the axis of interest.
+ */
+std::vector<double> calculateFELambdaMarginalDistribution(const BiasGrid&          grid,
+                                                          gmx::ArrayRef<const int> neighbors,
+                                                          gmx::ArrayRef<const double> probWeightNeighbor)
+{
+    const std::optional<int> lambdaAxisIndex = grid.lambdaAxisIndex();
+    GMX_RELEASE_ASSERT(lambdaAxisIndex,
+                       "There must be a free energy lambda axis in order to calculate the free "
+                       "energy lambda marginal distribution.");
+    const int           numFepLambdaStates = grid.numFepLambdaStates();
+    std::vector<double> lambdaMarginalDistribution(numFepLambdaStates, 0);
+
+    for (size_t i = 0; i < neighbors.size(); i++)
+    {
+        const int neighbor    = neighbors[i];
+        const int lambdaState = grid.point(neighbor).coordValue[lambdaAxisIndex.value()];
+        lambdaMarginalDistribution[lambdaState] += probWeightNeighbor[i];
+    }
+    return lambdaMarginalDistribution;
+}
+
 } // namespace
 
 void BiasState::calcConvolvedPmf(const std::vector<DimParams>& dimParams,
@@ -267,7 +321,7 @@ void BiasState::calcConvolvedPmf(const std::vector<DimParams>& dimParams,
                Note that this function only adds point within the target > 0 region.
                Sum weights, take the logarithm last to get the free energy. */
             double logWeight = biasedLogWeightFromPoint(dimParams, points_, grid, neighbor,
-                                                        biasNeighbor, point.coordValue);
+                                                        biasNeighbor, point.coordValue, {}, m);
             freeEnergyWeights += std::exp(logWeight);
         }
 
@@ -416,19 +470,31 @@ int BiasState::warnForHistogramAnomalies(const BiasGrid& grid, int biasIndex, do
 double BiasState::calcUmbrellaForceAndPotential(const std::vector<DimParams>& dimParams,
                                                 const BiasGrid&               grid,
                                                 int                           point,
+                                                ArrayRef<const double>        neighborLambdaDhdl,
                                                 gmx::ArrayRef<double>         force) const
 {
     double potential = 0;
     for (size_t d = 0; d < dimParams.size(); d++)
     {
-        double deviation =
-                getDeviationFromPointAlongGridAxis(grid, d, point, coordState_.coordValue()[d]);
-
-        double k = dimParams[d].k;
+        if (dimParams[d].isFepLambdaDimension())
+        {
+            if (!neighborLambdaDhdl.empty())
+            {
+                const int coordpointLambdaIndex = grid.point(point).coordValue[d];
+                force[d]                        = neighborLambdaDhdl[coordpointLambdaIndex];
+                /* The potential should not be affected by the lambda dimension. */
+            }
+        }
+        else
+        {
+            double deviation =
+                    getDeviationFromPointAlongGridAxis(grid, d, point, coordState_.coordValue()[d]);
+            double k = dimParams[d].k;
 
-        /* Force from harmonic potential 0.5*k*dev^2 */
-        force[d] = -k * deviation;
-        potential += 0.5 * k * deviation * deviation;
+            /* Force from harmonic potential 0.5*k*dev^2 */
+            force[d] = -k * deviation;
+            potential += 0.5 * k * deviation * deviation;
+        }
     }
 
     return potential;
@@ -437,6 +503,7 @@ double BiasState::calcUmbrellaForceAndPotential(const std::vector<DimParams>& di
 void BiasState::calcConvolvedForce(const std::vector<DimParams>& dimParams,
                                    const BiasGrid&               grid,
                                    gmx::ArrayRef<const double>   probWeightNeighbor,
+                                   ArrayRef<const double>        neighborLambdaDhdl,
                                    gmx::ArrayRef<double>         forceWorkBuffer,
                                    gmx::ArrayRef<double>         force) const
 {
@@ -454,7 +521,7 @@ void BiasState::calcConvolvedForce(const std::vector<DimParams>& dimParams,
         int    indexNeighbor  = neighbor[n];
 
         /* Get the umbrella force from this point. The returned potential is ignored here. */
-        calcUmbrellaForceAndPotential(dimParams, grid, indexNeighbor, forceFromNeighbor);
+        calcUmbrellaForceAndPotential(dimParams, grid, indexNeighbor, neighborLambdaDhdl, forceFromNeighbor);
 
         /* Add the weighted umbrella force to the convolved force. */
         for (size_t d = 0; d < dimParams.size(); d++)
@@ -467,18 +534,25 @@ void BiasState::calcConvolvedForce(const std::vector<DimParams>& dimParams,
 double BiasState::moveUmbrella(const std::vector<DimParams>& dimParams,
                                const BiasGrid&               grid,
                                gmx::ArrayRef<const double>   probWeightNeighbor,
+                               ArrayRef<const double>        neighborLambdaDhdl,
                                gmx::ArrayRef<double>         biasForce,
                                int64_t                       step,
                                int64_t                       seed,
-                               int                           indexSeed)
+                               int                           indexSeed,
+                               bool                          onlySampleUmbrellaGridpoint)
 {
     /* Generate and set a new coordinate reference value */
     coordState_.sampleUmbrellaGridpoint(grid, coordState_.gridpointIndex(), probWeightNeighbor,
                                         step, seed, indexSeed);
 
+    if (onlySampleUmbrellaGridpoint)
+    {
+        return 0;
+    }
+
     std::vector<double> newForce(dimParams.size());
-    double              newPotential =
-            calcUmbrellaForceAndPotential(dimParams, grid, coordState_.umbrellaGridpoint(), newForce);
+    double              newPotential = calcUmbrellaForceAndPotential(
+            dimParams, grid, coordState_.umbrellaGridpoint(), neighborLambdaDhdl, newForce);
 
     /*  A modification of the reference value at time t will lead to a different
         force over t-dt/2 to t and over t to t+dt/2. For high switching rates
@@ -923,7 +997,17 @@ bool BiasState::isSamplingRegionCovered(const BiasParams&             params,
     double weightThreshold = 1;
     for (int d = 0; d < grid.numDimensions(); d++)
     {
-        weightThreshold *= grid.axis(d).spacing() * std::sqrt(dimParams[d].betak * 0.5 * M_1_PI);
+        if (grid.axis(d).isFepLambdaAxis())
+        {
+            /* TODO: Verify that a threshold of 1.0 is OK. With a very high sample weight 1.0 can be
+             * reached quickly even in regions with low probability. Should the sample weight be
+             * taken into account here? */
+            weightThreshold *= 1.0;
+        }
+        else
+        {
+            weightThreshold *= grid.axis(d).spacing() * std::sqrt(dimParams[d].betak * 0.5 * M_1_PI);
+        }
     }
 
     /* Project the sampling weights onto each dimension */
@@ -1154,6 +1238,7 @@ void BiasState::updateFreeEnergyAndAddSamplesToHistogram(const std::vector<DimPa
 
 double BiasState::updateProbabilityWeightsAndConvolvedBias(const std::vector<DimParams>& dimParams,
                                                            const BiasGrid&               grid,
+                                                           gmx::ArrayRef<const double> neighborLambdaEnergies,
                                                            std::vector<double, AlignedAllocator<double>>* weight) const
 {
     /* Only neighbors of the current coordinate value will have a non-negligible chance of getting sampled */
@@ -1179,9 +1264,9 @@ double BiasState::updateProbabilityWeightsAndConvolvedBias(const std::vector<Dim
             if (n < neighbors.size())
             {
                 const int neighbor = neighbors[n];
-                (*weight)[n] =
-                        biasedLogWeightFromPoint(dimParams, points_, grid, neighbor,
-                                                 points_[neighbor].bias(), coordState_.coordValue());
+                (*weight)[n]       = biasedLogWeightFromPoint(
+                        dimParams, points_, grid, neighbor, points_[neighbor].bias(),
+                        coordState_.coordValue(), neighborLambdaEnergies, coordState_.gridpointIndex());
             }
             else
             {
@@ -1201,6 +1286,30 @@ double BiasState::updateProbabilityWeightsAndConvolvedBias(const std::vector<Dim
 
     /* Normalize probabilities to sum to 1 */
     double invWeightSum = 1 / weightSum;
+
+    /* When there is a free energy lambda state axis remove the convolved contributions along that
+     * axis from the total bias. This must be done after calculating invWeightSum (since weightSum
+     * will be modified), but before normalizing the weights (below). */
+    if (grid.hasLambdaAxis())
+    {
+        /* If there is only one axis the bias will not be convolved in any dimension. */
+        if (grid.axis().size() == 1)
+        {
+            weightSum = gmx::exp(points_[coordState_.gridpointIndex()].bias());
+        }
+        else
+        {
+            for (size_t i = 0; i < neighbors.size(); i++)
+            {
+                const int neighbor = neighbors[i];
+                if (pointsHaveDifferentLambda(grid, coordState_.gridpointIndex(), neighbor))
+                {
+                    weightSum -= weightData[i];
+                }
+            }
+        }
+    }
+
     for (double& w : *weight)
     {
         w *= invWeightSum;
@@ -1221,8 +1330,13 @@ double BiasState::calcConvolvedBias(const std::vector<DimParams>& dimParams,
     double weightSum = 0;
     for (int neighbor : gridPoint.neighbor)
     {
+        /* No convolution is required along the lambda dimension. */
+        if (pointsHaveDifferentLambda(grid, point, neighbor))
+        {
+            continue;
+        }
         double logWeight = biasedLogWeightFromPoint(dimParams, points_, grid, neighbor,
-                                                    points_[neighbor].bias(), coordValue);
+                                                    points_[neighbor].bias(), coordValue, {}, point);
         weightSum += std::exp(logWeight);
     }
 
@@ -1284,9 +1398,10 @@ void BiasState::sampleProbabilityWeights(const BiasGrid& grid, gmx::ArrayRef<con
     }
 }
 
-void BiasState::sampleCoordAndPmf(const BiasGrid&             grid,
-                                  gmx::ArrayRef<const double> probWeightNeighbor,
-                                  double                      convolvedBias)
+void BiasState::sampleCoordAndPmf(const std::vector<DimParams>& dimParams,
+                                  const BiasGrid&               grid,
+                                  gmx::ArrayRef<const double>   probWeightNeighbor,
+                                  double                        convolvedBias)
 {
     /* Sampling-based deconvolution extracting the PMF.
      * Update the PMF histogram with the current coordinate value.
@@ -1301,13 +1416,60 @@ void BiasState::sampleCoordAndPmf(const BiasGrid&             grid,
      * it works (mainly because how the PMF histogram is rescaled).
      */
 
-    /* Only save coordinate data that is in range (the given index is always
-     * in range even if the coordinate value is not).
-     */
-    if (grid.covers(coordState_.coordValue()))
+    const int                gridPointIndex  = coordState_.gridpointIndex();
+    const std::optional<int> lambdaAxisIndex = grid.lambdaAxisIndex();
+
+    /* Update the PMF of points along a lambda axis with their bias. */
+    if (lambdaAxisIndex)
     {
-        /* Save PMF sum and keep a histogram of the sampled coordinate values */
-        points_[coordState_.gridpointIndex()].samplePmf(convolvedBias);
+        const std::vector<int>& neighbors = grid.point(gridPointIndex).neighbor;
+
+        std::vector<double> lambdaMarginalDistribution =
+                calculateFELambdaMarginalDistribution(grid, neighbors, probWeightNeighbor);
+
+        awh_dvec coordValueAlongLambda = { coordState_.coordValue()[0], coordState_.coordValue()[1],
+                                           coordState_.coordValue()[2], coordState_.coordValue()[3] };
+        for (size_t i = 0; i < neighbors.size(); i++)
+        {
+            const int neighbor = neighbors[i];
+            double    bias;
+            if (pointsAlongLambdaAxis(grid, gridPointIndex, neighbor))
+            {
+                const double neighborLambda = grid.point(neighbor).coordValue[lambdaAxisIndex.value()];
+                if (neighbor == gridPointIndex)
+                {
+                    bias = convolvedBias;
+                }
+                else
+                {
+                    coordValueAlongLambda[lambdaAxisIndex.value()] = neighborLambda;
+                    bias = calcConvolvedBias(dimParams, grid, coordValueAlongLambda);
+                }
+
+                const double probWeight   = lambdaMarginalDistribution[neighborLambda];
+                const double weightedBias = bias - std::log(probWeight);
+
+                if (neighbor == gridPointIndex && grid.covers(coordState_.coordValue()))
+                {
+                    points_[neighbor].samplePmf(weightedBias);
+                }
+                else
+                {
+                    points_[neighbor].updatePmfUnvisited(weightedBias);
+                }
+            }
+        }
+    }
+    else
+    {
+        /* Only save coordinate data that is in range (the given index is always
+         * in range even if the coordinate value is not).
+         */
+        if (grid.covers(coordState_.coordValue()))
+        {
+            /* Save PMF sum and keep a histogram of the sampled coordinate values */
+            points_[gridPointIndex].samplePmf(convolvedBias);
+        }
     }
 
     /* Save probability weights for the update */