Merge release-4-6 into master
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.c
1 /*
2  *                This source code is part of
3  *
4  *                 G   R   O   M   A   C   S
5  *
6  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
7  * Copyright (c) 2001-2009, The GROMACS Development Team
8  *
9  * Gromacs is a library for molecular simulation and trajectory analysis,
10  * written by Erik Lindahl, David van der Spoel, Berk Hess, and others - for
11  * a full list of developers and information, check out http://www.gromacs.org
12  *
13  * This program is free software; you can redistribute it and/or modify it under
14  * the terms of the GNU Lesser General Public License as published by the Free
15  * Software Foundation; either version 2 of the License, or (at your option) any
16  * later version.
17  * As a special exception, you may use this file as part of a free software
18  * library without restriction.  Specifically, if other files instantiate
19  * templates or use macros or inline functions from this file, or you compile
20  * this file and link it with other files to produce an executable, this
21  * file does not by itself cause the resulting executable to be covered by
22  * the GNU Lesser General Public License.
23  *
24  * In plain-speak: do not worry about classes/macros/templates either - only
25  * changes to the library have to be LGPL, not an application linking with it.
26  *
27  * To help fund GROMACS development, we humbly ask that you cite
28  * the papers people have written on it - you can find them on the website!
29  */
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33
34 #include <math.h>
35
36 #include "types/simple.h"
37
38 #include "vec.h"
39 #include "smalloc.h"
40
41 #include "nb_kernel_allvsallgb.h"
42 #include "nrnb.h"
43
44 typedef struct
45 {
46     real **    pvdwparam;
47     int *      jindex;
48     int **     exclusion_mask;
49 }
50 gmx_allvsall_data_t;
51
52 static int
53 calc_maxoffset(int i, int natoms)
54 {
55     int maxoffset;
56
57     if ((natoms % 2) == 1)
58     {
59         /* Odd number of atoms, easy */
60         maxoffset = natoms/2;
61     }
62     else if ((natoms % 4) == 0)
63     {
64         /* Multiple of four is hard */
65         if (i < natoms/2)
66         {
67             if ((i % 2) == 0)
68             {
69                 maxoffset = natoms/2;
70             }
71             else
72             {
73                 maxoffset = natoms/2-1;
74             }
75         }
76         else
77         {
78             if ((i % 2) == 1)
79             {
80                 maxoffset = natoms/2;
81             }
82             else
83             {
84                 maxoffset = natoms/2-1;
85             }
86         }
87     }
88     else
89     {
90         /* natoms/2 = odd */
91         if ((i % 2) == 0)
92         {
93             maxoffset = natoms/2;
94         }
95         else
96         {
97             maxoffset = natoms/2-1;
98         }
99     }
100
101     return maxoffset;
102 }
103
104
105 static void
106 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
107                              t_blocka *              excl,
108                              int                     natoms)
109 {
110     int i, j, k;
111     int nj0, nj1;
112     int max_offset;
113     int max_excl_offset;
114     int iexcl;
115     int nj;
116
117     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
118      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
119      * whether they should interact or not.
120      *
121      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
122      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
123      * we create a jindex array with three elements per i atom: the starting point, the point to
124      * which we need to check exclusions, and the end point.
125      * This way we only have to allocate a short exclusion mask per i atom.
126      */
127
128     /* Allocate memory for our modified jindex array */
129     snew(aadata->jindex, 3*natoms);
130
131     /* Pointer to lists with exclusion masks */
132     snew(aadata->exclusion_mask, natoms);
133
134     for (i = 0; i < natoms; i++)
135     {
136         /* Start */
137         aadata->jindex[3*i]   = i+1;
138         max_offset            = calc_maxoffset(i, natoms);
139
140         /* Exclusions */
141         nj0   = excl->index[i];
142         nj1   = excl->index[i+1];
143
144         /* first check the max range */
145         max_excl_offset = -1;
146
147         for (j = nj0; j < nj1; j++)
148         {
149             iexcl = excl->a[j];
150
151             k = iexcl - i;
152
153             if (k+natoms <= max_offset)
154             {
155                 k += natoms;
156             }
157
158             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
159         }
160
161         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
162
163         aadata->jindex[3*i+1] = i+1+max_excl_offset;
164
165         snew(aadata->exclusion_mask[i], max_excl_offset);
166         /* Include everything by default */
167         for (j = 0; j < max_excl_offset; j++)
168         {
169             /* Use all-ones to mark interactions that should be present, compatible with SSE */
170             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
171         }
172
173         /* Go through exclusions again */
174         for (j = nj0; j < nj1; j++)
175         {
176             iexcl = excl->a[j];
177
178             k = iexcl - i;
179
180             if (k+natoms <= max_offset)
181             {
182                 k += natoms;
183             }
184
185             if (k > 0 && k <= max_excl_offset)
186             {
187                 /* Excluded, kill it! */
188                 aadata->exclusion_mask[i][k-1] = 0;
189             }
190         }
191
192         /* End */
193         aadata->jindex[3*i+2] = i+1+max_offset;
194     }
195 }
196
197
198 static void
199 setup_aadata(gmx_allvsall_data_t **  p_aadata,
200              t_blocka *              excl,
201              int                     natoms,
202              int *                   type,
203              int                     ntype,
204              real *                  pvdwparam)
205 {
206     int                  i, j, idx;
207     gmx_allvsall_data_t *aadata;
208     real                *p;
209
210     snew(aadata, 1);
211     *p_aadata = aadata;
212
213     /* Generate vdw params */
214     snew(aadata->pvdwparam, ntype);
215
216     for (i = 0; i < ntype; i++)
217     {
218         snew(aadata->pvdwparam[i], 2*natoms);
219         p = aadata->pvdwparam[i];
220
221         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
222         for (j = 0; j < natoms; j++)
223         {
224             idx             = i*ntype+type[j];
225             p[2*j]          = pvdwparam[2*idx];
226             p[2*j+1]        = pvdwparam[2*idx+1];
227         }
228     }
229
230     setup_exclusions_and_indices(aadata, excl, natoms);
231 }
232
233
234
235 void
236 nb_kernel_allvsallgb(t_nblist gmx_unused *     nlist,
237                      rvec *                    xx,
238                      rvec *                    ff,
239                      t_forcerec *              fr,
240                      t_mdatoms *               mdatoms,
241                      nb_kernel_data_t *        kernel_data,
242                      t_nrnb *                  nrnb)
243 {
244     gmx_allvsall_data_t *aadata;
245     int                  natoms;
246     int                  ni0, ni1;
247     int                  nj0, nj1, nj2;
248     int                  i, j, k;
249     real           *     charge;
250     int           *      type;
251     real                 facel;
252     real           *     pvdw;
253     int                  ggid;
254     int           *      mask;
255     real           *     GBtab;
256     real                 gbfactor;
257     real           *     invsqrta;
258     real           *     dvda;
259     real                 vgbtot, dvdasum;
260     int                  nnn, n0;
261
262     real                 ix, iy, iz, iq;
263     real                 fix, fiy, fiz;
264     real                 jx, jy, jz, qq;
265     real                 dx, dy, dz;
266     real                 tx, ty, tz;
267     real                 rsq, rinv, rinvsq, rinvsix;
268     real                 vcoul, vctot;
269     real                 c6, c12, Vvdw6, Vvdw12, Vvdwtot;
270     real                 fscal, dvdatmp, fijC, vgb;
271     real                 Y, F, Fp, Geps, Heps2, VV, FF, eps, eps2, r, rt;
272     real                 dvdaj, gbscale, isaprod, isai, isaj, gbtabscale;
273     real           *     f;
274     real           *     x;
275     t_blocka           * excl;
276     real           *     Vvdw;
277     real           *     Vc;
278     real           *     vpol;
279
280     x                   = xx[0];
281     f                   = ff[0];
282     charge              = mdatoms->chargeA;
283     type                = mdatoms->typeA;
284     gbfactor            = ((1.0/fr->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
285     facel               = fr->epsfac;
286     GBtab               = fr->gbtab.data;
287     gbtabscale          = fr->gbtab.scale;
288     invsqrta            = fr->invsqrta;
289     dvda                = fr->dvda;
290     vpol                = kernel_data->energygrp_polarization;
291
292     natoms              = mdatoms->nr;
293     ni0                 = mdatoms->start;
294     ni1                 = mdatoms->start+mdatoms->homenr;
295
296     aadata              = fr->AllvsAll_work;
297     excl                = kernel_data->exclusions;
298
299     Vc                  = kernel_data->energygrp_elec;
300     Vvdw                = kernel_data->energygrp_vdw;
301
302     if (aadata == NULL)
303     {
304         setup_aadata(&aadata, excl, natoms, type, fr->ntype, fr->nbfp);
305         fr->AllvsAll_work  = aadata;
306     }
307
308     for (i = ni0; i < ni1; i++)
309     {
310         /* We assume shifts are NOT used for all-vs-all interactions */
311
312         /* Load i atom data */
313         ix                = x[3*i];
314         iy                = x[3*i+1];
315         iz                = x[3*i+2];
316         iq                = facel*charge[i];
317
318         isai              = invsqrta[i];
319
320         pvdw              = aadata->pvdwparam[type[i]];
321
322         /* Zero the potential energy for this list */
323         Vvdwtot           = 0.0;
324         vctot             = 0.0;
325         vgbtot            = 0.0;
326         dvdasum           = 0.0;
327
328         /* Clear i atom forces */
329         fix               = 0.0;
330         fiy               = 0.0;
331         fiz               = 0.0;
332
333         /* Load limits for loop over neighbors */
334         nj0              = aadata->jindex[3*i];
335         nj1              = aadata->jindex[3*i+1];
336         nj2              = aadata->jindex[3*i+2];
337
338         mask             = aadata->exclusion_mask[i];
339
340         /* Prologue part, including exclusion mask */
341         for (j = nj0; j < nj1; j++, mask++)
342         {
343             if (*mask != 0)
344             {
345                 k = j%natoms;
346
347                 /* load j atom coordinates */
348                 jx                = x[3*k];
349                 jy                = x[3*k+1];
350                 jz                = x[3*k+2];
351
352                 /* Calculate distance */
353                 dx                = ix - jx;
354                 dy                = iy - jy;
355                 dz                = iz - jz;
356                 rsq               = dx*dx+dy*dy+dz*dz;
357
358                 /* Calculate 1/r and 1/r2 */
359                 rinv             = gmx_invsqrt(rsq);
360
361                 /* Load parameters for j atom */
362                 isaj              = invsqrta[k];
363                 isaprod           = isai*isaj;
364                 qq                = iq*charge[k];
365                 vcoul             = qq*rinv;
366                 fscal             = vcoul*rinv;
367                 qq                = isaprod*(-qq)*gbfactor;
368                 gbscale           = isaprod*gbtabscale;
369                 c6                = pvdw[2*k];
370                 c12               = pvdw[2*k+1];
371                 rinvsq            = rinv*rinv;
372
373                 /* Tabulated Generalized-Born interaction */
374                 dvdaj            = dvda[k];
375                 r                = rsq*rinv;
376
377                 /* Calculate table index */
378                 rt               = r*gbscale;
379                 n0               = rt;
380                 eps              = rt-n0;
381                 eps2             = eps*eps;
382                 nnn              = 4*n0;
383                 Y                = GBtab[nnn];
384                 F                = GBtab[nnn+1];
385                 Geps             = eps*GBtab[nnn+2];
386                 Heps2            = eps2*GBtab[nnn+3];
387                 Fp               = F+Geps+Heps2;
388                 VV               = Y+eps*Fp;
389                 FF               = Fp+Geps+2.0*Heps2;
390                 vgb              = qq*VV;
391                 fijC             = qq*FF*gbscale;
392                 dvdatmp          = -0.5*(vgb+fijC*r);
393                 dvdasum          = dvdasum + dvdatmp;
394                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
395                 vctot            = vctot + vcoul;
396                 vgbtot           = vgbtot + vgb;
397
398                 /* Lennard-Jones interaction */
399                 rinvsix          = rinvsq*rinvsq*rinvsq;
400                 Vvdw6            = c6*rinvsix;
401                 Vvdw12           = c12*rinvsix*rinvsix;
402                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
403                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
404
405                 /* Calculate temporary vectorial force */
406                 tx                = fscal*dx;
407                 ty                = fscal*dy;
408                 tz                = fscal*dz;
409
410                 /* Increment i atom force */
411                 fix               = fix + tx;
412                 fiy               = fiy + ty;
413                 fiz               = fiz + tz;
414
415                 /* Decrement j atom force */
416                 f[3*k]            = f[3*k]   - tx;
417                 f[3*k+1]          = f[3*k+1] - ty;
418                 f[3*k+2]          = f[3*k+2] - tz;
419             }
420             /* Inner loop uses 38 flops/iteration */
421         }
422
423         /* Main part, no exclusions */
424         for (j = nj1; j < nj2; j++)
425         {
426             k = j%natoms;
427
428             /* load j atom coordinates */
429             jx                = x[3*k];
430             jy                = x[3*k+1];
431             jz                = x[3*k+2];
432
433             /* Calculate distance */
434             dx                = ix - jx;
435             dy                = iy - jy;
436             dz                = iz - jz;
437             rsq               = dx*dx+dy*dy+dz*dz;
438
439             /* Calculate 1/r and 1/r2 */
440             rinv             = gmx_invsqrt(rsq);
441
442             /* Load parameters for j atom */
443             isaj              = invsqrta[k];
444             isaprod           = isai*isaj;
445             qq                = iq*charge[k];
446             vcoul             = qq*rinv;
447             fscal             = vcoul*rinv;
448             qq                = isaprod*(-qq)*gbfactor;
449             gbscale           = isaprod*gbtabscale;
450             c6                = pvdw[2*k];
451             c12               = pvdw[2*k+1];
452             rinvsq            = rinv*rinv;
453
454             /* Tabulated Generalized-Born interaction */
455             dvdaj            = dvda[k];
456             r                = rsq*rinv;
457
458             /* Calculate table index */
459             rt               = r*gbscale;
460             n0               = rt;
461             eps              = rt-n0;
462             eps2             = eps*eps;
463             nnn              = 4*n0;
464             Y                = GBtab[nnn];
465             F                = GBtab[nnn+1];
466             Geps             = eps*GBtab[nnn+2];
467             Heps2            = eps2*GBtab[nnn+3];
468             Fp               = F+Geps+Heps2;
469             VV               = Y+eps*Fp;
470             FF               = Fp+Geps+2.0*Heps2;
471             vgb              = qq*VV;
472             fijC             = qq*FF*gbscale;
473             dvdatmp          = -0.5*(vgb+fijC*r);
474             dvdasum          = dvdasum + dvdatmp;
475             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
476             vctot            = vctot + vcoul;
477             vgbtot           = vgbtot + vgb;
478
479             /* Lennard-Jones interaction */
480             rinvsix          = rinvsq*rinvsq*rinvsq;
481             Vvdw6            = c6*rinvsix;
482             Vvdw12           = c12*rinvsix*rinvsix;
483             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
484             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
485
486             /* Calculate temporary vectorial force */
487             tx                = fscal*dx;
488             ty                = fscal*dy;
489             tz                = fscal*dz;
490
491             /* Increment i atom force */
492             fix               = fix + tx;
493             fiy               = fiy + ty;
494             fiz               = fiz + tz;
495
496             /* Decrement j atom force */
497             f[3*k]            = f[3*k]   - tx;
498             f[3*k+1]          = f[3*k+1] - ty;
499             f[3*k+2]          = f[3*k+2] - tz;
500
501             /* Inner loop uses 38 flops/iteration */
502         }
503
504         f[3*i]   += fix;
505         f[3*i+1] += fiy;
506         f[3*i+2] += fiz;
507
508         /* Add potential energies to the group for this list */
509         ggid             = 0;
510
511         Vc[ggid]         = Vc[ggid] + vctot;
512         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
513         vpol[ggid]       = vpol[ggid] + vgbtot;
514         dvda[i]          = dvda[i] + dvdasum*isai*isai;
515
516         /* Outer loop uses 6 flops/iteration */
517     }
518
519     /* 12 flops per outer iteration
520      * 19 flops per inner iteration
521      */
522     inc_nrnb(nrnb, eNR_NBKERNEL_ELEC_VDW_VF, (ni1-ni0)*12 + ((ni1-ni0)*natoms/2)*19);
523 }