783cfab1f8a1335421cd098a76a2aca4c0597e70
[alexxy/gromacs.git] / src / gromacs / applied_forces / awh / biasstate.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2015,2016,2017,2018,2019, The GROMACS development team.
5  * Copyright (c) 2020,2021, by the GROMACS development team, led by
6  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
7  * and including many others, as listed in the AUTHORS file in the
8  * top-level source directory and at http://www.gromacs.org.
9  *
10  * GROMACS is free software; you can redistribute it and/or
11  * modify it under the terms of the GNU Lesser General Public License
12  * as published by the Free Software Foundation; either version 2.1
13  * of the License, or (at your option) any later version.
14  *
15  * GROMACS is distributed in the hope that it will be useful,
16  * but WITHOUT ANY WARRANTY; without even the implied warranty of
17  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
18  * Lesser General Public License for more details.
19  *
20  * You should have received a copy of the GNU Lesser General Public
21  * License along with GROMACS; if not, see
22  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
23  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
24  *
25  * If you want to redistribute modifications to GROMACS, please
26  * consider that scientific software is very special. Version
27  * control is crucial - bugs must be traceable. We will be happy to
28  * consider code for inclusion in the official distribution, but
29  * derived work must not be called official GROMACS. Details are found
30  * in the README & COPYING files - if they are missing, get the
31  * official version at http://www.gromacs.org.
32  *
33  * To help us fund GROMACS development, we humbly ask that you cite
34  * the research papers on the package. Check out http://www.gromacs.org.
35  */
36
37 /*! \internal \file
38  * \brief
39  * Implements the BiasState class.
40  *
41  * \author Viveca Lindahl
42  * \author Berk Hess <hess@kth.se>
43  * \ingroup module_awh
44  */
45
46 #include "gmxpre.h"
47
48 #include "biasstate.h"
49
50 #include <cassert>
51 #include <cmath>
52 #include <cstdio>
53 #include <cstdlib>
54 #include <cstring>
55
56 #include <algorithm>
57 #include <optional>
58
59 #include "gromacs/fileio/gmxfio.h"
60 #include "gromacs/fileio/xvgr.h"
61 #include "gromacs/gmxlib/network.h"
62 #include "gromacs/math/utilities.h"
63 #include "gromacs/mdrunutility/multisim.h"
64 #include "gromacs/mdtypes/awh_history.h"
65 #include "gromacs/mdtypes/awh_params.h"
66 #include "gromacs/mdtypes/commrec.h"
67 #include "gromacs/simd/simd.h"
68 #include "gromacs/simd/simd_math.h"
69 #include "gromacs/utility/arrayref.h"
70 #include "gromacs/utility/exceptions.h"
71 #include "gromacs/utility/gmxassert.h"
72 #include "gromacs/utility/smalloc.h"
73 #include "gromacs/utility/stringutil.h"
74
75 #include "biasgrid.h"
76 #include "pointstate.h"
77
78 namespace gmx
79 {
80
81 void BiasState::getPmf(gmx::ArrayRef<float> pmf) const
82 {
83     GMX_ASSERT(pmf.size() == points_.size(), "pmf should have the size of the bias grid");
84
85     /* The PMF is just the negative of the log of the sampled PMF histogram.
86      * Points with zero target weight are ignored, they will mostly contain noise.
87      */
88     for (size_t i = 0; i < points_.size(); i++)
89     {
90         pmf[i] = points_[i].inTargetRegion() ? -points_[i].logPmfSum() : GMX_FLOAT_MAX;
91     }
92 }
93
94 namespace
95 {
96
97 /*! \brief
98  * Sum an array over all simulations on the master rank of each simulation.
99  *
100  * \param[in,out] arrayRef      The data to sum.
101  * \param[in]     multiSimComm  Struct for multi-simulation communication.
102  */
103 void sumOverSimulations(gmx::ArrayRef<int> arrayRef, const gmx_multisim_t* multiSimComm)
104 {
105     gmx_sumi_sim(arrayRef.size(), arrayRef.data(), multiSimComm);
106 }
107
108 /*! \brief
109  * Sum an array over all simulations on the master rank of each simulation.
110  *
111  * \param[in,out] arrayRef      The data to sum.
112  * \param[in]     multiSimComm  Struct for multi-simulation communication.
113  */
114 void sumOverSimulations(gmx::ArrayRef<double> arrayRef, const gmx_multisim_t* multiSimComm)
115 {
116     gmx_sumd_sim(arrayRef.size(), arrayRef.data(), multiSimComm);
117 }
118
119 /*! \brief
120  * Sum an array over all simulations on all ranks of each simulation.
121  *
122  * This assumes the data is identical on all ranks within each simulation.
123  *
124  * \param[in,out] arrayRef      The data to sum.
125  * \param[in]     commRecord    Struct for intra-simulation communication.
126  * \param[in]     multiSimComm  Struct for multi-simulation communication.
127  */
128 template<typename T>
129 void sumOverSimulations(gmx::ArrayRef<T> arrayRef, const t_commrec* commRecord, const gmx_multisim_t* multiSimComm)
130 {
131     if (MASTER(commRecord))
132     {
133         sumOverSimulations(arrayRef, multiSimComm);
134     }
135     if (commRecord->nnodes > 1)
136     {
137         gmx_bcast(arrayRef.size() * sizeof(T), arrayRef.data(), commRecord->mpi_comm_mygroup);
138     }
139 }
140
141 /*! \brief
142  * Sum PMF over multiple simulations, when requested.
143  *
144  * \param[in,out] pointState         The state of the points in the bias.
145  * \param[in]     numSharedUpdate    The number of biases sharing the histogram.
146  * \param[in]     commRecord         Struct for intra-simulation communication.
147  * \param[in]     multiSimComm       Struct for multi-simulation communication.
148  */
149 void sumPmf(gmx::ArrayRef<PointState> pointState,
150             int                       numSharedUpdate,
151             const t_commrec*          commRecord,
152             const gmx_multisim_t*     multiSimComm)
153 {
154     if (numSharedUpdate == 1)
155     {
156         return;
157     }
158     GMX_ASSERT(multiSimComm != nullptr && numSharedUpdate % multiSimComm->numSimulations_ == 0,
159                "numSharedUpdate should be a multiple of multiSimComm->numSimulations_");
160     GMX_ASSERT(numSharedUpdate == multiSimComm->numSimulations_,
161                "Sharing within a simulation is not implemented (yet)");
162
163     std::vector<double> buffer(pointState.size());
164
165     /* Need to temporarily exponentiate the log weights to sum over simulations */
166     for (size_t i = 0; i < buffer.size(); i++)
167     {
168         buffer[i] = pointState[i].inTargetRegion() ? std::exp(pointState[i].logPmfSum()) : 0;
169     }
170
171     sumOverSimulations(gmx::ArrayRef<double>(buffer), commRecord, multiSimComm);
172
173     /* Take log again to get (non-normalized) PMF */
174     double normFac = 1.0 / numSharedUpdate;
175     for (gmx::index i = 0; i < pointState.ssize(); i++)
176     {
177         if (pointState[i].inTargetRegion())
178         {
179             pointState[i].setLogPmfSum(std::log(buffer[i] * normFac));
180         }
181     }
182 }
183
184 /*! \brief
185  * Find the minimum free energy value.
186  *
187  * \param[in] pointState  The state of the points.
188  * \returns the minimum free energy value.
189  */
190 double freeEnergyMinimumValue(gmx::ArrayRef<const PointState> pointState)
191 {
192     double fMin = GMX_FLOAT_MAX;
193
194     for (auto const& ps : pointState)
195     {
196         if (ps.inTargetRegion() && ps.freeEnergy() < fMin)
197         {
198             fMin = ps.freeEnergy();
199         }
200     }
201
202     return fMin;
203 }
204
205 /*! \brief
206  * Find and return the log of the probability weight of a point given a coordinate value.
207  *
208  * The unnormalized weight is given by
209  * w(point|value) = exp(bias(point) - U(value,point)),
210  * where U is a harmonic umbrella potential.
211  *
212  * \param[in] dimParams              The bias dimensions parameters
213  * \param[in] points                 The point state.
214  * \param[in] grid                   The grid.
215  * \param[in] pointIndex             Point to evaluate probability weight for.
216  * \param[in] pointBias              Bias for the point (as a log weight).
217  * \param[in] value                  Coordinate value.
218  * \param[in] neighborLambdaEnergies The energy of the system in neighboring lambdas states. Can be
219  * empty when there are no free energy lambda state dimensions.
220  * \param[in] gridpointIndex         The index of the current grid point.
221  * \returns the log of the biased probability weight.
222  */
223 double biasedLogWeightFromPoint(const std::vector<DimParams>&  dimParams,
224                                 const std::vector<PointState>& points,
225                                 const BiasGrid&                grid,
226                                 int                            pointIndex,
227                                 double                         pointBias,
228                                 const awh_dvec                 value,
229                                 gmx::ArrayRef<const double>    neighborLambdaEnergies,
230                                 int                            gridpointIndex)
231 {
232     double logWeight = detail::c_largeNegativeExponent;
233
234     /* Only points in the target region have non-zero weight */
235     if (points[pointIndex].inTargetRegion())
236     {
237         logWeight = pointBias;
238
239         /* Add potential for all parameter dimensions */
240         for (size_t d = 0; d < dimParams.size(); d++)
241         {
242             if (dimParams[d].isFepLambdaDimension())
243             {
244                 /* If this is not a sampling step or if this function is called from
245                  * calcConvolvedBias(), when writing energy subblocks, neighborLambdaEnergies will
246                  * be empty. No convolution is required along the lambda dimension. */
247                 if (!neighborLambdaEnergies.empty())
248                 {
249                     const int pointLambdaIndex     = grid.point(pointIndex).coordValue[d];
250                     const int gridpointLambdaIndex = grid.point(gridpointIndex).coordValue[d];
251                     logWeight -= dimParams[d].fepDimParams().beta
252                                  * (neighborLambdaEnergies[pointLambdaIndex]
253                                     - neighborLambdaEnergies[gridpointLambdaIndex]);
254                 }
255             }
256             else
257             {
258                 double dev = getDeviationFromPointAlongGridAxis(grid, d, pointIndex, value[d]);
259                 logWeight -= 0.5 * dimParams[d].pullDimParams().betak * dev * dev;
260             }
261         }
262     }
263     return logWeight;
264 }
265
266 /*! \brief
267  * Calculates the marginal distribution (marginal probability) for each value along
268  * a free energy lambda axis.
269  * The marginal distribution of one coordinate dimension value is the sum of the probability
270  * distribution of all values (herein all neighbor values) with the same value in the dimension
271  * of interest.
272  * \param[in] grid               The bias grid.
273  * \param[in] neighbors          The points to use for the calculation of the marginal distribution.
274  * \param[in] probWeightNeighbor Probability weights of the neighbors.
275  * \returns The calculated marginal distribution in a 1D array with
276  * as many elements as there are points along the axis of interest.
277  */
278 std::vector<double> calculateFELambdaMarginalDistribution(const BiasGrid&          grid,
279                                                           gmx::ArrayRef<const int> neighbors,
280                                                           gmx::ArrayRef<const double> probWeightNeighbor)
281 {
282     const std::optional<int> lambdaAxisIndex = grid.lambdaAxisIndex();
283     GMX_RELEASE_ASSERT(lambdaAxisIndex,
284                        "There must be a free energy lambda axis in order to calculate the free "
285                        "energy lambda marginal distribution.");
286     const int           numFepLambdaStates = grid.numFepLambdaStates();
287     std::vector<double> lambdaMarginalDistribution(numFepLambdaStates, 0);
288
289     for (size_t i = 0; i < neighbors.size(); i++)
290     {
291         const int neighbor    = neighbors[i];
292         const int lambdaState = grid.point(neighbor).coordValue[lambdaAxisIndex.value()];
293         lambdaMarginalDistribution[lambdaState] += probWeightNeighbor[i];
294     }
295     return lambdaMarginalDistribution;
296 }
297
298 } // namespace
299
300 void BiasState::calcConvolvedPmf(const std::vector<DimParams>& dimParams,
301                                  const BiasGrid&               grid,
302                                  std::vector<float>*           convolvedPmf) const
303 {
304     size_t numPoints = grid.numPoints();
305
306     convolvedPmf->resize(numPoints);
307
308     /* Get the PMF to convolve. */
309     std::vector<float> pmf(numPoints);
310     getPmf(pmf);
311
312     for (size_t m = 0; m < numPoints; m++)
313     {
314         double           freeEnergyWeights = 0;
315         const GridPoint& point             = grid.point(m);
316         for (auto& neighbor : point.neighbor)
317         {
318             /* Do not convolve the bias along a lambda axis - only use the pmf from the current point */
319             if (!pointsHaveDifferentLambda(grid, m, neighbor))
320             {
321                 /* The negative PMF is a positive bias. */
322                 double biasNeighbor = -pmf[neighbor];
323
324                 /* Add the convolved PMF weights for the neighbors of this point.
325                 Note that this function only adds point within the target > 0 region.
326                 Sum weights, take the logarithm last to get the free energy. */
327                 double logWeight = biasedLogWeightFromPoint(
328                         dimParams, points_, grid, neighbor, biasNeighbor, point.coordValue, {}, m);
329                 freeEnergyWeights += std::exp(logWeight);
330             }
331         }
332
333         GMX_RELEASE_ASSERT(freeEnergyWeights > 0,
334                            "Attempting to do log(<= 0) in AWH convolved PMF calculation.");
335         (*convolvedPmf)[m] = -std::log(static_cast<float>(freeEnergyWeights));
336     }
337 }
338
339 namespace
340 {
341
342 /*! \brief
343  * Updates the target distribution for all points.
344  *
345  * The target distribution is always updated for all points
346  * at the same time.
347  *
348  * \param[in,out] pointState  The state of all points.
349  * \param[in]     params      The bias parameters.
350  */
351 void updateTargetDistribution(gmx::ArrayRef<PointState> pointState, const BiasParams& params)
352 {
353     double freeEnergyCutoff = 0;
354     if (params.eTarget == AwhTargetType::Cutoff)
355     {
356         freeEnergyCutoff = freeEnergyMinimumValue(pointState) + params.freeEnergyCutoffInKT;
357     }
358
359     double sumTarget = 0;
360     for (PointState& ps : pointState)
361     {
362         sumTarget += ps.updateTargetWeight(params, freeEnergyCutoff);
363     }
364     GMX_RELEASE_ASSERT(sumTarget > 0, "We should have a non-zero distribution");
365
366     /* Normalize to 1 */
367     double invSum = 1.0 / sumTarget;
368     for (PointState& ps : pointState)
369     {
370         ps.scaleTarget(invSum);
371     }
372 }
373
374 /*! \brief
375  * Puts together a string describing a grid point.
376  *
377  * \param[in] grid         The grid.
378  * \param[in] point        BiasGrid point index.
379  * \returns a string for the point.
380  */
381 std::string gridPointValueString(const BiasGrid& grid, int point)
382 {
383     std::string pointString;
384
385     pointString += "(";
386
387     for (int d = 0; d < grid.numDimensions(); d++)
388     {
389         pointString += gmx::formatString("%g", grid.point(point).coordValue[d]);
390         if (d < grid.numDimensions() - 1)
391         {
392             pointString += ",";
393         }
394         else
395         {
396             pointString += ")";
397         }
398     }
399
400     return pointString;
401 }
402
403 } // namespace
404
405 int BiasState::warnForHistogramAnomalies(const BiasGrid& grid, int biasIndex, double t, FILE* fplog, int maxNumWarnings) const
406 {
407     GMX_ASSERT(fplog != nullptr, "Warnings can only be issued if there is log file.");
408     const double maxHistogramRatio = 0.5; /* Tolerance for printing a warning about the histogram ratios */
409
410     /* Sum up the histograms and get their normalization */
411     double sumVisits  = 0;
412     double sumWeights = 0;
413     for (auto& pointState : points_)
414     {
415         if (pointState.inTargetRegion())
416         {
417             sumVisits += pointState.numVisitsTot();
418             sumWeights += pointState.weightSumTot();
419         }
420     }
421     GMX_RELEASE_ASSERT(sumVisits > 0, "We should have visits");
422     GMX_RELEASE_ASSERT(sumWeights > 0, "We should have weight");
423     double invNormVisits = 1.0 / sumVisits;
424     double invNormWeight = 1.0 / sumWeights;
425
426     /* Check all points for warnings */
427     int    numWarnings = 0;
428     size_t numPoints   = grid.numPoints();
429     for (size_t m = 0; m < numPoints; m++)
430     {
431         /* Skip points close to boundary or non-target region */
432         const GridPoint& gridPoint = grid.point(m);
433         bool             skipPoint = false;
434         for (size_t n = 0; (n < gridPoint.neighbor.size()) && !skipPoint; n++)
435         {
436             int neighbor = gridPoint.neighbor[n];
437             skipPoint    = !points_[neighbor].inTargetRegion();
438             for (int d = 0; (d < grid.numDimensions()) && !skipPoint; d++)
439             {
440                 const GridPoint& neighborPoint = grid.point(neighbor);
441                 skipPoint                      = neighborPoint.index[d] == 0
442                             || neighborPoint.index[d] == grid.axis(d).numPoints() - 1;
443             }
444         }
445
446         /* Warn if the coordinate distribution is less than the target distribution with a certain fraction somewhere */
447         const double relativeWeight = points_[m].weightSumTot() * invNormWeight;
448         const double relativeVisits = points_[m].numVisitsTot() * invNormVisits;
449         if (!skipPoint && relativeVisits < relativeWeight * maxHistogramRatio)
450         {
451             std::string pointValueString = gridPointValueString(grid, m);
452             std::string warningMessage   = gmx::formatString(
453                     "\nawh%d warning: "
454                     "at t = %g ps the obtained coordinate distribution at coordinate value %s "
455                     "is less than a fraction %g of the reference distribution at that point. "
456                     "If you are not certain about your settings you might want to increase your "
457                     "pull force constant or "
458                     "modify your sampling region.\n",
459                     biasIndex + 1,
460                     t,
461                     pointValueString.c_str(),
462                     maxHistogramRatio);
463             gmx::TextLineWrapper wrapper;
464             wrapper.settings().setLineLength(c_linewidth);
465             fprintf(fplog, "%s", wrapper.wrapToString(warningMessage).c_str());
466
467             numWarnings++;
468         }
469         if (numWarnings >= maxNumWarnings)
470         {
471             break;
472         }
473     }
474
475     return numWarnings;
476 }
477
478 double BiasState::calcUmbrellaForceAndPotential(const std::vector<DimParams>& dimParams,
479                                                 const BiasGrid&               grid,
480                                                 int                           point,
481                                                 ArrayRef<const double>        neighborLambdaDhdl,
482                                                 gmx::ArrayRef<double>         force) const
483 {
484     double potential = 0;
485     for (size_t d = 0; d < dimParams.size(); d++)
486     {
487         if (dimParams[d].isFepLambdaDimension())
488         {
489             if (!neighborLambdaDhdl.empty())
490             {
491                 const int coordpointLambdaIndex = grid.point(point).coordValue[d];
492                 force[d]                        = neighborLambdaDhdl[coordpointLambdaIndex];
493                 /* The potential should not be affected by the lambda dimension. */
494             }
495         }
496         else
497         {
498             double deviation =
499                     getDeviationFromPointAlongGridAxis(grid, d, point, coordState_.coordValue()[d]);
500             double k = dimParams[d].pullDimParams().k;
501
502             /* Force from harmonic potential 0.5*k*dev^2 */
503             force[d] = -k * deviation;
504             potential += 0.5 * k * deviation * deviation;
505         }
506     }
507
508     return potential;
509 }
510
511 void BiasState::calcConvolvedForce(const std::vector<DimParams>& dimParams,
512                                    const BiasGrid&               grid,
513                                    gmx::ArrayRef<const double>   probWeightNeighbor,
514                                    ArrayRef<const double>        neighborLambdaDhdl,
515                                    gmx::ArrayRef<double>         forceWorkBuffer,
516                                    gmx::ArrayRef<double>         force) const
517 {
518     for (size_t d = 0; d < dimParams.size(); d++)
519     {
520         force[d] = 0;
521     }
522
523     /* Only neighboring points have non-negligible contribution. */
524     const std::vector<int>& neighbor          = grid.point(coordState_.gridpointIndex()).neighbor;
525     gmx::ArrayRef<double>   forceFromNeighbor = forceWorkBuffer;
526     for (size_t n = 0; n < neighbor.size(); n++)
527     {
528         double weightNeighbor = probWeightNeighbor[n];
529         int    indexNeighbor  = neighbor[n];
530
531         /* Get the umbrella force from this point. The returned potential is ignored here. */
532         calcUmbrellaForceAndPotential(dimParams, grid, indexNeighbor, neighborLambdaDhdl, forceFromNeighbor);
533
534         /* Add the weighted umbrella force to the convolved force. */
535         for (size_t d = 0; d < dimParams.size(); d++)
536         {
537             force[d] += forceFromNeighbor[d] * weightNeighbor;
538         }
539     }
540 }
541
542 double BiasState::moveUmbrella(const std::vector<DimParams>& dimParams,
543                                const BiasGrid&               grid,
544                                gmx::ArrayRef<const double>   probWeightNeighbor,
545                                ArrayRef<const double>        neighborLambdaDhdl,
546                                gmx::ArrayRef<double>         biasForce,
547                                int64_t                       step,
548                                int64_t                       seed,
549                                int                           indexSeed,
550                                bool                          onlySampleUmbrellaGridpoint)
551 {
552     /* Generate and set a new coordinate reference value */
553     coordState_.sampleUmbrellaGridpoint(
554             grid, coordState_.gridpointIndex(), probWeightNeighbor, step, seed, indexSeed);
555
556     if (onlySampleUmbrellaGridpoint)
557     {
558         return 0;
559     }
560
561     std::vector<double> newForce(dimParams.size());
562     double              newPotential = calcUmbrellaForceAndPotential(
563             dimParams, grid, coordState_.umbrellaGridpoint(), neighborLambdaDhdl, newForce);
564
565     /*  A modification of the reference value at time t will lead to a different
566         force over t-dt/2 to t and over t to t+dt/2. For high switching rates
567         this means the force and velocity will change signs roughly as often.
568         To avoid any issues we take the average of the previous and new force
569         at steps when the reference value has been moved. E.g. if the ref. value
570         is set every step to (coord dvalue +/- delta) would give zero force.
571      */
572     for (gmx::index d = 0; d < biasForce.ssize(); d++)
573     {
574         /* Average of the current and new force */
575         biasForce[d] = 0.5 * (biasForce[d] + newForce[d]);
576     }
577
578     return newPotential;
579 }
580
581 namespace
582 {
583
584 /*! \brief
585  * Sets the histogram rescaling factors needed to control the histogram size.
586  *
587  * For sake of robustness, the reference weight histogram can grow at a rate
588  * different from the actual sampling rate. Typically this happens for a limited
589  * initial time, alternatively growth is scaled down by a constant factor for all
590  * times. Since the size of the reference histogram sets the size of the free
591  * energy update this should be reflected also in the PMF. Thus the PMF histogram
592  * needs to be rescaled too.
593  *
594  * This function should only be called by the bias update function or wrapped by a function that
595  * knows what scale factors should be applied when, e.g,
596  * getSkippedUpdateHistogramScaleFactors().
597  *
598  * \param[in]  params             The bias parameters.
599  * \param[in]  newHistogramSize   New reference weight histogram size.
600  * \param[in]  oldHistogramSize   Previous reference weight histogram size (before adding new samples).
601  * \param[out] weightHistScaling  Scaling factor for the reference weight histogram.
602  * \param[out] logPmfSumScaling   Log of the scaling factor for the PMF histogram.
603  */
604 void setHistogramUpdateScaleFactors(const BiasParams& params,
605                                     double            newHistogramSize,
606                                     double            oldHistogramSize,
607                                     double*           weightHistScaling,
608                                     double*           logPmfSumScaling)
609 {
610
611     /* The two scaling factors below are slightly different (ignoring the log factor) because the
612        reference and the PMF histogram apply weight scaling differently. The weight histogram
613        applies is  locally, i.e. each sample is scaled down meaning all samples get equal weight.
614        It is done this way because that is what target type local Boltzmann (for which
615        target = weight histogram) needs. In contrast, the PMF histogram is rescaled globally
616        by repeatedly scaling down the whole histogram. The reasons for doing it this way are:
617        1) empirically this is necessary for converging the PMF; 2) since the extraction of
618        the PMF is theoretically only valid for a constant bias, new samples should get more
619        weight than old ones for which the bias is fluctuating more. */
620     *weightHistScaling =
621             newHistogramSize / (oldHistogramSize + params.updateWeight * params.localWeightScaling);
622     *logPmfSumScaling = std::log(newHistogramSize / (oldHistogramSize + params.updateWeight));
623 }
624
625 } // namespace
626
627 void BiasState::getSkippedUpdateHistogramScaleFactors(const BiasParams& params,
628                                                       double*           weightHistScaling,
629                                                       double*           logPmfSumScaling) const
630 {
631     GMX_ASSERT(params.skipUpdates(),
632                "Calling function for skipped updates when skipping updates is not allowed");
633
634     if (inInitialStage())
635     {
636         /* In between global updates the reference histogram size is kept constant so we trivially
637            know what the histogram size was at the time of the skipped update. */
638         double histogramSize = histogramSize_.histogramSize();
639         setHistogramUpdateScaleFactors(
640                 params, histogramSize, histogramSize, weightHistScaling, logPmfSumScaling);
641     }
642     else
643     {
644         /* In the final stage, the reference histogram grows at the sampling rate which gives trivial scale factors. */
645         *weightHistScaling = 1;
646         *logPmfSumScaling  = 0;
647     }
648 }
649
650 void BiasState::doSkippedUpdatesForAllPoints(const BiasParams& params)
651 {
652     double weightHistScaling;
653     double logPmfsumScaling;
654
655     getSkippedUpdateHistogramScaleFactors(params, &weightHistScaling, &logPmfsumScaling);
656
657     for (auto& pointState : points_)
658     {
659         bool didUpdate = pointState.performPreviouslySkippedUpdates(
660                 params, histogramSize_.numUpdates(), weightHistScaling, logPmfsumScaling);
661
662         /* Update the bias for this point only if there were skipped updates in the past to avoid calculating the log unneccessarily */
663         if (didUpdate)
664         {
665             pointState.updateBias();
666         }
667     }
668 }
669
670 void BiasState::doSkippedUpdatesInNeighborhood(const BiasParams& params, const BiasGrid& grid)
671 {
672     double weightHistScaling;
673     double logPmfsumScaling;
674
675     getSkippedUpdateHistogramScaleFactors(params, &weightHistScaling, &logPmfsumScaling);
676
677     /* For each neighbor point of the center point, refresh its state by adding the results of all past, skipped updates. */
678     const std::vector<int>& neighbors = grid.point(coordState_.gridpointIndex()).neighbor;
679     for (auto& neighbor : neighbors)
680     {
681         bool didUpdate = points_[neighbor].performPreviouslySkippedUpdates(
682                 params, histogramSize_.numUpdates(), weightHistScaling, logPmfsumScaling);
683
684         if (didUpdate)
685         {
686             points_[neighbor].updateBias();
687         }
688     }
689 }
690
691 namespace
692 {
693
694 /*! \brief
695  * Merge update lists from multiple sharing simulations.
696  *
697  * \param[in,out] updateList    Update list for this simulation (assumed >= npoints long).
698  * \param[in]     numPoints     Total number of points.
699  * \param[in]     commRecord    Struct for intra-simulation communication.
700  * \param[in]     multiSimComm  Struct for multi-simulation communication.
701  */
702 void mergeSharedUpdateLists(std::vector<int>*     updateList,
703                             int                   numPoints,
704                             const t_commrec*      commRecord,
705                             const gmx_multisim_t* multiSimComm)
706 {
707     std::vector<int> numUpdatesOfPoint;
708
709     /* Flag the update points of this sim.
710        TODO: we can probably avoid allocating this array and just use the input array. */
711     numUpdatesOfPoint.resize(numPoints, 0);
712     for (auto& pointIndex : *updateList)
713     {
714         numUpdatesOfPoint[pointIndex] = 1;
715     }
716
717     /* Sum over the sims to get all the flagged points */
718     sumOverSimulations(arrayRefFromArray(numUpdatesOfPoint.data(), numPoints), commRecord, multiSimComm);
719
720     /* Collect the indices of the flagged points in place. The resulting array will be the merged update list.*/
721     updateList->clear();
722     for (int m = 0; m < numPoints; m++)
723     {
724         if (numUpdatesOfPoint[m] > 0)
725         {
726             updateList->push_back(m);
727         }
728     }
729 }
730
731 /*! \brief
732  * Generate an update list of points sampled since the last update.
733  *
734  * \param[in] grid              The AWH bias.
735  * \param[in] points            The point state.
736  * \param[in] originUpdatelist  The origin of the rectangular region that has been sampled since
737  * last update. \param[in] endUpdatelist     The end of the rectangular that has been sampled since
738  * last update. \param[in,out] updateList    Local update list to set (assumed >= npoints long).
739  */
740 void makeLocalUpdateList(const BiasGrid&                grid,
741                          const std::vector<PointState>& points,
742                          const awh_ivec                 originUpdatelist,
743                          const awh_ivec                 endUpdatelist,
744                          std::vector<int>*              updateList)
745 {
746     awh_ivec origin;
747     awh_ivec numPoints;
748
749     /* Define the update search grid */
750     for (int d = 0; d < grid.numDimensions(); d++)
751     {
752         origin[d]    = originUpdatelist[d];
753         numPoints[d] = endUpdatelist[d] - originUpdatelist[d] + 1;
754
755         /* Because the end_updatelist is unwrapped it can be > (npoints - 1) so that numPoints can be > npoints in grid.
756            This helps for calculating the distance/number of points but should be removed and fixed when the way of
757            updating origin/end updatelist is changed (see sampleProbabilityWeights). */
758         numPoints[d] = std::min(grid.axis(d).numPoints(), numPoints[d]);
759     }
760
761     /* Make the update list */
762     updateList->clear();
763     int  pointIndex  = -1;
764     bool pointExists = true;
765     while (pointExists)
766     {
767         pointExists = advancePointInSubgrid(grid, origin, numPoints, &pointIndex);
768
769         if (pointExists && points[pointIndex].inTargetRegion())
770         {
771             updateList->push_back(pointIndex);
772         }
773     }
774 }
775
776 } // namespace
777
778 void BiasState::resetLocalUpdateRange(const BiasGrid& grid)
779 {
780     const int gridpointIndex = coordState_.gridpointIndex();
781     for (int d = 0; d < grid.numDimensions(); d++)
782     {
783         /* This gives the  minimum range consisting only of the current closest point. */
784         originUpdatelist_[d] = grid.point(gridpointIndex).index[d];
785         endUpdatelist_[d]    = grid.point(gridpointIndex).index[d];
786     }
787 }
788
789 namespace
790 {
791
792 /*! \brief
793  * Add partial histograms (accumulating between updates) to accumulating histograms.
794  *
795  * \param[in,out] pointState         The state of the points in the bias.
796  * \param[in,out] weightSumCovering  The weights for checking covering.
797  * \param[in]     numSharedUpdate    The number of biases sharing the histrogram.
798  * \param[in]     commRecord         Struct for intra-simulation communication.
799  * \param[in]     multiSimComm       Struct for multi-simulation communication.
800  * \param[in]     localUpdateList    List of points with data.
801  */
802 void sumHistograms(gmx::ArrayRef<PointState> pointState,
803                    gmx::ArrayRef<double>     weightSumCovering,
804                    int                       numSharedUpdate,
805                    const t_commrec*          commRecord,
806                    const gmx_multisim_t*     multiSimComm,
807                    const std::vector<int>&   localUpdateList)
808 {
809     /* The covering checking histograms are added before summing over simulations, so that the
810        weights from different simulations are kept distinguishable. */
811     for (int globalIndex : localUpdateList)
812     {
813         weightSumCovering[globalIndex] += pointState[globalIndex].weightSumIteration();
814     }
815
816     /* Sum histograms over multiple simulations if needed. */
817     if (numSharedUpdate > 1)
818     {
819         GMX_ASSERT(numSharedUpdate == multiSimComm->numSimulations_,
820                    "Sharing within a simulation is not implemented (yet)");
821
822         /* Collect the weights and counts in linear arrays to be able to use gmx_sumd_sim. */
823         std::vector<double> weightSum;
824         std::vector<double> coordVisits;
825
826         weightSum.resize(localUpdateList.size());
827         coordVisits.resize(localUpdateList.size());
828
829         for (size_t localIndex = 0; localIndex < localUpdateList.size(); localIndex++)
830         {
831             const PointState& ps = pointState[localUpdateList[localIndex]];
832
833             weightSum[localIndex]   = ps.weightSumIteration();
834             coordVisits[localIndex] = ps.numVisitsIteration();
835         }
836
837         sumOverSimulations(gmx::ArrayRef<double>(weightSum), commRecord, multiSimComm);
838         sumOverSimulations(gmx::ArrayRef<double>(coordVisits), commRecord, multiSimComm);
839
840         /* Transfer back the result */
841         for (size_t localIndex = 0; localIndex < localUpdateList.size(); localIndex++)
842         {
843             PointState& ps = pointState[localUpdateList[localIndex]];
844
845             ps.setPartialWeightAndCount(weightSum[localIndex], coordVisits[localIndex]);
846         }
847     }
848
849     /* Now add the partial counts and weights to the accumulating histograms.
850        Note: we still need to use the weights for the update so we wait
851        with resetting them until the end of the update. */
852     for (int globalIndex : localUpdateList)
853     {
854         pointState[globalIndex].addPartialWeightAndCount();
855     }
856 }
857
858 /*! \brief
859  * Label points along an axis as covered or not.
860  *
861  * A point is covered if it is surrounded by visited points up to a radius = coverRadius.
862  *
863  * \param[in]     visited        Visited? For each point.
864  * \param[in]     checkCovering  Check for covering? For each point.
865  * \param[in]     numPoints      The number of grid points along this dimension.
866  * \param[in]     period         Period in number of points.
867  * \param[in]     coverRadius    Cover radius, in points, needed for defining a point as covered.
868  * \param[in,out] covered        In this array elements are 1 for covered points and 0 for
869  * non-covered points, this routine assumes that \p covered has at least size \p numPoints.
870  */
871 void labelCoveredPoints(const std::vector<bool>& visited,
872                         const std::vector<bool>& checkCovering,
873                         int                      numPoints,
874                         int                      period,
875                         int                      coverRadius,
876                         gmx::ArrayRef<int>       covered)
877 {
878     GMX_ASSERT(covered.ssize() >= numPoints, "covered should be at least as large as the grid");
879
880     bool haveFirstNotVisited = false;
881     int  firstNotVisited     = -1;
882     int  notVisitedLow       = -1;
883     int  notVisitedHigh      = -1;
884
885     for (int n = 0; n < numPoints; n++)
886     {
887         if (checkCovering[n] && !visited[n])
888         {
889             if (!haveFirstNotVisited)
890             {
891                 notVisitedLow       = n;
892                 firstNotVisited     = n;
893                 haveFirstNotVisited = true;
894             }
895             else
896             {
897                 notVisitedHigh = n;
898
899                 /* Have now an interval I = [notVisitedLow,notVisitedHigh] of visited points bounded
900                    by unvisited points. The unvisted end points affect the coveredness of the
901                    visited with a reach equal to the cover radius. */
902                 int notCoveredLow  = notVisitedLow + coverRadius;
903                 int notCoveredHigh = notVisitedHigh - coverRadius;
904                 for (int i = notVisitedLow; i <= notVisitedHigh; i++)
905                 {
906                     covered[i] = static_cast<int>((i > notCoveredLow) && (i < notCoveredHigh));
907                 }
908
909                 /* Find a new interval to set covering for. Make the notVisitedHigh of this interval
910                    the notVisitedLow of the next. */
911                 notVisitedLow = notVisitedHigh;
912             }
913         }
914     }
915
916     /* Have labelled all the internal points. Now take care of the boundary regions. */
917     if (!haveFirstNotVisited)
918     {
919         /* No non-visited points <=> all points visited => all points covered. */
920
921         for (int n = 0; n < numPoints; n++)
922         {
923             covered[n] = 1;
924         }
925     }
926     else
927     {
928         int lastNotVisited = notVisitedLow;
929
930         /* For periodic boundaries, non-visited points can influence points
931            on the other side of the boundary so we need to wrap around. */
932
933         /* Lower end. For periodic boundaries the last upper end not visited point becomes the low-end not visited point.
934            For non-periodic boundaries there is no lower end point so a dummy value is used. */
935         int notVisitedHigh = firstNotVisited;
936         int notVisitedLow  = period > 0 ? (lastNotVisited - period) : -(coverRadius + 1);
937
938         int notCoveredLow  = notVisitedLow + coverRadius;
939         int notCoveredHigh = notVisitedHigh - coverRadius;
940
941         for (int i = 0; i <= notVisitedHigh; i++)
942         {
943             /* For non-periodic boundaries notCoveredLow = -1 will impose no restriction. */
944             covered[i] = static_cast<int>((i > notCoveredLow) && (i < notCoveredHigh));
945         }
946
947         /* Upper end. Same as for lower end but in the other direction. */
948         notVisitedHigh = period > 0 ? (firstNotVisited + period) : (numPoints + coverRadius);
949         notVisitedLow  = lastNotVisited;
950
951         notCoveredLow  = notVisitedLow + coverRadius;
952         notCoveredHigh = notVisitedHigh - coverRadius;
953
954         for (int i = notVisitedLow; i <= numPoints - 1; i++)
955         {
956             /* For non-periodic boundaries notCoveredHigh = numPoints will impose no restriction. */
957             covered[i] = static_cast<int>((i > notCoveredLow) && (i < notCoveredHigh));
958         }
959     }
960 }
961
962 } // namespace
963
964 bool BiasState::isSamplingRegionCovered(const BiasParams&             params,
965                                         const std::vector<DimParams>& dimParams,
966                                         const BiasGrid&               grid,
967                                         const t_commrec*              commRecord,
968                                         const gmx_multisim_t*         multiSimComm) const
969 {
970     /* Allocate and initialize arrays: one for checking visits along each dimension,
971        one for keeping track of which points to check and one for the covered points.
972        Possibly these could be kept as AWH variables to avoid these allocations. */
973     struct CheckDim
974     {
975         std::vector<bool> visited;
976         std::vector<bool> checkCovering;
977         // We use int for the covering array since we might use gmx_sumi_sim.
978         std::vector<int> covered;
979     };
980
981     std::vector<CheckDim> checkDim;
982     checkDim.resize(grid.numDimensions());
983
984     for (int d = 0; d < grid.numDimensions(); d++)
985     {
986         const size_t numPoints = grid.axis(d).numPoints();
987         checkDim[d].visited.resize(numPoints, false);
988         checkDim[d].checkCovering.resize(numPoints, false);
989         checkDim[d].covered.resize(numPoints, 0);
990     }
991
992     /* Set visited points along each dimension and which points should be checked for covering.
993        Specifically, points above the free energy cutoff (if there is one) or points outside
994        of the target region are ignored. */
995
996     /* Set the free energy cutoff */
997     double maxFreeEnergy = GMX_FLOAT_MAX;
998
999     if (params.eTarget == AwhTargetType::Cutoff)
1000     {
1001         maxFreeEnergy = freeEnergyMinimumValue(points_) + params.freeEnergyCutoffInKT;
1002     }
1003
1004     /* Set the threshold weight for a point to be considered visited. */
1005     double weightThreshold = 1;
1006     for (int d = 0; d < grid.numDimensions(); d++)
1007     {
1008         if (grid.axis(d).isFepLambdaAxis())
1009         {
1010             /* Do not modify the weight threshold based on a FEP lambda axis. The spread
1011              * of the sampling weights is not depending on a Gaussian distribution (like
1012              * below). */
1013             weightThreshold *= 1.0;
1014         }
1015         else
1016         {
1017             /* The spacing is proportional to 1/sqrt(betak). The weight threshold will be
1018              * approximately (given that the spacing can be modified if the dimension is periodic)
1019              * proportional to sqrt(1/(2*pi)). */
1020             weightThreshold *= grid.axis(d).spacing()
1021                                * std::sqrt(dimParams[d].pullDimParams().betak * 0.5 * M_1_PI);
1022         }
1023     }
1024
1025     /* Project the sampling weights onto each dimension */
1026     for (size_t m = 0; m < grid.numPoints(); m++)
1027     {
1028         const PointState& pointState = points_[m];
1029
1030         for (int d = 0; d < grid.numDimensions(); d++)
1031         {
1032             int n = grid.point(m).index[d];
1033
1034             /* Is visited if it was already visited or if there is enough weight at the current point */
1035             checkDim[d].visited[n] = checkDim[d].visited[n] || (weightSumCovering_[m] > weightThreshold);
1036
1037             /* Check for covering if there is at least point in this slice that is in the target region and within the cutoff */
1038             checkDim[d].checkCovering[n] =
1039                     checkDim[d].checkCovering[n]
1040                     || (pointState.inTargetRegion() && pointState.freeEnergy() < maxFreeEnergy);
1041         }
1042     }
1043
1044     /* Label each point along each dimension as covered or not. */
1045     for (int d = 0; d < grid.numDimensions(); d++)
1046     {
1047         labelCoveredPoints(checkDim[d].visited,
1048                            checkDim[d].checkCovering,
1049                            grid.axis(d).numPoints(),
1050                            grid.axis(d).numPointsInPeriod(),
1051                            params.coverRadius()[d],
1052                            checkDim[d].covered);
1053     }
1054
1055     /* Now check for global covering. Each dimension needs to be covered separately.
1056        A dimension is covered if each point is covered.  Multiple simulations collectively
1057        cover the points, i.e. a point is covered if any of the simulations covered it.
1058        However, visited points are not shared, i.e. if a point is covered or not is
1059        determined by the visits of a single simulation. In general the covering criterion is
1060        all points covered => all points are surrounded by visited points up to a radius = coverRadius.
1061        For 1 simulation, all points covered <=> all points visited. For multiple simulations
1062        however, all points visited collectively !=> all points covered, except for coverRadius = 0.
1063        In the limit of large coverRadius, all points covered => all points visited by at least one
1064        simulation (since no point will be covered until all points have been visited by a
1065        single simulation). Basically coverRadius sets how much "connectedness" (or mixing) a point
1066        needs with surrounding points before sharing covering information with other simulations. */
1067
1068     /* Communicate the covered points between sharing simulations if needed. */
1069     if (params.numSharedUpdate > 1)
1070     {
1071         /* For multiple dimensions this may not be the best way to do it. */
1072         for (int d = 0; d < grid.numDimensions(); d++)
1073         {
1074             sumOverSimulations(
1075                     gmx::arrayRefFromArray(checkDim[d].covered.data(), grid.axis(d).numPoints()),
1076                     commRecord,
1077                     multiSimComm);
1078         }
1079     }
1080
1081     /* Now check if for each dimension all points are covered. Break if not true. */
1082     bool allPointsCovered = true;
1083     for (int d = 0; d < grid.numDimensions() && allPointsCovered; d++)
1084     {
1085         for (int n = 0; n < grid.axis(d).numPoints() && allPointsCovered; n++)
1086         {
1087             allPointsCovered = (checkDim[d].covered[n] != 0);
1088         }
1089     }
1090
1091     return allPointsCovered;
1092 }
1093
1094 /*! \brief
1095  * Normalizes the free energy and PMF sum.
1096  *
1097  * \param[in] pointState  The state of the points.
1098  */
1099 static void normalizeFreeEnergyAndPmfSum(std::vector<PointState>* pointState)
1100 {
1101     double minF = freeEnergyMinimumValue(*pointState);
1102
1103     for (PointState& ps : *pointState)
1104     {
1105         ps.normalizeFreeEnergyAndPmfSum(minF);
1106     }
1107 }
1108
1109 void BiasState::updateFreeEnergyAndAddSamplesToHistogram(const std::vector<DimParams>& dimParams,
1110                                                          const BiasGrid&               grid,
1111                                                          const BiasParams&             params,
1112                                                          const t_commrec*              commRecord,
1113                                                          const gmx_multisim_t*         multiSimComm,
1114                                                          double                        t,
1115                                                          int64_t                       step,
1116                                                          FILE*                         fplog,
1117                                                          std::vector<int>*             updateList)
1118 {
1119     /* Note hat updateList is only used in this scope and is always
1120      * re-initialized. We do not use a local vector, because that would
1121      * cause reallocation every time this funtion is called and the vector
1122      * can be the size of the whole grid.
1123      */
1124
1125     /* Make a list of all local points, i.e. those that could have been touched since
1126        the last update. These are the points needed for summing histograms below
1127        (non-local points only add zeros). For local updates, this will also be the
1128        final update list. */
1129     makeLocalUpdateList(grid, points_, originUpdatelist_, endUpdatelist_, updateList);
1130     if (params.numSharedUpdate > 1)
1131     {
1132         mergeSharedUpdateLists(updateList, points_.size(), commRecord, multiSimComm);
1133     }
1134
1135     /* Reset the range for the next update */
1136     resetLocalUpdateRange(grid);
1137
1138     /* Add samples to histograms for all local points and sync simulations if needed */
1139     sumHistograms(points_, weightSumCovering_, params.numSharedUpdate, commRecord, multiSimComm, *updateList);
1140
1141     sumPmf(points_, params.numSharedUpdate, commRecord, multiSimComm);
1142
1143     /* Renormalize the free energy if values are too large. */
1144     bool needToNormalizeFreeEnergy = false;
1145     for (int& globalIndex : *updateList)
1146     {
1147         /* We want to keep the absolute value of the free energies to be less
1148            c_largePositiveExponent to be able to safely pass these values to exp(). The check below
1149            ensures this as long as the free energy values grow less than 0.5*c_largePositiveExponent
1150            in a return time to this neighborhood. For reasonable update sizes it's unlikely that
1151            this requirement would be broken. */
1152         if (std::abs(points_[globalIndex].freeEnergy()) > 0.5 * detail::c_largePositiveExponent)
1153         {
1154             needToNormalizeFreeEnergy = true;
1155             break;
1156         }
1157     }
1158
1159     /* Update target distribution? */
1160     bool needToUpdateTargetDistribution =
1161             (params.eTarget != AwhTargetType::Constant && params.isUpdateTargetStep(step));
1162
1163     /* In the initial stage, the histogram grows dynamically as a function of the number of coverings. */
1164     bool detectedCovering = false;
1165     if (inInitialStage())
1166     {
1167         detectedCovering =
1168                 (params.isCheckCoveringStep(step)
1169                  && isSamplingRegionCovered(params, dimParams, grid, commRecord, multiSimComm));
1170     }
1171
1172     /* The weighthistogram size after this update. */
1173     double newHistogramSize = histogramSize_.newHistogramSize(
1174             params, t, detectedCovering, points_, weightSumCovering_, fplog);
1175
1176     /* Make the update list. Usually we try to only update local points,
1177      * but if the update has non-trivial or non-deterministic effects
1178      * on non-local points a global update is needed. This is the case when:
1179      * 1) a covering occurred in the initial stage, leading to non-trivial
1180      *    histogram rescaling factors; or
1181      * 2) the target distribution will be updated, since we don't make any
1182      *    assumption on its form; or
1183      * 3) the AWH parameters are such that we never attempt to skip non-local
1184      *    updates; or
1185      * 4) the free energy values have grown so large that a renormalization
1186      *    is needed.
1187      */
1188     if (needToUpdateTargetDistribution || detectedCovering || !params.skipUpdates() || needToNormalizeFreeEnergy)
1189     {
1190         /* Global update, just add all points. */
1191         updateList->clear();
1192         for (size_t m = 0; m < points_.size(); m++)
1193         {
1194             if (points_[m].inTargetRegion())
1195             {
1196                 updateList->push_back(m);
1197             }
1198         }
1199     }
1200
1201     /* Set histogram scale factors. */
1202     double weightHistScalingSkipped = 0;
1203     double logPmfsumScalingSkipped  = 0;
1204     if (params.skipUpdates())
1205     {
1206         getSkippedUpdateHistogramScaleFactors(params, &weightHistScalingSkipped, &logPmfsumScalingSkipped);
1207     }
1208     double weightHistScalingNew;
1209     double logPmfsumScalingNew;
1210     setHistogramUpdateScaleFactors(
1211             params, newHistogramSize, histogramSize_.histogramSize(), &weightHistScalingNew, &logPmfsumScalingNew);
1212
1213     /* Update free energy and reference weight histogram for points in the update list. */
1214     for (int pointIndex : *updateList)
1215     {
1216         PointState* pointStateToUpdate = &points_[pointIndex];
1217
1218         /* Do updates from previous update steps that were skipped because this point was at that time non-local. */
1219         if (params.skipUpdates())
1220         {
1221             pointStateToUpdate->performPreviouslySkippedUpdates(
1222                     params, histogramSize_.numUpdates(), weightHistScalingSkipped, logPmfsumScalingSkipped);
1223         }
1224
1225         /* Now do an update with new sampling data. */
1226         pointStateToUpdate->updateWithNewSampling(
1227                 params, histogramSize_.numUpdates(), weightHistScalingNew, logPmfsumScalingNew);
1228     }
1229
1230     /* Only update the histogram size after we are done with the local point updates */
1231     histogramSize_.setHistogramSize(newHistogramSize, weightHistScalingNew);
1232
1233     if (needToNormalizeFreeEnergy)
1234     {
1235         normalizeFreeEnergyAndPmfSum(&points_);
1236     }
1237
1238     if (needToUpdateTargetDistribution)
1239     {
1240         /* The target distribution is always updated for all points at once. */
1241         updateTargetDistribution(points_, params);
1242     }
1243
1244     /* Update the bias. The bias is updated separately and last since it simply a function of
1245        the free energy and the target distribution and we want to avoid doing extra work. */
1246     for (int pointIndex : *updateList)
1247     {
1248         points_[pointIndex].updateBias();
1249     }
1250
1251     /* Increase the update counter. */
1252     histogramSize_.incrementNumUpdates();
1253 }
1254
1255 double BiasState::updateProbabilityWeightsAndConvolvedBias(const std::vector<DimParams>& dimParams,
1256                                                            const BiasGrid&               grid,
1257                                                            gmx::ArrayRef<const double> neighborLambdaEnergies,
1258                                                            std::vector<double, AlignedAllocator<double>>* weight) const
1259 {
1260     /* Only neighbors of the current coordinate value will have a non-negligible chance of getting sampled */
1261     const std::vector<int>& neighbors = grid.point(coordState_.gridpointIndex()).neighbor;
1262
1263 #if GMX_SIMD_HAVE_DOUBLE
1264     typedef SimdDouble PackType;
1265     constexpr int      packSize = GMX_SIMD_DOUBLE_WIDTH;
1266 #else
1267     typedef double PackType;
1268     constexpr int  packSize = 1;
1269 #endif
1270     /* Round the size of the weight array up to packSize */
1271     const int weightSize = ((neighbors.size() + packSize - 1) / packSize) * packSize;
1272     weight->resize(weightSize);
1273
1274     double* gmx_restrict weightData = weight->data();
1275     PackType             weightSumPack(0.0);
1276     for (size_t i = 0; i < neighbors.size(); i += packSize)
1277     {
1278         for (size_t n = i; n < i + packSize; n++)
1279         {
1280             if (n < neighbors.size())
1281             {
1282                 const int neighbor = neighbors[n];
1283                 (*weight)[n]       = biasedLogWeightFromPoint(dimParams,
1284                                                         points_,
1285                                                         grid,
1286                                                         neighbor,
1287                                                         points_[neighbor].bias(),
1288                                                         coordState_.coordValue(),
1289                                                         neighborLambdaEnergies,
1290                                                         coordState_.gridpointIndex());
1291             }
1292             else
1293             {
1294                 /* Pad with values that don't affect the result */
1295                 (*weight)[n] = detail::c_largeNegativeExponent;
1296             }
1297         }
1298         PackType weightPack = load<PackType>(weightData + i);
1299         weightPack          = gmx::exp(weightPack);
1300         weightSumPack       = weightSumPack + weightPack;
1301         store(weightData + i, weightPack);
1302     }
1303     /* Sum of probability weights */
1304     double weightSum = reduce(weightSumPack);
1305     GMX_RELEASE_ASSERT(weightSum > 0,
1306                        "zero probability weight when updating AWH probability weights.");
1307
1308     /* Normalize probabilities to sum to 1 */
1309     double invWeightSum = 1 / weightSum;
1310
1311     /* When there is a free energy lambda state axis remove the convolved contributions along that
1312      * axis from the total bias. This must be done after calculating invWeightSum (since weightSum
1313      * will be modified), but before normalizing the weights (below). */
1314     if (grid.hasLambdaAxis())
1315     {
1316         /* If there is only one axis the bias will not be convolved in any dimension. */
1317         if (grid.axis().size() == 1)
1318         {
1319             weightSum = gmx::exp(points_[coordState_.gridpointIndex()].bias());
1320         }
1321         else
1322         {
1323             for (size_t i = 0; i < neighbors.size(); i++)
1324             {
1325                 const int neighbor = neighbors[i];
1326                 if (pointsHaveDifferentLambda(grid, coordState_.gridpointIndex(), neighbor))
1327                 {
1328                     weightSum -= weightData[i];
1329                 }
1330             }
1331         }
1332     }
1333
1334     for (double& w : *weight)
1335     {
1336         w *= invWeightSum;
1337     }
1338
1339     /* Return the convolved bias */
1340     return std::log(weightSum);
1341 }
1342
1343 double BiasState::calcConvolvedBias(const std::vector<DimParams>& dimParams,
1344                                     const BiasGrid&               grid,
1345                                     const awh_dvec&               coordValue) const
1346 {
1347     int              point     = grid.nearestIndex(coordValue);
1348     const GridPoint& gridPoint = grid.point(point);
1349
1350     /* Sum the probability weights from the neighborhood of the given point */
1351     double weightSum = 0;
1352     for (int neighbor : gridPoint.neighbor)
1353     {
1354         /* No convolution is required along the lambda dimension. */
1355         if (pointsHaveDifferentLambda(grid, point, neighbor))
1356         {
1357             continue;
1358         }
1359         double logWeight = biasedLogWeightFromPoint(
1360                 dimParams, points_, grid, neighbor, points_[neighbor].bias(), coordValue, {}, point);
1361         weightSum += std::exp(logWeight);
1362     }
1363
1364     /* Returns -GMX_FLOAT_MAX if no neighboring points were in the target region. */
1365     return (weightSum > 0) ? std::log(weightSum) : -GMX_FLOAT_MAX;
1366 }
1367
1368 void BiasState::sampleProbabilityWeights(const BiasGrid& grid, gmx::ArrayRef<const double> probWeightNeighbor)
1369 {
1370     const std::vector<int>& neighbor = grid.point(coordState_.gridpointIndex()).neighbor;
1371
1372     /* Save weights for next update */
1373     for (size_t n = 0; n < neighbor.size(); n++)
1374     {
1375         points_[neighbor[n]].increaseWeightSumIteration(probWeightNeighbor[n]);
1376     }
1377
1378     /* Update the local update range. Two corner points define this rectangular
1379      * domain. We need to choose two new corner points such that the new domain
1380      * contains both the old update range and the current neighborhood.
1381      * In the simplest case when an update is performed every sample,
1382      * the update range would simply equal the current neighborhood.
1383      */
1384     int neighborStart = neighbor[0];
1385     int neighborLast  = neighbor[neighbor.size() - 1];
1386     for (int d = 0; d < grid.numDimensions(); d++)
1387     {
1388         int origin = grid.point(neighborStart).index[d];
1389         int last   = grid.point(neighborLast).index[d];
1390
1391         if (origin > last)
1392         {
1393             /* Unwrap if wrapped around the boundary (only happens for periodic
1394              * boundaries). This has been already for the stored index interval.
1395              */
1396             /* TODO: what we want to do is to find the smallest the update
1397              * interval that contains all points that need to be updated.
1398              * This amounts to combining two intervals, the current
1399              * [origin, end] update interval and the new touched neighborhood
1400              * into a new interval that contains all points from both the old
1401              * intervals.
1402              *
1403              * For periodic boundaries it becomes slightly more complicated
1404              * than for closed boundaries because then it needs not be
1405              * true that origin < end (so one can't simply relate the origin/end
1406              * in the min()/max() below). The strategy here is to choose the
1407              * origin closest to a reference point (index 0) and then unwrap
1408              * the end index if needed and choose the largest end index.
1409              * This ensures that both intervals are in the new interval
1410              * but it's not necessarily the smallest.
1411              * Currently we solve this by going through each possibility
1412              * and checking them.
1413              */
1414             last += grid.axis(d).numPointsInPeriod();
1415         }
1416
1417         originUpdatelist_[d] = std::min(originUpdatelist_[d], origin);
1418         endUpdatelist_[d]    = std::max(endUpdatelist_[d], last);
1419     }
1420 }
1421
1422 void BiasState::sampleCoordAndPmf(const std::vector<DimParams>& dimParams,
1423                                   const BiasGrid&               grid,
1424                                   gmx::ArrayRef<const double>   probWeightNeighbor,
1425                                   double                        convolvedBias)
1426 {
1427     /* Sampling-based deconvolution extracting the PMF.
1428      * Update the PMF histogram with the current coordinate value.
1429      *
1430      * Because of the finite width of the harmonic potential, the free energy
1431      * defined for each coordinate point does not exactly equal that of the
1432      * actual coordinate, the PMF. However, the PMF can be estimated by applying
1433      * the relation exp(-PMF) = exp(-bias_convolved)*P_biased/Z, i.e. by keeping a
1434      * reweighted histogram of the coordinate value. Strictly, this relies on
1435      * the unknown normalization constant Z being either constant or known. Here,
1436      * neither is true except in the long simulation time limit. Empirically however,
1437      * it works (mainly because how the PMF histogram is rescaled).
1438      */
1439
1440     const int                gridPointIndex  = coordState_.gridpointIndex();
1441     const std::optional<int> lambdaAxisIndex = grid.lambdaAxisIndex();
1442
1443     /* Update the PMF of points along a lambda axis with their bias. */
1444     if (lambdaAxisIndex)
1445     {
1446         const std::vector<int>& neighbors = grid.point(gridPointIndex).neighbor;
1447
1448         std::vector<double> lambdaMarginalDistribution =
1449                 calculateFELambdaMarginalDistribution(grid, neighbors, probWeightNeighbor);
1450
1451         awh_dvec coordValueAlongLambda = { coordState_.coordValue()[0],
1452                                            coordState_.coordValue()[1],
1453                                            coordState_.coordValue()[2],
1454                                            coordState_.coordValue()[3] };
1455         for (size_t i = 0; i < neighbors.size(); i++)
1456         {
1457             const int neighbor = neighbors[i];
1458             double    bias;
1459             if (pointsAlongLambdaAxis(grid, gridPointIndex, neighbor))
1460             {
1461                 const double neighborLambda = grid.point(neighbor).coordValue[lambdaAxisIndex.value()];
1462                 if (neighbor == gridPointIndex)
1463                 {
1464                     bias = convolvedBias;
1465                 }
1466                 else
1467                 {
1468                     coordValueAlongLambda[lambdaAxisIndex.value()] = neighborLambda;
1469                     bias = calcConvolvedBias(dimParams, grid, coordValueAlongLambda);
1470                 }
1471
1472                 const double probWeight = lambdaMarginalDistribution[neighborLambda];
1473                 const double weightedBias = bias - std::log(std::max(probWeight, GMX_DOUBLE_MIN)); // avoid log(0)
1474
1475                 if (neighbor == gridPointIndex && grid.covers(coordState_.coordValue()))
1476                 {
1477                     points_[neighbor].samplePmf(weightedBias);
1478                 }
1479                 else
1480                 {
1481                     points_[neighbor].updatePmfUnvisited(weightedBias);
1482                 }
1483             }
1484         }
1485     }
1486     else
1487     {
1488         /* Only save coordinate data that is in range (the given index is always
1489          * in range even if the coordinate value is not).
1490          */
1491         if (grid.covers(coordState_.coordValue()))
1492         {
1493             /* Save PMF sum and keep a histogram of the sampled coordinate values */
1494             points_[gridPointIndex].samplePmf(convolvedBias);
1495         }
1496     }
1497
1498     /* Save probability weights for the update */
1499     sampleProbabilityWeights(grid, probWeightNeighbor);
1500 }
1501
1502 void BiasState::initHistoryFromState(AwhBiasHistory* biasHistory) const
1503 {
1504     biasHistory->pointState.resize(points_.size());
1505 }
1506
1507 void BiasState::updateHistory(AwhBiasHistory* biasHistory, const BiasGrid& grid) const
1508 {
1509     GMX_RELEASE_ASSERT(biasHistory->pointState.size() == points_.size(),
1510                        "The AWH history setup does not match the AWH state.");
1511
1512     AwhBiasStateHistory* stateHistory = &biasHistory->state;
1513     stateHistory->umbrellaGridpoint   = coordState_.umbrellaGridpoint();
1514
1515     for (size_t m = 0; m < biasHistory->pointState.size(); m++)
1516     {
1517         AwhPointStateHistory* psh = &biasHistory->pointState[m];
1518
1519         points_[m].storeState(psh);
1520
1521         psh->weightsum_covering = weightSumCovering_[m];
1522     }
1523
1524     histogramSize_.storeState(stateHistory);
1525
1526     stateHistory->origin_index_updatelist = multiDimGridIndexToLinear(grid, originUpdatelist_);
1527     stateHistory->end_index_updatelist    = multiDimGridIndexToLinear(grid, endUpdatelist_);
1528 }
1529
1530 void BiasState::restoreFromHistory(const AwhBiasHistory& biasHistory, const BiasGrid& grid)
1531 {
1532     const AwhBiasStateHistory& stateHistory = biasHistory.state;
1533
1534     coordState_.restoreFromHistory(stateHistory);
1535
1536     if (biasHistory.pointState.size() != points_.size())
1537     {
1538         GMX_THROW(
1539                 InvalidInputError("Bias grid size in checkpoint and simulation do not match. "
1540                                   "Likely you provided a checkpoint from a different simulation."));
1541     }
1542     for (size_t m = 0; m < points_.size(); m++)
1543     {
1544         points_[m].setFromHistory(biasHistory.pointState[m]);
1545     }
1546
1547     for (size_t m = 0; m < weightSumCovering_.size(); m++)
1548     {
1549         weightSumCovering_[m] = biasHistory.pointState[m].weightsum_covering;
1550     }
1551
1552     histogramSize_.restoreFromHistory(stateHistory);
1553
1554     linearGridindexToMultiDim(grid, stateHistory.origin_index_updatelist, originUpdatelist_);
1555     linearGridindexToMultiDim(grid, stateHistory.end_index_updatelist, endUpdatelist_);
1556 }
1557
1558 void BiasState::broadcast(const t_commrec* commRecord)
1559 {
1560     gmx_bcast(sizeof(coordState_), &coordState_, commRecord->mpi_comm_mygroup);
1561
1562     gmx_bcast(points_.size() * sizeof(PointState), points_.data(), commRecord->mpi_comm_mygroup);
1563
1564     gmx_bcast(weightSumCovering_.size() * sizeof(double), weightSumCovering_.data(), commRecord->mpi_comm_mygroup);
1565
1566     gmx_bcast(sizeof(histogramSize_), &histogramSize_, commRecord->mpi_comm_mygroup);
1567 }
1568
1569 void BiasState::setFreeEnergyToConvolvedPmf(const std::vector<DimParams>& dimParams, const BiasGrid& grid)
1570 {
1571     std::vector<float> convolvedPmf;
1572
1573     calcConvolvedPmf(dimParams, grid, &convolvedPmf);
1574
1575     for (size_t m = 0; m < points_.size(); m++)
1576     {
1577         points_[m].setFreeEnergy(convolvedPmf[m]);
1578     }
1579 }
1580
1581 /*! \brief
1582  * Count trailing data rows containing only zeros.
1583  *
1584  * \param[in] data        2D data array.
1585  * \param[in] numRows     Number of rows in array.
1586  * \param[in] numColumns  Number of cols in array.
1587  * \returns the number of trailing zero rows.
1588  */
1589 static int countTrailingZeroRows(const double* const* data, int numRows, int numColumns)
1590 {
1591     int numZeroRows = 0;
1592     for (int m = numRows - 1; m >= 0; m--)
1593     {
1594         bool rowIsZero = true;
1595         for (int d = 0; d < numColumns; d++)
1596         {
1597             if (data[d][m] != 0)
1598             {
1599                 rowIsZero = false;
1600                 break;
1601             }
1602         }
1603
1604         if (!rowIsZero)
1605         {
1606             /* At a row with non-zero data */
1607             break;
1608         }
1609         else
1610         {
1611             /* Still at a zero data row, keep checking rows higher up. */
1612             numZeroRows++;
1613         }
1614     }
1615
1616     return numZeroRows;
1617 }
1618
1619 /*! \brief
1620  * Initializes the PMF and target with data read from an input table.
1621  *
1622  * \param[in]     dimParams   The dimension parameters.
1623  * \param[in]     grid        The grid.
1624  * \param[in]     filename    The filename to read PMF and target from.
1625  * \param[in]     numBias     Number of biases.
1626  * \param[in]     biasIndex   The index of the bias.
1627  * \param[in,out] pointState  The state of the points in this bias.
1628  */
1629 static void readUserPmfAndTargetDistribution(const std::vector<DimParams>& dimParams,
1630                                              const BiasGrid&               grid,
1631                                              const std::string&            filename,
1632                                              int                           numBias,
1633                                              int                           biasIndex,
1634                                              std::vector<PointState>*      pointState)
1635 {
1636     /* Read the PMF and target distribution.
1637        From the PMF, the convolved PMF, or the reference value free energy, can be calculated
1638        base on the force constant. The free energy and target together determine the bias.
1639      */
1640     std::string filenameModified(filename);
1641     if (numBias > 1)
1642     {
1643         size_t n = filenameModified.rfind('.');
1644         GMX_RELEASE_ASSERT(n != std::string::npos,
1645                            "The filename should contain an extension starting with .");
1646         filenameModified.insert(n, formatString("%d", biasIndex));
1647     }
1648
1649     std::string correctFormatMessage = formatString(
1650             "%s is expected in the following format. "
1651             "The first ndim column(s) should contain the coordinate values for each point, "
1652             "each column containing values of one dimension (in ascending order). "
1653             "For a multidimensional coordinate, points should be listed "
1654             "in the order obtained by traversing lower dimensions first. "
1655             "E.g. for two-dimensional grid of size nxn: "
1656             "(1, 1), (1, 2),..., (1, n), (2, 1), (2, 2), ..., , (n, n - 1), (n, n). "
1657             "Column ndim +  1 should contain the PMF value for each coordinate value. "
1658             "The target distribution values should be in column ndim + 2  or column ndim + 5. "
1659             "Make sure the input file ends with a new line but has no trailing new lines.",
1660             filename.c_str());
1661     gmx::TextLineWrapper wrapper;
1662     wrapper.settings().setLineLength(c_linewidth);
1663     correctFormatMessage = wrapper.wrapToString(correctFormatMessage);
1664
1665     double** data;
1666     int      numColumns;
1667     int      numRows = read_xvg(filenameModified.c_str(), &data, &numColumns);
1668
1669     /* Check basic data properties here. BiasGrid takes care of more complicated things. */
1670
1671     if (numRows <= 0)
1672     {
1673         std::string mesg = gmx::formatString(
1674                 "%s is empty!.\n\n%s", filename.c_str(), correctFormatMessage.c_str());
1675         GMX_THROW(InvalidInputError(mesg));
1676     }
1677
1678     /* Less than 2 points is not useful for PMF or target. */
1679     if (numRows < 2)
1680     {
1681         std::string mesg = gmx::formatString(
1682                 "%s contains too few data points (%d)."
1683                 "The minimum number of points is 2.",
1684                 filename.c_str(),
1685                 numRows);
1686         GMX_THROW(InvalidInputError(mesg));
1687     }
1688
1689     /* Make sure there are enough columns of data.
1690
1691        Two formats are allowed. Either with columns  {coords, PMF, target} or
1692        {coords, PMF, x, y, z, target, ...}. The latter format is allowed since that
1693        is how AWH output is written (x, y, z being other AWH variables). For this format,
1694        trailing columns are ignored.
1695      */
1696     int columnIndexTarget;
1697     int numColumnsMin  = dimParams.size() + 2;
1698     int columnIndexPmf = dimParams.size();
1699     if (numColumns == numColumnsMin)
1700     {
1701         columnIndexTarget = columnIndexPmf + 1;
1702     }
1703     else
1704     {
1705         columnIndexTarget = columnIndexPmf + 4;
1706     }
1707
1708     if (numColumns < numColumnsMin)
1709     {
1710         std::string mesg = gmx::formatString(
1711                 "The number of columns in %s should be at least %d."
1712                 "\n\n%s",
1713                 filename.c_str(),
1714                 numColumnsMin,
1715                 correctFormatMessage.c_str());
1716         GMX_THROW(InvalidInputError(mesg));
1717     }
1718
1719     /* read_xvg can give trailing zero data rows for trailing new lines in the input. We allow 1 zero row,
1720        since this could be real data. But multiple trailing zero rows cannot correspond to valid data. */
1721     int numZeroRows = countTrailingZeroRows(data, numRows, numColumns);
1722     if (numZeroRows > 1)
1723     {
1724         std::string mesg = gmx::formatString(
1725                 "Found %d trailing zero data rows in %s. Please remove trailing empty lines and "
1726                 "try again.",
1727                 numZeroRows,
1728                 filename.c_str());
1729         GMX_THROW(InvalidInputError(mesg));
1730     }
1731
1732     /* Convert from user units to internal units before sending the data of to grid. */
1733     for (size_t d = 0; d < dimParams.size(); d++)
1734     {
1735         double scalingFactor = dimParams[d].scaleUserInputToInternal(1);
1736         if (scalingFactor == 1)
1737         {
1738             continue;
1739         }
1740         for (size_t m = 0; m < pointState->size(); m++)
1741         {
1742             data[d][m] *= scalingFactor;
1743         }
1744     }
1745
1746     /* Get a data point for each AWH grid point so that they all get data. */
1747     std::vector<int> gridIndexToDataIndex(grid.numPoints());
1748     mapGridToDataGrid(&gridIndexToDataIndex, data, numRows, filename, grid, correctFormatMessage);
1749
1750     /* Extract the data for each grid point.
1751      * We check if the target distribution is zero for all points.
1752      */
1753     bool targetDistributionIsZero = true;
1754     for (size_t m = 0; m < pointState->size(); m++)
1755     {
1756         (*pointState)[m].setLogPmfSum(-data[columnIndexPmf][gridIndexToDataIndex[m]]);
1757         double target = data[columnIndexTarget][gridIndexToDataIndex[m]];
1758
1759         /* Check if the values are allowed. */
1760         if (target < 0)
1761         {
1762             std::string mesg = gmx::formatString(
1763                     "Target distribution weight at point %zu (%g) in %s is negative.",
1764                     m,
1765                     target,
1766                     filename.c_str());
1767             GMX_THROW(InvalidInputError(mesg));
1768         }
1769         if (target > 0)
1770         {
1771             targetDistributionIsZero = false;
1772         }
1773         (*pointState)[m].setTargetConstantWeight(target);
1774     }
1775
1776     if (targetDistributionIsZero)
1777     {
1778         std::string mesg =
1779                 gmx::formatString("The target weights given in column %d in %s are all 0",
1780                                   columnIndexTarget,
1781                                   filename.c_str());
1782         GMX_THROW(InvalidInputError(mesg));
1783     }
1784
1785     /* Free the arrays. */
1786     for (int m = 0; m < numColumns; m++)
1787     {
1788         sfree(data[m]);
1789     }
1790     sfree(data);
1791 }
1792
1793 void BiasState::normalizePmf(int numSharingSims)
1794 {
1795     /* The normalization of the PMF estimate matters because it determines how big effect the next sample has.
1796        Approximately (for large enough force constant) we should have:
1797        sum_x(exp(-pmf(x)) = nsamples*sum_xref(exp(-f(xref)).
1798      */
1799
1800     /* Calculate the normalization factor, i.e. divide by the pmf sum, multiply by the number of samples and the f sum */
1801     double expSumPmf = 0;
1802     double expSumF   = 0;
1803     for (const PointState& pointState : points_)
1804     {
1805         if (pointState.inTargetRegion())
1806         {
1807             expSumPmf += std::exp(pointState.logPmfSum());
1808             expSumF += std::exp(-pointState.freeEnergy());
1809         }
1810     }
1811     double numSamples = histogramSize_.histogramSize() / numSharingSims;
1812
1813     /* Renormalize */
1814     double logRenorm = std::log(numSamples * expSumF / expSumPmf);
1815     for (PointState& pointState : points_)
1816     {
1817         if (pointState.inTargetRegion())
1818         {
1819             pointState.setLogPmfSum(pointState.logPmfSum() + logRenorm);
1820         }
1821     }
1822 }
1823
1824 void BiasState::initGridPointState(const AwhBiasParams&          awhBiasParams,
1825                                    const std::vector<DimParams>& dimParams,
1826                                    const BiasGrid&               grid,
1827                                    const BiasParams&             params,
1828                                    const std::string&            filename,
1829                                    int                           numBias)
1830 {
1831     /* Modify PMF, free energy and the constant target distribution factor
1832      * to user input values if there is data given.
1833      */
1834     if (awhBiasParams.bUserData)
1835     {
1836         readUserPmfAndTargetDistribution(dimParams, grid, filename, numBias, params.biasIndex, &points_);
1837         setFreeEnergyToConvolvedPmf(dimParams, grid);
1838     }
1839
1840     /* The local Boltzmann distribution is special because the target distribution is updated as a function of the reference weighthistogram. */
1841     GMX_RELEASE_ASSERT(params.eTarget != AwhTargetType::LocalBoltzmann || points_[0].weightSumRef() != 0,
1842                        "AWH reference weight histogram not initialized properly with local "
1843                        "Boltzmann target distribution.");
1844
1845     updateTargetDistribution(points_, params);
1846
1847     for (PointState& pointState : points_)
1848     {
1849         if (pointState.inTargetRegion())
1850         {
1851             pointState.updateBias();
1852         }
1853         else
1854         {
1855             /* Note that for zero target this is a value that represents -infinity but should not be used for biasing. */
1856             pointState.setTargetToZero();
1857         }
1858     }
1859
1860     /* Set the initial reference weighthistogram. */
1861     const double histogramSize = histogramSize_.histogramSize();
1862     for (auto& pointState : points_)
1863     {
1864         pointState.setInitialReferenceWeightHistogram(histogramSize);
1865     }
1866
1867     /* Make sure the pmf is normalized consistently with the histogram size.
1868        Note: the target distribution and free energy need to be set here. */
1869     normalizePmf(params.numSharedUpdate);
1870 }
1871
1872 BiasState::BiasState(const AwhBiasParams&          awhBiasParams,
1873                      double                        histogramSizeInitial,
1874                      const std::vector<DimParams>& dimParams,
1875                      const BiasGrid&               grid) :
1876     coordState_(awhBiasParams, dimParams, grid),
1877     points_(grid.numPoints()),
1878     weightSumCovering_(grid.numPoints()),
1879     histogramSize_(awhBiasParams, histogramSizeInitial)
1880 {
1881     /* The minimum and maximum multidimensional point indices that are affected by the next update */
1882     for (size_t d = 0; d < dimParams.size(); d++)
1883     {
1884         int index            = grid.point(coordState_.gridpointIndex()).index[d];
1885         originUpdatelist_[d] = index;
1886         endUpdatelist_[d]    = index;
1887     }
1888 }
1889
1890 } // namespace gmx