Update copyright statements and change license to LGPL
[alexxy/gromacs.git] / src / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2009, The GROMACS Development Team
6  * Copyright (c) 2012, by the GROMACS development team, led by
7  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
8  * others, as listed in the AUTHORS file in the top-level source
9  * directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #ifdef HAVE_CONFIG_H
38 #include <config.h>
39 #endif
40
41 #include <math.h>
42
43 #include "types/simple.h"
44
45 #include "vec.h"
46 #include "smalloc.h"
47
48 #include "nb_kernel_allvsallgb.h"
49 #include "nrnb.h"
50
51 typedef struct 
52 {
53     real **    pvdwparam;
54     int *      jindex;
55     int **     exclusion_mask;
56
57 gmx_allvsall_data_t;
58                 
/*
 * Return how many j atoms (counted upwards from i+1) atom i has to visit in
 * the all-vs-all scheme.
 *
 * With an odd number of atoms every atom gets natoms/2 partners. With an even
 * number, atoms alternate between natoms/2 and natoms/2-1 depending on the
 * parity of i; when natoms is a multiple of four the parity rule is flipped
 * for the upper half of the atoms (i >= natoms/2).
 */
static int
calc_maxoffset(int i,int natoms)
{
    const int half = natoms/2;
    int       wanted_parity;

    /* Odd number of atoms: same range for everybody */
    if ((natoms % 2) == 1)
    {
        return half;
    }

    /* Even atoms take the long range, except in the upper half of a
     * multiple-of-four system where the odd atoms do.
     */
    wanted_parity = (((natoms % 4) == 0) && (i >= half)) ? 1 : 0;

    if ((i % 2) == wanted_parity)
    {
        return half;
    }

    return half-1;
}
110
111
112 static void
113 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
114                              t_blocka *              excl, 
115                              int                     natoms)
116 {
117     int i,j,k;
118     int nj0,nj1;
119     int max_offset;
120     int max_excl_offset;
121     int iexcl;
122     int nj;
123     
124     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
125      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
126      * whether they should interact or not. 
127      *
128      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
129      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
130      * we create a jindex array with three elements per i atom: the starting point, the point to
131      * which we need to check exclusions, and the end point.
132      * This way we only have to allocate a short exclusion mask per i atom.
133      */
134     
135     /* Allocate memory for our modified jindex array */
136     snew(aadata->jindex,3*natoms);
137     
138     /* Pointer to lists with exclusion masks */
139     snew(aadata->exclusion_mask,natoms);
140     
141     for(i=0;i<natoms;i++)
142     {
143         /* Start */
144         aadata->jindex[3*i]   = i+1;
145         max_offset = calc_maxoffset(i,natoms);
146
147         /* Exclusions */
148         nj0   = excl->index[i];
149         nj1   = excl->index[i+1];
150
151         /* first check the max range */
152         max_excl_offset = -1;
153         
154         for(j=nj0; j<nj1; j++)
155         {
156             iexcl = excl->a[j];
157             
158             k = iexcl - i;
159             
160             if( k+natoms <= max_offset )
161             {
162                 k+=natoms;
163             }
164             
165             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
166         }
167         
168         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
169         
170         aadata->jindex[3*i+1] = i+1+max_excl_offset;        
171         
172         snew(aadata->exclusion_mask[i],max_excl_offset);
173         /* Include everything by default */
174         for(j=0;j<max_excl_offset;j++)
175         {
176             /* Use all-ones to mark interactions that should be present, compatible with SSE */
177             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
178         }
179         
180         /* Go through exclusions again */
181         for(j=nj0; j<nj1; j++)
182         {
183             iexcl = excl->a[j];
184             
185             k = iexcl - i;
186             
187             if( k+natoms <= max_offset )
188             {
189                 k+=natoms;
190             }
191             
192             if(k>0 && k<=max_excl_offset)
193             {
194                 /* Excluded, kill it! */
195                 aadata->exclusion_mask[i][k-1] = 0;
196             }
197         }
198         
199         /* End */
200         aadata->jindex[3*i+2] = i+1+max_offset;        
201     }
202 }
203
204
205 static void
206 setup_aadata(gmx_allvsall_data_t **  p_aadata,
207                          t_blocka *              excl, 
208              int                     natoms,
209              int *                   type,
210              int                     ntype,
211              real *                  pvdwparam)
212 {
213         int i,j,idx;
214         gmx_allvsall_data_t *aadata;
215     real *p;
216         
217         snew(aadata,1);
218         *p_aadata = aadata;
219     
220     /* Generate vdw params */
221     snew(aadata->pvdwparam,ntype);
222         
223     for(i=0;i<ntype;i++)
224     {
225         snew(aadata->pvdwparam[i],2*natoms);
226         p = aadata->pvdwparam[i];
227
228         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
229         for(j=0;j<natoms;j++)
230         {
231             idx             = i*ntype+type[j];
232             p[2*j]          = pvdwparam[2*idx];
233             p[2*j+1]        = pvdwparam[2*idx+1];
234         }        
235     }
236     
237     setup_exclusions_and_indices(aadata,excl,natoms);
238 }
239
240
241
/*
 * All-vs-all nonbonded kernel: Coulomb, tabulated Generalized-Born and
 * Lennard-Jones interactions, energies and forces.
 *
 * nlist        Not referenced in this function.
 * xx           Coordinates; read as a flat array of 3 reals per atom.
 * ff           Forces; accumulated into as a flat array of 3 reals per atom.
 * fr           Supplies epsfac, epsilon_r, gb_epsilon_solvent, the GB table
 *              (gbtab.data/gbtab.scale), invsqrta, dvda, ntype, nbfp and the
 *              cached work structure AllvsAll_work.
 * mdatoms      Supplies chargeA, typeA, nr and the i range start/homenr.
 * kernel_data  Supplies the exclusions and the energy arrays
 *              energygrp_elec, energygrp_vdw and energygrp_polarization.
 * nrnb         Flop accounting; incremented once at the end.
 *
 * For every i atom in [start, start+homenr) the j range stored in the jindex
 * triplet of that atom is walked, with j wrapped as j%natoms. The first part
 * of the range is filtered through the exclusion mask, the second part is
 * evaluated unconditionally. Forces are applied to both atoms of each pair,
 * dvda is updated for both atoms, and the energies are added to element 0 of
 * the three energy arrays.
 *
 * The work structure is created on the first call (when fr->AllvsAll_work is
 * NULL) and stored back in fr for later calls.
 */
