Fix output string for GPU support
[alexxy/gromacs.git] / src / gromacs / applied_forces / awh / pointstate.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2015,2016,2017,2018,2019,2020, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *
38  * \brief
39  * Declares and defines the PointState class.
40  *
41  * Since nearly all operations on PointState objects occur in loops over
42  * (parts of) the grid of an AWH bias, all these methods should be inlined.
43  * Only samplePmf() is called only once per step and is thus not inlined.
44  *
45  * \author Viveca Lindahl
46  * \author Berk Hess <hess@kth.se>
47  * \ingroup module_awh
48  */
49
50 #ifndef GMX_AWH_POINTSTATE_H
51 #define GMX_AWH_POINTSTATE_H
52
53 #include <cmath>
54
55 #include "gromacs/mdtypes/awh_history.h"
56 #include "gromacs/mdtypes/awh_params.h"
57 #include "gromacs/utility/gmxassert.h"
58
59 #include "biasparams.h"
60
61 namespace gmx
62 {
63
64 namespace detail
65 {
66
67 //! A value that can be passed to exp() with result 0, also with SIMD
68 constexpr double c_largeNegativeExponent = -10000.0;
69
70 //! The largest acceptable positive exponent for variables that are passed to exp().
71 constexpr double c_largePositiveExponent = 700.0;
72
73 } // namespace detail
74
75 /*! \internal
76  * \brief The state of a coordinate point.
77  *
78  * This class contains all the state variables of a coordinate point
79  * (on the bias grid) and methods to update the state of a point.
80  */
81 class PointState
82 {
83 public:
84     /*! \brief Constructs a point state with default values. */
85     PointState() :
86         bias_(0),
87         freeEnergy_(0),
88         target_(1),
89         targetConstantWeight_(1),
90         weightSumIteration_(0),
91         weightSumTot_(0),
92         weightSumRef_(1),
93         lastUpdateIndex_(0),
94         logPmfSum_(0),
95         numVisitsIteration_(0),
96         numVisitsTot_(0)
97     {
98     }
99
100     /*! \brief
101      * Set all values in the state to those from a history.
102      *
103      * \param[in] psh  Coordinate point history to copy from.
104      */
105     void setFromHistory(const AwhPointStateHistory& psh)
106     {
107         target_             = psh.target;
108         freeEnergy_         = psh.free_energy;
109         bias_               = psh.bias;
110         weightSumIteration_ = psh.weightsum_iteration;
111         weightSumTot_       = psh.weightsum_tot;
112         weightSumRef_       = psh.weightsum_ref;
113         lastUpdateIndex_    = psh.last_update_index;
114         logPmfSum_          = psh.log_pmfsum;
115         numVisitsIteration_ = psh.visits_iteration;
116         numVisitsTot_       = psh.visits_tot;
117     }
118
119     /*! \brief
120      * Store the state of a point in a history struct.
121      *
122      * \param[in,out] psh  Coordinate point history to copy to.
123      */
124     void storeState(AwhPointStateHistory* psh) const
125     {
126         psh->target              = target_;
127         psh->free_energy         = freeEnergy_;
128         psh->bias                = bias_;
129         psh->weightsum_iteration = weightSumIteration_;
130         psh->weightsum_tot       = weightSumTot_;
131         psh->weightsum_ref       = weightSumRef_;
132         psh->last_update_index   = lastUpdateIndex_;
133         psh->log_pmfsum          = logPmfSum_;
134         psh->visits_iteration    = numVisitsIteration_;
135         psh->visits_tot          = numVisitsTot_;
136     }
137
138     /*! \brief
139      * Query if the point is in the target region.
140      *
141      * \returns true if the point is in the target region.
142      */
143     bool inTargetRegion() const { return target_ > 0; }
144
145     /*! \brief Return the bias function estimate. */
146     double bias() const { return bias_; }
147
148     /*! \brief Set the target to zero and the bias to minus infinity. */
149     void setTargetToZero()
150     {
151         target_ = 0;
152         /* the bias = log(target) + const = -infty */
153         bias_ = detail::c_largeNegativeExponent;
154     }
155
156     /*! \brief Return the free energy. */
157     double freeEnergy() const { return freeEnergy_; }
158
159     /*! \brief Set the free energy, only to be used at initialization.
160      *
161      * \param[in] freeEnergy  The free energy.
162      */
163     void setFreeEnergy(double freeEnergy) { freeEnergy_ = freeEnergy; }
164
165     /*! \brief Return the target distribution value. */
166     double target() const { return target_; }
167
168     /*! \brief Return the weight accumulated since the last update. */
169     double weightSumIteration() const { return weightSumIteration_; }
170
171     /*! \brief Increases the weight accumulated since the last update.
172      *
173      * \param[in] weight  The amount to add to the weight
174      */
175     void increaseWeightSumIteration(double weight) { weightSumIteration_ += weight; }
176
177     /*! \brief Returns the accumulated weight */
178     double weightSumTot() const { return weightSumTot_; }
179
180     /*! \brief Return the reference weight histogram. */
181     double weightSumRef() const { return weightSumRef_; }
182
183     /*! \brief Return log(PmfSum). */
184     double logPmfSum() const { return logPmfSum_; }
185
186     /*! \brief Set log(PmfSum).
187      *
188      * TODO: Replace this setter function with a more elegant solution.
189      *
190      * \param[in] logPmfSum  The log(PmfSum).
191      */
192     void setLogPmfSum(double logPmfSum) { logPmfSum_ = logPmfSum; }
193
194     /*! \brief Return the number of visits since the last update */
195     double numVisitsIteration() const { return numVisitsIteration_; }
196
197     /*! \brief Return the total number of visits */
198     double numVisitsTot() const { return numVisitsTot_; }
199
200     /*! \brief Set the constant target weight factor.
201      *
202      * \param[in] targetConstantWeight  The target weight factor.
203      */
204     void setTargetConstantWeight(double targetConstantWeight)
205     {
206         targetConstantWeight_ = targetConstantWeight;
207     }
208
209     /*! \brief Updates the bias of a point. */
210     void updateBias()
211     {
212         GMX_ASSERT(target_ > 0, "AWH target distribution must be > 0 to calculate the point bias.");
213
214         bias_ = freeEnergy() + std::log(target_);
215     }
216
217     /*! \brief Set the initial reference weighthistogram.
218      *
219      * \param[in] histogramSize  The weight histogram size.
220      */
221     void setInitialReferenceWeightHistogram(double histogramSize)
222     {
223         weightSumRef_ = histogramSize * target_;
224     }
225
226     /*! \brief Correct free energy and PMF sum for the change in minimum.
227      *
228      * \param[in] minimumFreeEnergy  The free energy at the minimum;
229      */
230     void normalizeFreeEnergyAndPmfSum(double minimumFreeEnergy)
231     {
232         if (inTargetRegion())
233         {
234             /* The sign of the free energy and PMF constants are opposite
235              * because the PMF samples are reweighted with the negative
236              * bias e^(-bias) ~ e^(-free energy).
237              */
238             freeEnergy_ -= minimumFreeEnergy;
239             logPmfSum_ += minimumFreeEnergy;
240         }
241     }
242
243     /*! \brief Apply previous updates that were skipped.
244      *
245      * An update can only be skipped if the parameters needed for the update are constant or
246      * deterministic so that the same update can be performed at a later time.
247      * Here, the necessary parameters are the sampled weight and scaling factors for the
248      * histograms. The scaling factors are provided as arguments only to avoid recalculating
249      * them for each point
250      *
251      * The last update index is also updated here.
252      *
253      * \param[in] params             The AWH bias parameters.
254      * \param[in] numUpdates         The global number of updates.
255      * \param[in] weighthistScaling  Scale factor for the reference weight histogram.
256      * \param[in] logPmfSumScaling   Scale factor for the reference PMF histogram.
257      * \returns true if at least one update was applied.
258      */
259     bool performPreviouslySkippedUpdates(const BiasParams& params,
260                                          int64_t           numUpdates,
261                                          double            weighthistScaling,
262                                          double            logPmfSumScaling)
263     {
264         GMX_ASSERT(params.skipUpdates(),
265                    "Calling function for skipped updates when skipping updates is not allowed");
266
267         if (!inTargetRegion())
268         {
269             return false;
270         }
271
272         /* The most current past update */
273         int64_t lastUpdateIndex   = numUpdates;
274         int64_t numUpdatesSkipped = lastUpdateIndex - lastUpdateIndex_;
275
276         if (numUpdatesSkipped == 0)
277         {
278             /* Was not updated */
279             return false;
280         }
281
282         for (int64_t i = 0; i < numUpdatesSkipped; i++)
283         {
284             /* This point was non-local at the time of the update meaning no weight */
285             updateFreeEnergyAndWeight(params, 0, weighthistScaling, logPmfSumScaling);
286         }
287
288         /* Only past updates are applied here. */
289         lastUpdateIndex_ = lastUpdateIndex;
290
291         return true;
292     }
293
294     /*! \brief Apply a point update with new sampling.
295      *
296      * \note The last update index is also updated here.
297      * \note The new sampling containers are cleared here.
298      *
299      * \param[in] params              The AWH bias parameters.
300      * \param[in] numUpdates          The global number of updates.
301      * \param[in] weighthistScaling   Scaling factor for the reference weight histogram.
302      * \param[in] logPmfSumScaling    Log of the scaling factor for the PMF histogram.
303      */
304     void updateWithNewSampling(const BiasParams& params, int64_t numUpdates, double weighthistScaling, double logPmfSumScaling)
305     {
306         GMX_RELEASE_ASSERT(lastUpdateIndex_ == numUpdates,
307                            "When doing a normal update, the point update index should match the "
308                            "global index, otherwise we lost (skipped?) updates.");
309
310         updateFreeEnergyAndWeight(params, weightSumIteration_, weighthistScaling, logPmfSumScaling);
311         lastUpdateIndex_ += 1;
312
313         /* Clear the iteration collection data */
314         weightSumIteration_ = 0;
315         numVisitsIteration_ = 0;
316     }
317
318
319     /*! \brief Update the PMF histogram with the current coordinate value.
320      *
321      * \param[in] convolvedBias  The convolved bias.
322      */
323     void samplePmf(double convolvedBias);
324
325     /*! \brief Update the PMF histogram of unvisited coordinate values
326      * (along a lambda axis)
327      *
328      * \param[in] bias  The bias to update with.
329      */
330     void updatePmfUnvisited(double bias);
331
332 private:
333     /*! \brief Update the free energy estimate of a point.
334      *
335      * The free energy update here is inherently local, i.e. it just depends on local sampling and
336      * on constant AWH parameters. This assumes that the variables used here are kept constant, at
337      * least in between global updates.
338      *
339      * \param[in] params          The AWH bias parameters.
340      * \param[in] weightAtPoint   Sampled probability weight at this point.
341      */
342     void updateFreeEnergy(const BiasParams& params, double weightAtPoint)
343     {
344         double weighthistSampled = weightSumRef() + weightAtPoint;
345         double weighthistTarget  = weightSumRef() + params.updateWeight * target_;
346
347         double df = -std::log(weighthistSampled / weighthistTarget);
348         freeEnergy_ += df;
349
350         GMX_RELEASE_ASSERT(std::abs(freeEnergy_) < detail::c_largePositiveExponent,
351                            "Very large free energy differences or badly normalized free energy in "
352                            "AWH update.");
353     }
354
355     /*! \brief Update the reference weight histogram of a point.
356      *
357      * \param[in] params         The AWH bias parameters.
358      * \param[in] weightAtPoint  Sampled probability weight at this point.
359      * \param[in] scaleFactor    Factor to rescale the histogram with.
360      */
361     void updateWeightHistogram(const BiasParams& params, double weightAtPoint, double scaleFactor)
362     {
363         if (params.idealWeighthistUpdate)
364         {
365             /* Grow histogram using the target distribution. */
366             weightSumRef_ += target_ * params.updateWeight * params.localWeightScaling;
367         }
368         else
369         {
370             /* Grow using the actual samples (which are distributed ~ as target). */
371             weightSumRef_ += weightAtPoint * params.localWeightScaling;
372         }
373
374         weightSumRef_ *= scaleFactor;
375     }
376
377     /*! \brief Apply a point update.
378      *
379      * This updates local properties that can be updated without
380      * accessing or affecting all points.
381      * This excludes updating the size of reference weight histogram and
382      * the target distribution. The bias update is excluded only because
383      * if updates have been skipped this function will be called multiple
384      * times, while the bias only needs to be updated once (last).
385      *
386      * Since this function only performs the update with the given
387      * arguments and does not know anything about the time of the update,
388      * the last update index is not updated here. The caller should take
389      * care of updating the update index.
390      *
391      * \param[in] params             The AWH bias parameters.
392      * \param[in] weightAtPoint      Sampled probability weight at this point.
393      * \param[in] weighthistScaling  Scaling factor for the reference weight histogram.
394      * \param[in] logPmfSumScaling   Log of the scaling factor for the PMF histogram.
395      */
396     void updateFreeEnergyAndWeight(const BiasParams& params,
397                                    double            weightAtPoint,
398                                    double            weighthistScaling,
399                                    double            logPmfSumScaling)
400     {
401         updateFreeEnergy(params, weightAtPoint);
402         updateWeightHistogram(params, weightAtPoint, weighthistScaling);
403         logPmfSum_ += logPmfSumScaling;
404     }
405
406
407 public:
408     /*! \brief Update the target weight of a point.
409      *
410      * Note that renormalization over all points is needed after the update.
411      *
412      * \param[in] params            The AWH bias parameters.
413      * \param[in] freeEnergyCutoff  The cut-off for the free energy for target type "cutoff".
414      * \returns the updated value of the target.
415      */
416     double updateTargetWeight(const BiasParams& params, double freeEnergyCutoff)
417     {
418         switch (params.eTarget)
419         {
420             case eawhtargetCONSTANT: target_ = 1; break;
421             case eawhtargetCUTOFF:
422             {
423                 double df = freeEnergy_ - freeEnergyCutoff;
424                 target_   = 1 / (1 + std::exp(df));
425                 break;
426             }
427             case eawhtargetBOLTZMANN:
428                 target_ = std::exp(-params.temperatureScaleFactor * freeEnergy_);
429                 break;
430             case eawhtargetLOCALBOLTZMANN: target_ = weightSumRef_; break;
431         }
432
433         /* All target types can be modulated by a constant factor. */
434         target_ *= targetConstantWeight_;
435
436         return target_;
437     }
438
439     /*! \brief Set the weight and count accumulated since the last update.
440      *
441      * \param[in] weightSum  The weight-sum value
442      * \param[in] numVisits  The number of visits
443      */
444     void setPartialWeightAndCount(double weightSum, double numVisits)
445     {
446         weightSumIteration_ = weightSum;
447         numVisitsIteration_ = numVisits;
448     }
449
450     /*! \brief Add the weights and counts accumulated between updates. */
451     void addPartialWeightAndCount()
452     {
453         weightSumTot_ += weightSumIteration_;
454         numVisitsTot_ += numVisitsIteration_;
455     }
456
457     /*! \brief Scale the target weight of the point.
458      *
459      * \param[in] scaleFactor  Factor to scale with.
460      */
461     void scaleTarget(double scaleFactor) { target_ *= scaleFactor; }
462
463 private:
464     double bias_;                 /**< Current biasing function estimate */
465     double freeEnergy_;           /**< Current estimate of the convolved free energy/PMF. */
466     double target_;               /**< Current target distribution, normalized to 1 */
467     double targetConstantWeight_; /**< Constant target weight, from user data. */
468     double weightSumIteration_; /**< Accumulated weight this iteration; note: only contains data for this Bias, even when sharing biases. */
469     double weightSumTot_;       /**< Accumulated weights, never reset */
470     double weightSumRef_; /**< The reference weight histogram determining the free energy updates */
471     int64_t lastUpdateIndex_; /**< The last update that was performed at this point, in units of number of updates. */
472     double  logPmfSum_;          /**< Logarithm of the PMF histogram */
473     double  numVisitsIteration_; /**< Visits to this bin this iteration; note: only contains data for this Bias, even when sharing biases. */
474     double  numVisitsTot_;       /**< Accumulated visits to this bin */
475 };
476
477 } // namespace gmx
478
479 #endif /* GMX_AWH_POINTSTATE_H */