56c2a1a2ccdb041cdd78ef856eb452c5a6975b7f
[alexxy/gromacs.git] / python_packaging / sample_restraint / src / cpp / ensemblepotential.cpp
1 /*! \file
2  * \brief Code to implement the potential declared in ensemblepotential.h
3  *
4  * This file currently contains boilerplate that will not be necessary in future gmxapi releases, as
5  * well as additional code used in implementing the restrained ensemble example workflow.
6  *
7  * A simpler restraint potential would only update the calculate() function. If a callback function is
8  * not needed or desired, remove the callback() code from this file and from ensemblepotential.h
9  *
10  * \author M. Eric Irrgang <ericirrgang@gmail.com>
11  */
12
13 #include "ensemblepotential.h"
14
15 #include <cassert>
16 #include <cmath>
17
18 #include <memory>
19 #include <vector>
20
21 #include "gmxapi/context.h"
22 #include "gmxapi/session.h"
23 #include "gmxapi/md/mdsignals.h"
24
25 #include "sessionresources.h"
26
27 namespace plugin
28 {
29
30 /*!
31  * \brief Discretize a density field on a grid.
32  *
33  * Apply a Gaussian blur when building a density grid for a list of values.
34  * Normalize such that the area under each sample is 1.0/num_samples.
35  */
36 class BlurToGrid
37 {
38     public:
39         /*!
40          * \brief Construct the blurring functor.
41          *
42          * \param low The coordinate value of the first grid point.
43          * \param gridSpacing Distance between grid points.
44          * \param sigma Gaussian parameter for blurring inputs onto the grid.
45          */
46         BlurToGrid(double low,
47                    double gridSpacing,
48                    double sigma) :
49             low_{low},
50             binWidth_{gridSpacing},
51             sigma_{sigma}
52         {
53         };
54
55         /*!
56          * \brief Callable for the functor.
57          *
58          * \param samples A list of values to be blurred onto the grid.
59          * \param grid Pointer to the container into which to accumulate a blurred histogram of samples.
60          *
61          * Example:
62          *
63          *     # Acquire 3 samples to be discretized with blurring.
64          *     std::vector<double> someData = {3.7, 8.1, 4.2};
65          *
66          *     # Create an empty grid to store magnitudes for points 0.5, 1.0, ..., 10.0.
67          *     std::vector<double> histogram(20, 0.);
68          *
69          *     # Specify the above grid and a Gaussian parameter of 0.8.
70          *     auto blur = BlurToGrid(0.5, 0.5, 0.8);
71          *
72          *     # Collect the density grid for the samples.
73          *     blur(someData, &histogram);
74          *
75          */
76         void operator()(const std::vector<double>& samples,
77                         std::vector<double>* grid)
78         {
79             const auto nbins = grid->size();
80             const double& dx{binWidth_};
81             const auto num_samples = samples.size();
82
83             const double denominator = 1.0 / (2 * sigma_ * sigma_);
84             const double normalization = 1.0 / (num_samples * sqrt(2.0 * M_PI * sigma_ * sigma_));
85             // We aren't doing any filtering of values too far away to contribute meaningfully, which
86             // is admittedly wasteful for large sigma...
87             for (size_t i = 0;i < nbins;++i)
88             {
89                 double bin_value{0};
90                 const double bin_x{low_ + i * dx};
91                 for (const auto distance : samples)
92                 {
93                     const double relative_distance{bin_x - distance};
94                     const auto numerator = -relative_distance * relative_distance;
95                     bin_value += normalization * exp(numerator * denominator);
96                 }
97                 grid->at(i) = bin_value;
98             }
99         };
100
101     private:
102         /// Minimum value of bin zero
103         const double low_;
104
105         /// Size of each bin
106         const double binWidth_;
107
108         /// Smoothing factor
109         const double sigma_;
110 };
111
112 EnsemblePotential::EnsemblePotential(size_t nbins,
113                                    double binWidth,
114                                    double minDist,
115                                    double maxDist,
116                                    PairHist experimental,
117                                    unsigned int nSamples,
118                                    double samplePeriod,
119                                    unsigned int nWindows,
120                                    double k,
121                                    double sigma) :
122     nBins_{nbins},
123     binWidth_{binWidth},
124     minDist_{minDist},
125     maxDist_{maxDist},
126     histogram_(nbins,
127                0),
128     experimental_{std::move(experimental)},
129     nSamples_{nSamples},
130     currentSample_{0},
131     samplePeriod_{samplePeriod},
132     // In actuality, we have nsamples at (samplePeriod - dt), but we don't have access to dt.
133     nextSampleTime_{samplePeriod},
134     distanceSamples_(nSamples),
135     nWindows_{nWindows},
136     currentWindow_{0},
137     windowStartTime_{0},
138     nextWindowUpdateTime_{nSamples * samplePeriod},
139     windows_{},
140     k_{k},
141     sigma_{sigma}
142 {}
143
144 EnsemblePotential::EnsemblePotential(const input_param_type& params) :
145     EnsemblePotential(params.nBins,
146                      params.binWidth,
147                      params.minDist,
148                      params.maxDist,
149                      params.experimental,
150                      params.nSamples,
151                      params.samplePeriod,
152                      params.nWindows,
153                      params.k,
154                      params.sigma)
155 {
156 }
157
158 //
159 //
160 // HERE is the (optional) function that updates the state of the restraint periodically.
161 // It is called before calculate() once per timestep per simulation (on the master rank of
162 // a parallelized simulation).
163 //
164 //
165 void EnsemblePotential::callback(gmx::Vector v,
166                                  gmx::Vector v0,
167                                  double t,
168                                  const Resources& resources)
169 {
170     const auto rdiff = v - v0;
171     const auto Rsquared = dot(rdiff,
172                               rdiff);
173     const auto R = sqrt(Rsquared);
174
175     // Store historical data every sample_period steps
176     if (t >= nextSampleTime_)
177     {
178         distanceSamples_[currentSample_++] = R;
179         nextSampleTime_ = (currentSample_ + 1) * samplePeriod_ + windowStartTime_;
180     };
181
182     // Every nsteps:
183     //   0. Drop oldest window
184     //   1. Reduce historical data for this restraint in this simulation.
185     //   2. Call out to the global reduction for this window.
186     //   3. On update, checkpoint the historical data source.
187     //   4. Update historic windows.
188     //   5. Use handles retained from previous windows to reconstruct the smoothed working histogram
189     if (t >= nextWindowUpdateTime_)
190     {
191         // Get next histogram array, recycling old one if available.
192         std::unique_ptr<Matrix<double>> new_window = std::make_unique<Matrix<double>>(1,
193                                                                                               nBins_);
194         std::unique_ptr<Matrix<double>> temp_window;
195         if (windows_.size() == nWindows_)
196         {
197             // Recycle the oldest window.
198             // \todo wrap this in a helper class that manages a buffer we can shuffle through.
199             windows_[0].swap(temp_window);
200             windows_.erase(windows_.begin());
201         }
202         else
203         {
204             auto new_temp_window = std::make_unique<Matrix<double>>(1,
205                                                                             nBins_);
206             assert(new_temp_window);
207             temp_window.swap(new_temp_window);
208         }
209
210         // Reduce sampled data for this restraint in this simulation, applying a Gaussian blur to fill a grid.
211         auto blur = BlurToGrid(0.0,
212                                binWidth_,
213                                sigma_);
214         assert(new_window != nullptr);
215         assert(distanceSamples_.size() == nSamples_);
216         assert(currentSample_ == nSamples_);
217         blur(distanceSamples_,
218              new_window->vector());
219         // We can just do the blur locally since there aren't many bins. Bundling these operations for
220         // all restraints could give us a chance at some parallelism. We should at least use some
221         // threading if we can.
222
223         // We request a handle each time before using resources to make error handling easier if there is a failure in
224         // one of the ensemble member processes and to give more freedom to how resources are managed from step to step.
225         auto ensemble = resources.getHandle();
226         // Get global reduction (sum) and checkpoint.
227         assert(temp_window != nullptr);
228         // Todo: in reduce function, give us a mean instead of a sum.
229         ensemble.reduce(*new_window,
230                         temp_window.get());
231
232         // Update window list with smoothed data.
233         windows_.emplace_back(std::move(new_window));
234
235         // Get new histogram difference. Subtract the experimental distribution to get the values to use in our potential.
236         for (auto& bin : histogram_)
237         {
238             bin = 0;
239         }
240         for (const auto& window : windows_)
241         {
242             for (size_t i = 0;i < window->cols();++i)
243             {
244                 histogram_.at(i) += (window->vector()->at(i) - experimental_.at(i)) / windows_.size();
245             }
246         }
247
248
249         // Note we do not have the integer timestep available here. Therefore, we can't guarantee that updates occur
250         // with the same number of MD steps in each interval, and the interval will effectively lose digits as the
251         // simulation progresses, so _update_period should be cleanly representable in binary. When we extract this
252         // to a facility, we can look for a part of the code with access to the current timestep.
253         windowStartTime_ = t;
254         nextWindowUpdateTime_ = nSamples_ * samplePeriod_ + windowStartTime_;
255         ++currentWindow_; // This is currently never used. I'm not sure it will be, either...
256
257         // Reset sample bufering.
258         currentSample_ = 0;
259         // Reset sample times.
260         nextSampleTime_ = t + samplePeriod_;
261     };
262
263 }
264
265
266 //
267 //
268 // HERE is the function that does the calculation of the restraint force.
269 //
270 //
271 gmx::PotentialPointData EnsemblePotential::calculate(gmx::Vector v,
272                                                     gmx::Vector v0,
273                                                     double /* t */)
274 {
275     // This is not the vector from v to v0. It is the position of a site
276     // at v, relative to the origin v0. This is a potentially confusing convention...
277     const auto rdiff = v - v0;
278     const auto Rsquared = dot(rdiff,
279                               rdiff);
280     const auto R = sqrt(Rsquared);
281
282
283     // Compute output
284     gmx::PotentialPointData output;
285     // Energy not needed right now.
286 //    output.energy = 0;
287
288     if (R != 0) // Direction of force is ill-defined when v == v0
289     {
290
291         double f{0};
292
293         if (R > maxDist_)
294         {
295             // apply a force to reduce R
296             f = k_ * (maxDist_ - R);
297         }
298         else if (R < minDist_)
299         {
300             // apply a force to increase R
301             f = k_ * (minDist_ - R);
302         }
303         else
304         {
305             double f_scal{0};
306
307             const size_t numBins = histogram_.size();
308             double normConst = sqrt(2 * M_PI) * sigma_ * sigma_ * sigma_;
309
310             for (size_t n = 0;n < numBins;n++)
311             {
312                 const double x{n * binWidth_ - R};
313                 const double argExp{-0.5 * x * x / (sigma_ * sigma_)};
314                 f_scal += histogram_.at(n) * exp(argExp) * x / normConst;
315             }
316             f = -k_ * f_scal;
317         }
318
319         const auto magnitude = f / norm(rdiff);
320         output.force = rdiff * static_cast<decltype(rdiff[0])>(magnitude);
321     }
322     return output;
323 }
324
325 std::unique_ptr<ensemble_input_param_type>
326 makeEnsembleParams(size_t nbins,
327                    double binWidth,
328                    double minDist,
329                    double maxDist,
330                    const std::vector<double>& experimental,
331                    unsigned int nSamples,
332                    double samplePeriod,
333                    unsigned int nWindows,
334                    double k,
335                    double sigma)
336 {
337     using std::make_unique;
338     auto params = make_unique<ensemble_input_param_type>();
339     params->nBins = nbins;
340     params->binWidth = binWidth;
341     params->minDist = minDist;
342     params->maxDist = maxDist;
343     params->experimental = experimental;
344     params->nSamples = nSamples;
345     params->samplePeriod = samplePeriod;
346     params->nWindows = nWindows;
347     params->k = k;
348     params->sigma = sigma;
349
350     return params;
351 };
352
353 // Important: Explicitly instantiate a definition for the templated class declared in ensemblepotential.h.
354 // Failing to do this will cause a linker error.
355 template
356 class ::plugin::RestraintModule<EnsembleRestraint>;
357
358 } // end namespace plugin