Merge release-4-6 into master
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.c
1 /*
2  *                This source code is part of
3  *
4  *                 G   R   O   M   A   C   S
5  *
6  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
7  * Copyright (c) 2001-2009, The GROMACS Development Team
8  *
9  * Gromacs is a library for molecular simulation and trajectory analysis,
10  * written by Erik Lindahl, David van der Spoel, Berk Hess, and others - for
11  * a full list of developers and information, check out http://www.gromacs.org
12  *
13  * This program is free software; you can redistribute it and/or modify it under 
14  * the terms of the GNU Lesser General Public License as published by the Free 
15  * Software Foundation; either version 2 of the License, or (at your option) any 
16  * later version.
17  * As a special exception, you may use this file as part of a free software
18  * library without restriction.  Specifically, if other files instantiate
19  * templates or use macros or inline functions from this file, or you compile
20  * this file and link it with other files to produce an executable, this
21  * file does not by itself cause the resulting executable to be covered by
22  * the GNU Lesser General Public License.  
23  *
24  * In plain-speak: do not worry about classes/macros/templates either - only
25  * changes to the library have to be LGPL, not an application linking with it.
26  *
27  * To help fund GROMACS development, we humbly ask that you cite
28  * the papers people have written on it - you can find them on the website!
29  */
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33
34 #include <math.h>
35
36 #include "types/simple.h"
37
38 #include "vec.h"
39 #include "smalloc.h"
40
41 #include "nb_kernel_allvsallgb.h"
42 #include "nrnb.h"
43
44 typedef struct 
45 {
46     real **    pvdwparam;
47     int *      jindex;
48     int **     exclusion_mask;
49
50 gmx_allvsall_data_t;
51                 
/* Return the number of j atoms (offsets 1..maxoffset after atom i) that
 * atom i handles in the all-vs-all scheme.
 *
 * For an odd atom count every atom gets natoms/2. For an even count the
 * atoms alternate between natoms/2 and natoms/2-1 depending on the parity
 * of i; when natoms is a multiple of four the parity rule is flipped for
 * the upper half of the atoms (i >= natoms/2).
 */
static int 
calc_maxoffset(int i,int natoms)
{
    const int half = natoms/2;
    int       takes_full_half;

    if ((natoms % 2) == 1)
    {
        /* Odd number of atoms: same range for everybody */
        return half;
    }

    if ((natoms % 4) == 0 && i >= half)
    {
        /* Multiple of four, upper half: odd atoms get the longer range */
        takes_full_half = ((i % 2) == 1);
    }
    else
    {
        /* Lower half of a multiple of four, or natoms/2 odd:
         * even atoms get the longer range
         */
        takes_full_half = ((i % 2) == 0);
    }

    return takes_full_half ? half : half-1;
}
103
104
105 static void
106 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
107                              t_blocka *              excl, 
108                              int                     natoms)
109 {
110     int i,j,k;
111     int nj0,nj1;
112     int max_offset;
113     int max_excl_offset;
114     int iexcl;
115     int nj;
116     
117     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
118      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
119      * whether they should interact or not. 
120      *
121      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
122      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
123      * we create a jindex array with three elements per i atom: the starting point, the point to
124      * which we need to check exclusions, and the end point.
125      * This way we only have to allocate a short exclusion mask per i atom.
126      */
127     
128     /* Allocate memory for our modified jindex array */
129     snew(aadata->jindex,3*natoms);
130     
131     /* Pointer to lists with exclusion masks */
132     snew(aadata->exclusion_mask,natoms);
133     
134     for(i=0;i<natoms;i++)
135     {
136         /* Start */
137         aadata->jindex[3*i]   = i+1;
138         max_offset = calc_maxoffset(i,natoms);
139
140         /* Exclusions */
141         nj0   = excl->index[i];
142         nj1   = excl->index[i+1];
143
144         /* first check the max range */
145         max_excl_offset = -1;
146         
147         for(j=nj0; j<nj1; j++)
148         {
149             iexcl = excl->a[j];
150             
151             k = iexcl - i;
152             
153             if( k+natoms <= max_offset )
154             {
155                 k+=natoms;
156             }
157             
158             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
159         }
160         
161         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
162         
163         aadata->jindex[3*i+1] = i+1+max_excl_offset;        
164         
165         snew(aadata->exclusion_mask[i],max_excl_offset);
166         /* Include everything by default */
167         for(j=0;j<max_excl_offset;j++)
168         {
169             /* Use all-ones to mark interactions that should be present, compatible with SSE */
170             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
171         }
172         
173         /* Go through exclusions again */
174         for(j=nj0; j<nj1; j++)
175         {
176             iexcl = excl->a[j];
177             
178             k = iexcl - i;
179             
180             if( k+natoms <= max_offset )
181             {
182                 k+=natoms;
183             }
184             
185             if(k>0 && k<=max_excl_offset)
186             {
187                 /* Excluded, kill it! */
188                 aadata->exclusion_mask[i][k-1] = 0;
189             }
190         }
191         
192         /* End */
193         aadata->jindex[3*i+2] = i+1+max_offset;        
194     }
195 }
196
197
198 static void
199 setup_aadata(gmx_allvsall_data_t **  p_aadata,
200                          t_blocka *              excl, 
201              int                     natoms,
202              int *                   type,
203              int                     ntype,
204              real *                  pvdwparam)
205 {
206         int i,j,idx;
207         gmx_allvsall_data_t *aadata;
208     real *p;
209         
210         snew(aadata,1);
211         *p_aadata = aadata;
212     
213     /* Generate vdw params */
214     snew(aadata->pvdwparam,ntype);
215         
216     for(i=0;i<ntype;i++)
217     {
218         snew(aadata->pvdwparam[i],2*natoms);
219         p = aadata->pvdwparam[i];
220
221         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
222         for(j=0;j<natoms;j++)
223         {
224             idx             = i*ntype+type[j];
225             p[2*j]          = pvdwparam[2*idx];
226             p[2*j+1]        = pvdwparam[2*idx+1];
227         }        
228     }
229     
230     setup_exclusions_and_indices(aadata,excl,natoms);
231 }
232
233
234
235 void
236 nb_kernel_allvsallgb(t_nblist *                nlist,
237                      rvec *                    xx,
238                      rvec *                    ff,
239                      t_forcerec *              fr,
240                      t_mdatoms *               mdatoms,
241                      nb_kernel_data_t *        kernel_data,
242                      t_nrnb *                  nrnb)
243 {
244         gmx_allvsall_data_t *aadata;
245         int        natoms;
246         int        ni0,ni1;
247         int        nj0,nj1,nj2;
248         int        i,j,k;
249         real *     charge;
250         int *      type;
251     real       facel;
252         real *     pvdw;
253         int        ggid;
254     int *      mask;
255     real *     GBtab;
256     real       gbfactor;
257     real *     invsqrta;
258     real *     dvda;
259     real       vgbtot,dvdasum;
260     int        nnn,n0;
261     
262     real       ix,iy,iz,iq;
263     real       fix,fiy,fiz;
264     real       jx,jy,jz,qq;
265     real       dx,dy,dz;
266     real       tx,ty,tz;
267     real       rsq,rinv,rinvsq,rinvsix;
268     real       vcoul,vctot;
269     real       c6,c12,Vvdw6,Vvdw12,Vvdwtot;
270     real       fscal,dvdatmp,fijC,vgb;
271     real       Y,F,Fp,Geps,Heps2,VV,FF,eps,eps2,r,rt;
272     real       dvdaj,gbscale,isaprod,isai,isaj,gbtabscale;
273     real *     f;
274     real *     x;
275     t_blocka * excl;
276     real *     Vvdw;
277     real *     Vc;
278     real *     vpol;
279
280     x                   = xx[0];
281     f                   = ff[0];
282         charge              = mdatoms->chargeA;
283         type                = mdatoms->typeA;
284     gbfactor            = ((1.0/fr->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
285         facel               = fr->epsfac;
286     GBtab               = fr->gbtab.data;
287     gbtabscale          = fr->gbtab.scale;
288     invsqrta            = fr->invsqrta;
289     dvda                = fr->dvda;
290     vpol                = kernel_data->energygrp_polarization;
291
292     natoms              = mdatoms->nr;
293         ni0                 = mdatoms->start;
294         ni1                 = mdatoms->start+mdatoms->homenr;
295     
296     aadata              = fr->AllvsAll_work;
297     excl                = kernel_data->exclusions;
298
299     Vc                  = kernel_data->energygrp_elec;
300     Vvdw                = kernel_data->energygrp_vdw;
301
302         if(aadata==NULL)
303         {
304                 setup_aadata(&aadata,excl,natoms,type,fr->ntype,fr->nbfp);
305         fr->AllvsAll_work  = aadata;
306         }
307
308         for(i=ni0; i<ni1; i++)
309         {
310                 /* We assume shifts are NOT used for all-vs-all interactions */
311                 
312                 /* Load i atom data */
313         ix                = x[3*i];
314         iy                = x[3*i+1];
315         iz                = x[3*i+2];
316         iq                = facel*charge[i];
317         
318         isai              = invsqrta[i];
319
320         pvdw              = aadata->pvdwparam[type[i]];
321         
322                 /* Zero the potential energy for this list */
323                 Vvdwtot           = 0.0;
324         vctot             = 0.0;
325         vgbtot            = 0.0;
326         dvdasum           = 0.0;              
327
328                 /* Clear i atom forces */
329         fix               = 0.0;
330         fiy               = 0.0;
331         fiz               = 0.0;
332         
333                 /* Load limits for loop over neighbors */
334                 nj0              = aadata->jindex[3*i];
335                 nj1              = aadata->jindex[3*i+1];
336                 nj2              = aadata->jindex[3*i+2];
337
338         mask             = aadata->exclusion_mask[i];
339                 
340         /* Prologue part, including exclusion mask */
341         for(j=nj0; j<nj1; j++,mask++)
342         {          
343             if(*mask!=0)
344             {
345                 k = j%natoms;
346                 
347                 /* load j atom coordinates */
348                 jx                = x[3*k];
349                 jy                = x[3*k+1];
350                 jz                = x[3*k+2];
351                 
352                 /* Calculate distance */
353                 dx                = ix - jx;      
354                 dy                = iy - jy;      
355                 dz                = iz - jz;      
356                 rsq               = dx*dx+dy*dy+dz*dz;
357                 
358                 /* Calculate 1/r and 1/r2 */
359                 rinv             = gmx_invsqrt(rsq);
360                 
361                 /* Load parameters for j atom */
362                 isaj             = invsqrta[k];  
363                 isaprod          = isai*isaj;      
364                 qq               = iq*charge[k]; 
365                 vcoul            = qq*rinv;      
366                 fscal            = vcoul*rinv;   
367                 qq               = isaprod*(-qq)*gbfactor;  
368                 gbscale          = isaprod*gbtabscale;
369                 c6                = pvdw[2*k];
370                 c12               = pvdw[2*k+1];
371                 rinvsq           = rinv*rinv;  
372                 
373                 /* Tabulated Generalized-Born interaction */
374                 dvdaj            = dvda[k];      
375                 r                = rsq*rinv;   
376                 
377                 /* Calculate table index */
378                 rt               = r*gbscale;      
379                 n0               = rt;             
380                 eps              = rt-n0;          
381                 eps2             = eps*eps;        
382                 nnn              = 4*n0;           
383                 Y                = GBtab[nnn];     
384                 F                = GBtab[nnn+1];   
385                 Geps             = eps*GBtab[nnn+2];
386                 Heps2            = eps2*GBtab[nnn+3];
387                 Fp               = F+Geps+Heps2;   
388                 VV               = Y+eps*Fp;       
389                 FF               = Fp+Geps+2.0*Heps2;
390                 vgb              = qq*VV;          
391                 fijC             = qq*FF*gbscale;  
392                 dvdatmp          = -0.5*(vgb+fijC*r);
393                 dvdasum          = dvdasum + dvdatmp;
394                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
395                 vctot            = vctot + vcoul;  
396                 vgbtot           = vgbtot + vgb;
397                 
398                 /* Lennard-Jones interaction */
399                 rinvsix          = rinvsq*rinvsq*rinvsq;
400                 Vvdw6            = c6*rinvsix;     
401                 Vvdw12           = c12*rinvsix*rinvsix;
402                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
403                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
404                                 
405                 /* Calculate temporary vectorial force */
406                 tx                = fscal*dx;     
407                 ty                = fscal*dy;     
408                 tz                = fscal*dz;     
409                 
410                 /* Increment i atom force */
411                 fix               = fix + tx;      
412                 fiy               = fiy + ty;      
413                 fiz               = fiz + tz;      
414             
415                 /* Decrement j atom force */
416                 f[3*k]            = f[3*k]   - tx;
417                 f[3*k+1]          = f[3*k+1] - ty;
418                 f[3*k+2]          = f[3*k+2] - tz;
419             }
420             /* Inner loop uses 38 flops/iteration */
421         }
422
423         /* Main part, no exclusions */
424         for(j=nj1; j<nj2; j++)
425         {       
426             k = j%natoms;
427
428             /* load j atom coordinates */
429             jx                = x[3*k];
430             jy                = x[3*k+1];
431             jz                = x[3*k+2];
432             
433             /* Calculate distance */
434             dx                = ix - jx;      
435             dy                = iy - jy;      
436             dz                = iz - jz;      
437             rsq               = dx*dx+dy*dy+dz*dz;
438             
439             /* Calculate 1/r and 1/r2 */
440             rinv             = gmx_invsqrt(rsq);
441             
442             /* Load parameters for j atom */
443             isaj             = invsqrta[k];  
444             isaprod          = isai*isaj;      
445             qq               = iq*charge[k]; 
446             vcoul            = qq*rinv;      
447             fscal            = vcoul*rinv;   
448             qq               = isaprod*(-qq)*gbfactor;  
449             gbscale          = isaprod*gbtabscale;
450             c6                = pvdw[2*k];
451             c12               = pvdw[2*k+1];
452             rinvsq           = rinv*rinv;  
453             
454             /* Tabulated Generalized-Born interaction */
455             dvdaj            = dvda[k];      
456             r                = rsq*rinv;   
457             
458             /* Calculate table index */
459             rt               = r*gbscale;      
460             n0               = rt;             
461             eps              = rt-n0;          
462             eps2             = eps*eps;        
463             nnn              = 4*n0;           
464             Y                = GBtab[nnn];     
465             F                = GBtab[nnn+1];   
466             Geps             = eps*GBtab[nnn+2];
467             Heps2            = eps2*GBtab[nnn+3];
468             Fp               = F+Geps+Heps2;   
469             VV               = Y+eps*Fp;       
470             FF               = Fp+Geps+2.0*Heps2;
471             vgb              = qq*VV;          
472             fijC             = qq*FF*gbscale;  
473             dvdatmp          = -0.5*(vgb+fijC*r);
474             dvdasum          = dvdasum + dvdatmp;
475             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
476             vctot            = vctot + vcoul;  
477             vgbtot           = vgbtot + vgb;
478
479             /* Lennard-Jones interaction */
480             rinvsix          = rinvsq*rinvsq*rinvsq;
481             Vvdw6            = c6*rinvsix;     
482             Vvdw12           = c12*rinvsix*rinvsix;
483             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
484             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
485             
486             /* Calculate temporary vectorial force */
487             tx                = fscal*dx;     
488             ty                = fscal*dy;     
489             tz                = fscal*dz;     
490             
491             /* Increment i atom force */
492             fix               = fix + tx;      
493             fiy               = fiy + ty;      
494             fiz               = fiz + tz;      
495             
496             /* Decrement j atom force */
497             f[3*k]            = f[3*k]   - tx;
498             f[3*k+1]          = f[3*k+1] - ty;
499             f[3*k+2]          = f[3*k+2] - tz;
500             
501             /* Inner loop uses 38 flops/iteration */
502         }
503         
504         f[3*i]   += fix;
505         f[3*i+1] += fiy;
506         f[3*i+2] += fiz;
507                 
508                 /* Add potential energies to the group for this list */
509                 ggid             = 0;         
510         
511                 Vc[ggid]         = Vc[ggid] + vctot;
512         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
513         vpol[ggid]       = vpol[ggid] + vgbtot;
514         dvda[i]          = dvda[i] + dvdasum*isai*isai;
515
516                 /* Outer loop uses 6 flops/iteration */
517         }    
518
519     /* 12 flops per outer iteration
520      * 19 flops per inner iteration
521      */
522     inc_nrnb(nrnb,eNR_NBKERNEL_ELEC_VDW_VF,(ni1-ni0)*12 + ((ni1-ni0)*natoms/2)*19);
523 }
524
525