2e9fbc1b8020201664bded23612d1d6cd63ab62f
[alexxy/gromacs.git] / src / gromacs / statistics / statistics.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2012,2014,2015,2017, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "statistics.h"
40
41 #include <cmath>
42
43 #include "gromacs/math/functions.h"
44 #include "gromacs/math/vec.h"
45 #include "gromacs/utility/fatalerror.h"
46 #include "gromacs/utility/real.h"
47 #include "gromacs/utility/smalloc.h"
48
49 static int gmx_dnint(double x)
50 {
51     return static_cast<int>(x+0.5);
52 }
53
54 typedef struct gmx_stats {
55     double  aa, a, b, sigma_aa, sigma_a, sigma_b, aver, sigma_aver, error;
56     double  rmsd, Rdata, Rfit, Rfitaa, chi2, chi2aa;
57     double *x, *y, *dx, *dy;
58     int     computed;
59     int     np, np_c, nalloc;
60 } gmx_stats;
61
62 gmx_stats_t gmx_stats_init()
63 {
64     gmx_stats *stats;
65
66     snew(stats, 1);
67
68     return (gmx_stats_t) stats;
69 }
70
71 int gmx_stats_get_npoints(gmx_stats_t gstats, int *N)
72 {
73     gmx_stats *stats = (gmx_stats *) gstats;
74
75     *N = stats->np;
76
77     return estatsOK;
78 }
79
80 void gmx_stats_free(gmx_stats_t gstats)
81 {
82     gmx_stats *stats = (gmx_stats *) gstats;
83
84     sfree(stats->x);
85     sfree(stats->y);
86     sfree(stats->dx);
87     sfree(stats->dy);
88     sfree(stats);
89 }
90
91 int gmx_stats_add_point(gmx_stats_t gstats, double x, double y,
92                         double dx, double dy)
93 {
94     gmx_stats *stats = gstats;
95
96     if (stats->np+1 >= stats->nalloc)
97     {
98         if (stats->nalloc == 0)
99         {
100             stats->nalloc = 1024;
101         }
102         else
103         {
104             stats->nalloc *= 2;
105         }
106         srenew(stats->x, stats->nalloc);
107         srenew(stats->y, stats->nalloc);
108         srenew(stats->dx, stats->nalloc);
109         srenew(stats->dy, stats->nalloc);
110         for (int i = stats->np; (i < stats->nalloc); i++)
111         {
112             stats->x[i]  = 0;
113             stats->y[i]  = 0;
114             stats->dx[i] = 0;
115             stats->dy[i] = 0;
116         }
117     }
118     stats->x[stats->np]  = x;
119     stats->y[stats->np]  = y;
120     stats->dx[stats->np] = dx;
121     stats->dy[stats->np] = dy;
122     stats->np++;
123     stats->computed = 0;
124
125     return estatsOK;
126 }
127
128 int gmx_stats_get_point(gmx_stats_t gstats, real *x, real *y,
129                         real *dx, real *dy, real level)
130 {
131     gmx_stats *stats = gstats;
132     int        ok, outlier;
133     real       rmsd, r;
134
135     if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
136     {
137         return ok;
138     }
139     outlier = 0;
140     while ((outlier == 0) && (stats->np_c < stats->np))
141     {
142         r       = std::abs(stats->x[stats->np_c] - stats->y[stats->np_c]);
143         outlier = (r > rmsd*level);
144         if (outlier)
145         {
146             if (nullptr != x)
147             {
148                 *x  = stats->x[stats->np_c];
149             }
150             if (nullptr != y)
151             {
152                 *y  = stats->y[stats->np_c];
153             }
154             if (nullptr != dx)
155             {
156                 *dx = stats->dx[stats->np_c];
157             }
158             if (nullptr != dy)
159             {
160                 *dy = stats->dy[stats->np_c];
161             }
162         }
163         stats->np_c++;
164
165         if (outlier)
166         {
167             return estatsOK;
168         }
169     }
170
171     stats->np_c = 0;
172
173     return estatsNO_POINTS;
174 }
175
176 int gmx_stats_add_points(gmx_stats_t gstats, int n, real *x, real *y,
177                          real *dx, real *dy)
178 {
179     for (int i = 0; (i < n); i++)
180     {
181         int ok;
182         if ((ok = gmx_stats_add_point(gstats, x[i], y[i],
183                                       (nullptr != dx) ? dx[i] : 0,
184                                       (nullptr != dy) ? dy[i] : 0)) != estatsOK)
185         {
186             return ok;
187         }
188     }
189     return estatsOK;
190 }
191
192 static int gmx_stats_compute(gmx_stats *stats, int weight)
193 {
194     double yy, yx, xx, sx, sy, dy, chi2, chi2aa, d2;
195     double ssxx, ssyy, ssxy;
196     double w, wtot, yx_nw, sy_nw, sx_nw, yy_nw, xx_nw, dx2, dy2;
197
198     int    N = stats->np;
199
200     if (stats->computed == 0)
201     {
202         if (N < 1)
203         {
204             return estatsNO_POINTS;
205         }
206
207         xx   = xx_nw = 0;
208         yy   = yy_nw = 0;
209         yx   = yx_nw = 0;
210         sx   = sx_nw = 0;
211         sy   = sy_nw = 0;
212         wtot = 0;
213         d2   = 0;
214         for (int i = 0; (i < N); i++)
215         {
216             d2 += gmx::square(stats->x[i]-stats->y[i]);
217             if ((stats->dy[i]) && (weight == elsqWEIGHT_Y))
218             {
219                 w = 1/gmx::square(stats->dy[i]);
220             }
221             else
222             {
223                 w = 1;
224             }
225
226             wtot  += w;
227
228             xx    += w*gmx::square(stats->x[i]);
229             xx_nw += gmx::square(stats->x[i]);
230
231             yy    += w*gmx::square(stats->y[i]);
232             yy_nw += gmx::square(stats->y[i]);
233
234             yx    += w*stats->y[i]*stats->x[i];
235             yx_nw += stats->y[i]*stats->x[i];
236
237             sx    += w*stats->x[i];
238             sx_nw += stats->x[i];
239
240             sy    += w*stats->y[i];
241             sy_nw += stats->y[i];
242         }
243
244         /* Compute average, sigma and error */
245         stats->aver       = sy_nw/N;
246         stats->sigma_aver = std::sqrt(yy_nw/N - gmx::square(sy_nw/N));
247         stats->error      = stats->sigma_aver/std::sqrt(static_cast<double>(N));
248
249         /* Compute RMSD between x and y */
250         stats->rmsd = std::sqrt(d2/N);
251
252         /* Correlation coefficient for data */
253         yx_nw       /= N;
254         xx_nw       /= N;
255         yy_nw       /= N;
256         sx_nw       /= N;
257         sy_nw       /= N;
258         ssxx         = N*(xx_nw - gmx::square(sx_nw));
259         ssyy         = N*(yy_nw - gmx::square(sy_nw));
260         ssxy         = N*(yx_nw - (sx_nw*sy_nw));
261         stats->Rdata = std::sqrt(gmx::square(ssxy)/(ssxx*ssyy));
262
263         /* Compute straight line through datapoints, either with intercept
264            zero (result in aa) or with intercept variable (results in a
265            and b) */
266         yx = yx/wtot;
267         xx = xx/wtot;
268         sx = sx/wtot;
269         sy = sy/wtot;
270
271         stats->aa = (yx/xx);
272         stats->a  = (yx-sx*sy)/(xx-sx*sx);
273         stats->b  = (sy)-(stats->a)*(sx);
274
275         /* Compute chi2, deviation from a line y = ax+b. Also compute
276            chi2aa which returns the deviation from a line y = ax. */
277         chi2   = 0;
278         chi2aa = 0;
279         for (int i = 0; (i < N); i++)
280         {
281             if (stats->dy[i] > 0)
282             {
283                 dy = stats->dy[i];
284             }
285             else
286             {
287                 dy = 1;
288             }
289             chi2aa += gmx::square((stats->y[i]-(stats->aa*stats->x[i]))/dy);
290             chi2   += gmx::square((stats->y[i]-(stats->a*stats->x[i]+stats->b))/dy);
291         }
292         if (N > 2)
293         {
294             stats->chi2   = std::sqrt(chi2/(N-2));
295             stats->chi2aa = std::sqrt(chi2aa/(N-2));
296
297             /* Look up equations! */
298             dx2            = (xx-sx*sx);
299             dy2            = (yy-sy*sy);
300             stats->sigma_a = std::sqrt(stats->chi2/((N-2)*dx2));
301             stats->sigma_b = stats->sigma_a*std::sqrt(xx);
302             stats->Rfit    = std::abs(ssxy)/std::sqrt(ssxx*ssyy);
303             stats->Rfitaa  = stats->aa*std::sqrt(dx2/dy2);
304         }
305         else
306         {
307             stats->chi2    = 0;
308             stats->chi2aa  = 0;
309             stats->sigma_a = 0;
310             stats->sigma_b = 0;
311             stats->Rfit    = 0;
312             stats->Rfitaa  = 0;
313         }
314
315         stats->computed = 1;
316     }
317
318     return estatsOK;
319 }
320
321 int gmx_stats_get_ab(gmx_stats_t gstats, int weight,
322                      real *a, real *b, real *da, real *db,
323                      real *chi2, real *Rfit)
324 {
325     gmx_stats *stats = gstats;
326     int        ok;
327
328     if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
329     {
330         return ok;
331     }
332     if (nullptr != a)
333     {
334         *a    = stats->a;
335     }
336     if (nullptr != b)
337     {
338         *b    = stats->b;
339     }
340     if (nullptr != da)
341     {
342         *da   = stats->sigma_a;
343     }
344     if (nullptr != db)
345     {
346         *db   = stats->sigma_b;
347     }
348     if (nullptr != chi2)
349     {
350         *chi2 = stats->chi2;
351     }
352     if (nullptr != Rfit)
353     {
354         *Rfit = stats->Rfit;
355     }
356
357     return estatsOK;
358 }
359
360 int gmx_stats_get_a(gmx_stats_t gstats, int weight, real *a, real *da,
361                     real *chi2, real *Rfit)
362 {
363     gmx_stats *stats = gstats;
364     int        ok;
365
366     if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
367     {
368         return ok;
369     }
370     if (nullptr != a)
371     {
372         *a    = stats->aa;
373     }
374     if (nullptr != da)
375     {
376         *da   = stats->sigma_aa;
377     }
378     if (nullptr != chi2)
379     {
380         *chi2 = stats->chi2aa;
381     }
382     if (nullptr != Rfit)
383     {
384         *Rfit = stats->Rfitaa;
385     }
386
387     return estatsOK;
388 }
389
390 int gmx_stats_get_average(gmx_stats_t gstats, real *aver)
391 {
392     gmx_stats *stats = gstats;
393     int        ok;
394
395     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
396     {
397         return ok;
398     }
399
400     *aver = stats->aver;
401
402     return estatsOK;
403 }
404
405 int gmx_stats_get_ase(gmx_stats_t gstats, real *aver, real *sigma, real *error)
406 {
407     gmx_stats *stats = gstats;
408     int        ok;
409
410     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
411     {
412         return ok;
413     }
414
415     if (nullptr != aver)
416     {
417         *aver  = stats->aver;
418     }
419     if (nullptr != sigma)
420     {
421         *sigma = stats->sigma_aver;
422     }
423     if (nullptr != error)
424     {
425         *error = stats->error;
426     }
427
428     return estatsOK;
429 }
430
431 int gmx_stats_get_sigma(gmx_stats_t gstats, real *sigma)
432 {
433     gmx_stats *stats = gstats;
434     int        ok;
435
436     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
437     {
438         return ok;
439     }
440
441     *sigma = stats->sigma_aver;
442
443     return estatsOK;
444 }
445
446 int gmx_stats_get_error(gmx_stats_t gstats, real *error)
447 {
448     gmx_stats *stats = gstats;
449     int        ok;
450
451     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
452     {
453         return ok;
454     }
455
456     *error = stats->error;
457
458     return estatsOK;
459 }
460
461 int gmx_stats_get_corr_coeff(gmx_stats_t gstats, real *R)
462 {
463     gmx_stats *stats = gstats;
464     int        ok;
465
466     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
467     {
468         return ok;
469     }
470
471     *R = stats->Rdata;
472
473     return estatsOK;
474 }
475
476 int gmx_stats_get_rmsd(gmx_stats_t gstats, real *rmsd)
477 {
478     gmx_stats *stats = gstats;
479     int        ok;
480
481     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
482     {
483         return ok;
484     }
485
486     *rmsd = stats->rmsd;
487
488     return estatsOK;
489 }
490
491 int gmx_stats_dump_xy(gmx_stats_t gstats, FILE *fp)
492 {
493     gmx_stats *stats = gstats;
494
495     for (int i = 0; (i < stats->np); i++)
496     {
497         fprintf(fp, "%12g  %12g  %12g  %12g\n", stats->x[i], stats->y[i],
498                 stats->dx[i], stats->dy[i]);
499     }
500
501     return estatsOK;
502 }
503
504 int gmx_stats_remove_outliers(gmx_stats_t gstats, double level)
505 {
506     gmx_stats *stats = gstats;
507     int        iter  = 1, done = 0, ok;
508     real       rmsd, r;
509
510     while ((stats->np >= 10) && !done)
511     {
512         if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
513         {
514             return ok;
515         }
516         done = 1;
517         for (int i = 0; (i < stats->np); )
518         {
519             r = std::abs(stats->x[i]-stats->y[i]);
520             if (r > level*rmsd)
521             {
522                 fprintf(stderr, "Removing outlier, iter = %d, rmsd = %g, x = %g, y = %g\n",
523                         iter, rmsd, stats->x[i], stats->y[i]);
524                 if (i < stats->np-1)
525                 {
526                     stats->x[i]  = stats->x[stats->np-1];
527                     stats->y[i]  = stats->y[stats->np-1];
528                     stats->dx[i] = stats->dx[stats->np-1];
529                     stats->dy[i] = stats->dy[stats->np-1];
530                 }
531                 stats->np--;
532                 done = 0;
533             }
534             else
535             {
536                 i++;
537             }
538         }
539         iter++;
540     }
541
542     return estatsOK;
543 }
544
545 int gmx_stats_make_histogram(gmx_stats_t gstats, real binwidth, int *nb,
546                              int ehisto, int normalized, real **x, real **y)
547 {
548     gmx_stats *stats = gstats;
549     int        index = 0, nbins = *nb, *nindex;
550     double     minx, maxx, maxy, miny, delta, dd, minh;
551
552     if (((binwidth <= 0) && (nbins <= 0)) ||
553         ((binwidth > 0) && (nbins > 0)))
554     {
555         return estatsINVALID_INPUT;
556     }
557     if (stats->np <= 2)
558     {
559         return estatsNO_POINTS;
560     }
561     minx = maxx = stats->x[0];
562     miny = maxy = stats->y[0];
563     for (int i = 1; (i < stats->np); i++)
564     {
565         miny = (stats->y[i] < miny) ? stats->y[i] : miny;
566         maxy = (stats->y[i] > maxy) ? stats->y[i] : maxy;
567         minx = (stats->x[i] < minx) ? stats->x[i] : minx;
568         maxx = (stats->x[i] > maxx) ? stats->x[i] : maxx;
569     }
570     if (ehisto == ehistoX)
571     {
572         delta = maxx-minx;
573         minh  = minx;
574     }
575     else if (ehisto == ehistoY)
576     {
577         delta = maxy-miny;
578         minh  = miny;
579     }
580     else
581     {
582         return estatsINVALID_INPUT;
583     }
584
585     if (binwidth == 0)
586     {
587         binwidth = (delta)/nbins;
588     }
589     else
590     {
591         nbins = gmx_dnint((delta)/binwidth + 0.5);
592     }
593     snew(*x, nbins);
594     snew(nindex, nbins);
595     for (int i = 0; (i < nbins); i++)
596     {
597         (*x)[i] = minh + binwidth*(i+0.5);
598     }
599     if (normalized == 0)
600     {
601         dd = 1;
602     }
603     else
604     {
605         dd = 1.0/(binwidth*stats->np);
606     }
607
608     snew(*y, nbins);
609     for (int i = 0; (i < stats->np); i++)
610     {
611         if (ehisto == ehistoY)
612         {
613             index = static_cast<int>((stats->y[i]-miny)/binwidth);
614         }
615         else if (ehisto == ehistoX)
616         {
617             index = static_cast<int>((stats->x[i]-minx)/binwidth);
618         }
619         if (index < 0)
620         {
621             index = 0;
622         }
623         if (index > nbins-1)
624         {
625             index = nbins-1;
626         }
627         (*y)[index] += dd;
628         nindex[index]++;
629     }
630     if (*nb == 0)
631     {
632         *nb = nbins;
633     }
634     for (int i = 0; (i < nbins); i++)
635     {
636         if (nindex[i] > 0)
637         {
638             (*y)[i] /= nindex[i];
639         }
640     }
641
642     sfree(nindex);
643
644     return estatsOK;
645 }
646
647 static const char *stats_error[estatsNR] =
648 {
649     "All well in STATS land",
650     "No points",
651     "Not enough memory",
652     "Invalid histogram input",
653     "Unknown error",
654     "Not implemented yet"
655 };
656
657 const char *gmx_stats_message(int estats)
658 {
659     if ((estats >= 0) && (estats < estatsNR))
660     {
661         return stats_error[estats];
662     }
663     else
664     {
665         return stats_error[estatsERROR];
666     }
667 }
668
669 /* Old convenience functions, should be merged with the core
670    statistics above. */
671 int lsq_y_ax(int n, real x[], real y[], real *a)
672 {
673     gmx_stats_t lsq = gmx_stats_init();
674     int         ok;
675     real        da, chi2, Rfit;
676
677     gmx_stats_add_points(lsq, n, x, y, nullptr, nullptr);
678     ok = gmx_stats_get_a(lsq, elsqWEIGHT_NONE, a, &da, &chi2, &Rfit);
679     gmx_stats_free(lsq);
680
681     return ok;
682 }
683
684 static int low_lsq_y_ax_b(int n, real *xr, double *xd, real yr[],
685                           real *a, real *b, real *r, real *chi2)
686 {
687     gmx_stats_t lsq = gmx_stats_init();
688     int         ok;
689
690     for (int i = 0; (i < n); i++)
691     {
692         double pt;
693
694         if (xd != nullptr)
695         {
696             pt = xd[i];
697         }
698         else if (xr != nullptr)
699         {
700             pt = xr[i];
701         }
702         else
703         {
704             gmx_incons("Either xd or xr has to be non-NULL in low_lsq_y_ax_b()");
705         }
706
707         if ((ok = gmx_stats_add_point(lsq, pt, yr[i], 0, 0)) != estatsOK)
708         {
709             gmx_stats_free(lsq);
710             return ok;
711         }
712     }
713     ok = gmx_stats_get_ab(lsq, elsqWEIGHT_NONE, a, b, nullptr, nullptr, chi2, r);
714     gmx_stats_free(lsq);
715
716     return ok;
717 }
718
719 int lsq_y_ax_b(int n, real x[], real y[], real *a, real *b, real *r, real *chi2)
720 {
721     return low_lsq_y_ax_b(n, x, nullptr, y, a, b, r, chi2);
722 }
723
724 int lsq_y_ax_b_xdouble(int n, double x[], real y[], real *a, real *b,
725                        real *r, real *chi2)
726 {
727     return low_lsq_y_ax_b(n, nullptr, x, y, a, b, r, chi2);
728 }
729
730 int lsq_y_ax_b_error(int n, real x[], real y[], real dy[],
731                      real *a, real *b, real *da, real *db,
732                      real *r, real *chi2)
733 {
734     gmx_stats_t lsq = gmx_stats_init();
735     int         ok;
736
737     for (int i = 0; (i < n); i++)
738     {
739         ok = gmx_stats_add_point(lsq, x[i], y[i], 0, dy[i]);
740         if (ok != estatsOK)
741         {
742             gmx_stats_free(lsq);
743             return ok;
744         }
745     }
746     ok = gmx_stats_get_ab(lsq, elsqWEIGHT_Y, a, b, da, db, chi2, r);
747     gmx_stats_free(lsq);
748
749     return ok;
750 }