8462fc660839680961a49b13507ad5eb5e74d86c
[alexxy/gromacs.git] / src / gromacs / linearalgebra / matrix.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *                This source code is part of
4  *
5  *                 G   R   O   M   A   C   S
6  *
7  *          GROningen MAchine for Chemical Simulations
8  *
9  *                        VERSION 4.0.99
10  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
11  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
12  * Copyright (c) 2001-2008, The GROMACS development team,
13  * check out http://www.gromacs.org for more information.
14
15  * This program is free software; you can redistribute it and/or
16  * modify it under the terms of the GNU General Public License
17  * as published by the Free Software Foundation; either version 2
18  * of the License, or (at your option) any later version.
19  *
20  * If you want to redistribute modifications, please consider that
21  * scientific software is very special. Version control is crucial -
22  * bugs must be traceable. We will be happy to consider code for
23  * inclusion in the official distribution, but derived work must not
24  * be called official GROMACS. Details are found in the README & COPYING
25  * files - if they are missing, get the official version at www.gromacs.org.
26  *
27  * To help us fund GROMACS development, we humbly ask that you cite
28  * the papers on the package - you can find them in the top README file.
29  *
30  * For more info, check our website at http://www.gromacs.org
31  *
32  * And Hey:
33  * Groningen Machine for Chemical Simulation
34  */
35 #include "matrix.h"
36
37 #ifdef HAVE_CONFIG_H
38 #include <config.h>
39 #endif
40
41 #include <stdio.h>
42
43 #include "gromacs/legacyheaders/gmx_fatal.h"
44 #include "gromacs/legacyheaders/smalloc.h"
45 #include "gromacs/legacyheaders/vec.h"
46
47 #include "gmx_lapack.h"
48
49 double **alloc_matrix(int n, int m)
50 {
51     double **ptr;
52     int      i;
53
54     /* There's always time for more pointer arithmetic! */
55     /* This is necessary in order to be able to work with LAPACK */
56     snew(ptr, n);
57     snew(ptr[0], n*m);
58     for (i = 1; (i < n); i++)
59     {
60         ptr[i] = ptr[i-1]+m;
61     }
62     return ptr;
63 }
64
65 void free_matrix(double **a)
66 {
67     int i;
68
69     sfree(a[0]);
70     sfree(a);
71 }
72
73 #define DEBUG_MATRIX
74 void matrix_multiply(FILE *fp, int n, int m, double **x, double **y, double **z)
75 {
76     int i, j, k;
77
78 #ifdef DEBUG_MATRIX
79     if (fp)
80     {
81         fprintf(fp, "Multiplying %d x %d matrix with a %d x %d matrix\n",
82                 n, m, m, n);
83     }
84     if (fp)
85     {
86         for (i = 0; (i < n); i++)
87         {
88             for (j = 0; (j < m); j++)
89             {
90                 fprintf(fp, " %7g", x[i][j]);
91             }
92             fprintf(fp, "\n");
93         }
94     }
95 #endif
96     for (i = 0; (i < m); i++)
97     {
98         for (j = 0; (j < m); j++)
99         {
100             z[i][j] = 0;
101             for (k = 0; (k < n); k++)
102             {
103                 z[i][j] += x[k][i]*y[j][k];
104             }
105         }
106     }
107 }
108
109 static void dump_matrix(FILE *fp, const char *title, int n, double **a)
110 {
111     double d = 1;
112     int    i, j;
113
114     fprintf(fp, "%s\n", title);
115     for (i = 0; (i < n); i++)
116     {
117         d = d*a[i][i];
118         for (j = 0; (j < n); j++)
119         {
120             fprintf(fp, " %8.2f", a[i][j]);
121         }
122         fprintf(fp, "\n");
123     }
124     fprintf(fp, "Prod a[i][i] = %g\n", d);
125 }
126
127 int matrix_invert(FILE *fp, int n, double **a)
128 {
129     int      i, j, m, lda, *ipiv, lwork, info;
130     double **test = NULL, **id, *work;
131
132 #ifdef DEBUG_MATRIX
133     if (fp)
134     {
135         fprintf(fp, "Inverting %d square matrix\n", n);
136         test = alloc_matrix(n, n);
137         for (i = 0; (i < n); i++)
138         {
139             for (j = 0; (j < n); j++)
140             {
141                 test[i][j] = a[i][j];
142             }
143         }
144         dump_matrix(fp, "before inversion", n, a);
145     }
146 #endif
147     snew(ipiv, n);
148     lwork = n*n;
149     snew(work, lwork);
150     m     = lda   = n;
151     info  = 0;
152     F77_FUNC(dgetrf, DGETRF) (&n, &m, a[0], &lda, ipiv, &info);
153 #ifdef DEBUG_MATRIX
154     if (fp)
155     {
156         dump_matrix(fp, "after dgetrf", n, a);
157     }
158 #endif
159     if (info != 0)
160     {
161         return info;
162     }
163     F77_FUNC(dgetri, DGETRI) (&n, a[0], &lda, ipiv, work, &lwork, &info);
164 #ifdef DEBUG_MATRIX
165     if (fp)
166     {
167         dump_matrix(fp, "after dgetri", n, a);
168     }
169 #endif
170     if (info != 0)
171     {
172         return info;
173     }
174
175 #ifdef DEBUG_MATRIX
176     if (fp)
177     {
178         id = alloc_matrix(n, n);
179         matrix_multiply(fp, n, n, test, a, id);
180         dump_matrix(fp, "And here is the product of A and Ainv", n, id);
181         free_matrix(id);
182         free_matrix(test);
183     }
184 #endif
185     sfree(ipiv);
186     sfree(work);
187
188     return 0;
189 }
190
191 double multi_regression(FILE *fp, int nrow, double *y, int ncol,
192                         double **xx, double *a0)
193 {
194     int    row, niter, i, j;
195     double ax, chi2, **a, **at, **ata, *atx;
196
197     a   = alloc_matrix(nrow, ncol);
198     at  = alloc_matrix(ncol, nrow);
199     ata = alloc_matrix(ncol, ncol);
200     for (i = 0; (i < nrow); i++)
201     {
202         for (j = 0; (j < ncol); j++)
203         {
204             at[j][i] = a[i][j] = xx[j][i];
205         }
206     }
207     matrix_multiply(fp, nrow, ncol, a, at, ata);
208     if ((row = matrix_invert(fp, ncol, ata)) != 0)
209     {
210         gmx_fatal(FARGS, "Matrix inversion failed. Incorrect row = %d.\nThis probably indicates that you do not have sufficient data points, or that some parameters are linearly dependent.",
211                   row);
212     }
213     snew(atx, ncol);
214
215     for (i = 0; (i < ncol); i++)
216     {
217         atx[i] = 0;
218         for (j = 0; (j < nrow); j++)
219         {
220             atx[i] += at[i][j]*y[j];
221         }
222     }
223     for (i = 0; (i < ncol); i++)
224     {
225         a0[i] = 0;
226         for (j = 0; (j < ncol); j++)
227         {
228             a0[i] += ata[i][j]*atx[j];
229         }
230     }
231     chi2 = 0;
232     for (j = 0; (j < nrow); j++)
233     {
234         ax = 0;
235         for (i = 0; (i < ncol); i++)
236         {
237             ax += a0[i]*a[j][i];
238         }
239         chi2 += sqr(y[j]-ax);
240     }
241
242     sfree(atx);
243     free_matrix(a);
244     free_matrix(at);
245     free_matrix(ata);
246
247     return chi2;
248 }