242 void
243 nb_kernel_allvsallgb(t_nblist *                nlist,
244                      rvec *                    xx,
245                      rvec *                    ff,
246                      t_forcerec *              fr,
247                      t_mdatoms *               mdatoms,
248                      nb_kernel_data_t *        kernel_data,
249                      t_nrnb *                  nrnb)
250 {
251         gmx_allvsall_data_t *aadata;
252         int        natoms;
253         int        ni0,ni1;
254         int        nj0,nj1,nj2;
255         int        i,j,k;
256         real *     charge;
257         int *      type;
258     real       facel;
259         real *     pvdw;
260         int        ggid;
261     int *      mask;
262     real *     GBtab;
263     real       gbfactor;
264     real *     invsqrta;
265     real *     dvda;
266     real       vgbtot,dvdasum;
267     int        nnn,n0;
268     
269     real       ix,iy,iz,iq;
270     real       fix,fiy,fiz;
271     real       jx,jy,jz,qq;
272     real       dx,dy,dz;
273     real       tx,ty,tz;
274     real       rsq,rinv,rinvsq,rinvsix;
275     real       vcoul,vctot;
276     real       c6,c12,Vvdw6,Vvdw12,Vvdwtot;
277     real       fscal,dvdatmp,fijC,vgb;
278     real       Y,F,Fp,Geps,Heps2,VV,FF,eps,eps2,r,rt;
279     real       dvdaj,gbscale,isaprod,isai,isaj,gbtabscale;
280     real *     f;
281     real *     x;
282     t_blocka * excl;
283     real *     Vvdw;
284     real *     Vc;
285     real *     vpol;
286
    /* View coordinates and forces as flat arrays with 3 reals per atom */
287     x                   = xx[0];
288     f                   = ff[0];
289         charge              = mdatoms->chargeA;
290         type                = mdatoms->typeA;
    /* Prefactor of the GB term: 1/epsilon_r - 1/gb_epsilon_solvent */
291     gbfactor            = ((1.0/fr->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
292         facel               = fr->epsfac;
293     GBtab               = fr->gbtab.data;
294     gbtabscale          = fr->gbtab.scale;
    /* NOTE(review): invsqrta is presumably 1/sqrt(Born radius) per atom - confirm */
295     invsqrta            = fr->invsqrta;
296     dvda                = fr->dvda;
297     vpol                = kernel_data->energygrp_polarization;
298
299     natoms              = mdatoms->nr;
300         ni0                 = mdatoms->start;
301         ni1                 = mdatoms->start+mdatoms->homenr;
302     
303     aadata              = fr->AllvsAll_work;
304     excl                = kernel_data->exclusions;
305
306     Vc                  = kernel_data->energygrp_elec;
307     Vvdw                = kernel_data->energygrp_vdw;
308
        /* First call: build the work data and cache it in the force record */
309         if(aadata==NULL)
310         {
311                 setup_aadata(&aadata,excl,natoms,type,fr->ntype,fr->nbfp);
312         fr->AllvsAll_work  = aadata;
313         }
314
315         for(i=ni0; i<ni1; i++)
316         {
317                 /* We assume shifts are NOT used for all-vs-all interactions */
318                 
319                 /* Load i atom data */
320         ix                = x[3*i];
321         iy                = x[3*i+1];
322         iz                = x[3*i+2];
323         iq                = facel*charge[i];
324         
325         isai              = invsqrta[i];
326
        /* Parameter row for the type of atom i; indexed by j atom below */
327         pvdw              = aadata->pvdwparam[type[i]];
328         
329                 /* Zero the potential energy for this list */
330                 Vvdwtot           = 0.0;
331         vctot             = 0.0;
332         vgbtot            = 0.0;
333         dvdasum           = 0.0;              
334
335                 /* Clear i atom forces */
336         fix               = 0.0;
337         fiy               = 0.0;
338         fiz               = 0.0;
339         
340                 /* Load limits for loop over neighbors */
341                 nj0              = aadata->jindex[3*i];
342                 nj1              = aadata->jindex[3*i+1];
343                 nj2              = aadata->jindex[3*i+2];
344
345         mask             = aadata->exclusion_mask[i];
346                 
347         /* Prologue part, including exclusion mask */
348         for(j=nj0; j<nj1; j++,mask++)
349         {          
            /* A zero mask entry marks an excluded pair: skip it */
350             if(*mask!=0)
351             {
                /* j may run past the last atom; wrap it into [0,natoms) */
352                 k = j%natoms;
353                 
354                 /* load j atom coordinates */
355                 jx                = x[3*k];
356                 jy                = x[3*k+1];
357                 jz                = x[3*k+2];
358                 
359                 /* Calculate distance */
360                 dx                = ix - jx;      
361                 dy                = iy - jy;      
362                 dz                = iz - jz;      
363                 rsq               = dx*dx+dy*dy+dz*dz;
364                 
365                 /* Calculate 1/r (1/r2 is formed below as rinv*rinv) */
366                 rinv             = gmx_invsqrt(rsq);
367                 
368                 /* Load parameters for j atom */
369                 isaj             = invsqrta[k];  
370                 isaprod          = isai*isaj;      
371                 qq               = iq*charge[k]; 
372                 vcoul            = qq*rinv;      
373                 fscal            = vcoul*rinv;   
                /* From here on qq holds the charge product scaled for the GB term */
374                 qq               = isaprod*(-qq)*gbfactor;  
375                 gbscale          = isaprod*gbtabscale;
376                 c6                = pvdw[2*k];
377                 c12               = pvdw[2*k+1];
378                 rinvsq           = rinv*rinv;  
379                 
380                 /* Tabulated Generalized-Born interaction */
381                 dvdaj            = dvda[k];      
382                 r                = rsq*rinv;   
383                 
384                 /* Calculate table index */
385                 rt               = r*gbscale;      
                /* Implicit real to int conversion gives the table point index */
386                 n0               = rt;             
387                 eps              = rt-n0;          
388                 eps2             = eps*eps;        
                /* Four table entries are read per table point */
389                 nnn              = 4*n0;           
390                 Y                = GBtab[nnn];     
391                 F                = GBtab[nnn+1];   
392                 Geps             = eps*GBtab[nnn+2];
393                 Heps2            = eps2*GBtab[nnn+3];
394                 Fp               = F+Geps+Heps2;   
395                 VV               = Y+eps*Fp;       
396                 FF               = Fp+Geps+2.0*Heps2;
397                 vgb              = qq*VV;          
398                 fijC             = qq*FF*gbscale;  
399                 dvdatmp          = -0.5*(vgb+fijC*r);
400                 dvdasum          = dvdasum + dvdatmp;
401                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
402                 vctot            = vctot + vcoul;  
403                 vgbtot           = vgbtot + vgb;
404                 
405                 /* Lennard-Jones interaction */
406                 rinvsix          = rinvsq*rinvsq*rinvsq;
407                 Vvdw6            = c6*rinvsix;     
408                 Vvdw12           = c12*rinvsix*rinvsix;
409                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
                /* Combine the LJ, GB-table and Coulomb contributions into one scalar */
410                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
411                                 
412                 /* Calculate temporary vectorial force */
413                 tx                = fscal*dx;     
414                 ty                = fscal*dy;     
415                 tz                = fscal*dz;     
416                 
417                 /* Increment i atom force */
418                 fix               = fix + tx;      
419                 fiy               = fiy + ty;      
420                 fiz               = fiz + tz;      
421             
422                 /* Decrement j atom force */
423                 f[3*k]            = f[3*k]   - tx;
424                 f[3*k+1]          = f[3*k+1] - ty;
425                 f[3*k+2]          = f[3*k+2] - tz;
426             }
427             /* Inner loop uses 38 flops/iteration */
428         }
429
430         /* Main part, no exclusions; same pair computation as the prologue */
431         for(j=nj1; j<nj2; j++)
432         {       
433             k = j%natoms;
434
435             /* load j atom coordinates */
436             jx                = x[3*k];
437             jy                = x[3*k+1];
438             jz                = x[3*k+2];
439             
440             /* Calculate distance */
441             dx                = ix - jx;      
442             dy                = iy - jy;      
443             dz                = iz - jz;      
444             rsq               = dx*dx+dy*dy+dz*dz;
445             
446             /* Calculate 1/r (1/r2 is formed below as rinv*rinv) */
447             rinv             = gmx_invsqrt(rsq);
448             
449             /* Load parameters for j atom */
450             isaj             = invsqrta[k];  
451             isaprod          = isai*isaj;      
452             qq               = iq*charge[k]; 
453             vcoul            = qq*rinv;      
454             fscal            = vcoul*rinv;   
455             qq               = isaprod*(-qq)*gbfactor;  
456             gbscale          = isaprod*gbtabscale;
457             c6                = pvdw[2*k];
458             c12               = pvdw[2*k+1];
459             rinvsq           = rinv*rinv;  
460             
461             /* Tabulated Generalized-Born interaction */
462             dvdaj            = dvda[k];      
463             r                = rsq*rinv;   
464             
465             /* Calculate table index */
466             rt               = r*gbscale;      
467             n0               = rt;             
468             eps              = rt-n0;          
469             eps2             = eps*eps;        
470             nnn              = 4*n0;           
471             Y                = GBtab[nnn];     
472             F                = GBtab[nnn+1];   
473             Geps             = eps*GBtab[nnn+2];
474             Heps2            = eps2*GBtab[nnn+3];
475             Fp               = F+Geps+Heps2;   
476             VV               = Y+eps*Fp;       
477             FF               = Fp+Geps+2.0*Heps2;
478             vgb              = qq*VV;          
479             fijC             = qq*FF*gbscale;  
480             dvdatmp          = -0.5*(vgb+fijC*r);
481             dvdasum          = dvdasum + dvdatmp;
482             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
483             vctot            = vctot + vcoul;  
484             vgbtot           = vgbtot + vgb;
485
486             /* Lennard-Jones interaction */
487             rinvsix          = rinvsq*rinvsq*rinvsq;
488             Vvdw6            = c6*rinvsix;     
489             Vvdw12           = c12*rinvsix*rinvsix;
490             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
491             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
492             
493             /* Calculate temporary vectorial force */
494             tx                = fscal*dx;     
495             ty                = fscal*dy;     
496             tz                = fscal*dz;     
497             
498             /* Increment i atom force */
499             fix               = fix + tx;      
500             fiy               = fiy + ty;      
501             fiz               = fiz + tz;      
502             
503             /* Decrement j atom force */
504             f[3*k]            = f[3*k]   - tx;
505             f[3*k+1]          = f[3*k+1] - ty;
506             f[3*k+2]          = f[3*k+2] - tz;
507             
508             /* Inner loop uses 38 flops/iteration */
509         }
510         
        /* Write the accumulated i atom force back to the force array */
511         f[3*i]   += fix;
512         f[3*i+1] += fiy;
513         f[3*i+2] += fiz;
514                 
515                 /* Add potential energies to the group for this list; group index 0 is always used */
516                 ggid             = 0;         
517         
518                 Vc[ggid]         = Vc[ggid] + vctot;
519         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
520         vpol[ggid]       = vpol[ggid] + vgbtot;
521         dvda[i]          = dvda[i] + dvdasum*isai*isai;
522
523                 /* Outer loop uses 6 flops/iteration */
524         }    
525
526     /* 12 flops per outer iteration
527      * 19 flops per inner iteration
      *
      * NOTE(review): the comments inside the loops state 6 flops per outer and
      * 38 flops per inner iteration, which does not match the numbers used
      * here - confirm which accounting is intended.
528      */
529     inc_nrnb(nrnb,eNR_NBKERNEL_ELEC_VDW_VF,(ni1-ni0)*12 + ((ni1-ni0)*natoms/2)*19);
530 }
531
532