Merge remote-tracking branch 'origin/release-2021' into master
[alexxy/gromacs.git] / src / gromacs / applied_forces / awh / biasstate.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2015,2016,2017,2018,2019, The GROMACS development team.
5  * Copyright (c) 2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36
37 /*! \internal \file
38  * \brief
39  * Implements the BiasState class.
40  *
41  * \author Viveca Lindahl
42  * \author Berk Hess <hess@kth.se>
43  * \ingroup module_awh
44  */
45
46 #include "gmxpre.h"
47
48 #include "biasstate.h"
49
50 #include <cassert>
51 #include <cmath>
52 #include <cstdio>
53 #include <cstdlib>
54 #include <cstring>
55
56 #include <algorithm>
57 #include <optional>
58
59 #include "gromacs/fileio/gmxfio.h"
60 #include "gromacs/fileio/xvgr.h"
61 #include "gromacs/gmxlib/network.h"
62 #include "gromacs/math/units.h"
63 #include "gromacs/math/utilities.h"
64 #include "gromacs/mdrunutility/multisim.h"
65 #include "gromacs/mdtypes/awh_history.h"
66 #include "gromacs/mdtypes/awh_params.h"
67 #include "gromacs/mdtypes/commrec.h"
68 #include "gromacs/simd/simd.h"
69 #include "gromacs/simd/simd_math.h"
70 #include "gromacs/utility/arrayref.h"
71 #include "gromacs/utility/exceptions.h"
72 #include "gromacs/utility/gmxassert.h"
73 #include "gromacs/utility/smalloc.h"
74 #include "gromacs/utility/stringutil.h"
75
76 #include "biasgrid.h"
77 #include "pointstate.h"
78
79 namespace gmx
80 {
81
82 void BiasState::getPmf(gmx::ArrayRef<float> pmf) const
83 {
84     GMX_ASSERT(pmf.size() == points_.size(), "pmf should have the size of the bias grid");
85
86     /* The PMF is just the negative of the log of the sampled PMF histogram.
87      * Points with zero target weight are ignored, they will mostly contain noise.
88      */
89     for (size_t i = 0; i < points_.size(); i++)
90     {
91         pmf[i] = points_[i].inTargetRegion() ? -points_[i].logPmfSum() : GMX_FLOAT_MAX;
92     }
93 }
94
95 namespace
96 {
97
98 /*! \brief
99  * Sum an array over all simulations on the master rank of each simulation.
100  *
101  * \param[in,out] arrayRef      The data to sum.
102  * \param[in]     multiSimComm  Struct for multi-simulation communication.
103  */
104 void sumOverSimulations(gmx::ArrayRef<int> arrayRef, const gmx_multisim_t* multiSimComm)
105 {
106     gmx_sumi_sim(arrayRef.size(), arrayRef.data(), multiSimComm);
107 }
108
109 /*! \brief
110  * Sum an array over all simulations on the master rank of each simulation.
111  *
112  * \param[in,out] arrayRef      The data to sum.
113  * \param[in]     multiSimComm  Struct for multi-simulation communication.
114  */
115 void sumOverSimulations(gmx::ArrayRef<double> arrayRef, const gmx_multisim_t* multiSimComm)
116 {
117     gmx_sumd_sim(arrayRef.size(), arrayRef.data(), multiSimComm);
118 }
119
120 /*! \brief
121  * Sum an array over all simulations on all ranks of each simulation.
122  *
123  * This assumes the data is identical on all ranks within each simulation.
124  *
125  * \param[in,out] arrayRef      The data to sum.
126  * \param[in]     commRecord    Struct for intra-simulation communication.
127  * \param[in]     multiSimComm  Struct for multi-simulation communication.
128  */
129 template<typename T>
130 void sumOverSimulations(gmx::ArrayRef<T> arrayRef, const t_commrec* commRecord, const gmx_multisim_t* multiSimComm)
131 {
132     if (MASTER(commRecord))
133     {
134         sumOverSimulations(arrayRef, multiSimComm);
135     }
136     if (commRecord->nnodes > 1)
137     {
138         gmx_bcast(arrayRef.size() * sizeof(T), arrayRef.data(), commRecord->mpi_comm_mygroup);
139     }
140 }
141
142 /*! \brief
143  * Sum PMF over multiple simulations, when requested.
144  *
145  * \param[in,out] pointState         The state of the points in the bias.
146  * \param[in]     numSharedUpdate    The number of biases sharing the histogram.
147  * \param[in]     commRecord         Struct for intra-simulation communication.
148  * \param[in]     multiSimComm       Struct for multi-simulation communication.
149  */
150 void sumPmf(gmx::ArrayRef<PointState> pointState,
151             int                       numSharedUpdate,
152             const t_commrec*          commRecord,
153             const gmx_multisim_t*     multiSimComm)
154 {
155     if (numSharedUpdate == 1)
156     {
157         return;
158     }
159     GMX_ASSERT(multiSimComm != nullptr && numSharedUpdate % multiSimComm->numSimulations_ == 0,
160                "numSharedUpdate should be a multiple of multiSimComm->numSimulations_");
161     GMX_ASSERT(numSharedUpdate == multiSimComm->numSimulations_,
162                "Sharing within a simulation is not implemented (yet)");
163
164     std::vector<double> buffer(pointState.size());
165
166     /* Need to temporarily exponentiate the log weights to sum over simulations */
167     for (size_t i = 0; i < buffer.size(); i++)
168     {
169         buffer[i] = pointState[i].inTargetRegion() ? std::exp(pointState[i].logPmfSum()) : 0;
170     }
171
172     sumOverSimulations(gmx::ArrayRef<double>(buffer), commRecord, multiSimComm);
173
174     /* Take log again to get (non-normalized) PMF */
175     double normFac = 1.0 / numSharedUpdate;
176     for (gmx::index i = 0; i < pointState.ssize(); i++)
177     {
178         if (pointState[i].inTargetRegion())
179         {
180             pointState[i].setLogPmfSum(std::log(buffer[i] * normFac));
181         }
182     }
183 }
184
185 /*! \brief
186  * Find the minimum free energy value.
187  *
188  * \param[in] pointState  The state of the points.
189  * \returns the minimum free energy value.
190  */
191 double freeEnergyMinimumValue(gmx::ArrayRef<const PointState> pointState)
192 {
193     double fMin = GMX_FLOAT_MAX;
194
195     for (auto const& ps : pointState)
196     {
197         if (ps.inTargetRegion() && ps.freeEnergy() < fMin)
198         {
199             fMin = ps.freeEnergy();
200         }
201     }
202
203     return fMin;
204 }
205
206 /*! \brief
207  * Find and return the log of the probability weight of a point given a coordinate value.
208  *
209  * The unnormalized weight is given by
210  * w(point|value) = exp(bias(point) - U(value,point)),
211  * where U is a harmonic umbrella potential.
212  *
213  * \param[in] dimParams              The bias dimensions parameters
214  * \param[in] points                 The point state.
215  * \param[in] grid                   The grid.
216  * \param[in] pointIndex             Point to evaluate probability weight for.
217  * \param[in] pointBias              Bias for the point (as a log weight).
218  * \param[in] value                  Coordinate value.
219  * \param[in] neighborLambdaEnergies The energy of the system in neighboring lambdas states. Can be
220  * empty when there are no free energy lambda state dimensions.
221  * \param[in] gridpointIndex         The index of the current grid point.
222  * \returns the log of the biased probability weight.
223  */
224 double biasedLogWeightFromPoint(ArrayRef<const DimParams>  dimParams,
225                                 ArrayRef<const PointState> points,
226                                 const BiasGrid&            grid,
227                                 int                        pointIndex,
228                                 double                     pointBias,
229                                 const awh_dvec             value,
230                                 ArrayRef<const double>     neighborLambdaEnergies,
231                                 int                        gridpointIndex)
232 {
233     double logWeight = detail::c_largeNegativeExponent;
234
235     /* Only points in the target region have non-zero weight */
236     if (points[pointIndex].inTargetRegion())
237     {
238         logWeight = pointBias;
239
240         /* Add potential for all parameter dimensions */
241         for (size_t d = 0; d < dimParams.size(); d++)
242         {
243             if (dimParams[d].isFepLambdaDimension())
244             {
245                 /* If this is not a sampling step or if this function is called from
246                  * calcConvolvedBias(), when writing energy subblocks, neighborLambdaEnergies will
247                  * be empty. No convolution is required along the lambda dimension. */
248                 if (!neighborLambdaEnergies.empty())
249                 {
250                     const int pointLambdaIndex     = grid.point(pointIndex).coordValue[d];
251                     const int gridpointLambdaIndex = grid.point(gridpointIndex).coordValue[d];
252                     logWeight -= dimParams[d].fepDimParams().beta
253                                  * (neighborLambdaEnergies[pointLambdaIndex]
254                                     - neighborLambdaEnergies[gridpointLambdaIndex]);
255                 }
256             }
257             else
258             {
259                 double dev = getDeviationFromPointAlongGridAxis(grid, d, pointIndex, value[d]);
260                 logWeight -= 0.5 * dimParams[d].pullDimParams().betak * dev * dev;
261             }
262         }
263     }
264     return logWeight;
265 }
266
267 /*! \brief
268  * Calculates the marginal distribution (marginal probability) for each value along
269  * a free energy lambda axis.
270  * The marginal distribution of one coordinate dimension value is the sum of the probability
271  * distribution of all values (herein all neighbor values) with the same value in the dimension
272  * of interest.
273  * \param[in] grid               The bias grid.
274  * \param[in] neighbors          The points to use for the calculation of the marginal distribution.
275  * \param[in] probWeightNeighbor Probability weights of the neighbors.
276  * \returns The calculated marginal distribution in a 1D array with
277  * as many elements as there are points along the axis of interest.
278  */
279 std::vector<double> calculateFELambdaMarginalDistribution(const BiasGrid&        grid,
280                                                           ArrayRef<const int>    neighbors,
281                                                           ArrayRef<const double> probWeightNeighbor)
282 {
283     const std::optional<int> lambdaAxisIndex = grid.lambdaAxisIndex();
284     GMX_RELEASE_ASSERT(lambdaAxisIndex,
285                        "There must be a free energy lambda axis in order to calculate the free "
286                        "energy lambda marginal distribution.");
287     const int           numFepLambdaStates = grid.numFepLambdaStates();
288     std::vector<double> lambdaMarginalDistribution(numFepLambdaStates, 0);
289
290     for (size_t i = 0; i < neighbors.size(); i++)
291     {
292         const int neighbor    = neighbors[i];
293         const int lambdaState = grid.point(neighbor).coordValue[lambdaAxisIndex.value()];
294         lambdaMarginalDistribution[lambdaState] += probWeightNeighbor[i];
295     }
296     return lambdaMarginalDistribution;
297 }
298
299 } // namespace
300
301 void BiasState::calcConvolvedPmf(ArrayRef<const DimParams> dimParams,
302                                  const BiasGrid&           grid,
303                                  std::vector<float>*       convolvedPmf) const
304 {
305     size_t numPoints = grid.numPoints();
306
307     convolvedPmf->resize(numPoints);
308
309     /* Get the PMF to convolve. */
310     std::vector<float> pmf(numPoints);
311     getPmf(pmf);
312
313     for (size_t m = 0; m < numPoints; m++)
314     {
315         double           freeEnergyWeights = 0;
316         const GridPoint& point             = grid.point(m);
317         for (const auto& neighbor : point.neighbor)
318         {
319             /* Do not convolve the bias along a lambda axis - only use the pmf from the current point */
320             if (!pointsHaveDifferentLambda(grid, m, neighbor))
321             {
322                 /* The negative PMF is a positive bias. */
323                 double biasNeighbor = -pmf[neighbor];
324
325                 /* Add the convolved PMF weights for the neighbors of this point.
326                 Note that this function only adds point within the target > 0 region.
327                 Sum weights, take the logarithm last to get the free energy. */
328                 double logWeight = biasedLogWeightFromPoint(
329                         dimParams, points_, grid, neighbor, biasNeighbor, point.coordValue, {}, m);
330                 freeEnergyWeights += std::exp(logWeight);
331             }
332         }
333
334         GMX_RELEASE_ASSERT(freeEnergyWeights > 0,
335                            "Attempting to do log(<= 0) in AWH convolved PMF calculation.");
336         (*convolvedPmf)[m] = -std::log(static_cast<float>(freeEnergyWeights));
337     }
338 }
339
340 namespace
341 {
342
343 /*! \brief
344  * Updates the target distribution for all points.
345  *
346  * The target distribution is always updated for all points
347  * at the same time.
348  *
349  * \param[in,out] pointState  The state of all points.
350  * \param[in]     params      The bias parameters.
351  */
352 void updateTargetDistribution(ArrayRef<PointState> pointState, const BiasParams& params)
353 {
354     double freeEnergyCutoff = 0;
355     if (params.eTarget == AwhTargetType::Cutoff)
356     {
357         freeEnergyCutoff = freeEnergyMinimumValue(pointState) + params.freeEnergyCutoffInKT;
358     }
359
360     double sumTarget = 0;
361     for (PointState& ps : pointState)
362     {
363         sumTarget += ps.updateTargetWeight(params, freeEnergyCutoff);
364     }
365     GMX_RELEASE_ASSERT(sumTarget > 0, "We should have a non-zero distribution");
366
367     /* Normalize to 1 */
368     double invSum = 1.0 / sumTarget;
369     for (PointState& ps : pointState)
370     {
371         ps.scaleTarget(invSum);
372     }
373 }
374
375 /*! \brief
376  * Puts together a string describing a grid point.
377  *
378  * \param[in] grid         The grid.
379  * \param[in] point        BiasGrid point index.
380  * \returns a string for the point.
381  */
382 std::string gridPointValueString(const BiasGrid& grid, int point)
383 {
384     std::string pointString;
385
386     pointString += "(";
387
388     for (int d = 0; d < grid.numDimensions(); d++)
389     {
390         pointString += gmx::formatString("%g", grid.point(point).coordValue[d]);
391         if (d < grid.numDimensions() - 1)
392         {
393             pointString += ",";
394         }
395         else
396         {
397             pointString += ")";
398         }
399     }
400
401     return pointString;
402 }
403
404 } // namespace
405
406 int BiasState::warnForHistogramAnomalies(const BiasGrid& grid, int biasIndex, double t, FILE* fplog, int maxNumWarnings) const
407 {
408     GMX_ASSERT(fplog != nullptr, "Warnings can only be issued if there is log file.");
409     const double maxHistogramRatio = 0.5; /* Tolerance for printing a warning about the histogram ratios */
410
411     /* Sum up the histograms and get their normalization */
412     double sumVisits  = 0;
413     double sumWeights = 0;
414     for (const auto& pointState : points_)
415     {
416         if (pointState.inTargetRegion())
417         {
418             sumVisits += pointState.numVisitsTot();
419             sumWeights += pointState.weightSumTot();
420         }
421     }
422     GMX_RELEASE_ASSERT(sumVisits > 0, "We should have visits");
423     GMX_RELEASE_ASSERT(sumWeights > 0, "We should have weight");
424     double invNormVisits = 1.0 / sumVisits;
425     double invNormWeight = 1.0 / sumWeights;
426
427     /* Check all points for warnings */
428     int    numWarnings = 0;
429     size_t numPoints   = grid.numPoints();
430     for (size_t m = 0; m < numPoints; m++)
431     {
432         /* Skip points close to boundary or non-target region */
433         const GridPoint& gridPoint = grid.point(m);
434         bool             skipPoint = false;
435         for (size_t n = 0; (n < gridPoint.neighbor.size()) && !skipPoint; n++)
436         {
437             int neighbor = gridPoint.neighbor[n];
438             skipPoint    = !points_[neighbor].inTargetRegion();
439             for (int d = 0; (d < grid.numDimensions()) && !skipPoint; d++)
440             {
441                 const GridPoint& neighborPoint = grid.point(neighbor);
442                 skipPoint                      = neighborPoint.index[d] == 0
443                             || neighborPoint.index[d] == grid.axis(d).numPoints() - 1;
444             }
445         }
446
447         /* Warn if the coordinate distribution is less than the target distribution with a certain fraction somewhere */
448         const double relativeWeight = points_[m].weightSumTot() * invNormWeight;
449         const double relativeVisits = points_[m].numVisitsTot() * invNormVisits;
450         if (!skipPoint && relativeVisits < relativeWeight * maxHistogramRatio)
451         {
452             std::string pointValueString = gridPointValueString(grid, m);
453             std::string warningMessage   = gmx::formatString(
454                     "\nawh%d warning: "
455                     "at t = %g ps the obtained coordinate distribution at coordinate value %s "
456                     "is less than a fraction %g of the reference distribution at that point. "
457                     "If you are not certain about your settings you might want to increase your "
458                     "pull force constant or "
459                     "modify your sampling region.\n",
460                     biasIndex + 1,
461                     t,
462                     pointValueString.c_str(),
463                     maxHistogramRatio);
464             gmx::TextLineWrapper wrapper;
465             wrapper.settings().setLineLength(c_linewidth);
466             fprintf(fplog, "%s", wrapper.wrapToString(warningMessage).c_str());
467
468             numWarnings++;
469         }
470         if (numWarnings >= maxNumWarnings)
471         {
472             break;
473         }
474     }
475
476     return numWarnings;
477 }
478
479 double BiasState::calcUmbrellaForceAndPotential(ArrayRef<const DimParams> dimParams,
480                                                 const BiasGrid&           grid,
481                                                 int                       point,
482                                                 ArrayRef<const double>    neighborLambdaDhdl,
483                                                 ArrayRef<double>          force) const
484 {
485     double potential = 0;
486     for (size_t d = 0; d < dimParams.size(); d++)
487     {
488         if (dimParams[d].isFepLambdaDimension())
489         {
490             if (!neighborLambdaDhdl.empty())
491             {
492                 const int coordpointLambdaIndex = grid.point(point).coordValue[d];
493                 force[d]                        = neighborLambdaDhdl[coordpointLambdaIndex];
494                 /* The potential should not be affected by the lambda dimension. */
495             }
496         }
497         else
498         {
499             double deviation =
500                     getDeviationFromPointAlongGridAxis(grid, d, point, coordState_.coordValue()[d]);
501             double k = dimParams[d].pullDimParams().k;
502
503             /* Force from harmonic potential 0.5*k*dev^2 */
504             force[d] = -k * deviation;
505             potential += 0.5 * k * deviation * deviation;
506         }
507     }
508
509     return potential;
510 }
511
512 void BiasState::calcConvolvedForce(ArrayRef<const DimParams> dimParams,
513                                    const BiasGrid&           grid,
514                                    ArrayRef<const double>    probWeightNeighbor,
515                                    ArrayRef<const double>    neighborLambdaDhdl,
516                                    ArrayRef<double>          forceWorkBuffer,
517                                    ArrayRef<double>          force) const
518 {
519     for (size_t d = 0; d < dimParams.size(); d++)
520     {
521         force[d] = 0;
522     }
523
524     /* Only neighboring points have non-negligible contribution. */
525     const std::vector<int>& neighbor          = grid.point(coordState_.gridpointIndex()).neighbor;
526     gmx::ArrayRef<double>   forceFromNeighbor = forceWorkBuffer;
527     for (size_t n = 0; n < neighbor.size(); n++)
528     {
529         double weightNeighbor = probWeightNeighbor[n];
530         int    indexNeighbor  = neighbor[n];
531
532         /* Get the umbrella force from this point. The returned potential is ignored here. */
533         calcUmbrellaForceAndPotential(dimParams, grid, indexNeighbor, neighborLambdaDhdl, forceFromNeighbor);
534
535         /* Add the weighted umbrella force to the convolved force. */
536         for (size_t d = 0; d < dimParams.size(); d++)
537         {
538             force[d] += forceFromNeighbor[d] * weightNeighbor;
539         }
540     }
541 }
542
543 double BiasState::moveUmbrella(ArrayRef<const DimParams> dimParams,
544                                const BiasGrid&           grid,
545                                ArrayRef<const double>    probWeightNeighbor,
546                                ArrayRef<const double>    neighborLambdaDhdl,
547                                ArrayRef<double>          biasForce,
548                                int64_t                   step,
549                                int64_t                   seed,
550                                int                       indexSeed,
551                                bool                      onlySampleUmbrellaGridpoint)
552 {
553     /* Generate and set a new coordinate reference value */
554     coordState_.sampleUmbrellaGridpoint(
555             grid, coordState_.gridpointIndex(), probWeightNeighbor, step, seed, indexSeed);
556
557     if (onlySampleUmbrellaGridpoint)
558     {
559         return 0;
560     }
561
562     std::vector<double> newForce(dimParams.size());
563     double              newPotential = calcUmbrellaForceAndPotential(
564             dimParams, grid, coordState_.umbrellaGridpoint(), neighborLambdaDhdl, newForce);
565
566     /*  A modification of the reference value at time t will lead to a different
567         force over t-dt/2 to t and over t to t+dt/2. For high switching rates
568         this means the force and velocity will change signs roughly as often.
569         To avoid any issues we take the average of the previous and new force
570         at steps when the reference value has been moved. E.g. if the ref. value
571         is set every step to (coord dvalue +/- delta) would give zero force.
572      */
573     for (gmx::index d = 0; d < biasForce.ssize(); d++)
574     {
575         /* Average of the current and new force */
576         biasForce[d] = 0.5 * (biasForce[d] + newForce[d]);
577     }
578
579     return newPotential;
580 }
581
582 namespace
583 {
584
585 /*! \brief
586  * Sets the histogram rescaling factors needed to control the histogram size.
587  *
588  * For sake of robustness, the reference weight histogram can grow at a rate
589  * different from the actual sampling rate. Typically this happens for a limited
590  * initial time, alternatively growth is scaled down by a constant factor for all
591  * times. Since the size of the reference histogram sets the size of the free
592  * energy update this should be reflected also in the PMF. Thus the PMF histogram
593  * needs to be rescaled too.
594  *
595  * This function should only be called by the bias update function or wrapped by a function that
596  * knows what scale factors should be applied when, e.g,
597  * getSkippedUpdateHistogramScaleFactors().
598  *
599  * \param[in]  params             The bias parameters.
600  * \param[in]  newHistogramSize   New reference weight histogram size.
601  * \param[in]  oldHistogramSize   Previous reference weight histogram size (before adding new samples).
602  * \param[out] weightHistScaling  Scaling factor for the reference weight histogram.
603  * \param[out] logPmfSumScaling   Log of the scaling factor for the PMF histogram.
604  */
605 void setHistogramUpdateScaleFactors(const BiasParams& params,
606                                     double            newHistogramSize,
607                                     double            oldHistogramSize,
608                                     double*           weightHistScaling,
609                                     double*           logPmfSumScaling)
610 {
611
612     /* The two scaling factors below are slightly different (ignoring the log factor) because the
613        reference and the PMF histogram apply weight scaling differently. The weight histogram
614        applies is  locally, i.e. each sample is scaled down meaning all samples get equal weight.
615        It is done this way because that is what target type local Boltzmann (for which
616        target = weight histogram) needs. In contrast, the PMF histogram is rescaled globally
617        by repeatedly scaling down the whole histogram. The reasons for doing it this way are:
618        1) empirically this is necessary for converging the PMF; 2) since the extraction of
619        the PMF is theoretically only valid for a constant bias, new samples should get more
620        weight than old ones for which the bias is fluctuating more. */
621     *weightHistScaling =
622             newHistogramSize / (oldHistogramSize + params.updateWeight * params.localWeightScaling);
623     *logPmfSumScaling = std::log(newHistogramSize / (oldHistogramSize + params.updateWeight));
624 }
625
626 } // namespace
627
628 void BiasState::getSkippedUpdateHistogramScaleFactors(const BiasParams& params,
629                                                       double*           weightHistScaling,
630                                                       double*           logPmfSumScaling) const
631 {
632     GMX_ASSERT(params.skipUpdates(),
633                "Calling function for skipped updates when skipping updates is not allowed");
634
635     if (inInitialStage())
636     {
637         /* In between global updates the reference histogram size is kept constant so we trivially
638            know what the histogram size was at the time of the skipped update. */
639         double histogramSize = histogramSize_.histogramSize();
640         setHistogramUpdateScaleFactors(
641                 params, histogramSize, histogramSize, weightHistScaling, logPmfSumScaling);
642     }
643     else
644     {
645         /* In the final stage, the reference histogram grows at the sampling rate which gives trivial scale factors. */
646         *weightHistScaling = 1;
647         *logPmfSumScaling  = 0;
648     }
649 }
650
651 void BiasState::doSkippedUpdatesForAllPoints(const BiasParams& params)
652 {
653     double weightHistScaling;
654     double logPmfsumScaling;
655
656     getSkippedUpdateHistogramScaleFactors(params, &weightHistScaling, &logPmfsumScaling);
657
658     for (auto& pointState : points_)
659     {
660         bool didUpdate = pointState.performPreviouslySkippedUpdates(
661                 params, histogramSize_.numUpdates(), weightHistScaling, logPmfsumScaling);
662
663         /* Update the bias for this point only if there were skipped updates in the past to avoid calculating the log unneccessarily */
664         if (didUpdate)
665         {
666             pointState.updateBias();
667         }
668     }
669 }
670
671 void BiasState::doSkippedUpdatesInNeighborhood(const BiasParams& params, const BiasGrid& grid)
672 {
673     double weightHistScaling;
674     double logPmfsumScaling;
675
676     getSkippedUpdateHistogramScaleFactors(params, &weightHistScaling, &logPmfsumScaling);
677
678     /* For each neighbor point of the center point, refresh its state by adding the results of all past, skipped updates. */
679     const std::vector<int>& neighbors = grid.point(coordState_.gridpointIndex()).neighbor;
680     for (const auto& neighbor : neighbors)
681     {
682         bool didUpdate = points_[neighbor].performPreviouslySkippedUpdates(
683                 params, histogramSize_.numUpdates(), weightHistScaling, logPmfsumScaling);
684
685         if (didUpdate)
686         {
687             points_[neighbor].updateBias();
688         }
689     }
690 }
691
692 namespace
693 {
694
695 /*! \brief
696  * Merge update lists from multiple sharing simulations.
697  *
698  * \param[in,out] updateList    Update list for this simulation (assumed >= npoints long).
699  * \param[in]     numPoints     Total number of points.
700  * \param[in]     commRecord    Struct for intra-simulation communication.
701  * \param[in]     multiSimComm  Struct for multi-simulation communication.
702  */
703 void mergeSharedUpdateLists(std::vector<int>*     updateList,
704                             int                   numPoints,
705                             const t_commrec*      commRecord,
706                             const gmx_multisim_t* multiSimComm)
707 {
708     std::vector<int> numUpdatesOfPoint;
709
710     /* Flag the update points of this sim.
711        TODO: we can probably avoid allocating this array and just use the input array. */
712     numUpdatesOfPoint.resize(numPoints, 0);
713     for (auto& pointIndex : *updateList)
714     {
715         numUpdatesOfPoint[pointIndex] = 1;
716     }
717
718     /* Sum over the sims to get all the flagged points */
719     sumOverSimulations(arrayRefFromArray(numUpdatesOfPoint.data(), numPoints), commRecord, multiSimComm);
720
721     /* Collect the indices of the flagged points in place. The resulting array will be the merged update list.*/
722     updateList->clear();
723     for (int m = 0; m < numPoints; m++)
724     {
725         if (numUpdatesOfPoint[m] > 0)
726         {
727             updateList->push_back(m);
728         }
729     }
730 }
731
732 /*! \brief
733  * Generate an update list of points sampled since the last update.
734  *
735  * \param[in] grid              The AWH bias.
736  * \param[in] points            The point state.
737  * \param[in] originUpdatelist  The origin of the rectangular region that has been sampled since
738  * last update. \param[in] endUpdatelist     The end of the rectangular that has been sampled since
739  * last update. \param[in,out] updateList    Local update list to set (assumed >= npoints long).
740  */
741 void makeLocalUpdateList(const BiasGrid&            grid,
742                          ArrayRef<const PointState> points,
743                          const awh_ivec             originUpdatelist,
744                          const awh_ivec             endUpdatelist,
745                          std::vector<int>*          updateList)
746 {
747     awh_ivec origin;
748     awh_ivec numPoints;
749
750     /* Define the update search grid */
751     for (int d = 0; d < grid.numDimensions(); d++)
752     {
753         origin[d]    = originUpdatelist[d];
754         numPoints[d] = endUpdatelist[d] - originUpdatelist[d] + 1;
755
756         /* Because the end_updatelist is unwrapped it can be > (npoints - 1) so that numPoints can be > npoints in grid.
757            This helps for calculating the distance/number of points but should be removed and fixed when the way of
758            updating origin/end updatelist is changed (see sampleProbabilityWeights). */
759         numPoints[d] = std::min(grid.axis(d).numPoints(), numPoints[d]);
760     }
761
762     /* Make the update list */
763     updateList->clear();
764     int  pointIndex  = -1;
765     bool pointExists = true;
766     while (pointExists)
767     {
768         pointExists = advancePointInSubgrid(grid, origin, numPoints, &pointIndex);
769
770         if (pointExists && points[pointIndex].inTargetRegion())
771         {
772             updateList->push_back(pointIndex);
773         }
774     }
775 }
776
777 } // namespace
778
779 void BiasState::resetLocalUpdateRange(const BiasGrid& grid)
780 {
781     const int gridpointIndex = coordState_.gridpointIndex();
782     for (int d = 0; d < grid.numDimensions(); d++)
783     {
784         /* This gives the  minimum range consisting only of the current closest point. */
785         originUpdatelist_[d] = grid.point(gridpointIndex).index[d];
786         endUpdatelist_[d]    = grid.point(gridpointIndex).index[d];
787     }
788 }
789
790 namespace
791 {
792
793 /*! \brief
794  * Add partial histograms (accumulating between updates) to accumulating histograms.
795  *
796  * \param[in,out] pointState         The state of the points in the bias.
797  * \param[in,out] weightSumCovering  The weights for checking covering.
798  * \param[in]     numSharedUpdate    The number of biases sharing the histrogram.
799  * \param[in]     commRecord         Struct for intra-simulation communication.
800  * \param[in]     multiSimComm       Struct for multi-simulation communication.
801  * \param[in]     localUpdateList    List of points with data.
802  */
803 void sumHistograms(gmx::ArrayRef<PointState> pointState,
804                    gmx::ArrayRef<double>     weightSumCovering,
805                    int                       numSharedUpdate,
806                    const t_commrec*          commRecord,
807                    const gmx_multisim_t*     multiSimComm,
808                    const std::vector<int>&   localUpdateList)
809 {
810     /* The covering checking histograms are added before summing over simulations, so that the
811        weights from different simulations are kept distinguishable. */
812     for (int globalIndex : localUpdateList)
813     {
814         weightSumCovering[globalIndex] += pointState[globalIndex].weightSumIteration();
815     }
816
817     /* Sum histograms over multiple simulations if needed. */
818     if (numSharedUpdate > 1)
819     {
820         GMX_ASSERT(numSharedUpdate == multiSimComm->numSimulations_,
821                    "Sharing within a simulation is not implemented (yet)");
822
823         /* Collect the weights and counts in linear arrays to be able to use gmx_sumd_sim. */
824         std::vector<double> weightSum;
825         std::vector<double> coordVisits;
826
827         weightSum.resize(localUpdateList.size());
828         coordVisits.resize(localUpdateList.size());
829
830         for (size_t localIndex = 0; localIndex < localUpdateList.size(); localIndex++)
831         {
832             const PointState& ps = pointState[localUpdateList[localIndex]];
833
834             weightSum[localIndex]   = ps.weightSumIteration();
835             coordVisits[localIndex] = ps.numVisitsIteration();
836         }
837
838         sumOverSimulations(gmx::ArrayRef<double>(weightSum), commRecord, multiSimComm);
839         sumOverSimulations(gmx::ArrayRef<double>(coordVisits), commRecord, multiSimComm);
840
841         /* Transfer back the result */
842         for (size_t localIndex = 0; localIndex < localUpdateList.size(); localIndex++)
843         {
844             PointState& ps = pointState[localUpdateList[localIndex]];
845
846             ps.setPartialWeightAndCount(weightSum[localIndex], coordVisits[localIndex]);
847         }
848     }
849
850     /* Now add the partial counts and weights to the accumulating histograms.
851        Note: we still need to use the weights for the update so we wait
852        with resetting them until the end of the update. */
853     for (int globalIndex : localUpdateList)
854     {
855         pointState[globalIndex].addPartialWeightAndCount();
856     }
857 }
858
859 /*! \brief
860  * Label points along an axis as covered or not.
861  *
862  * A point is covered if it is surrounded by visited points up to a radius = coverRadius.
863  *
864  * \param[in]     visited        Visited? For each point.
865  * \param[in]     checkCovering  Check for covering? For each point.
866  * \param[in]     numPoints      The number of grid points along this dimension.
867  * \param[in]     period         Period in number of points.
868  * \param[in]     coverRadius    Cover radius, in points, needed for defining a point as covered.
869  * \param[in,out] covered        In this array elements are 1 for covered points and 0 for
870  * non-covered points, this routine assumes that \p covered has at least size \p numPoints.
871  */
872 void labelCoveredPoints(const std::vector<bool>& visited,
873                         const std::vector<bool>& checkCovering,
874                         int                      numPoints,
875                         int                      period,
876                         int                      coverRadius,
877                         gmx::ArrayRef<int>       covered)
878 {
879     GMX_ASSERT(covered.ssize() >= numPoints, "covered should be at least as large as the grid");
880
881     bool haveFirstNotVisited = false;
882     int  firstNotVisited     = -1;
883     int  notVisitedLow       = -1;
884     int  notVisitedHigh      = -1;
885
886     for (int n = 0; n < numPoints; n++)
887     {
888         if (checkCovering[n] && !visited[n])
889         {
890             if (!haveFirstNotVisited)
891             {
892                 notVisitedLow       = n;
893                 firstNotVisited     = n;
894                 haveFirstNotVisited = true;
895             }
896             else
897             {
898                 notVisitedHigh = n;
899
900                 /* Have now an interval I = [notVisitedLow,notVisitedHigh] of visited points bounded
901                    by unvisited points. The unvisted end points affect the coveredness of the
902                    visited with a reach equal to the cover radius. */
903                 int notCoveredLow  = notVisitedLow + coverRadius;
904                 int notCoveredHigh = notVisitedHigh - coverRadius;
905                 for (int i = notVisitedLow; i <= notVisitedHigh; i++)
906                 {
907                     covered[i] = static_cast<int>((i > notCoveredLow) && (i < notCoveredHigh));
908                 }
909
910                 /* Find a new interval to set covering for. Make the notVisitedHigh of this interval
911                    the notVisitedLow of the next. */
912                 notVisitedLow = notVisitedHigh;
913             }
914         }
915     }
916
917     /* Have labelled all the internal points. Now take care of the boundary regions. */
918     if (!haveFirstNotVisited)
919     {
920         /* No non-visited points <=> all points visited => all points covered. */
921
922         for (int n = 0; n < numPoints; n++)
923         {
924             covered[n] = 1;
925         }
926     }
927     else
928     {
929         int lastNotVisited = notVisitedLow;
930
931         /* For periodic boundaries, non-visited points can influence points
932            on the other side of the boundary so we need to wrap around. */
933
934         /* Lower end. For periodic boundaries the last upper end not visited point becomes the low-end not visited point.
935            For non-periodic boundaries there is no lower end point so a dummy value is used. */
936         int notVisitedHigh = firstNotVisited;
937         int notVisitedLow  = period > 0 ? (lastNotVisited - period) : -(coverRadius + 1);
938
939         int notCoveredLow  = notVisitedLow + coverRadius;
940         int notCoveredHigh = notVisitedHigh - coverRadius;
941
942         for (int i = 0; i <= notVisitedHigh; i++)
943         {
944             /* For non-periodic boundaries notCoveredLow = -1 will impose no restriction. */
945             covered[i] = static_cast<int>((i > notCoveredLow) && (i < notCoveredHigh));
946         }
947
948         /* Upper end. Same as for lower end but in the other direction. */
949         notVisitedHigh = period > 0 ? (firstNotVisited + period) : (numPoints + coverRadius);
950         notVisitedLow  = lastNotVisited;
951
952         notCoveredLow  = notVisitedLow + coverRadius;
953         notCoveredHigh = notVisitedHigh - coverRadius;
954
955         for (int i = notVisitedLow; i <= numPoints - 1; i++)
956         {
957             /* For non-periodic boundaries notCoveredHigh = numPoints will impose no restriction. */
958             covered[i] = static_cast<int>((i > notCoveredLow) && (i < notCoveredHigh));
959         }
960     }
961 }
962
963 } // namespace
964
965 bool BiasState::isSamplingRegionCovered(const BiasParams&         params,
966                                         ArrayRef<const DimParams> dimParams,
967                                         const BiasGrid&           grid,
968                                         const t_commrec*          commRecord,
969                                         const gmx_multisim_t*     multiSimComm) const
970 {
971     /* Allocate and initialize arrays: one for checking visits along each dimension,
972        one for keeping track of which points to check and one for the covered points.
973        Possibly these could be kept as AWH variables to avoid these allocations. */
974     struct CheckDim
975     {
976         std::vector<bool> visited;
977         std::vector<bool> checkCovering;
978         // We use int for the covering array since we might use gmx_sumi_sim.
979         std::vector<int> covered;
980     };
981
982     std::vector<CheckDim> checkDim;
983     checkDim.resize(grid.numDimensions());
984
985     for (int d = 0; d < grid.numDimensions(); d++)
986     {
987         const size_t numPoints = grid.axis(d).numPoints();
988         checkDim[d].visited.resize(numPoints, false);
989         checkDim[d].checkCovering.resize(numPoints, false);
990         checkDim[d].covered.resize(numPoints, 0);
991     }
992
993     /* Set visited points along each dimension and which points should be checked for covering.
994        Specifically, points above the free energy cutoff (if there is one) or points outside
995        of the target region are ignored. */
996
997     /* Set the free energy cutoff */
998     double maxFreeEnergy = GMX_FLOAT_MAX;
999
1000     if (params.eTarget == AwhTargetType::Cutoff)
1001     {
1002         maxFreeEnergy = freeEnergyMinimumValue(points_) + params.freeEnergyCutoffInKT;
1003     }
1004
1005     /* Set the threshold weight for a point to be considered visited. */
1006     double weightThreshold = 1;
1007     for (int d = 0; d < grid.numDimensions(); d++)
1008     {
1009         if (grid.axis(d).isFepLambdaAxis())
1010         {
1011             /* Do not modify the weight threshold based on a FEP lambda axis. The spread
1012              * of the sampling weights is not depending on a Gaussian distribution (like
1013              * below). */
1014             weightThreshold *= 1.0;
1015         }
1016         else
1017         {
1018             /* The spacing is proportional to 1/sqrt(betak). The weight threshold will be
1019              * approximately (given that the spacing can be modified if the dimension is periodic)
1020              * proportional to sqrt(1/(2*pi)). */
1021             weightThreshold *= grid.axis(d).spacing()
1022                                * std::sqrt(dimParams[d].pullDimParams().betak * 0.5 * M_1_PI);
1023         }
1024     }
1025
1026     /* Project the sampling weights onto each dimension */
1027     for (size_t m = 0; m < grid.numPoints(); m++)
1028     {
1029         const PointState& pointState = points_[m];
1030
1031         for (int d = 0; d < grid.numDimensions(); d++)
1032         {
1033             int n = grid.point(m).index[d];
1034
1035             /* Is visited if it was already visited or if there is enough weight at the current point */
1036             checkDim[d].visited[n] = checkDim[d].visited[n] || (weightSumCovering_[m] > weightThreshold);
1037
1038             /* Check for covering if there is at least point in this slice that is in the target region and within the cutoff */
1039             checkDim[d].checkCovering[n] =
1040                     checkDim[d].checkCovering[n]
1041                     || (pointState.inTargetRegion() && pointState.freeEnergy() < maxFreeEnergy);
1042         }
1043     }
1044
1045     /* Label each point along each dimension as covered or not. */
1046     for (int d = 0; d < grid.numDimensions(); d++)
1047     {
1048         labelCoveredPoints(checkDim[d].visited,
1049                            checkDim[d].checkCovering,
1050                            grid.axis(d).numPoints(),
1051                            grid.axis(d).numPointsInPeriod(),
1052                            params.coverRadius()[d],
1053                            checkDim[d].covered);
1054     }
1055
1056     /* Now check for global covering. Each dimension needs to be covered separately.
1057        A dimension is covered if each point is covered.  Multiple simulations collectively
1058        cover the points, i.e. a point is covered if any of the simulations covered it.
1059        However, visited points are not shared, i.e. if a point is covered or not is
1060        determined by the visits of a single simulation. In general the covering criterion is
1061        all points covered => all points are surrounded by visited points up to a radius = coverRadius.
1062        For 1 simulation, all points covered <=> all points visited. For multiple simulations
1063        however, all points visited collectively !=> all points covered, except for coverRadius = 0.
1064        In the limit of large coverRadius, all points covered => all points visited by at least one
1065        simulation (since no point will be covered until all points have been visited by a
1066        single simulation). Basically coverRadius sets how much "connectedness" (or mixing) a point
1067        needs with surrounding points before sharing covering information with other simulations. */
1068
1069     /* Communicate the covered points between sharing simulations if needed. */
1070     if (params.numSharedUpdate > 1)
1071     {
1072         /* For multiple dimensions this may not be the best way to do it. */
1073         for (int d = 0; d < grid.numDimensions(); d++)
1074         {
1075             sumOverSimulations(
1076                     gmx::arrayRefFromArray(checkDim[d].covered.data(), grid.axis(d).numPoints()),
1077                     commRecord,
1078                     multiSimComm);
1079         }
1080     }
1081
1082     /* Now check if for each dimension all points are covered. Break if not true. */
1083     bool allPointsCovered = true;
1084     for (int d = 0; d < grid.numDimensions() && allPointsCovered; d++)
1085     {
1086         for (int n = 0; n < grid.axis(d).numPoints() && allPointsCovered; n++)
1087         {
1088             allPointsCovered = (checkDim[d].covered[n] != 0);
1089         }
1090     }
1091
1092     return allPointsCovered;
1093 }
1094
1095 /*! \brief
1096  * Normalizes the free energy and PMF sum.
1097  *
1098  * \param[in] pointState  The state of the points.
1099  */
1100 static void normalizeFreeEnergyAndPmfSum(std::vector<PointState>* pointState)
1101 {
1102     double minF = freeEnergyMinimumValue(*pointState);
1103
1104     for (PointState& ps : *pointState)
1105     {
1106         ps.normalizeFreeEnergyAndPmfSum(minF);
1107     }
1108 }
1109
1110 void BiasState::updateFreeEnergyAndAddSamplesToHistogram(ArrayRef<const DimParams> dimParams,
1111                                                          const BiasGrid&           grid,
1112                                                          const BiasParams&         params,
1113                                                          const t_commrec*          commRecord,
1114                                                          const gmx_multisim_t*     multiSimComm,
1115                                                          double                    t,
1116                                                          int64_t                   step,
1117                                                          FILE*                     fplog,
1118                                                          std::vector<int>*         updateList)
1119 {
1120     /* Note hat updateList is only used in this scope and is always
1121      * re-initialized. We do not use a local vector, because that would
1122      * cause reallocation every time this funtion is called and the vector
1123      * can be the size of the whole grid.
1124      */
1125
1126     /* Make a list of all local points, i.e. those that could have been touched since
1127        the last update. These are the points needed for summing histograms below
1128        (non-local points only add zeros). For local updates, this will also be the
1129        final update list. */
1130     makeLocalUpdateList(grid, points_, originUpdatelist_, endUpdatelist_, updateList);
1131     if (params.numSharedUpdate > 1)
1132     {
1133         mergeSharedUpdateLists(updateList, points_.size(), commRecord, multiSimComm);
1134     }
1135
1136     /* Reset the range for the next update */
1137     resetLocalUpdateRange(grid);
1138
1139     /* Add samples to histograms for all local points and sync simulations if needed */
1140     sumHistograms(points_, weightSumCovering_, params.numSharedUpdate, commRecord, multiSimComm, *updateList);
1141
1142     sumPmf(points_, params.numSharedUpdate, commRecord, multiSimComm);
1143
1144     /* Renormalize the free energy if values are too large. */
1145     bool needToNormalizeFreeEnergy = false;
1146     for (int& globalIndex : *updateList)
1147     {
1148         /* We want to keep the absolute value of the free energies to be less
1149            c_largePositiveExponent to be able to safely pass these values to exp(). The check below
1150            ensures this as long as the free energy values grow less than 0.5*c_largePositiveExponent
1151            in a return time to this neighborhood. For reasonable update sizes it's unlikely that
1152            this requirement would be broken. */
1153         if (std::abs(points_[globalIndex].freeEnergy()) > 0.5 * detail::c_largePositiveExponent)
1154         {
1155             needToNormalizeFreeEnergy = true;
1156             break;
1157         }
1158     }
1159
1160     /* Update target distribution? */
1161     bool needToUpdateTargetDistribution =
1162             (params.eTarget != AwhTargetType::Constant && params.isUpdateTargetStep(step));
1163
1164     /* In the initial stage, the histogram grows dynamically as a function of the number of coverings. */
1165     bool detectedCovering = false;
1166     if (inInitialStage())
1167     {
1168         detectedCovering =
1169                 (params.isCheckCoveringStep(step)
1170                  && isSamplingRegionCovered(params, dimParams, grid, commRecord, multiSimComm));
1171     }
1172
1173     /* The weighthistogram size after this update. */
1174     double newHistogramSize = histogramSize_.newHistogramSize(
1175             params, t, detectedCovering, points_, weightSumCovering_, fplog);
1176
1177     /* Make the update list. Usually we try to only update local points,
1178      * but if the update has non-trivial or non-deterministic effects
1179      * on non-local points a global update is needed. This is the case when:
1180      * 1) a covering occurred in the initial stage, leading to non-trivial
1181      *    histogram rescaling factors; or
1182      * 2) the target distribution will be updated, since we don't make any
1183      *    assumption on its form; or
1184      * 3) the AWH parameters are such that we never attempt to skip non-local
1185      *    updates; or
1186      * 4) the free energy values have grown so large that a renormalization
1187      *    is needed.
1188      */
1189     if (needToUpdateTargetDistribution || detectedCovering || !params.skipUpdates() || needToNormalizeFreeEnergy)
1190     {
1191         /* Global update, just add all points. */
1192         updateList->clear();
1193         for (size_t m = 0; m < points_.size(); m++)
1194         {
1195             if (points_[m].inTargetRegion())
1196             {
1197                 updateList->push_back(m);
1198             }
1199         }
1200     }
1201
1202     /* Set histogram scale factors. */
1203     double weightHistScalingSkipped = 0;
1204     double logPmfsumScalingSkipped  = 0;
1205     if (params.skipUpdates())
1206     {
1207         getSkippedUpdateHistogramScaleFactors(params, &weightHistScalingSkipped, &logPmfsumScalingSkipped);
1208     }
1209     double weightHistScalingNew;
1210     double logPmfsumScalingNew;
1211     setHistogramUpdateScaleFactors(
1212             params, newHistogramSize, histogramSize_.histogramSize(), &weightHistScalingNew, &logPmfsumScalingNew);
1213
1214     /* Update free energy and reference weight histogram for points in the update list. */
1215     for (int pointIndex : *updateList)
1216     {
1217         PointState* pointStateToUpdate = &points_[pointIndex];
1218
1219         /* Do updates from previous update steps that were skipped because this point was at that time non-local. */
1220         if (params.skipUpdates())
1221         {
1222             pointStateToUpdate->performPreviouslySkippedUpdates(
1223                     params, histogramSize_.numUpdates(), weightHistScalingSkipped, logPmfsumScalingSkipped);
1224         }
1225
1226         /* Now do an update with new sampling data. */
1227         pointStateToUpdate->updateWithNewSampling(
1228                 params, histogramSize_.numUpdates(), weightHistScalingNew, logPmfsumScalingNew);
1229     }
1230
1231     /* Only update the histogram size after we are done with the local point updates */
1232     histogramSize_.setHistogramSize(newHistogramSize, weightHistScalingNew);
1233
1234     if (needToNormalizeFreeEnergy)
1235     {
1236         normalizeFreeEnergyAndPmfSum(&points_);
1237     }
1238
1239     if (needToUpdateTargetDistribution)
1240     {
1241         /* The target distribution is always updated for all points at once. */
1242         updateTargetDistribution(points_, params);
1243     }
1244
1245     /* Update the bias. The bias is updated separately and last since it simply a function of
1246        the free energy and the target distribution and we want to avoid doing extra work. */
1247     for (int pointIndex : *updateList)
1248     {
1249         points_[pointIndex].updateBias();
1250     }
1251
1252     /* Increase the update counter. */
1253     histogramSize_.incrementNumUpdates();
1254 }
1255
1256 double BiasState::updateProbabilityWeightsAndConvolvedBias(ArrayRef<const DimParams> dimParams,
1257                                                            const BiasGrid&           grid,
1258                                                            ArrayRef<const double> neighborLambdaEnergies,
1259                                                            std::vector<double, AlignedAllocator<double>>* weight) const
1260 {
1261     /* Only neighbors of the current coordinate value will have a non-negligible chance of getting sampled */
1262     const std::vector<int>& neighbors = grid.point(coordState_.gridpointIndex()).neighbor;
1263
1264 #if GMX_SIMD_HAVE_DOUBLE
1265     typedef SimdDouble PackType;
1266     constexpr int      packSize = GMX_SIMD_DOUBLE_WIDTH;
1267 #else
1268     typedef double PackType;
1269     constexpr int  packSize = 1;
1270 #endif
1271     /* Round the size of the weight array up to packSize */
1272     const int weightSize = ((neighbors.size() + packSize - 1) / packSize) * packSize;
1273     weight->resize(weightSize);
1274
1275     double* gmx_restrict weightData = weight->data();
1276     PackType             weightSumPack(0.0);
1277     for (size_t i = 0; i < neighbors.size(); i += packSize)
1278     {
1279         for (size_t n = i; n < i + packSize; n++)
1280         {
1281             if (n < neighbors.size())
1282             {
1283                 const int neighbor = neighbors[n];
1284                 (*weight)[n]       = biasedLogWeightFromPoint(dimParams,
1285                                                         points_,
1286                                                         grid,
1287                                                         neighbor,
1288                                                         points_[neighbor].bias(),
1289                                                         coordState_.coordValue(),
1290                                                         neighborLambdaEnergies,
1291                                                         coordState_.gridpointIndex());
1292             }
1293             else
1294             {
1295                 /* Pad with values that don't affect the result */
1296                 (*weight)[n] = detail::c_largeNegativeExponent;
1297             }
1298         }
1299         PackType weightPack = load<PackType>(weightData + i);
1300         weightPack          = gmx::exp(weightPack);
1301         weightSumPack       = weightSumPack + weightPack;
1302         store(weightData + i, weightPack);
1303     }
1304     /* Sum of probability weights */
1305     double weightSum = reduce(weightSumPack);
1306     GMX_RELEASE_ASSERT(weightSum > 0,
1307                        "zero probability weight when updating AWH probability weights.");
1308
1309     /* Normalize probabilities to sum to 1 */
1310     double invWeightSum = 1 / weightSum;
1311
1312     /* When there is a free energy lambda state axis remove the convolved contributions along that
1313      * axis from the total bias. This must be done after calculating invWeightSum (since weightSum
1314      * will be modified), but before normalizing the weights (below). */
1315     if (grid.hasLambdaAxis())
1316     {
1317         /* If there is only one axis the bias will not be convolved in any dimension. */
1318         if (grid.axis().size() == 1)
1319         {
1320             weightSum = gmx::exp(points_[coordState_.gridpointIndex()].bias());
1321         }
1322         else
1323         {
1324             for (size_t i = 0; i < neighbors.size(); i++)
1325             {
1326                 const int neighbor = neighbors[i];
1327                 if (pointsHaveDifferentLambda(grid, coordState_.gridpointIndex(), neighbor))
1328                 {
1329                     weightSum -= weightData[i];
1330                 }
1331             }
1332         }
1333     }
1334
1335     for (double& w : *weight)
1336     {
1337         w *= invWeightSum;
1338     }
1339
1340     /* Return the convolved bias */
1341     return std::log(weightSum);
1342 }
1343
1344 double BiasState::calcConvolvedBias(ArrayRef<const DimParams> dimParams,
1345                                     const BiasGrid&           grid,
1346                                     const awh_dvec&           coordValue) const
1347 {
1348     int              point     = grid.nearestIndex(coordValue);
1349     const GridPoint& gridPoint = grid.point(point);
1350
1351     /* Sum the probability weights from the neighborhood of the given point */
1352     double weightSum = 0;
1353     for (int neighbor : gridPoint.neighbor)
1354     {
1355         /* No convolution is required along the lambda dimension. */
1356         if (pointsHaveDifferentLambda(grid, point, neighbor))
1357         {
1358             continue;
1359         }
1360         double logWeight = biasedLogWeightFromPoint(
1361                 dimParams, points_, grid, neighbor, points_[neighbor].bias(), coordValue, {}, point);
1362         weightSum += std::exp(logWeight);
1363     }
1364
1365     /* Returns -GMX_FLOAT_MAX if no neighboring points were in the target region. */
1366     return (weightSum > 0) ? std::log(weightSum) : -GMX_FLOAT_MAX;
1367 }
1368
1369 void BiasState::sampleProbabilityWeights(const BiasGrid& grid, gmx::ArrayRef<const double> probWeightNeighbor)
1370 {
1371     const std::vector<int>& neighbor = grid.point(coordState_.gridpointIndex()).neighbor;
1372
1373     /* Save weights for next update */
1374     for (size_t n = 0; n < neighbor.size(); n++)
1375     {
1376         points_[neighbor[n]].increaseWeightSumIteration(probWeightNeighbor[n]);
1377     }
1378
1379     /* Update the local update range. Two corner points define this rectangular
1380      * domain. We need to choose two new corner points such that the new domain
1381      * contains both the old update range and the current neighborhood.
1382      * In the simplest case when an update is performed every sample,
1383      * the update range would simply equal the current neighborhood.
1384      */
1385     int neighborStart = neighbor[0];
1386     int neighborLast  = neighbor[neighbor.size() - 1];
1387     for (int d = 0; d < grid.numDimensions(); d++)
1388     {
1389         int origin = grid.point(neighborStart).index[d];
1390         int last   = grid.point(neighborLast).index[d];
1391
1392         if (origin > last)
1393         {
1394             /* Unwrap if wrapped around the boundary (only happens for periodic
1395              * boundaries). This has been already for the stored index interval.
1396              */
1397             /* TODO: what we want to do is to find the smallest the update
1398              * interval that contains all points that need to be updated.
1399              * This amounts to combining two intervals, the current
1400              * [origin, end] update interval and the new touched neighborhood
1401              * into a new interval that contains all points from both the old
1402              * intervals.
1403              *
1404              * For periodic boundaries it becomes slightly more complicated
1405              * than for closed boundaries because then it needs not be
1406              * true that origin < end (so one can't simply relate the origin/end
1407              * in the min()/max() below). The strategy here is to choose the
1408              * origin closest to a reference point (index 0) and then unwrap
1409              * the end index if needed and choose the largest end index.
1410              * This ensures that both intervals are in the new interval
1411              * but it's not necessarily the smallest.
1412              * Currently we solve this by going through each possibility
1413              * and checking them.
1414              */
1415             last += grid.axis(d).numPointsInPeriod();
1416         }
1417
1418         originUpdatelist_[d] = std::min(originUpdatelist_[d], origin);
1419         endUpdatelist_[d]    = std::max(endUpdatelist_[d], last);
1420     }
1421 }
1422
1423 void BiasState::sampleCoordAndPmf(const std::vector<DimParams>& dimParams,
1424                                   const BiasGrid&               grid,
1425                                   gmx::ArrayRef<const double>   probWeightNeighbor,
1426                                   double                        convolvedBias)
1427 {
1428     /* Sampling-based deconvolution extracting the PMF.
1429      * Update the PMF histogram with the current coordinate value.
1430      *
1431      * Because of the finite width of the harmonic potential, the free energy
1432      * defined for each coordinate point does not exactly equal that of the
1433      * actual coordinate, the PMF. However, the PMF can be estimated by applying
1434      * the relation exp(-PMF) = exp(-bias_convolved)*P_biased/Z, i.e. by keeping a
1435      * reweighted histogram of the coordinate value. Strictly, this relies on
1436      * the unknown normalization constant Z being either constant or known. Here,
1437      * neither is true except in the long simulation time limit. Empirically however,
1438      * it works (mainly because how the PMF histogram is rescaled).
1439      */
1440
1441     const int                gridPointIndex  = coordState_.gridpointIndex();
1442     const std::optional<int> lambdaAxisIndex = grid.lambdaAxisIndex();
1443
1444     /* Update the PMF of points along a lambda axis with their bias. */
1445     if (lambdaAxisIndex)
1446     {
1447         const std::vector<int>& neighbors = grid.point(gridPointIndex).neighbor;
1448
1449         std::vector<double> lambdaMarginalDistribution =
1450                 calculateFELambdaMarginalDistribution(grid, neighbors, probWeightNeighbor);
1451
1452         awh_dvec coordValueAlongLambda = { coordState_.coordValue()[0],
1453                                            coordState_.coordValue()[1],
1454                                            coordState_.coordValue()[2],
1455                                            coordState_.coordValue()[3] };
1456         for (size_t i = 0; i < neighbors.size(); i++)
1457         {
1458             const int neighbor = neighbors[i];
1459             double    bias;
1460             if (pointsAlongLambdaAxis(grid, gridPointIndex, neighbor))
1461             {
1462                 const double neighborLambda = grid.point(neighbor).coordValue[lambdaAxisIndex.value()];
1463                 if (neighbor == gridPointIndex)
1464                 {
1465                     bias = convolvedBias;
1466                 }
1467                 else
1468                 {
1469                     coordValueAlongLambda[lambdaAxisIndex.value()] = neighborLambda;
1470                     bias = calcConvolvedBias(dimParams, grid, coordValueAlongLambda);
1471                 }
1472
1473                 const double probWeight = lambdaMarginalDistribution[neighborLambda];
1474                 const double weightedBias = bias - std::log(std::max(probWeight, GMX_DOUBLE_MIN)); // avoid log(0)
1475
1476                 if (neighbor == gridPointIndex && grid.covers(coordState_.coordValue()))
1477                 {
1478                     points_[neighbor].samplePmf(weightedBias);
1479                 }
1480                 else
1481                 {
1482                     points_[neighbor].updatePmfUnvisited(weightedBias);
1483                 }
1484             }
1485         }
1486     }
1487     else
1488     {
1489         /* Only save coordinate data that is in range (the given index is always
1490          * in range even if the coordinate value is not).
1491          */
1492         if (grid.covers(coordState_.coordValue()))
1493         {
1494             /* Save PMF sum and keep a histogram of the sampled coordinate values */
1495             points_[gridPointIndex].samplePmf(convolvedBias);
1496         }
1497     }
1498
1499     /* Save probability weights for the update */
1500     sampleProbabilityWeights(grid, probWeightNeighbor);
1501 }
1502
1503 void BiasState::initHistoryFromState(AwhBiasHistory* biasHistory) const
1504 {
1505     biasHistory->pointState.resize(points_.size());
1506 }
1507
1508 void BiasState::updateHistory(AwhBiasHistory* biasHistory, const BiasGrid& grid) const
1509 {
1510     GMX_RELEASE_ASSERT(biasHistory->pointState.size() == points_.size(),
1511                        "The AWH history setup does not match the AWH state.");
1512
1513     AwhBiasStateHistory* stateHistory = &biasHistory->state;
1514     stateHistory->umbrellaGridpoint   = coordState_.umbrellaGridpoint();
1515
1516     for (size_t m = 0; m < biasHistory->pointState.size(); m++)
1517     {
1518         AwhPointStateHistory* psh = &biasHistory->pointState[m];
1519
1520         points_[m].storeState(psh);
1521
1522         psh->weightsum_covering = weightSumCovering_[m];
1523     }
1524
1525     histogramSize_.storeState(stateHistory);
1526
1527     stateHistory->origin_index_updatelist = multiDimGridIndexToLinear(grid, originUpdatelist_);
1528     stateHistory->end_index_updatelist    = multiDimGridIndexToLinear(grid, endUpdatelist_);
1529 }
1530
1531 void BiasState::restoreFromHistory(const AwhBiasHistory& biasHistory, const BiasGrid& grid)
1532 {
1533     const AwhBiasStateHistory& stateHistory = biasHistory.state;
1534
1535     coordState_.restoreFromHistory(stateHistory);
1536
1537     if (biasHistory.pointState.size() != points_.size())
1538     {
1539         GMX_THROW(
1540                 InvalidInputError("Bias grid size in checkpoint and simulation do not match. "
1541                                   "Likely you provided a checkpoint from a different simulation."));
1542     }
1543     for (size_t m = 0; m < points_.size(); m++)
1544     {
1545         points_[m].setFromHistory(biasHistory.pointState[m]);
1546     }
1547
1548     for (size_t m = 0; m < weightSumCovering_.size(); m++)
1549     {
1550         weightSumCovering_[m] = biasHistory.pointState[m].weightsum_covering;
1551     }
1552
1553     histogramSize_.restoreFromHistory(stateHistory);
1554
1555     linearGridindexToMultiDim(grid, stateHistory.origin_index_updatelist, originUpdatelist_);
1556     linearGridindexToMultiDim(grid, stateHistory.end_index_updatelist, endUpdatelist_);
1557 }
1558
1559 void BiasState::broadcast(const t_commrec* commRecord)
1560 {
1561     gmx_bcast(sizeof(coordState_), &coordState_, commRecord->mpi_comm_mygroup);
1562
1563     gmx_bcast(points_.size() * sizeof(PointState), points_.data(), commRecord->mpi_comm_mygroup);
1564
1565     gmx_bcast(weightSumCovering_.size() * sizeof(double), weightSumCovering_.data(), commRecord->mpi_comm_mygroup);
1566
1567     gmx_bcast(sizeof(histogramSize_), &histogramSize_, commRecord->mpi_comm_mygroup);
1568 }
1569
1570 void BiasState::setFreeEnergyToConvolvedPmf(ArrayRef<const DimParams> dimParams, const BiasGrid& grid)
1571 {
1572     std::vector<float> convolvedPmf;
1573
1574     calcConvolvedPmf(dimParams, grid, &convolvedPmf);
1575
1576     for (size_t m = 0; m < points_.size(); m++)
1577     {
1578         points_[m].setFreeEnergy(convolvedPmf[m]);
1579     }
1580 }
1581
1582 /*! \brief
1583  * Count trailing data rows containing only zeros.
1584  *
1585  * \param[in] data        2D data array.
1586  * \param[in] numRows     Number of rows in array.
1587  * \param[in] numColumns  Number of cols in array.
1588  * \returns the number of trailing zero rows.
1589  */
1590 static int countTrailingZeroRows(const double* const* data, int numRows, int numColumns)
1591 {
1592     int numZeroRows = 0;
1593     for (int m = numRows - 1; m >= 0; m--)
1594     {
1595         bool rowIsZero = true;
1596         for (int d = 0; d < numColumns; d++)
1597         {
1598             if (data[d][m] != 0)
1599             {
1600                 rowIsZero = false;
1601                 break;
1602             }
1603         }
1604
1605         if (!rowIsZero)
1606         {
1607             /* At a row with non-zero data */
1608             break;
1609         }
1610         else
1611         {
1612             /* Still at a zero data row, keep checking rows higher up. */
1613             numZeroRows++;
1614         }
1615     }
1616
1617     return numZeroRows;
1618 }
1619
1620 /*! \brief
1621  * Initializes the PMF and target with data read from an input table.
1622  *
1623  * \param[in]     dimParams   The dimension parameters.
1624  * \param[in]     grid        The grid.
1625  * \param[in]     filename    The filename to read PMF and target from.
1626  * \param[in]     numBias     Number of biases.
1627  * \param[in]     biasIndex   The index of the bias.
1628  * \param[in,out] pointState  The state of the points in this bias.
1629  */
1630 static void readUserPmfAndTargetDistribution(ArrayRef<const DimParams> dimParams,
1631                                              const BiasGrid&           grid,
1632                                              const std::string&        filename,
1633                                              int                       numBias,
1634                                              int                       biasIndex,
1635                                              std::vector<PointState>*  pointState)
1636 {
1637     /* Read the PMF and target distribution.
1638        From the PMF, the convolved PMF, or the reference value free energy, can be calculated
1639        base on the force constant. The free energy and target together determine the bias.
1640      */
1641     std::string filenameModified(filename);
1642     if (numBias > 1)
1643     {
1644         size_t n = filenameModified.rfind('.');
1645         GMX_RELEASE_ASSERT(n != std::string::npos,
1646                            "The filename should contain an extension starting with .");
1647         filenameModified.insert(n, formatString("%d", biasIndex));
1648     }
1649
1650     std::string correctFormatMessage = formatString(
1651             "%s is expected in the following format. "
1652             "The first ndim column(s) should contain the coordinate values for each point, "
1653             "each column containing values of one dimension (in ascending order). "
1654             "For a multidimensional coordinate, points should be listed "
1655             "in the order obtained by traversing lower dimensions first. "
1656             "E.g. for two-dimensional grid of size nxn: "
1657             "(1, 1), (1, 2),..., (1, n), (2, 1), (2, 2), ..., , (n, n - 1), (n, n). "
1658             "Column ndim +  1 should contain the PMF value for each coordinate value. "
1659             "The target distribution values should be in column ndim + 2  or column ndim + 5. "
1660             "Make sure the input file ends with a new line but has no trailing new lines.",
1661             filename.c_str());
1662     gmx::TextLineWrapper wrapper;
1663     wrapper.settings().setLineLength(c_linewidth);
1664     correctFormatMessage = wrapper.wrapToString(correctFormatMessage);
1665
1666     double** data;
1667     int      numColumns;
1668     int      numRows = read_xvg(filenameModified.c_str(), &data, &numColumns);
1669
1670     /* Check basic data properties here. BiasGrid takes care of more complicated things. */
1671
1672     if (numRows <= 0)
1673     {
1674         std::string mesg = gmx::formatString(
1675                 "%s is empty!.\n\n%s", filename.c_str(), correctFormatMessage.c_str());
1676         GMX_THROW(InvalidInputError(mesg));
1677     }
1678
1679     /* Less than 2 points is not useful for PMF or target. */
1680     if (numRows < 2)
1681     {
1682         std::string mesg = gmx::formatString(
1683                 "%s contains too few data points (%d)."
1684                 "The minimum number of points is 2.",
1685                 filename.c_str(),
1686                 numRows);
1687         GMX_THROW(InvalidInputError(mesg));
1688     }
1689
1690     /* Make sure there are enough columns of data.
1691
1692        Two formats are allowed. Either with columns  {coords, PMF, target} or
1693        {coords, PMF, x, y, z, target, ...}. The latter format is allowed since that
1694        is how AWH output is written (x, y, z being other AWH variables). For this format,
1695        trailing columns are ignored.
1696      */
1697     int columnIndexTarget;
1698     int numColumnsMin  = dimParams.size() + 2;
1699     int columnIndexPmf = dimParams.size();
1700     if (numColumns == numColumnsMin)
1701     {
1702         columnIndexTarget = columnIndexPmf + 1;
1703     }
1704     else
1705     {
1706         columnIndexTarget = columnIndexPmf + 4;
1707     }
1708
1709     if (numColumns < numColumnsMin)
1710     {
1711         std::string mesg = gmx::formatString(
1712                 "The number of columns in %s should be at least %d."
1713                 "\n\n%s",
1714                 filename.c_str(),
1715                 numColumnsMin,
1716                 correctFormatMessage.c_str());
1717         GMX_THROW(InvalidInputError(mesg));
1718     }
1719
1720     /* read_xvg can give trailing zero data rows for trailing new lines in the input. We allow 1 zero row,
1721        since this could be real data. But multiple trailing zero rows cannot correspond to valid data. */
1722     int numZeroRows = countTrailingZeroRows(data, numRows, numColumns);
1723     if (numZeroRows > 1)
1724     {
1725         std::string mesg = gmx::formatString(
1726                 "Found %d trailing zero data rows in %s. Please remove trailing empty lines and "
1727                 "try again.",
1728                 numZeroRows,
1729                 filename.c_str());
1730         GMX_THROW(InvalidInputError(mesg));
1731     }
1732
1733     /* Convert from user units to internal units before sending the data of to grid. */
1734     for (size_t d = 0; d < dimParams.size(); d++)
1735     {
1736         double scalingFactor = dimParams[d].scaleUserInputToInternal(1);
1737         if (scalingFactor == 1)
1738         {
1739             continue;
1740         }
1741         for (size_t m = 0; m < pointState->size(); m++)
1742         {
1743             data[d][m] *= scalingFactor;
1744         }
1745     }
1746
1747     /* Get a data point for each AWH grid point so that they all get data. */
1748     std::vector<int> gridIndexToDataIndex(grid.numPoints());
1749     mapGridToDataGrid(&gridIndexToDataIndex, data, numRows, filename, grid, correctFormatMessage);
1750
1751     /* Extract the data for each grid point.
1752      * We check if the target distribution is zero for all points.
1753      */
1754     bool targetDistributionIsZero = true;
1755     for (size_t m = 0; m < pointState->size(); m++)
1756     {
1757         (*pointState)[m].setLogPmfSum(-data[columnIndexPmf][gridIndexToDataIndex[m]]);
1758         double target = data[columnIndexTarget][gridIndexToDataIndex[m]];
1759
1760         /* Check if the values are allowed. */
1761         if (target < 0)
1762         {
1763             std::string mesg = gmx::formatString(
1764                     "Target distribution weight at point %zu (%g) in %s is negative.",
1765                     m,
1766                     target,
1767                     filename.c_str());
1768             GMX_THROW(InvalidInputError(mesg));
1769         }
1770         if (target > 0)
1771         {
1772             targetDistributionIsZero = false;
1773         }
1774         (*pointState)[m].setTargetConstantWeight(target);
1775     }
1776
1777     if (targetDistributionIsZero)
1778     {
1779         std::string mesg =
1780                 gmx::formatString("The target weights given in column %d in %s are all 0",
1781                                   columnIndexTarget,
1782                                   filename.c_str());
1783         GMX_THROW(InvalidInputError(mesg));
1784     }
1785
1786     /* Free the arrays. */
1787     for (int m = 0; m < numColumns; m++)
1788     {
1789         sfree(data[m]);
1790     }
1791     sfree(data);
1792 }
1793
1794 void BiasState::normalizePmf(int numSharingSims)
1795 {
1796     /* The normalization of the PMF estimate matters because it determines how big effect the next sample has.
1797        Approximately (for large enough force constant) we should have:
1798        sum_x(exp(-pmf(x)) = nsamples*sum_xref(exp(-f(xref)).
1799      */
1800
1801     /* Calculate the normalization factor, i.e. divide by the pmf sum, multiply by the number of samples and the f sum */
1802     double expSumPmf = 0;
1803     double expSumF   = 0;
1804     for (const PointState& pointState : points_)
1805     {
1806         if (pointState.inTargetRegion())
1807         {
1808             expSumPmf += std::exp(pointState.logPmfSum());
1809             expSumF += std::exp(-pointState.freeEnergy());
1810         }
1811     }
1812     double numSamples = histogramSize_.histogramSize() / numSharingSims;
1813
1814     /* Renormalize */
1815     double logRenorm = std::log(numSamples * expSumF / expSumPmf);
1816     for (PointState& pointState : points_)
1817     {
1818         if (pointState.inTargetRegion())
1819         {
1820             pointState.setLogPmfSum(pointState.logPmfSum() + logRenorm);
1821         }
1822     }
1823 }
1824
1825 void BiasState::initGridPointState(const AwhBiasParams&      awhBiasParams,
1826                                    ArrayRef<const DimParams> dimParams,
1827                                    const BiasGrid&           grid,
1828                                    const BiasParams&         params,
1829                                    const std::string&        filename,
1830                                    int                       numBias)
1831 {
1832     /* Modify PMF, free energy and the constant target distribution factor
1833      * to user input values if there is data given.
1834      */
1835     if (awhBiasParams.userPMFEstimate())
1836     {
1837         readUserPmfAndTargetDistribution(dimParams, grid, filename, numBias, params.biasIndex, &points_);
1838         setFreeEnergyToConvolvedPmf(dimParams, grid);
1839     }
1840
1841     /* The local Boltzmann distribution is special because the target distribution is updated as a function of the reference weighthistogram. */
1842     GMX_RELEASE_ASSERT(params.eTarget != AwhTargetType::LocalBoltzmann || points_[0].weightSumRef() != 0,
1843                        "AWH reference weight histogram not initialized properly with local "
1844                        "Boltzmann target distribution.");
1845
1846     updateTargetDistribution(points_, params);
1847
1848     for (PointState& pointState : points_)
1849     {
1850         if (pointState.inTargetRegion())
1851         {
1852             pointState.updateBias();
1853         }
1854         else
1855         {
1856             /* Note that for zero target this is a value that represents -infinity but should not be used for biasing. */
1857             pointState.setTargetToZero();
1858         }
1859     }
1860
1861     /* Set the initial reference weighthistogram. */
1862     const double histogramSize = histogramSize_.histogramSize();
1863     for (auto& pointState : points_)
1864     {
1865         pointState.setInitialReferenceWeightHistogram(histogramSize);
1866     }
1867
1868     /* Make sure the pmf is normalized consistently with the histogram size.
1869        Note: the target distribution and free energy need to be set here. */
1870     normalizePmf(params.numSharedUpdate);
1871 }
1872
1873 BiasState::BiasState(const AwhBiasParams&      awhBiasParams,
1874                      double                    histogramSizeInitial,
1875                      ArrayRef<const DimParams> dimParams,
1876                      const BiasGrid&           grid) :
1877     coordState_(awhBiasParams, dimParams, grid),
1878     points_(grid.numPoints()),
1879     weightSumCovering_(grid.numPoints()),
1880     histogramSize_(awhBiasParams, histogramSizeInitial)
1881 {
1882     /* The minimum and maximum multidimensional point indices that are affected by the next update */
1883     for (size_t d = 0; d < dimParams.size(); d++)
1884     {
1885         int index            = grid.point(coordState_.gridpointIndex()).index[d];
1886         originUpdatelist_[d] = index;
1887         endUpdatelist_[d]    = index;
1888     }
1889 }
1890
1891 } // namespace gmx