Moved statistics source to C++
[alexxy/gromacs.git] / src / gromacs / statistics / statistics.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2012,2014,2015, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "statistics.h"
40
41 #include <cmath>
42
43 #include "gromacs/math/vec.h"
44 #include "gromacs/utility/real.h"
45 #include "gromacs/utility/smalloc.h"
46
47 static int gmx_dnint(double x)
48 {
49     return static_cast<int>(x+0.5);
50 }
51
52 typedef struct gmx_stats {
53     double  aa, a, b, sigma_aa, sigma_a, sigma_b, aver, sigma_aver, error;
54     double  rmsd, Rdata, Rfit, Rfitaa, chi2, chi2aa;
55     double *x, *y, *dx, *dy;
56     int     computed;
57     int     np, np_c, nalloc;
58 } gmx_stats;
59
60 gmx_stats_t gmx_stats_init()
61 {
62     gmx_stats *stats;
63
64     snew(stats, 1);
65
66     return (gmx_stats_t) stats;
67 }
68
69 int gmx_stats_get_npoints(gmx_stats_t gstats, int *N)
70 {
71     gmx_stats *stats = (gmx_stats *) gstats;
72
73     *N = stats->np;
74
75     return estatsOK;
76 }
77
78 int gmx_stats_done(gmx_stats_t gstats)
79 {
80     gmx_stats *stats = (gmx_stats *) gstats;
81
82     sfree(stats->x);
83     stats->x = NULL;
84     sfree(stats->y);
85     stats->y = NULL;
86     sfree(stats->dx);
87     stats->dx = NULL;
88     sfree(stats->dy);
89     stats->dy = NULL;
90
91     return estatsOK;
92 }
93
94 int gmx_stats_add_point(gmx_stats_t gstats, double x, double y,
95                         double dx, double dy)
96 {
97     gmx_stats *stats = gstats;
98
99     if (stats->np+1 >= stats->nalloc)
100     {
101         if (stats->nalloc == 0)
102         {
103             stats->nalloc = 1024;
104         }
105         else
106         {
107             stats->nalloc *= 2;
108         }
109         srenew(stats->x, stats->nalloc);
110         srenew(stats->y, stats->nalloc);
111         srenew(stats->dx, stats->nalloc);
112         srenew(stats->dy, stats->nalloc);
113         for (int i = stats->np; (i < stats->nalloc); i++)
114         {
115             stats->x[i]  = 0;
116             stats->y[i]  = 0;
117             stats->dx[i] = 0;
118             stats->dy[i] = 0;
119         }
120     }
121     stats->x[stats->np]  = x;
122     stats->y[stats->np]  = y;
123     stats->dx[stats->np] = dx;
124     stats->dy[stats->np] = dy;
125     stats->np++;
126     stats->computed = 0;
127
128     return estatsOK;
129 }
130
131 int gmx_stats_get_point(gmx_stats_t gstats, real *x, real *y,
132                         real *dx, real *dy, real level)
133 {
134     gmx_stats *stats = gstats;
135     int        ok, outlier;
136     real       rmsd, r;
137
138     if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
139     {
140         return ok;
141     }
142     outlier = 0;
143     while ((outlier == 0) && (stats->np_c < stats->np))
144     {
145         r       = std::abs(stats->x[stats->np_c] - stats->y[stats->np_c]);
146         outlier = (r > rmsd*level);
147         if (outlier)
148         {
149             if (NULL != x)
150             {
151                 *x  = stats->x[stats->np_c];
152             }
153             if (NULL != y)
154             {
155                 *y  = stats->y[stats->np_c];
156             }
157             if (NULL != dx)
158             {
159                 *dx = stats->dx[stats->np_c];
160             }
161             if (NULL != dy)
162             {
163                 *dy = stats->dy[stats->np_c];
164             }
165         }
166         stats->np_c++;
167
168         if (outlier)
169         {
170             return estatsOK;
171         }
172     }
173
174     stats->np_c = 0;
175
176     return estatsNO_POINTS;
177 }
178
179 int gmx_stats_add_points(gmx_stats_t gstats, int n, real *x, real *y,
180                          real *dx, real *dy)
181 {
182     for (int i = 0; (i < n); i++)
183     {
184         int ok;
185         if ((ok = gmx_stats_add_point(gstats, x[i], y[i],
186                                       (NULL != dx) ? dx[i] : 0,
187                                       (NULL != dy) ? dy[i] : 0)) != estatsOK)
188         {
189             return ok;
190         }
191     }
192     return estatsOK;
193 }
194
195 static int gmx_stats_compute(gmx_stats *stats, int weight)
196 {
197     double yy, yx, xx, sx, sy, dy, chi2, chi2aa, d2;
198     double ssxx, ssyy, ssxy;
199     double w, wtot, yx_nw, sy_nw, sx_nw, yy_nw, xx_nw, dx2, dy2;
200
201     int    N = stats->np;
202
203     if (stats->computed == 0)
204     {
205         if (N < 1)
206         {
207             return estatsNO_POINTS;
208         }
209
210         xx   = xx_nw = 0;
211         yy   = yy_nw = 0;
212         yx   = yx_nw = 0;
213         sx   = sx_nw = 0;
214         sy   = sy_nw = 0;
215         wtot = 0;
216         d2   = 0;
217         for (int i = 0; (i < N); i++)
218         {
219             d2 += dsqr(stats->x[i]-stats->y[i]);
220             if ((stats->dy[i]) && (weight == elsqWEIGHT_Y))
221             {
222                 w = 1/dsqr(stats->dy[i]);
223             }
224             else
225             {
226                 w = 1;
227             }
228
229             wtot  += w;
230
231             xx    += w*dsqr(stats->x[i]);
232             xx_nw += dsqr(stats->x[i]);
233
234             yy    += w*dsqr(stats->y[i]);
235             yy_nw += dsqr(stats->y[i]);
236
237             yx    += w*stats->y[i]*stats->x[i];
238             yx_nw += stats->y[i]*stats->x[i];
239
240             sx    += w*stats->x[i];
241             sx_nw += stats->x[i];
242
243             sy    += w*stats->y[i];
244             sy_nw += stats->y[i];
245         }
246
247         /* Compute average, sigma and error */
248         stats->aver       = sy_nw/N;
249         stats->sigma_aver = std::sqrt(yy_nw/N - dsqr(sy_nw/N));
250         stats->error      = stats->sigma_aver/std::sqrt(static_cast<double>(N));
251
252         /* Compute RMSD between x and y */
253         stats->rmsd = std::sqrt(d2/N);
254
255         /* Correlation coefficient for data */
256         yx_nw       /= N;
257         xx_nw       /= N;
258         yy_nw       /= N;
259         sx_nw       /= N;
260         sy_nw       /= N;
261         ssxx         = N*(xx_nw - dsqr(sx_nw));
262         ssyy         = N*(yy_nw - dsqr(sy_nw));
263         ssxy         = N*(yx_nw - (sx_nw*sy_nw));
264         stats->Rdata = std::sqrt(dsqr(ssxy)/(ssxx*ssyy));
265
266         /* Compute straight line through datapoints, either with intercept
267            zero (result in aa) or with intercept variable (results in a
268            and b) */
269         yx = yx/wtot;
270         xx = xx/wtot;
271         sx = sx/wtot;
272         sy = sy/wtot;
273
274         stats->aa = (yx/xx);
275         stats->a  = (yx-sx*sy)/(xx-sx*sx);
276         stats->b  = (sy)-(stats->a)*(sx);
277
278         /* Compute chi2, deviation from a line y = ax+b. Also compute
279            chi2aa which returns the deviation from a line y = ax. */
280         chi2   = 0;
281         chi2aa = 0;
282         for (int i = 0; (i < N); i++)
283         {
284             if (stats->dy[i] > 0)
285             {
286                 dy = stats->dy[i];
287             }
288             else
289             {
290                 dy = 1;
291             }
292             chi2aa += dsqr((stats->y[i]-(stats->aa*stats->x[i]))/dy);
293             chi2   += dsqr((stats->y[i]-(stats->a*stats->x[i]+stats->b))/dy);
294         }
295         if (N > 2)
296         {
297             stats->chi2   = std::sqrt(chi2/(N-2));
298             stats->chi2aa = std::sqrt(chi2aa/(N-2));
299
300             /* Look up equations! */
301             dx2            = (xx-sx*sx);
302             dy2            = (yy-sy*sy);
303             stats->sigma_a = std::sqrt(stats->chi2/((N-2)*dx2));
304             stats->sigma_b = stats->sigma_a*std::sqrt(xx);
305             stats->Rfit    = std::abs(ssxy)/std::sqrt(ssxx*ssyy);
306             stats->Rfitaa  = stats->aa*std::sqrt(dx2/dy2);
307         }
308         else
309         {
310             stats->chi2    = 0;
311             stats->chi2aa  = 0;
312             stats->sigma_a = 0;
313             stats->sigma_b = 0;
314             stats->Rfit    = 0;
315             stats->Rfitaa  = 0;
316         }
317
318         stats->computed = 1;
319     }
320
321     return estatsOK;
322 }
323
324 int gmx_stats_get_ab(gmx_stats_t gstats, int weight,
325                      real *a, real *b, real *da, real *db,
326                      real *chi2, real *Rfit)
327 {
328     gmx_stats *stats = gstats;
329     int        ok;
330
331     if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
332     {
333         return ok;
334     }
335     if (NULL != a)
336     {
337         *a    = stats->a;
338     }
339     if (NULL != b)
340     {
341         *b    = stats->b;
342     }
343     if (NULL != da)
344     {
345         *da   = stats->sigma_a;
346     }
347     if (NULL != db)
348     {
349         *db   = stats->sigma_b;
350     }
351     if (NULL != chi2)
352     {
353         *chi2 = stats->chi2;
354     }
355     if (NULL != Rfit)
356     {
357         *Rfit = stats->Rfit;
358     }
359
360     return estatsOK;
361 }
362
363 int gmx_stats_get_a(gmx_stats_t gstats, int weight, real *a, real *da,
364                     real *chi2, real *Rfit)
365 {
366     gmx_stats *stats = gstats;
367     int        ok;
368
369     if ((ok = gmx_stats_compute(stats, weight)) != estatsOK)
370     {
371         return ok;
372     }
373     if (NULL != a)
374     {
375         *a    = stats->aa;
376     }
377     if (NULL != da)
378     {
379         *da   = stats->sigma_aa;
380     }
381     if (NULL != chi2)
382     {
383         *chi2 = stats->chi2aa;
384     }
385     if (NULL != Rfit)
386     {
387         *Rfit = stats->Rfitaa;
388     }
389
390     return estatsOK;
391 }
392
393 int gmx_stats_get_average(gmx_stats_t gstats, real *aver)
394 {
395     gmx_stats *stats = gstats;
396     int        ok;
397
398     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
399     {
400         return ok;
401     }
402
403     *aver = stats->aver;
404
405     return estatsOK;
406 }
407
408 int gmx_stats_get_ase(gmx_stats_t gstats, real *aver, real *sigma, real *error)
409 {
410     gmx_stats *stats = gstats;
411     int        ok;
412
413     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
414     {
415         return ok;
416     }
417
418     if (NULL != aver)
419     {
420         *aver  = stats->aver;
421     }
422     if (NULL != sigma)
423     {
424         *sigma = stats->sigma_aver;
425     }
426     if (NULL != error)
427     {
428         *error = stats->error;
429     }
430
431     return estatsOK;
432 }
433
434 int gmx_stats_get_sigma(gmx_stats_t gstats, real *sigma)
435 {
436     gmx_stats *stats = gstats;
437     int        ok;
438
439     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
440     {
441         return ok;
442     }
443
444     *sigma = stats->sigma_aver;
445
446     return estatsOK;
447 }
448
449 int gmx_stats_get_error(gmx_stats_t gstats, real *error)
450 {
451     gmx_stats *stats = gstats;
452     int        ok;
453
454     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
455     {
456         return ok;
457     }
458
459     *error = stats->error;
460
461     return estatsOK;
462 }
463
464 int gmx_stats_get_corr_coeff(gmx_stats_t gstats, real *R)
465 {
466     gmx_stats *stats = gstats;
467     int        ok;
468
469     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
470     {
471         return ok;
472     }
473
474     *R = stats->Rdata;
475
476     return estatsOK;
477 }
478
479 int gmx_stats_get_rmsd(gmx_stats_t gstats, real *rmsd)
480 {
481     gmx_stats *stats = gstats;
482     int        ok;
483
484     if ((ok = gmx_stats_compute(stats, elsqWEIGHT_NONE)) != estatsOK)
485     {
486         return ok;
487     }
488
489     *rmsd = stats->rmsd;
490
491     return estatsOK;
492 }
493
494 int gmx_stats_dump_xy(gmx_stats_t gstats, FILE *fp)
495 {
496     gmx_stats *stats = gstats;
497
498     for (int i = 0; (i < stats->np); i++)
499     {
500         fprintf(fp, "%12g  %12g  %12g  %12g\n", stats->x[i], stats->y[i],
501                 stats->dx[i], stats->dy[i]);
502     }
503
504     return estatsOK;
505 }
506
507 int gmx_stats_remove_outliers(gmx_stats_t gstats, double level)
508 {
509     gmx_stats *stats = gstats;
510     int        iter  = 1, done = 0, ok;
511     real       rmsd, r;
512
513     while ((stats->np >= 10) && !done)
514     {
515         if ((ok = gmx_stats_get_rmsd(gstats, &rmsd)) != estatsOK)
516         {
517             return ok;
518         }
519         done = 1;
520         for (int i = 0; (i < stats->np); )
521         {
522             r = std::abs(stats->x[i]-stats->y[i]);
523             if (r > level*rmsd)
524             {
525                 fprintf(stderr, "Removing outlier, iter = %d, rmsd = %g, x = %g, y = %g\n",
526                         iter, rmsd, stats->x[i], stats->y[i]);
527                 if (i < stats->np-1)
528                 {
529                     stats->x[i]  = stats->x[stats->np-1];
530                     stats->y[i]  = stats->y[stats->np-1];
531                     stats->dx[i] = stats->dx[stats->np-1];
532                     stats->dy[i] = stats->dy[stats->np-1];
533                 }
534                 stats->np--;
535                 done = 0;
536             }
537             else
538             {
539                 i++;
540             }
541         }
542         iter++;
543     }
544
545     return estatsOK;
546 }
547
548 int gmx_stats_make_histogram(gmx_stats_t gstats, real binwidth, int *nb,
549                              int ehisto, int normalized, real **x, real **y)
550 {
551     gmx_stats *stats = gstats;
552     int        index = 0, nbins = *nb, *nindex;
553     double     minx, maxx, maxy, miny, delta, dd, minh;
554
555     if (((binwidth <= 0) && (nbins <= 0)) ||
556         ((binwidth > 0) && (nbins > 0)))
557     {
558         return estatsINVALID_INPUT;
559     }
560     if (stats->np <= 2)
561     {
562         return estatsNO_POINTS;
563     }
564     minx = maxx = stats->x[0];
565     miny = maxy = stats->y[0];
566     for (int i = 1; (i < stats->np); i++)
567     {
568         miny = (stats->y[i] < miny) ? stats->y[i] : miny;
569         maxy = (stats->y[i] > maxy) ? stats->y[i] : maxy;
570         minx = (stats->x[i] < minx) ? stats->x[i] : minx;
571         maxx = (stats->x[i] > maxx) ? stats->x[i] : maxx;
572     }
573     if (ehisto == ehistoX)
574     {
575         delta = maxx-minx;
576         minh  = minx;
577     }
578     else if (ehisto == ehistoY)
579     {
580         delta = maxy-miny;
581         minh  = miny;
582     }
583     else
584     {
585         return estatsINVALID_INPUT;
586     }
587
588     if (binwidth == 0)
589     {
590         binwidth = (delta)/nbins;
591     }
592     else
593     {
594         nbins = gmx_dnint((delta)/binwidth + 0.5);
595     }
596     snew(*x, nbins);
597     snew(nindex, nbins);
598     for (int i = 0; (i < nbins); i++)
599     {
600         (*x)[i] = minh + binwidth*(i+0.5);
601     }
602     if (normalized == 0)
603     {
604         dd = 1;
605     }
606     else
607     {
608         dd = 1.0/(binwidth*stats->np);
609     }
610
611     snew(*y, nbins);
612     for (int i = 0; (i < stats->np); i++)
613     {
614         if (ehisto == ehistoY)
615         {
616             index = static_cast<int>((stats->y[i]-miny)/binwidth);
617         }
618         else if (ehisto == ehistoX)
619         {
620             index = static_cast<int>((stats->x[i]-minx)/binwidth);
621         }
622         if (index < 0)
623         {
624             index = 0;
625         }
626         if (index > nbins-1)
627         {
628             index = nbins-1;
629         }
630         (*y)[index] += dd;
631         nindex[index]++;
632     }
633     if (*nb == 0)
634     {
635         *nb = nbins;
636     }
637     for (int i = 0; (i < nbins); i++)
638     {
639         if (nindex[i] > 0)
640         {
641             (*y)[i] /= nindex[i];
642         }
643     }
644
645     sfree(nindex);
646
647     return estatsOK;
648 }
649
650 static const char *stats_error[estatsNR] =
651 {
652     "All well in STATS land",
653     "No points",
654     "Not enough memory",
655     "Invalid histogram input",
656     "Unknown error",
657     "Not implemented yet"
658 };
659
660 const char *gmx_stats_message(int estats)
661 {
662     if ((estats >= 0) && (estats < estatsNR))
663     {
664         return stats_error[estats];
665     }
666     else
667     {
668         return stats_error[estatsERROR];
669     }
670 }
671
672 /* Old convenience functions, should be merged with the core
673    statistics above. */
674 int lsq_y_ax(int n, real x[], real y[], real *a)
675 {
676     gmx_stats_t lsq = gmx_stats_init();
677     int         ok;
678     real        da, chi2, Rfit;
679
680     gmx_stats_add_points(lsq, n, x, y, 0, 0);
681     if ((ok = gmx_stats_get_a(lsq, elsqWEIGHT_NONE, a, &da, &chi2, &Rfit)) != estatsOK)
682     {
683         return ok;
684     }
685
686     return estatsOK;
687 }
688
689 static int low_lsq_y_ax_b(int n, real *xr, double *xd, real yr[],
690                           real *a, real *b, real *r, real *chi2)
691 {
692     gmx_stats_t lsq = gmx_stats_init();
693     int         ok;
694
695     for (int i = 0; (i < n); i++)
696     {
697         double pt;
698
699         if (xd != NULL)
700         {
701             pt = xd[i];
702         }
703         else if (xr != NULL)
704         {
705             pt = xr[i];
706         }
707         else
708         {
709             gmx_incons("Either xd or xr has to be non-NULL in low_lsq_y_ax_b()");
710         }
711
712         if ((ok = gmx_stats_add_point(lsq, pt, yr[i], 0, 0)) != estatsOK)
713         {
714             return ok;
715         }
716     }
717     if ((ok = gmx_stats_get_ab(lsq, elsqWEIGHT_NONE, a, b, NULL, NULL, chi2, r)) != estatsOK)
718     {
719         return ok;
720     }
721
722     return estatsOK;
723 }
724
725 int lsq_y_ax_b(int n, real x[], real y[], real *a, real *b, real *r, real *chi2)
726 {
727     return low_lsq_y_ax_b(n, x, NULL, y, a, b, r, chi2);
728 }
729
730 int lsq_y_ax_b_xdouble(int n, double x[], real y[], real *a, real *b,
731                        real *r, real *chi2)
732 {
733     return low_lsq_y_ax_b(n, NULL, x, y, a, b, r, chi2);
734 }
735
736 int lsq_y_ax_b_error(int n, real x[], real y[], real dy[],
737                      real *a, real *b, real *da, real *db,
738                      real *r, real *chi2)
739 {
740     gmx_stats_t lsq = gmx_stats_init();
741     int         ok;
742
743     for (int i = 0; (i < n); i++)
744     {
745         if ((ok = gmx_stats_add_point(lsq, x[i], y[i], 0, dy[i])) != estatsOK)
746         {
747             return ok;
748         }
749     }
750     if ((ok = gmx_stats_get_ab(lsq, elsqWEIGHT_Y, a, b, da, db, chi2, r)) != estatsOK)
751     {
752         return ok;
753     }
754     if ((ok = gmx_stats_done(lsq)) != estatsOK)
755     {
756         return ok;
757     }
758     sfree(lsq);
759
760     return estatsOK;
761 }