221534441870fe925ffcd377676d6788fc44f4ce
[alexxy/gromacs.git] / src / gromacs / awh / pointstate.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2015,2016,2017,2018,2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *
38  * \brief
39  * Declares and defines the PointState class.
40  *
41  * Since nearly all operations on PointState objects occur in loops over
42  * (parts of) the grid of an AWH bias, all these methods should be inlined.
43  * Only samplePmf() is called only once per step and is thus not inlined.
44  *
45  * \author Viveca Lindahl
46  * \author Berk Hess <hess@kth.se>
47  * \ingroup module_awh
48  */
49
50 #ifndef GMX_AWH_POINTSTATE_H
51 #define GMX_AWH_POINTSTATE_H
52
53 #include <cmath>
54
55 #include "gromacs/mdtypes/awh_history.h"
56 #include "gromacs/mdtypes/awh_params.h"
57 #include "gromacs/utility/gmxassert.h"
58
59 #include "biasparams.h"
60
61 namespace gmx
62 {
63
64 namespace detail
65 {
66
67 //! A value that can be passed to exp() with result 0, also with SIMD
68 constexpr double c_largeNegativeExponent = -10000.0;
69
70 //! The largest acceptable positive exponent for variables that are passed to exp().
71 constexpr double c_largePositiveExponent =  700.0;
72
73 }   // namepace detail
74
75 /*! \internal
76  * \brief The state of a coordinate point.
77  *
78  * This class contains all the state variables of a coordinate point
79  * (on the bias grid) and methods to update the state of a point.
80  */
81 class PointState
82 {
83     public:
84         /*! \brief Constructs a point state with default values. */
85         PointState() : bias_(0),
86                        freeEnergy_(0),
87                        target_(1),
88                        targetConstantWeight_(1),
89                        weightSumIteration_(0),
90                        weightSumTot_(0),
91                        weightSumRef_(1),
92                        lastUpdateIndex_(0),
93                        logPmfSum_(0),
94                        numVisitsIteration_(0),
95                        numVisitsTot_(0)
96         {
97         }
98
99         /*! \brief
100          * Set all values in the state to those from a history.
101          *
102          * \param[in] psh  Coordinate point history to copy from.
103          */
104         void setFromHistory(const AwhPointStateHistory &psh)
105         {
106             target_             = psh.target;
107             freeEnergy_         = psh.free_energy;
108             bias_               = psh.bias;
109             weightSumIteration_ = psh.weightsum_iteration;
110             weightSumTot_       = psh.weightsum_tot;
111             weightSumRef_       = psh.weightsum_ref;
112             lastUpdateIndex_    = psh.last_update_index;
113             logPmfSum_          = psh.log_pmfsum;
114             numVisitsIteration_ = psh.visits_iteration;
115             numVisitsTot_       = psh.visits_tot;
116         }
117
118         /*! \brief
119          * Store the state of a point in a history struct.
120          *
121          * \param[in,out] psh  Coordinate point history to copy to.
122          */
123         void storeState(AwhPointStateHistory *psh) const
124         {
125             psh->target              = target_;
126             psh->free_energy         = freeEnergy_;
127             psh->bias                = bias_;
128             psh->weightsum_iteration = weightSumIteration_;
129             psh->weightsum_tot       = weightSumTot_;
130             psh->weightsum_ref       = weightSumRef_;
131             psh->last_update_index   = lastUpdateIndex_;
132             psh->log_pmfsum          = logPmfSum_;
133             psh->visits_iteration    = numVisitsIteration_;
134             psh->visits_tot          = numVisitsTot_;
135         }
136
137         /*! \brief
138          * Query if the point is in the target region.
139          *
140          * \returns true if the point is in the target region.
141          */
142         bool inTargetRegion() const
143         {
144             return target_ > 0;
145         }
146
147         /*! \brief Return the bias function estimate. */
148         double bias() const
149         {
150             return bias_;
151         }
152
153         /*! \brief Set the target to zero and the bias to minus infinity. */
154         void setTargetToZero()
155         {
156             target_ = 0;
157             /* the bias = log(target) + const = -infty */
158             bias_   = detail::c_largeNegativeExponent;
159         }
160
161         /*! \brief Return the free energy. */
162         double freeEnergy() const
163         {
164             return freeEnergy_;
165         }
166
167         /*! \brief Set the free energy, only to be used at initialization.
168          *
169          * \param[in] freeEnergy  The free energy.
170          */
171         void setFreeEnergy(double freeEnergy)
172         {
173             freeEnergy_ = freeEnergy;
174         }
175
176         /*! \brief Return the target distribution value. */
177         double target() const
178         {
179             return target_;
180         }
181
182         /*! \brief Return the weight accumulated since the last update. */
183         double weightSumIteration() const
184         {
185             return weightSumIteration_;
186         }
187
188         /*! \brief Increases the weight accumulated since the last update.
189          *
190          * \param[in] weight  The amount to add to the weight
191          */
192         void increaseWeightSumIteration(double weight)
193         {
194             weightSumIteration_ += weight;
195         }
196
197         /*! \brief Returns the accumulated weight */
198         double weightSumTot() const
199         {
200             return weightSumTot_;
201         }
202
203         /*! \brief Return the reference weight histogram. */
204         double weightSumRef() const
205         {
206             return weightSumRef_;
207         }
208
209         /*! \brief Return log(PmfSum). */
210         double logPmfSum() const
211         {
212             return logPmfSum_;
213         }
214
215         /*! \brief Set log(PmfSum).
216          *
217          * TODO: Replace this setter function with a more elegant solution.
218          *
219          * \param[in] logPmfSum  The log(PmfSum).
220          */
221         void setLogPmfSum(double logPmfSum)
222         {
223             logPmfSum_ = logPmfSum;
224         }
225
226         /*! \brief Return the number of visits since the last update */
227         double numVisitsIteration() const
228         {
229             return numVisitsIteration_;
230         }
231
232         /*! \brief Return the total number of visits */
233         double numVisitsTot() const
234         {
235             return numVisitsTot_;
236         }
237
238         /*! \brief Set the constant target weight factor.
239          *
240          * \param[in] targetConstantWeight  The target weight factor.
241          */
242         void setTargetConstantWeight(double targetConstantWeight)
243         {
244             targetConstantWeight_ = targetConstantWeight;
245         }
246
247         /*! \brief Updates the bias of a point. */
248         void updateBias()
249         {
250             GMX_ASSERT(target_ > 0, "AWH target distribution must be > 0 to calculate the point bias.");
251
252             bias_ = freeEnergy() + std::log(target_);
253         }
254
255         /*! \brief Set the initial reference weighthistogram.
256          *
257          * \param[in] histogramSize  The weight histogram size.
258          */
259         void setInitialReferenceWeightHistogram(double histogramSize)
260         {
261             weightSumRef_ = histogramSize*target_;
262         }
263
264         /*! \brief Correct free energy and PMF sum for the change in minimum.
265          *
266          * \param[in] minimumFreeEnergy  The free energy at the minimum;
267          */
268         void normalizeFreeEnergyAndPmfSum(double minimumFreeEnergy)
269         {
270             if (inTargetRegion())
271             {
272                 /* The sign of the free energy and PMF constants are opposite
273                  * because the PMF samples are reweighted with the negative
274                  * bias e^(-bias) ~ e^(-free energy).
275                  */
276                 freeEnergy_ -= minimumFreeEnergy;
277                 logPmfSum_  += minimumFreeEnergy;
278             }
279         }
280
281         /*! \brief Apply previous updates that were skipped.
282          *
283          * An update can only be skipped if the parameters needed for the update are constant or
284          * deterministic so that the same update can be performed at a later time.
285          * Here, the necessary parameters are the sampled weight and scaling factors for the
286          * histograms. The scaling factors are provided as arguments only to avoid recalculating
287          * them for each point
288          *
289          * The last update index is also updated here.
290          *
291          * \param[in] params             The AWH bias parameters.
292          * \param[in] numUpdates         The global number of updates.
293          * \param[in] weighthistScaling  Scale factor for the reference weight histogram.
294          * \param[in] logPmfSumScaling   Scale factor for the reference PMF histogram.
295          * \returns true if at least one update was applied.
296          */
297         bool performPreviouslySkippedUpdates(const BiasParams &params,
298                                              int64_t           numUpdates,
299                                              double            weighthistScaling,
300                                              double            logPmfSumScaling)
301         {
302             GMX_ASSERT(params.skipUpdates(), "Calling function for skipped updates when skipping updates is not allowed");
303
304             if (!inTargetRegion())
305             {
306                 return false;
307             }
308
309             /* The most current past update */
310             int64_t lastUpdateIndex   = numUpdates;
311             int64_t numUpdatesSkipped = lastUpdateIndex - lastUpdateIndex_;
312
313             if (numUpdatesSkipped == 0)
314             {
315                 /* Was not updated */
316                 return false;
317             }
318
319             for (int i = 0; i < numUpdatesSkipped; i++)
320             {
321                 /* This point was non-local at the time of the update meaning no weight */
322                 updateFreeEnergyAndWeight(params, 0, weighthistScaling, logPmfSumScaling);
323             }
324
325             /* Only past updates are applied here. */
326             lastUpdateIndex_ = lastUpdateIndex;
327
328             return true;
329         }
330
331         /*! \brief Apply a point update with new sampling.
332          *
333          * \note The last update index is also updated here.
334          * \note The new sampling containers are cleared here.
335          *
336          * \param[in] params              The AWH bias parameters.
337          * \param[in] numUpdates          The global number of updates.
338          * \param[in] weighthistScaling   Scaling factor for the reference weight histogram.
339          * \param[in] logPmfSumScaling    Log of the scaling factor for the PMF histogram.
340          */
341         void updateWithNewSampling(const BiasParams &params,
342                                    int64_t           numUpdates,
343                                    double            weighthistScaling,
344                                    double            logPmfSumScaling)
345         {
346             GMX_RELEASE_ASSERT(lastUpdateIndex_ == numUpdates, "When doing a normal update, the point update index should match the global index, otherwise we lost (skipped?) updates.");
347
348             updateFreeEnergyAndWeight(params, weightSumIteration_, weighthistScaling, logPmfSumScaling);
349             lastUpdateIndex_    += 1;
350
351             /* Clear the iteration collection data */
352             weightSumIteration_  = 0;
353             numVisitsIteration_  = 0;
354         }
355
356
357         /*! \brief Update the PMF histogram with the current coordinate value.
358          *
359          * \param[in] convolvedBias  The convolved bias.
360          */
361         void samplePmf(double convolvedBias);
362
363     private:
364         /*! \brief Update the free energy estimate of a point.
365          *
366          * The free energy update here is inherently local, i.e. it just depends on local sampling and on constant
367          * AWH parameters. This assumes that the variables used here are kept constant, at least in between
368          * global updates.
369          *
370          * \param[in] params          The AWH bias parameters.
371          * \param[in] weightAtPoint   Sampled probability weight at this point.
372          */
373         void updateFreeEnergy(const BiasParams &params,
374                               double            weightAtPoint)
375         {
376             double weighthistSampled  = weightSumRef() + weightAtPoint;
377             double weighthistTarget   = weightSumRef() + params.updateWeight*target_;
378
379             double df                 = -std::log(weighthistSampled/weighthistTarget);
380             freeEnergy_              += df;
381
382             GMX_RELEASE_ASSERT(std::abs(freeEnergy_) < detail::c_largePositiveExponent,
383                                "Very large free energy differences or badly normalized free energy in AWH update.");
384         }
385
386         /*! \brief Update the reference weight histogram of a point.
387          *
388          * \param[in] params         The AWH bias parameters.
389          * \param[in] weightAtPoint  Sampled probability weight at this point.
390          * \param[in] scaleFactor    Factor to rescale the histogram with.
391          */
392         void updateWeightHistogram(const BiasParams &params,
393                                    double            weightAtPoint,
394                                    double            scaleFactor)
395         {
396             if (params.idealWeighthistUpdate)
397             {
398                 /* Grow histogram using the target distribution. */
399                 weightSumRef_ += target_*params.updateWeight*params.localWeightScaling;
400             }
401             else
402             {
403                 /* Grow using the actual samples (which are distributed ~ as target). */
404                 weightSumRef_ += weightAtPoint*params.localWeightScaling;
405             }
406
407             weightSumRef_ *= scaleFactor;
408         }
409
410         /*! \brief Apply a point update.
411          *
412          * This updates local properties that can be updated without
413          * accessing or affecting all points.
414          * This excludes updating the size of reference weight histogram and
415          * the target distribution. The bias update is excluded only because
416          * if updates have been skipped this function will be called multiple
417          * times, while the bias only needs to be updated once (last).
418          *
419          * Since this function only performs the update with the given
420          * arguments and does not know anything about the time of the update,
421          * the last update index is not updated here. The caller should take
422          * care of updating the update index.
423          *
424          * \param[in] params             The AWH bias parameters.
425          * \param[in] weightAtPoint      Sampled probability weight at this point.
426          * \param[in] weighthistScaling  Scaling factor for the reference weight histogram.
427          * \param[in] logPmfSumScaling   Log of the scaling factor for the PMF histogram.
428          */
429         void updateFreeEnergyAndWeight(const BiasParams &params,
430                                        double            weightAtPoint,
431                                        double            weighthistScaling,
432                                        double            logPmfSumScaling)
433         {
434             updateFreeEnergy(params, weightAtPoint);
435             updateWeightHistogram(params, weightAtPoint, weighthistScaling);
436             logPmfSum_ += logPmfSumScaling;
437         }
438
439
440     public:
441         /*! \brief Update the target weight of a point.
442          *
443          * Note that renormalization over all points is needed after the update.
444          *
445          * \param[in] params            The AWH bias parameters.
446          * \param[in] freeEnergyCutoff  The cut-off for the free energy for target type "cutoff".
447          * \returns the updated value of the target.
448          */
449         double updateTargetWeight(const BiasParams &params,
450                                   double            freeEnergyCutoff)
451         {
452             switch (params.eTarget)
453             {
454                 case eawhtargetCONSTANT:
455                     target_   = 1;
456                     break;
457                 case eawhtargetCUTOFF:
458                 {
459                     double df = freeEnergy_ - freeEnergyCutoff;
460                     target_   = 1/(1 + std::exp(df));
461                     break;
462                 }
463                 case eawhtargetBOLTZMANN:
464                     target_   = std::exp(-params.temperatureScaleFactor*freeEnergy_);
465                     break;
466                 case eawhtargetLOCALBOLTZMANN:
467                     target_   = weightSumRef_;
468                     break;
469             }
470
471             /* All target types can be modulated by a constant factor. */
472             target_ *= targetConstantWeight_;
473
474             return target_;
475         }
476
477         /*! \brief Set the weight and count accumulated since the last update.
478          *
479          * \param[in] weightSum  The weight-sum value
480          * \param[in] numVisits  The number of visits
481          */
482         void setPartialWeightAndCount(double weightSum,
483                                       double numVisits)
484         {
485             weightSumIteration_ = weightSum;
486             numVisitsIteration_ = numVisits;
487         }
488
489         /*! \brief Add the weights and counts accumulated between updates. */
490         void addPartialWeightAndCount()
491         {
492             weightSumTot_ += weightSumIteration_;
493             numVisitsTot_ += numVisitsIteration_;
494         }
495
496         /*! \brief Scale the target weight of the point.
497          *
498          * \param[in] scaleFactor  Factor to scale with.
499          */
500         void scaleTarget(double scaleFactor)
501         {
502             target_ *= scaleFactor;
503         }
504
505     private:
506         double      bias_;                 /**< Current biasing function estimate */
507         double      freeEnergy_;           /**< Current estimate of the convolved free energy/PMF. */
508         double      target_;               /**< Current target distribution, normalized to 1 */
509         double      targetConstantWeight_; /**< Constant target weight, from user data. */
510         double      weightSumIteration_;   /**< Accumulated weight this iteration; note: only contains data for this Bias, even when sharing biases. */
511         double      weightSumTot_;         /**< Accumulated weights, never reset */
512         double      weightSumRef_;         /**< The reference weight histogram determining the free energy updates */
513         int64_t     lastUpdateIndex_;      /**< The last update that was performed at this point, in units of number of updates. */
514         double      logPmfSum_;            /**< Logarithm of the PMF histogram */
515         double      numVisitsIteration_;   /**< Visits to this bin this iteration; note: only contains data for this Bias, even when sharing biases. */
516         double      numVisitsTot_;         /**< Accumulated visits to this bin */
517 };
518
519 }      // namespace gmx
520
521 #endif /* GMX_AWH_POINTSTATE_H */