Merge release-5-0 into master
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2009, The GROMACS Development Team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "config.h"
38
39 #include <math.h>
40
41 #include "types/simple.h"
42
43 #include "gromacs/math/vec.h"
44 #include "gromacs/utility/smalloc.h"
45
46 #include "nb_kernel_allvsallgb.h"
47 #include "nrnb.h"
48
49 typedef struct
50 {
51     real **    pvdwparam;
52     int *      jindex;
53     int **     exclusion_mask;
54 }
55 gmx_allvsall_data_t;
56
57 static int
58 calc_maxoffset(int i, int natoms)
59 {
60     int maxoffset;
61
62     if ((natoms % 2) == 1)
63     {
64         /* Odd number of atoms, easy */
65         maxoffset = natoms/2;
66     }
67     else if ((natoms % 4) == 0)
68     {
69         /* Multiple of four is hard */
70         if (i < natoms/2)
71         {
72             if ((i % 2) == 0)
73             {
74                 maxoffset = natoms/2;
75             }
76             else
77             {
78                 maxoffset = natoms/2-1;
79             }
80         }
81         else
82         {
83             if ((i % 2) == 1)
84             {
85                 maxoffset = natoms/2;
86             }
87             else
88             {
89                 maxoffset = natoms/2-1;
90             }
91         }
92     }
93     else
94     {
95         /* natoms/2 = odd */
96         if ((i % 2) == 0)
97         {
98             maxoffset = natoms/2;
99         }
100         else
101         {
102             maxoffset = natoms/2-1;
103         }
104     }
105
106     return maxoffset;
107 }
108
109
110 static void
111 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
112                              t_blocka *              excl,
113                              int                     natoms)
114 {
115     int i, j, k;
116     int nj0, nj1;
117     int max_offset;
118     int max_excl_offset;
119     int iexcl;
120     int nj;
121
122     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
123      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
124      * whether they should interact or not.
125      *
126      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
127      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
128      * we create a jindex array with three elements per i atom: the starting point, the point to
129      * which we need to check exclusions, and the end point.
130      * This way we only have to allocate a short exclusion mask per i atom.
131      */
132
133     /* Allocate memory for our modified jindex array */
134     snew(aadata->jindex, 3*natoms);
135
136     /* Pointer to lists with exclusion masks */
137     snew(aadata->exclusion_mask, natoms);
138
139     for (i = 0; i < natoms; i++)
140     {
141         /* Start */
142         aadata->jindex[3*i]   = i+1;
143         max_offset            = calc_maxoffset(i, natoms);
144
145         /* Exclusions */
146         nj0   = excl->index[i];
147         nj1   = excl->index[i+1];
148
149         /* first check the max range */
150         max_excl_offset = -1;
151
152         for (j = nj0; j < nj1; j++)
153         {
154             iexcl = excl->a[j];
155
156             k = iexcl - i;
157
158             if (k+natoms <= max_offset)
159             {
160                 k += natoms;
161             }
162
163             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
164         }
165
166         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
167
168         aadata->jindex[3*i+1] = i+1+max_excl_offset;
169
170         snew(aadata->exclusion_mask[i], max_excl_offset);
171         /* Include everything by default */
172         for (j = 0; j < max_excl_offset; j++)
173         {
174             /* Use all-ones to mark interactions that should be present, compatible with SSE */
175             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
176         }
177
178         /* Go through exclusions again */
179         for (j = nj0; j < nj1; j++)
180         {
181             iexcl = excl->a[j];
182
183             k = iexcl - i;
184
185             if (k+natoms <= max_offset)
186             {
187                 k += natoms;
188             }
189
190             if (k > 0 && k <= max_excl_offset)
191             {
192                 /* Excluded, kill it! */
193                 aadata->exclusion_mask[i][k-1] = 0;
194             }
195         }
196
197         /* End */
198         aadata->jindex[3*i+2] = i+1+max_offset;
199     }
200 }
201
202
203 static void
204 setup_aadata(gmx_allvsall_data_t **  p_aadata,
205              t_blocka *              excl,
206              int                     natoms,
207              int *                   type,
208              int                     ntype,
209              real *                  pvdwparam)
210 {
211     int                  i, j, idx;
212     gmx_allvsall_data_t *aadata;
213     real                *p;
214
215     snew(aadata, 1);
216     *p_aadata = aadata;
217
218     /* Generate vdw params */
219     snew(aadata->pvdwparam, ntype);
220
221     for (i = 0; i < ntype; i++)
222     {
223         snew(aadata->pvdwparam[i], 2*natoms);
224         p = aadata->pvdwparam[i];
225
226         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
227         for (j = 0; j < natoms; j++)
228         {
229             idx             = i*ntype+type[j];
230             p[2*j]          = pvdwparam[2*idx];
231             p[2*j+1]        = pvdwparam[2*idx+1];
232         }
233     }
234
235     setup_exclusions_and_indices(aadata, excl, natoms);
236 }
237
238
239
240 void
241 nb_kernel_allvsallgb(t_nblist gmx_unused *     nlist,
242                      rvec *                    xx,
243                      rvec *                    ff,
244                      t_forcerec *              fr,
245                      t_mdatoms *               mdatoms,
246                      nb_kernel_data_t *        kernel_data,
247                      t_nrnb *                  nrnb)
248 {
249     gmx_allvsall_data_t *aadata;
250     int                  natoms;
251     int                  ni0, ni1;
252     int                  nj0, nj1, nj2;
253     int                  i, j, k;
254     real           *     charge;
255     int           *      type;
256     real                 facel;
257     real           *     pvdw;
258     int                  ggid;
259     int           *      mask;
260     real           *     GBtab;
261     real                 gbfactor;
262     real           *     invsqrta;
263     real           *     dvda;
264     real                 vgbtot, dvdasum;
265     int                  nnn, n0;
266
267     real                 ix, iy, iz, iq;
268     real                 fix, fiy, fiz;
269     real                 jx, jy, jz, qq;
270     real                 dx, dy, dz;
271     real                 tx, ty, tz;
272     real                 rsq, rinv, rinvsq, rinvsix;
273     real                 vcoul, vctot;
274     real                 c6, c12, Vvdw6, Vvdw12, Vvdwtot;
275     real                 fscal, dvdatmp, fijC, vgb;
276     real                 Y, F, Fp, Geps, Heps2, VV, FF, eps, eps2, r, rt;
277     real                 dvdaj, gbscale, isaprod, isai, isaj, gbtabscale;
278     real           *     f;
279     real           *     x;
280     t_blocka           * excl;
281     real           *     Vvdw;
282     real           *     Vc;
283     real           *     vpol;
284
285     x                   = xx[0];
286     f                   = ff[0];
287     charge              = mdatoms->chargeA;
288     type                = mdatoms->typeA;
289     gbfactor            = ((1.0/fr->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
290     facel               = fr->epsfac;
291     GBtab               = fr->gbtab.data;
292     gbtabscale          = fr->gbtab.scale;
293     invsqrta            = fr->invsqrta;
294     dvda                = fr->dvda;
295     vpol                = kernel_data->energygrp_polarization;
296
297     natoms              = mdatoms->nr;
298     ni0                 = 0;
299     ni1                 = mdatoms->homenr;
300
301     aadata              = fr->AllvsAll_work;
302     excl                = kernel_data->exclusions;
303
304     Vc                  = kernel_data->energygrp_elec;
305     Vvdw                = kernel_data->energygrp_vdw;
306
307     if (aadata == NULL)
308     {
309         setup_aadata(&aadata, excl, natoms, type, fr->ntype, fr->nbfp);
310         fr->AllvsAll_work  = aadata;
311     }
312
313     for (i = ni0; i < ni1; i++)
314     {
315         /* We assume shifts are NOT used for all-vs-all interactions */
316
317         /* Load i atom data */
318         ix                = x[3*i];
319         iy                = x[3*i+1];
320         iz                = x[3*i+2];
321         iq                = facel*charge[i];
322
323         isai              = invsqrta[i];
324
325         pvdw              = aadata->pvdwparam[type[i]];
326
327         /* Zero the potential energy for this list */
328         Vvdwtot           = 0.0;
329         vctot             = 0.0;
330         vgbtot            = 0.0;
331         dvdasum           = 0.0;
332
333         /* Clear i atom forces */
334         fix               = 0.0;
335         fiy               = 0.0;
336         fiz               = 0.0;
337
338         /* Load limits for loop over neighbors */
339         nj0              = aadata->jindex[3*i];
340         nj1              = aadata->jindex[3*i+1];
341         nj2              = aadata->jindex[3*i+2];
342
343         mask             = aadata->exclusion_mask[i];
344
345         /* Prologue part, including exclusion mask */
346         for (j = nj0; j < nj1; j++, mask++)
347         {
348             if (*mask != 0)
349             {
350                 k = j%natoms;
351
352                 /* load j atom coordinates */
353                 jx                = x[3*k];
354                 jy                = x[3*k+1];
355                 jz                = x[3*k+2];
356
357                 /* Calculate distance */
358                 dx                = ix - jx;
359                 dy                = iy - jy;
360                 dz                = iz - jz;
361                 rsq               = dx*dx+dy*dy+dz*dz;
362
363                 /* Calculate 1/r and 1/r2 */
364                 rinv             = gmx_invsqrt(rsq);
365
366                 /* Load parameters for j atom */
367                 isaj              = invsqrta[k];
368                 isaprod           = isai*isaj;
369                 qq                = iq*charge[k];
370                 vcoul             = qq*rinv;
371                 fscal             = vcoul*rinv;
372                 qq                = isaprod*(-qq)*gbfactor;
373                 gbscale           = isaprod*gbtabscale;
374                 c6                = pvdw[2*k];
375                 c12               = pvdw[2*k+1];
376                 rinvsq            = rinv*rinv;
377
378                 /* Tabulated Generalized-Born interaction */
379                 dvdaj            = dvda[k];
380                 r                = rsq*rinv;
381
382                 /* Calculate table index */
383                 rt               = r*gbscale;
384                 n0               = rt;
385                 eps              = rt-n0;
386                 eps2             = eps*eps;
387                 nnn              = 4*n0;
388                 Y                = GBtab[nnn];
389                 F                = GBtab[nnn+1];
390                 Geps             = eps*GBtab[nnn+2];
391                 Heps2            = eps2*GBtab[nnn+3];
392                 Fp               = F+Geps+Heps2;
393                 VV               = Y+eps*Fp;
394                 FF               = Fp+Geps+2.0*Heps2;
395                 vgb              = qq*VV;
396                 fijC             = qq*FF*gbscale;
397                 dvdatmp          = -0.5*(vgb+fijC*r);
398                 dvdasum          = dvdasum + dvdatmp;
399                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
400                 vctot            = vctot + vcoul;
401                 vgbtot           = vgbtot + vgb;
402
403                 /* Lennard-Jones interaction */
404                 rinvsix          = rinvsq*rinvsq*rinvsq;
405                 Vvdw6            = c6*rinvsix;
406                 Vvdw12           = c12*rinvsix*rinvsix;
407                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
408                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
409
410                 /* Calculate temporary vectorial force */
411                 tx                = fscal*dx;
412                 ty                = fscal*dy;
413                 tz                = fscal*dz;
414
415                 /* Increment i atom force */
416                 fix               = fix + tx;
417                 fiy               = fiy + ty;
418                 fiz               = fiz + tz;
419
420                 /* Decrement j atom force */
421                 f[3*k]            = f[3*k]   - tx;
422                 f[3*k+1]          = f[3*k+1] - ty;
423                 f[3*k+2]          = f[3*k+2] - tz;
424             }
425             /* Inner loop uses 38 flops/iteration */
426         }
427
428         /* Main part, no exclusions */
429         for (j = nj1; j < nj2; j++)
430         {
431             k = j%natoms;
432
433             /* load j atom coordinates */
434             jx                = x[3*k];
435             jy                = x[3*k+1];
436             jz                = x[3*k+2];
437
438             /* Calculate distance */
439             dx                = ix - jx;
440             dy                = iy - jy;
441             dz                = iz - jz;
442             rsq               = dx*dx+dy*dy+dz*dz;
443
444             /* Calculate 1/r and 1/r2 */
445             rinv             = gmx_invsqrt(rsq);
446
447             /* Load parameters for j atom */
448             isaj              = invsqrta[k];
449             isaprod           = isai*isaj;
450             qq                = iq*charge[k];
451             vcoul             = qq*rinv;
452             fscal             = vcoul*rinv;
453             qq                = isaprod*(-qq)*gbfactor;
454             gbscale           = isaprod*gbtabscale;
455             c6                = pvdw[2*k];
456             c12               = pvdw[2*k+1];
457             rinvsq            = rinv*rinv;
458
459             /* Tabulated Generalized-Born interaction */
460             dvdaj            = dvda[k];
461             r                = rsq*rinv;
462
463             /* Calculate table index */
464             rt               = r*gbscale;
465             n0               = rt;
466             eps              = rt-n0;
467             eps2             = eps*eps;
468             nnn              = 4*n0;
469             Y                = GBtab[nnn];
470             F                = GBtab[nnn+1];
471             Geps             = eps*GBtab[nnn+2];
472             Heps2            = eps2*GBtab[nnn+3];
473             Fp               = F+Geps+Heps2;
474             VV               = Y+eps*Fp;
475             FF               = Fp+Geps+2.0*Heps2;
476             vgb              = qq*VV;
477             fijC             = qq*FF*gbscale;
478             dvdatmp          = -0.5*(vgb+fijC*r);
479             dvdasum          = dvdasum + dvdatmp;
480             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
481             vctot            = vctot + vcoul;
482             vgbtot           = vgbtot + vgb;
483
484             /* Lennard-Jones interaction */
485             rinvsix          = rinvsq*rinvsq*rinvsq;
486             Vvdw6            = c6*rinvsix;
487             Vvdw12           = c12*rinvsix*rinvsix;
488             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
489             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
490
491             /* Calculate temporary vectorial force */
492             tx                = fscal*dx;
493             ty                = fscal*dy;
494             tz                = fscal*dz;
495
496             /* Increment i atom force */
497             fix               = fix + tx;
498             fiy               = fiy + ty;
499             fiz               = fiz + tz;
500
501             /* Decrement j atom force */
502             f[3*k]            = f[3*k]   - tx;
503             f[3*k+1]          = f[3*k+1] - ty;
504             f[3*k+2]          = f[3*k+2] - tz;
505
506             /* Inner loop uses 38 flops/iteration */
507         }
508
509         f[3*i]   += fix;
510         f[3*i+1] += fiy;
511         f[3*i+2] += fiz;
512
513         /* Add potential energies to the group for this list */
514         ggid             = 0;
515
516         Vc[ggid]         = Vc[ggid] + vctot;
517         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
518         vpol[ggid]       = vpol[ggid] + vgbtot;
519         dvda[i]          = dvda[i] + dvdasum*isai*isai;
520
521         /* Outer loop uses 6 flops/iteration */
522     }
523
524     /* 12 flops per outer iteration
525      * 19 flops per inner iteration
526      */
527     inc_nrnb(nrnb, eNR_NBKERNEL_ELEC_VDW_VF, (ni1-ni0)*12 + ((ni1-ni0)*natoms/2)*19);
528 }