Remove topology support for implicit solvation
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2009, The GROMACS Development Team.
6  * Copyright (c) 2013,2014,2015,2017,2018, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "nb_kernel_allvsallgb.h"
40
41 #include "config.h"
42
43 #include <math.h>
44
45 #include "gromacs/gmxlib/nrnb.h"
46 #include "gromacs/utility/real.h"
47 #include "gromacs/utility/smalloc.h"
48
49 typedef struct
50 {
51     real **    pvdwparam;
52     int *      jindex;
53     int **     exclusion_mask;
54 }
55 gmx_allvsall_data_t;
56
/* Number of forward j offsets (i+1 .. i+maxoffset, modulo natoms) that atom i
 * handles, chosen so every unordered pair is visited exactly once.
 *
 * For odd natoms every atom takes natoms/2 partners. For even natoms, the pair
 * (i, i+natoms/2) lies at the same distance in both directions. Exactly one
 * of the two atoms takes the full natoms/2 and the other takes natoms/2-1.
 * Which one depends on the parity of i, and, when natoms is a multiple of
 * four, also on which half of the range i lies in.
 */
static int
calc_maxoffset(int i, int natoms)
{
    const int half = natoms/2;
    int       takesFullHalf;

    if ((natoms % 2) == 1)
    {
        return half;
    }

    if ((natoms % 4) == 0)
    {
        takesFullHalf = (i < half) ? ((i % 2) == 0) : ((i % 2) == 1);
    }
    else
    {
        takesFullHalf = ((i % 2) == 0);
    }

    return takesFullHalf ? half : half - 1;
}
108
109
110 static void
111 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
112                              t_blocka *              excl,
113                              int                     natoms)
114 {
115     int i, j, k;
116     int nj0, nj1;
117     int max_offset;
118     int max_excl_offset;
119     int iexcl;
120
121     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
122      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
123      * whether they should interact or not.
124      *
125      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
126      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
127      * we create a jindex array with three elements per i atom: the starting point, the point to
128      * which we need to check exclusions, and the end point.
129      * This way we only have to allocate a short exclusion mask per i atom.
130      */
131
132     /* Allocate memory for our modified jindex array */
133     snew(aadata->jindex, 3*natoms);
134
135     /* Pointer to lists with exclusion masks */
136     snew(aadata->exclusion_mask, natoms);
137
138     for (i = 0; i < natoms; i++)
139     {
140         /* Start */
141         aadata->jindex[3*i]   = i+1;
142         max_offset            = calc_maxoffset(i, natoms);
143
144         /* Exclusions */
145         nj0   = excl->index[i];
146         nj1   = excl->index[i+1];
147
148         /* first check the max range */
149         max_excl_offset = -1;
150
151         for (j = nj0; j < nj1; j++)
152         {
153             iexcl = excl->a[j];
154
155             k = iexcl - i;
156
157             if (k+natoms <= max_offset)
158             {
159                 k += natoms;
160             }
161
162             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
163         }
164
165         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
166
167         aadata->jindex[3*i+1] = i+1+max_excl_offset;
168
169         snew(aadata->exclusion_mask[i], max_excl_offset);
170         /* Include everything by default */
171         for (j = 0; j < max_excl_offset; j++)
172         {
173             /* Use all-ones to mark interactions that should be present, compatible with SSE */
174             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
175         }
176
177         /* Go through exclusions again */
178         for (j = nj0; j < nj1; j++)
179         {
180             iexcl = excl->a[j];
181
182             k = iexcl - i;
183
184             if (k+natoms <= max_offset)
185             {
186                 k += natoms;
187             }
188
189             if (k > 0 && k <= max_excl_offset)
190             {
191                 /* Excluded, kill it! */
192                 aadata->exclusion_mask[i][k-1] = 0;
193             }
194         }
195
196         /* End */
197         aadata->jindex[3*i+2] = i+1+max_offset;
198     }
199 }
200
201
202 static void
203 setup_aadata(gmx_allvsall_data_t **  p_aadata,
204              t_blocka *              excl,
205              int                     natoms,
206              int *                   type,
207              int                     ntype,
208              real *                  pvdwparam)
209 {
210     int                  i, j, idx;
211     gmx_allvsall_data_t *aadata;
212     real                *p;
213
214     snew(aadata, 1);
215     *p_aadata = aadata;
216
217     /* Generate vdw params */
218     snew(aadata->pvdwparam, ntype);
219
220     for (i = 0; i < ntype; i++)
221     {
222         snew(aadata->pvdwparam[i], 2*natoms);
223         p = aadata->pvdwparam[i];
224
225         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
226         for (j = 0; j < natoms; j++)
227         {
228             idx             = i*ntype+type[j];
229             p[2*j]          = pvdwparam[2*idx];
230             p[2*j+1]        = pvdwparam[2*idx+1];
231         }
232     }
233
234     setup_exclusions_and_indices(aadata, excl, natoms);
235 }
236
237
238
239 void
240 nb_kernel_allvsallgb(t_nblist gmx_unused *     nlist,
241                      rvec *                    xx,
242                      rvec *                    ff,
243                      struct t_forcerec *       fr,
244                      t_mdatoms *               mdatoms,
245                      nb_kernel_data_t *        kernel_data,
246                      t_nrnb *                  nrnb)
247 {
248     gmx_allvsall_data_t *aadata;
249     int                  natoms;
250     int                  ni0, ni1;
251     int                  nj0, nj1, nj2;
252     int                  i, j, k;
253     real           *     charge;
254     int           *      type;
255     real                 facel;
256     real           *     pvdw;
257     int                  ggid;
258     int           *      mask;
259     real           *     GBtab;
260     real                 gbfactor;
261     real           *     invsqrta;
262     real           *     dvda;
263     real                 vgbtot, dvdasum;
264     int                  nnn, n0;
265
266     real                 ix, iy, iz, iq;
267     real                 fix, fiy, fiz;
268     real                 jx, jy, jz, qq;
269     real                 dx, dy, dz;
270     real                 tx, ty, tz;
271     real                 rsq, rinv, rinvsq, rinvsix;
272     real                 vcoul, vctot;
273     real                 c6, c12, Vvdw6, Vvdw12, Vvdwtot;
274     real                 fscal, dvdatmp, fijC, vgb;
275     real                 Y, F, Fp, Geps, Heps2, VV, FF, eps, eps2, r, rt;
276     real                 dvdaj, gbscale, isaprod, isai, isaj, gbtabscale;
277     real           *     f;
278     real           *     x;
279     t_blocka           * excl;
280     real           *     Vvdw;
281     real           *     Vc;
282     real           *     vpol;
283
284     x                   = xx[0];
285     f                   = ff[0];
286     charge              = mdatoms->chargeA;
287     type                = mdatoms->typeA;
288     gbfactor            = ((1.0/fr->ic->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
289     facel               = fr->ic->epsfac;
290     GBtab               = fr->gbtab->data;
291     gbtabscale          = fr->gbtab->scale;
292     invsqrta            = fr->invsqrta;
293     dvda                = fr->dvda;
294     vpol                = kernel_data->energygrp_polarization;
295
296     natoms              = mdatoms->nr;
297     ni0                 = 0;
298     ni1                 = mdatoms->homenr;
299
300     aadata              = reinterpret_cast<gmx_allvsall_data_t *>(fr->AllvsAll_work);
301     excl                = kernel_data->exclusions;
302
303     Vc                  = kernel_data->energygrp_elec;
304     Vvdw                = kernel_data->energygrp_vdw;
305
306     if (aadata == NULL)
307     {
308         setup_aadata(&aadata, excl, natoms, type, fr->ntype, fr->nbfp);
309         fr->AllvsAll_work  = aadata;
310     }
311
312     for (i = ni0; i < ni1; i++)
313     {
314         /* We assume shifts are NOT used for all-vs-all interactions */
315
316         /* Load i atom data */
317         ix                = x[3*i];
318         iy                = x[3*i+1];
319         iz                = x[3*i+2];
320         iq                = facel*charge[i];
321
322         isai              = invsqrta[i];
323
324         pvdw              = aadata->pvdwparam[type[i]];
325
326         /* Zero the potential energy for this list */
327         Vvdwtot           = 0.0;
328         vctot             = 0.0;
329         vgbtot            = 0.0;
330         dvdasum           = 0.0;
331
332         /* Clear i atom forces */
333         fix               = 0.0;
334         fiy               = 0.0;
335         fiz               = 0.0;
336
337         /* Load limits for loop over neighbors */
338         nj0              = aadata->jindex[3*i];
339         nj1              = aadata->jindex[3*i+1];
340         nj2              = aadata->jindex[3*i+2];
341
342         mask             = aadata->exclusion_mask[i];
343
344         /* Prologue part, including exclusion mask */
345         for (j = nj0; j < nj1; j++, mask++)
346         {
347             if (*mask != 0)
348             {
349                 k = j%natoms;
350
351                 /* load j atom coordinates */
352                 jx                = x[3*k];
353                 jy                = x[3*k+1];
354                 jz                = x[3*k+2];
355
356                 /* Calculate distance */
357                 dx                = ix - jx;
358                 dy                = iy - jy;
359                 dz                = iz - jz;
360                 rsq               = dx*dx+dy*dy+dz*dz;
361
362                 /* Calculate 1/r and 1/r2 */
363                 rinv             = 1.0/sqrt(rsq);
364
365                 /* Load parameters for j atom */
366                 isaj              = invsqrta[k];
367                 isaprod           = isai*isaj;
368                 qq                = iq*charge[k];
369                 vcoul             = qq*rinv;
370                 fscal             = vcoul*rinv;
371                 qq                = isaprod*(-qq)*gbfactor;
372                 gbscale           = isaprod*gbtabscale;
373                 c6                = pvdw[2*k];
374                 c12               = pvdw[2*k+1];
375                 rinvsq            = rinv*rinv;
376
377                 /* Tabulated Generalized-Born interaction */
378                 dvdaj            = dvda[k];
379                 r                = rsq*rinv;
380
381                 /* Calculate table index */
382                 rt               = r*gbscale;
383                 n0               = rt;
384                 eps              = rt-n0;
385                 eps2             = eps*eps;
386                 nnn              = 4*n0;
387                 Y                = GBtab[nnn];
388                 F                = GBtab[nnn+1];
389                 Geps             = eps*GBtab[nnn+2];
390                 Heps2            = eps2*GBtab[nnn+3];
391                 Fp               = F+Geps+Heps2;
392                 VV               = Y+eps*Fp;
393                 FF               = Fp+Geps+2.0*Heps2;
394                 vgb              = qq*VV;
395                 fijC             = qq*FF*gbscale;
396                 dvdatmp          = -0.5*(vgb+fijC*r);
397                 dvdasum          = dvdasum + dvdatmp;
398                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
399                 vctot            = vctot + vcoul;
400                 vgbtot           = vgbtot + vgb;
401
402                 /* Lennard-Jones interaction */
403                 rinvsix          = rinvsq*rinvsq*rinvsq;
404                 Vvdw6            = c6*rinvsix;
405                 Vvdw12           = c12*rinvsix*rinvsix;
406                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
407                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
408
409                 /* Calculate temporary vectorial force */
410                 tx                = fscal*dx;
411                 ty                = fscal*dy;
412                 tz                = fscal*dz;
413
414                 /* Increment i atom force */
415                 fix               = fix + tx;
416                 fiy               = fiy + ty;
417                 fiz               = fiz + tz;
418
419                 /* Decrement j atom force */
420                 f[3*k]            = f[3*k]   - tx;
421                 f[3*k+1]          = f[3*k+1] - ty;
422                 f[3*k+2]          = f[3*k+2] - tz;
423             }
424             /* Inner loop uses 38 flops/iteration */
425         }
426
427         /* Main part, no exclusions */
428         for (j = nj1; j < nj2; j++)
429         {
430             k = j%natoms;
431
432             /* load j atom coordinates */
433             jx                = x[3*k];
434             jy                = x[3*k+1];
435             jz                = x[3*k+2];
436
437             /* Calculate distance */
438             dx                = ix - jx;
439             dy                = iy - jy;
440             dz                = iz - jz;
441             rsq               = dx*dx+dy*dy+dz*dz;
442
443             /* Calculate 1/r and 1/r2 */
444             rinv             = 1.0/sqrt(rsq);
445
446             /* Load parameters for j atom */
447             isaj              = invsqrta[k];
448             isaprod           = isai*isaj;
449             qq                = iq*charge[k];
450             vcoul             = qq*rinv;
451             fscal             = vcoul*rinv;
452             qq                = isaprod*(-qq)*gbfactor;
453             gbscale           = isaprod*gbtabscale;
454             c6                = pvdw[2*k];
455             c12               = pvdw[2*k+1];
456             rinvsq            = rinv*rinv;
457
458             /* Tabulated Generalized-Born interaction */
459             dvdaj            = dvda[k];
460             r                = rsq*rinv;
461
462             /* Calculate table index */
463             rt               = r*gbscale;
464             n0               = rt;
465             eps              = rt-n0;
466             eps2             = eps*eps;
467             nnn              = 4*n0;
468             Y                = GBtab[nnn];
469             F                = GBtab[nnn+1];
470             Geps             = eps*GBtab[nnn+2];
471             Heps2            = eps2*GBtab[nnn+3];
472             Fp               = F+Geps+Heps2;
473             VV               = Y+eps*Fp;
474             FF               = Fp+Geps+2.0*Heps2;
475             vgb              = qq*VV;
476             fijC             = qq*FF*gbscale;
477             dvdatmp          = -0.5*(vgb+fijC*r);
478             dvdasum          = dvdasum + dvdatmp;
479             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
480             vctot            = vctot + vcoul;
481             vgbtot           = vgbtot + vgb;
482
483             /* Lennard-Jones interaction */
484             rinvsix          = rinvsq*rinvsq*rinvsq;
485             Vvdw6            = c6*rinvsix;
486             Vvdw12           = c12*rinvsix*rinvsix;
487             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
488             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
489
490             /* Calculate temporary vectorial force */
491             tx                = fscal*dx;
492             ty                = fscal*dy;
493             tz                = fscal*dz;
494
495             /* Increment i atom force */
496             fix               = fix + tx;
497             fiy               = fiy + ty;
498             fiz               = fiz + tz;
499
500             /* Decrement j atom force */
501             f[3*k]            = f[3*k]   - tx;
502             f[3*k+1]          = f[3*k+1] - ty;
503             f[3*k+2]          = f[3*k+2] - tz;
504
505             /* Inner loop uses 38 flops/iteration */
506         }
507
508         f[3*i]   += fix;
509         f[3*i+1] += fiy;
510         f[3*i+2] += fiz;
511
512         /* Add potential energies to the group for this list */
513         ggid             = 0;
514
515         Vc[ggid]         = Vc[ggid] + vctot;
516         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
517         vpol[ggid]       = vpol[ggid] + vgbtot;
518         dvda[i]          = dvda[i] + dvdasum*isai*isai;
519
520         /* Outer loop uses 6 flops/iteration */
521     }
522
523     /* 12 flops per outer iteration
524      * 19 flops per inner iteration
525      */
526     inc_nrnb(nrnb, eNR_NBKERNEL_ELEC_VDW_VF, (ni1-ni0)*12 + ((ni1-ni0)*natoms/2)*19);
527 }