38ccc006e770863d3591c69aa0c8d029efd62dc2
[alexxy/gromacs.git] / src / gromacs / statistics / statistics.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2012,2014,2015,2017,2018 by the GROMACS development team.
7  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 #include "gmxpre.h"
39
40 #include "gromacs/utility/enumerationhelpers.h"
41 #include "statistics.h"
42
43 #include <cmath>
44
45 #include "gromacs/math/functions.h"
46 #include "gromacs/math/vec.h"
47 #include "gromacs/utility/enumerationhelpers.h"
48 #include "gromacs/utility/fatalerror.h"
49 #include "gromacs/utility/real.h"
50 #include "gromacs/utility/smalloc.h"
51
52
53 typedef struct gmx_stats
54 {
55     double  aa, a, b, sigma_aa, sigma_a, sigma_b, aver, sigma_aver, error;
56     double  rmsd, Rdata, Rfit, Rfitaa, chi2, chi2aa;
57     double *x, *y, *dx, *dy;
58     int     computed;
59     int     np, np_c, nalloc;
60 } gmx_stats;
61
62 gmx_stats_t gmx_stats_init()
63 {
64     gmx_stats* stats;
65
66     snew(stats, 1);
67
68     return static_cast<gmx_stats_t>(stats);
69 }
70
71 void gmx_stats_free(gmx_stats_t gstats)
72 {
73     gmx_stats* stats = static_cast<gmx_stats*>(gstats);
74
75     sfree(stats->x);
76     sfree(stats->y);
77     sfree(stats->dx);
78     sfree(stats->dy);
79     sfree(stats);
80 }
81
82 StatisticsStatus gmx_stats_add_point(gmx_stats_t gstats, double x, double y, double dx, double dy)
83 {
84     gmx_stats* stats = gstats;
85
86     if (stats->np + 1 >= stats->nalloc)
87     {
88         if (stats->nalloc == 0)
89         {
90             stats->nalloc = 1024;
91         }
92         else
93         {
94             stats->nalloc *= 2;
95         }
96         srenew(stats->x, stats->nalloc);
97         srenew(stats->y, stats->nalloc);
98         srenew(stats->dx, stats->nalloc);
99         srenew(stats->dy, stats->nalloc);
100         for (int i = stats->np; (i < stats->nalloc); i++)
101         {
102             stats->x[i]  = 0;
103             stats->y[i]  = 0;
104             stats->dx[i] = 0;
105             stats->dy[i] = 0;
106         }
107     }
108     stats->x[stats->np]  = x;
109     stats->y[stats->np]  = y;
110     stats->dx[stats->np] = dx;
111     stats->dy[stats->np] = dy;
112     stats->np++;
113     stats->computed = 0;
114
115     return StatisticsStatus::Ok;
116 }
117
118 static StatisticsStatus gmx_stats_compute(gmx_stats* stats, int weight)
119 {
120     double yy, yx, xx, sx, sy, dy, chi2, chi2aa, d2;
121     double ssxx, ssyy, ssxy;
122     double w, wtot, yx_nw, sy_nw, sx_nw, yy_nw, xx_nw, dx2, dy2;
123
124     int N = stats->np;
125
126     if (stats->computed == 0)
127     {
128         if (N < 1)
129         {
130             return StatisticsStatus::NoPoints;
131         }
132
133         xx = xx_nw = 0;
134         yy = yy_nw = 0;
135         yx = yx_nw = 0;
136         sx = sx_nw = 0;
137         sy = sy_nw = 0;
138         wtot       = 0;
139         d2         = 0;
140         for (int i = 0; (i < N); i++)
141         {
142             d2 += gmx::square(stats->x[i] - stats->y[i]);
143             if (((stats->dy[i]) != 0.0) && (weight == elsqWEIGHT_Y))
144             {
145                 w = 1 / gmx::square(stats->dy[i]);
146             }
147             else
148             {
149                 w = 1;
150             }
151
152             wtot += w;
153
154             xx += w * gmx::square(stats->x[i]);
155             xx_nw += gmx::square(stats->x[i]);
156
157             yy += w * gmx::square(stats->y[i]);
158             yy_nw += gmx::square(stats->y[i]);
159
160             yx += w * stats->y[i] * stats->x[i];
161             yx_nw += stats->y[i] * stats->x[i];
162
163             sx += w * stats->x[i];
164             sx_nw += stats->x[i];
165
166             sy += w * stats->y[i];
167             sy_nw += stats->y[i];
168         }
169
170         /* Compute average, sigma and error */
171         stats->aver       = sy_nw / N;
172         stats->sigma_aver = std::sqrt(yy_nw / N - gmx::square(sy_nw / N));
173         stats->error      = stats->sigma_aver / std::sqrt(static_cast<double>(N));
174
175         /* Compute RMSD between x and y */
176         stats->rmsd = std::sqrt(d2 / N);
177
178         /* Correlation coefficient for data */
179         yx_nw /= N;
180         xx_nw /= N;
181         yy_nw /= N;
182         sx_nw /= N;
183         sy_nw /= N;
184         ssxx         = N * (xx_nw - gmx::square(sx_nw));
185         ssyy         = N * (yy_nw - gmx::square(sy_nw));
186         ssxy         = N * (yx_nw - (sx_nw * sy_nw));
187         stats->Rdata = std::sqrt(gmx::square(ssxy) / (ssxx * ssyy));
188
189         /* Compute straight line through datapoints, either with intercept
190            zero (result in aa) or with intercept variable (results in a
191            and b) */
192         yx = yx / wtot;
193         xx = xx / wtot;
194         sx = sx / wtot;
195         sy = sy / wtot;
196
197         stats->aa = (yx / xx);
198         stats->a  = (yx - sx * sy) / (xx - sx * sx);
199         stats->b  = (sy) - (stats->a) * (sx);
200
201         /* Compute chi2, deviation from a line y = ax+b. Also compute
202            chi2aa which returns the deviation from a line y = ax. */
203         chi2   = 0;
204         chi2aa = 0;
205         for (int i = 0; (i < N); i++)
206         {
207             if (stats->dy[i] > 0)
208             {
209                 dy = stats->dy[i];
210             }
211             else
212             {
213                 dy = 1;
214             }
215             chi2aa += gmx::square((stats->y[i] - (stats->aa * stats->x[i])) / dy);
216             chi2 += gmx::square((stats->y[i] - (stats->a * stats->x[i] + stats->b)) / dy);
217         }
218         if (N > 2)
219         {
220             stats->chi2   = std::sqrt(chi2 / (N - 2));
221             stats->chi2aa = std::sqrt(chi2aa / (N - 2));
222
223             /* Look up equations! */
224             dx2            = (xx - sx * sx);
225             dy2            = (yy - sy * sy);
226             stats->sigma_a = std::sqrt(stats->chi2 / ((N - 2) * dx2));
227             stats->sigma_b = stats->sigma_a * std::sqrt(xx);
228             stats->Rfit    = std::abs(ssxy) / std::sqrt(ssxx * ssyy);
229             stats->Rfitaa  = stats->aa * std::sqrt(dx2 / dy2);
230         }
231         else
232         {
233             stats->chi2    = 0;
234             stats->chi2aa  = 0;
235             stats->sigma_a = 0;
236             stats->sigma_b = 0;
237             stats->Rfit    = 0;
238             stats->Rfitaa  = 0;
239         }
240
241         stats->computed = 1;
242     }
243
244     return StatisticsStatus::Ok;
245 }
246
247 StatisticsStatus
248 gmx_stats_get_ab(gmx_stats_t gstats, int weight, real* a, real* b, real* da, real* db, real* chi2, real* Rfit)
249 {
250     gmx_stats*       stats = gstats;
251     StatisticsStatus ok;
252
253     if ((ok = gmx_stats_compute(stats, weight)) != StatisticsStatus::Ok)
254     {
255         return ok;
256     }
257     if (nullptr != a)
258     {
259         *a = stats->a;
260     }
261     if (nullptr != b)
262     {
263         *b = stats->b;
264     }
265     if (nullptr != da)
266     {
267         *da = stats->sigma_a;
268     }
269     if (nullptr != db)
270     {
271         *db = stats->sigma_b;
272     }
273     if (nullptr != chi2)
274     {
275         *chi2 = stats->chi2;
276     }
277     if (nullptr != Rfit)
278     {
279         *Rfit = stats->Rfit;
280     }
281
282     return StatisticsStatus::Ok;
283 }
284
285 StatisticsStatus gmx_stats_get_average(gmx_stats_t gstats, real* aver)
286 {
287     gmx_stats*       stats = gstats;
288     StatisticsStatus ok;
289
290     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != StatisticsStatus::Ok)
291     {
292         return ok;
293     }
294
295     *aver = stats->aver;
296
297     return StatisticsStatus::Ok;
298 }
299
300 StatisticsStatus gmx_stats_get_ase(gmx_stats_t gstats, real* aver, real* sigma, real* error)
301 {
302     gmx_stats*       stats = gstats;
303     StatisticsStatus ok;
304
305     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != StatisticsStatus::Ok)
306     {
307         return ok;
308     }
309
310     if (nullptr != aver)
311     {
312         *aver = stats->aver;
313     }
314     if (nullptr != sigma)
315     {
316         *sigma = stats->sigma_aver;
317     }
318     if (nullptr != error)
319     {
320         *error = stats->error;
321     }
322
323     return StatisticsStatus::Ok;
324 }
325
326 static const char* enumValueToString(StatisticsStatus enumValue)
327 {
328     constexpr gmx::EnumerationArray<StatisticsStatus, const char*> statisticsStatusNames = {
329         "All well in STATS land",  "No points",     "Not enough memory",
330         "Invalid histogram input", "Unknown error", "Not implemented yet"
331     };
332     return statisticsStatusNames[enumValue];
333 }
334
335 void gmx_stats_message([[maybe_unused]] StatisticsStatus estats)
336 {
337     GMX_ASSERT(estats == StatisticsStatus::Ok, enumValueToString(estats));
338 }
339
340 static StatisticsStatus
341 low_lsq_y_ax_b(int n, const real* xr, const double* xd, real yr[], real* a, real* b, real* r, real* chi2)
342 {
343     gmx_stats_t      lsq = gmx_stats_init();
344     StatisticsStatus ok;
345
346     for (int i = 0; (i < n); i++)
347     {
348         double pt;
349
350         if (xd != nullptr)
351         {
352             pt = xd[i];
353         }
354         else if (xr != nullptr)
355         {
356             pt = xr[i];
357         }
358         else
359         {
360             gmx_incons("Either xd or xr has to be non-NULL in low_lsq_y_ax_b()");
361         }
362
363         if ((ok = gmx_stats_add_point(lsq, pt, yr[i], 0, 0)) != StatisticsStatus::Ok)
364         {
365             gmx_stats_free(lsq);
366             return ok;
367         }
368     }
369     ok = gmx_stats_get_ab(lsq, elsqWEIGHT_NONE, a, b, nullptr, nullptr, chi2, r);
370     gmx_stats_free(lsq);
371
372     return ok;
373 }
374
375 StatisticsStatus lsq_y_ax_b(int n, real x[], real y[], real* a, real* b, real* r, real* chi2)
376 {
377     return low_lsq_y_ax_b(n, x, nullptr, y, a, b, r, chi2);
378 }
379
380 StatisticsStatus lsq_y_ax_b_xdouble(int n, double x[], real y[], real* a, real* b, real* r, real* chi2)
381 {
382     return low_lsq_y_ax_b(n, nullptr, x, y, a, b, r, chi2);
383 }
384
385 StatisticsStatus
386 lsq_y_ax_b_error(int n, real x[], real y[], real dy[], real* a, real* b, real* da, real* db, real* r, real* chi2)
387 {
388     gmx_stats_t      lsq = gmx_stats_init();
389     StatisticsStatus ok;
390
391     for (int i = 0; (i < n); i++)
392     {
393         ok = gmx_stats_add_point(lsq, x[i], y[i], 0, dy[i]);
394         if (ok != StatisticsStatus::Ok)
395         {
396             gmx_stats_free(lsq);
397             return ok;
398         }
399     }
400     ok = gmx_stats_get_ab(lsq, elsqWEIGHT_Y, a, b, da, db, chi2, r);
401     gmx_stats_free(lsq);
402
403     return ok;
404 }