Sort all includes in src/gromacs
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2009, The GROMACS Development Team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "nb_kernel_allvsallgb.h"
40
41 #include "config.h"
42
43 #include <math.h>
44
45 #include "gromacs/legacyheaders/nrnb.h"
46 #include "gromacs/legacyheaders/types/simple.h"
47 #include "gromacs/math/vec.h"
48 #include "gromacs/utility/smalloc.h"
49
50 typedef struct
51 {
52     real **    pvdwparam;
53     int *      jindex;
54     int **     exclusion_mask;
55 }
56 gmx_allvsall_data_t;
57
58 static int
59 calc_maxoffset(int i, int natoms)
60 {
61     int maxoffset;
62
63     if ((natoms % 2) == 1)
64     {
65         /* Odd number of atoms, easy */
66         maxoffset = natoms/2;
67     }
68     else if ((natoms % 4) == 0)
69     {
70         /* Multiple of four is hard */
71         if (i < natoms/2)
72         {
73             if ((i % 2) == 0)
74             {
75                 maxoffset = natoms/2;
76             }
77             else
78             {
79                 maxoffset = natoms/2-1;
80             }
81         }
82         else
83         {
84             if ((i % 2) == 1)
85             {
86                 maxoffset = natoms/2;
87             }
88             else
89             {
90                 maxoffset = natoms/2-1;
91             }
92         }
93     }
94     else
95     {
96         /* natoms/2 = odd */
97         if ((i % 2) == 0)
98         {
99             maxoffset = natoms/2;
100         }
101         else
102         {
103             maxoffset = natoms/2-1;
104         }
105     }
106
107     return maxoffset;
108 }
109
110
111 static void
112 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
113                              t_blocka *              excl,
114                              int                     natoms)
115 {
116     int i, j, k;
117     int nj0, nj1;
118     int max_offset;
119     int max_excl_offset;
120     int iexcl;
121     int nj;
122
123     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
124      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
125      * whether they should interact or not.
126      *
127      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
128      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
129      * we create a jindex array with three elements per i atom: the starting point, the point to
130      * which we need to check exclusions, and the end point.
131      * This way we only have to allocate a short exclusion mask per i atom.
132      */
133
134     /* Allocate memory for our modified jindex array */
135     snew(aadata->jindex, 3*natoms);
136
137     /* Pointer to lists with exclusion masks */
138     snew(aadata->exclusion_mask, natoms);
139
140     for (i = 0; i < natoms; i++)
141     {
142         /* Start */
143         aadata->jindex[3*i]   = i+1;
144         max_offset            = calc_maxoffset(i, natoms);
145
146         /* Exclusions */
147         nj0   = excl->index[i];
148         nj1   = excl->index[i+1];
149
150         /* first check the max range */
151         max_excl_offset = -1;
152
153         for (j = nj0; j < nj1; j++)
154         {
155             iexcl = excl->a[j];
156
157             k = iexcl - i;
158
159             if (k+natoms <= max_offset)
160             {
161                 k += natoms;
162             }
163
164             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
165         }
166
167         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
168
169         aadata->jindex[3*i+1] = i+1+max_excl_offset;
170
171         snew(aadata->exclusion_mask[i], max_excl_offset);
172         /* Include everything by default */
173         for (j = 0; j < max_excl_offset; j++)
174         {
175             /* Use all-ones to mark interactions that should be present, compatible with SSE */
176             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
177         }
178
179         /* Go through exclusions again */
180         for (j = nj0; j < nj1; j++)
181         {
182             iexcl = excl->a[j];
183
184             k = iexcl - i;
185
186             if (k+natoms <= max_offset)
187             {
188                 k += natoms;
189             }
190
191             if (k > 0 && k <= max_excl_offset)
192             {
193                 /* Excluded, kill it! */
194                 aadata->exclusion_mask[i][k-1] = 0;
195             }
196         }
197
198         /* End */
199         aadata->jindex[3*i+2] = i+1+max_offset;
200     }
201 }
202
203
204 static void
205 setup_aadata(gmx_allvsall_data_t **  p_aadata,
206              t_blocka *              excl,
207              int                     natoms,
208              int *                   type,
209              int                     ntype,
210              real *                  pvdwparam)
211 {
212     int                  i, j, idx;
213     gmx_allvsall_data_t *aadata;
214     real                *p;
215
216     snew(aadata, 1);
217     *p_aadata = aadata;
218
219     /* Generate vdw params */
220     snew(aadata->pvdwparam, ntype);
221
222     for (i = 0; i < ntype; i++)
223     {
224         snew(aadata->pvdwparam[i], 2*natoms);
225         p = aadata->pvdwparam[i];
226
227         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
228         for (j = 0; j < natoms; j++)
229         {
230             idx             = i*ntype+type[j];
231             p[2*j]          = pvdwparam[2*idx];
232             p[2*j+1]        = pvdwparam[2*idx+1];
233         }
234     }
235
236     setup_exclusions_and_indices(aadata, excl, natoms);
237 }
238
239
240
241 void
242 nb_kernel_allvsallgb(t_nblist gmx_unused *     nlist,
243                      rvec *                    xx,
244                      rvec *                    ff,
245                      t_forcerec *              fr,
246                      t_mdatoms *               mdatoms,
247                      nb_kernel_data_t *        kernel_data,
248                      t_nrnb *                  nrnb)
249 {
250     gmx_allvsall_data_t *aadata;
251     int                  natoms;
252     int                  ni0, ni1;
253     int                  nj0, nj1, nj2;
254     int                  i, j, k;
255     real           *     charge;
256     int           *      type;
257     real                 facel;
258     real           *     pvdw;
259     int                  ggid;
260     int           *      mask;
261     real           *     GBtab;
262     real                 gbfactor;
263     real           *     invsqrta;
264     real           *     dvda;
265     real                 vgbtot, dvdasum;
266     int                  nnn, n0;
267
268     real                 ix, iy, iz, iq;
269     real                 fix, fiy, fiz;
270     real                 jx, jy, jz, qq;
271     real                 dx, dy, dz;
272     real                 tx, ty, tz;
273     real                 rsq, rinv, rinvsq, rinvsix;
274     real                 vcoul, vctot;
275     real                 c6, c12, Vvdw6, Vvdw12, Vvdwtot;
276     real                 fscal, dvdatmp, fijC, vgb;
277     real                 Y, F, Fp, Geps, Heps2, VV, FF, eps, eps2, r, rt;
278     real                 dvdaj, gbscale, isaprod, isai, isaj, gbtabscale;
279     real           *     f;
280     real           *     x;
281     t_blocka           * excl;
282     real           *     Vvdw;
283     real           *     Vc;
284     real           *     vpol;
285
286     x                   = xx[0];
287     f                   = ff[0];
288     charge              = mdatoms->chargeA;
289     type                = mdatoms->typeA;
290     gbfactor            = ((1.0/fr->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
291     facel               = fr->epsfac;
292     GBtab               = fr->gbtab.data;
293     gbtabscale          = fr->gbtab.scale;
294     invsqrta            = fr->invsqrta;
295     dvda                = fr->dvda;
296     vpol                = kernel_data->energygrp_polarization;
297
298     natoms              = mdatoms->nr;
299     ni0                 = 0;
300     ni1                 = mdatoms->homenr;
301
302     aadata              = fr->AllvsAll_work;
303     excl                = kernel_data->exclusions;
304
305     Vc                  = kernel_data->energygrp_elec;
306     Vvdw                = kernel_data->energygrp_vdw;
307
308     if (aadata == NULL)
309     {
310         setup_aadata(&aadata, excl, natoms, type, fr->ntype, fr->nbfp);
311         fr->AllvsAll_work  = aadata;
312     }
313
314     for (i = ni0; i < ni1; i++)
315     {
316         /* We assume shifts are NOT used for all-vs-all interactions */
317
318         /* Load i atom data */
319         ix                = x[3*i];
320         iy                = x[3*i+1];
321         iz                = x[3*i+2];
322         iq                = facel*charge[i];
323
324         isai              = invsqrta[i];
325
326         pvdw              = aadata->pvdwparam[type[i]];
327
328         /* Zero the potential energy for this list */
329         Vvdwtot           = 0.0;
330         vctot             = 0.0;
331         vgbtot            = 0.0;
332         dvdasum           = 0.0;
333
334         /* Clear i atom forces */
335         fix               = 0.0;
336         fiy               = 0.0;
337         fiz               = 0.0;
338
339         /* Load limits for loop over neighbors */
340         nj0              = aadata->jindex[3*i];
341         nj1              = aadata->jindex[3*i+1];
342         nj2              = aadata->jindex[3*i+2];
343
344         mask             = aadata->exclusion_mask[i];
345
346         /* Prologue part, including exclusion mask */
347         for (j = nj0; j < nj1; j++, mask++)
348         {
349             if (*mask != 0)
350             {
351                 k = j%natoms;
352
353                 /* load j atom coordinates */
354                 jx                = x[3*k];
355                 jy                = x[3*k+1];
356                 jz                = x[3*k+2];
357
358                 /* Calculate distance */
359                 dx                = ix - jx;
360                 dy                = iy - jy;
361                 dz                = iz - jz;
362                 rsq               = dx*dx+dy*dy+dz*dz;
363
364                 /* Calculate 1/r and 1/r2 */
365                 rinv             = gmx_invsqrt(rsq);
366
367                 /* Load parameters for j atom */
368                 isaj              = invsqrta[k];
369                 isaprod           = isai*isaj;
370                 qq                = iq*charge[k];
371                 vcoul             = qq*rinv;
372                 fscal             = vcoul*rinv;
373                 qq                = isaprod*(-qq)*gbfactor;
374                 gbscale           = isaprod*gbtabscale;
375                 c6                = pvdw[2*k];
376                 c12               = pvdw[2*k+1];
377                 rinvsq            = rinv*rinv;
378
379                 /* Tabulated Generalized-Born interaction */
380                 dvdaj            = dvda[k];
381                 r                = rsq*rinv;
382
383                 /* Calculate table index */
384                 rt               = r*gbscale;
385                 n0               = rt;
386                 eps              = rt-n0;
387                 eps2             = eps*eps;
388                 nnn              = 4*n0;
389                 Y                = GBtab[nnn];
390                 F                = GBtab[nnn+1];
391                 Geps             = eps*GBtab[nnn+2];
392                 Heps2            = eps2*GBtab[nnn+3];
393                 Fp               = F+Geps+Heps2;
394                 VV               = Y+eps*Fp;
395                 FF               = Fp+Geps+2.0*Heps2;
396                 vgb              = qq*VV;
397                 fijC             = qq*FF*gbscale;
398                 dvdatmp          = -0.5*(vgb+fijC*r);
399                 dvdasum          = dvdasum + dvdatmp;
400                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
401                 vctot            = vctot + vcoul;
402                 vgbtot           = vgbtot + vgb;
403
404                 /* Lennard-Jones interaction */
405                 rinvsix          = rinvsq*rinvsq*rinvsq;
406                 Vvdw6            = c6*rinvsix;
407                 Vvdw12           = c12*rinvsix*rinvsix;
408                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
409                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
410
411                 /* Calculate temporary vectorial force */
412                 tx                = fscal*dx;
413                 ty                = fscal*dy;
414                 tz                = fscal*dz;
415
416                 /* Increment i atom force */
417                 fix               = fix + tx;
418                 fiy               = fiy + ty;
419                 fiz               = fiz + tz;
420
421                 /* Decrement j atom force */
422                 f[3*k]            = f[3*k]   - tx;
423                 f[3*k+1]          = f[3*k+1] - ty;
424                 f[3*k+2]          = f[3*k+2] - tz;
425             }
426             /* Inner loop uses 38 flops/iteration */
427         }
428
429         /* Main part, no exclusions */
430         for (j = nj1; j < nj2; j++)
431         {
432             k = j%natoms;
433
434             /* load j atom coordinates */
435             jx                = x[3*k];
436             jy                = x[3*k+1];
437             jz                = x[3*k+2];
438
439             /* Calculate distance */
440             dx                = ix - jx;
441             dy                = iy - jy;
442             dz                = iz - jz;
443             rsq               = dx*dx+dy*dy+dz*dz;
444
445             /* Calculate 1/r and 1/r2 */
446             rinv             = gmx_invsqrt(rsq);
447
448             /* Load parameters for j atom */
449             isaj              = invsqrta[k];
450             isaprod           = isai*isaj;
451             qq                = iq*charge[k];
452             vcoul             = qq*rinv;
453             fscal             = vcoul*rinv;
454             qq                = isaprod*(-qq)*gbfactor;
455             gbscale           = isaprod*gbtabscale;
456             c6                = pvdw[2*k];
457             c12               = pvdw[2*k+1];
458             rinvsq            = rinv*rinv;
459
460             /* Tabulated Generalized-Born interaction */
461             dvdaj            = dvda[k];
462             r                = rsq*rinv;
463
464             /* Calculate table index */
465             rt               = r*gbscale;
466             n0               = rt;
467             eps              = rt-n0;
468             eps2             = eps*eps;
469             nnn              = 4*n0;
470             Y                = GBtab[nnn];
471             F                = GBtab[nnn+1];
472             Geps             = eps*GBtab[nnn+2];
473             Heps2            = eps2*GBtab[nnn+3];
474             Fp               = F+Geps+Heps2;
475             VV               = Y+eps*Fp;
476             FF               = Fp+Geps+2.0*Heps2;
477             vgb              = qq*VV;
478             fijC             = qq*FF*gbscale;
479             dvdatmp          = -0.5*(vgb+fijC*r);
480             dvdasum          = dvdasum + dvdatmp;
481             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
482             vctot            = vctot + vcoul;
483             vgbtot           = vgbtot + vgb;
484
485             /* Lennard-Jones interaction */
486             rinvsix          = rinvsq*rinvsq*rinvsq;
487             Vvdw6            = c6*rinvsix;
488             Vvdw12           = c12*rinvsix*rinvsix;
489             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
490             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
491
492             /* Calculate temporary vectorial force */
493             tx                = fscal*dx;
494             ty                = fscal*dy;
495             tz                = fscal*dz;
496
497             /* Increment i atom force */
498             fix               = fix + tx;
499             fiy               = fiy + ty;
500             fiz               = fiz + tz;
501
502             /* Decrement j atom force */
503             f[3*k]            = f[3*k]   - tx;
504             f[3*k+1]          = f[3*k+1] - ty;
505             f[3*k+2]          = f[3*k+2] - tz;
506
507             /* Inner loop uses 38 flops/iteration */
508         }
509
510         f[3*i]   += fix;
511         f[3*i+1] += fiy;
512         f[3*i+2] += fiz;
513
514         /* Add potential energies to the group for this list */
515         ggid             = 0;
516
517         Vc[ggid]         = Vc[ggid] + vctot;
518         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
519         vpol[ggid]       = vpol[ggid] + vgbtot;
520         dvda[i]          = dvda[i] + dvdasum*isai*isai;
521
522         /* Outer loop uses 6 flops/iteration */
523     }
524
525     /* 12 flops per outer iteration
526      * 19 flops per inner iteration
527      */
528     inc_nrnb(nrnb, eNR_NBKERNEL_ELEC_VDW_VF, (ni1-ni0)*12 + ((ni1-ni0)*natoms/2)*19);
529 }