Merge remote-tracking branch 'gerrit/release-4-6'
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsallgb.c
1 /*
2  *                This source code is part of
3  *
4  *                 G   R   O   M   A   C   S
5  *
6  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
7  * Copyright (c) 2001-2009, The GROMACS Development Team
8  *
9  * Gromacs is a library for molecular simulation and trajectory analysis,
10  * written by Erik Lindahl, David van der Spoel, Berk Hess, and others - for
11  * a full list of developers and information, check out http://www.gromacs.org
12  *
13  * This program is free software; you can redistribute it and/or modify it under 
14  * the terms of the GNU Lesser General Public License as published by the Free 
15  * Software Foundation; either version 2 of the License, or (at your option) any 
16  * later version.
17  * As a special exception, you may use this file as part of a free software
18  * library without restriction.  Specifically, if other files instantiate
19  * templates or use macros or inline functions from this file, or you compile
20  * this file and link it with other files to produce an executable, this
21  * file does not by itself cause the resulting executable to be covered by
22  * the GNU Lesser General Public License.  
23  *
24  * In plain-speak: do not worry about classes/macros/templates either - only
25  * changes to the library have to be LGPL, not an application linking with it.
26  *
27  * To help fund GROMACS development, we humbly ask that you cite
28  * the papers people have written on it - you can find them on the website!
29  */
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33
34 #include <math.h>
35
36 #include "types/simple.h"
37
38 #include "vec.h"
39 #include "smalloc.h"
40
41 #include "nb_kernel_allvsallgb.h"
42
43 typedef struct 
44 {
45     real **    pvdwparam;
46     int *      jindex;
47     int **     exclusion_mask;
48
49 gmx_allvsall_data_t;
50                 
/* Return how many j atoms (i+1, i+2, ... modulo natoms) atom i is paired with.
 *
 * The result is either natoms/2 or natoms/2-1. For positive natoms the
 * results summed over all i equal natoms*(natoms-1)/2, the number of
 * unique pairs; the parity rules below presumably exist so that every
 * pair is visited from exactly one side.
 */
static int
calc_maxoffset(int i,int natoms)
{
    const int half = natoms/2;
    int       gets_full_half;

    /* Odd atom count: every atom simply takes half of the others */
    if ((natoms % 2) == 1)
    {
        return half;
    }

    if ((natoms % 4) == 0 && i >= half)
    {
        /* Multiple of four, upper half of the atoms: odd i takes the extra pair */
        gets_full_half = ((i % 2) == 1);
    }
    else
    {
        /* Multiple of four (lower half), or natoms/2 odd: even i takes the extra pair */
        gets_full_half = ((i % 2) == 0);
    }

    return gets_full_half ? half : half-1;
}
102
103
104 static void
105 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
106                              t_blocka *              excl, 
107                              int                     natoms)
108 {
109     int i,j,k;
110     int nj0,nj1;
111     int max_offset;
112     int max_excl_offset;
113     int iexcl;
114     int nj;
115     
116     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
117      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
118      * whether they should interact or not. 
119      *
120      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
121      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
122      * we create a jindex array with three elements per i atom: the starting point, the point to
123      * which we need to check exclusions, and the end point.
124      * This way we only have to allocate a short exclusion mask per i atom.
125      */
126     
127     /* Allocate memory for our modified jindex array */
128     snew(aadata->jindex,3*natoms);
129     
130     /* Pointer to lists with exclusion masks */
131     snew(aadata->exclusion_mask,natoms);
132     
133     for(i=0;i<natoms;i++)
134     {
135         /* Start */
136         aadata->jindex[3*i]   = i+1;
137         max_offset = calc_maxoffset(i,natoms);
138
139         /* Exclusions */
140         nj0   = excl->index[i];
141         nj1   = excl->index[i+1];
142
143         /* first check the max range */
144         max_excl_offset = -1;
145         
146         for(j=nj0; j<nj1; j++)
147         {
148             iexcl = excl->a[j];
149             
150             k = iexcl - i;
151             
152             if( k+natoms <= max_offset )
153             {
154                 k+=natoms;
155             }
156             
157             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
158         }
159         
160         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
161         
162         aadata->jindex[3*i+1] = i+1+max_excl_offset;        
163         
164         snew(aadata->exclusion_mask[i],max_excl_offset);
165         /* Include everything by default */
166         for(j=0;j<max_excl_offset;j++)
167         {
168             /* Use all-ones to mark interactions that should be present, compatible with SSE */
169             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
170         }
171         
172         /* Go through exclusions again */
173         for(j=nj0; j<nj1; j++)
174         {
175             iexcl = excl->a[j];
176             
177             k = iexcl - i;
178             
179             if( k+natoms <= max_offset )
180             {
181                 k+=natoms;
182             }
183             
184             if(k>0 && k<=max_excl_offset)
185             {
186                 /* Excluded, kill it! */
187                 aadata->exclusion_mask[i][k-1] = 0;
188             }
189         }
190         
191         /* End */
192         aadata->jindex[3*i+2] = i+1+max_offset;        
193     }
194 }
195
196
197 static void
198 setup_aadata(gmx_allvsall_data_t **  p_aadata,
199                          t_blocka *              excl, 
200              int                     natoms,
201              int *                   type,
202              int                     ntype,
203              real *                  pvdwparam)
204 {
205         int i,j,idx;
206         gmx_allvsall_data_t *aadata;
207     real *p;
208         
209         snew(aadata,1);
210         *p_aadata = aadata;
211     
212     /* Generate vdw params */
213     snew(aadata->pvdwparam,ntype);
214         
215     for(i=0;i<ntype;i++)
216     {
217         snew(aadata->pvdwparam[i],2*natoms);
218         p = aadata->pvdwparam[i];
219
220         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
221         for(j=0;j<natoms;j++)
222         {
223             idx             = i*ntype+type[j];
224             p[2*j]          = pvdwparam[2*idx];
225             p[2*j+1]        = pvdwparam[2*idx+1];
226         }        
227     }
228     
229     setup_exclusions_and_indices(aadata,excl,natoms);
230 }
231
232
233
234 void
235 nb_kernel_allvsallgb(t_forcerec *           fr,
236                      t_mdatoms *            mdatoms,
237                      t_blocka *             excl,    
238                      real *                 x,
239                      real *                 f,
240                      real *                 Vc,
241                      real *                 Vvdw,
242                      real *                 vpol,
243                      int *                  outeriter,
244                      int *                  inneriter,
245                      void *                 work)
246 {
247         gmx_allvsall_data_t *aadata;
248         int        natoms;
249         int        ni0,ni1;
250         int        nj0,nj1,nj2;
251         int        i,j,k;
252         real *     charge;
253         int *      type;
254     real       facel;
255         real *     pvdw;
256         int        ggid;
257     int *      mask;
258     real *     GBtab;
259     real       gbfactor;
260     real *     invsqrta;
261     real *     dvda;
262     real       vgbtot,dvdasum;
263     int        nnn,n0;
264     
265     real       ix,iy,iz,iq;
266     real       fix,fiy,fiz;
267     real       jx,jy,jz,qq;
268     real       dx,dy,dz;
269     real       tx,ty,tz;
270     real       rsq,rinv,rinvsq,rinvsix;
271     real       vcoul,vctot;
272     real       c6,c12,Vvdw6,Vvdw12,Vvdwtot;
273     real       fscal,dvdatmp,fijC,vgb;
274     real       Y,F,Fp,Geps,Heps2,VV,FF,eps,eps2,r,rt;
275     real       dvdaj,gbscale,isaprod,isai,isaj,gbtabscale;
276     
277         charge              = mdatoms->chargeA;
278         type                = mdatoms->typeA;
279     gbfactor            = ((1.0/fr->epsilon_r) - (1.0/fr->gb_epsilon_solvent));
280         facel               = fr->epsfac;
281     GBtab               = fr->gbtab.tab;
282     gbtabscale          = fr->gbtab.scale;
283     invsqrta            = fr->invsqrta;
284     dvda                = fr->dvda;
285     
286     natoms              = mdatoms->nr;
287         ni0                 = mdatoms->start;
288         ni1                 = mdatoms->start+mdatoms->homenr;
289     
290     aadata = *((gmx_allvsall_data_t **)work);
291
292         if(aadata==NULL)
293         {
294                 setup_aadata(&aadata,excl,natoms,type,fr->ntype,fr->nbfp);
295         *((gmx_allvsall_data_t **)work) = aadata;
296         }
297
298         for(i=ni0; i<ni1; i++)
299         {
300                 /* We assume shifts are NOT used for all-vs-all interactions */
301                 
302                 /* Load i atom data */
303         ix                = x[3*i];
304         iy                = x[3*i+1];
305         iz                = x[3*i+2];
306         iq                = facel*charge[i];
307         
308         isai              = invsqrta[i];
309
310         pvdw              = aadata->pvdwparam[type[i]];
311         
312                 /* Zero the potential energy for this list */
313                 Vvdwtot           = 0.0;
314         vctot             = 0.0;
315         vgbtot            = 0.0;
316         dvdasum           = 0.0;              
317
318                 /* Clear i atom forces */
319         fix               = 0.0;
320         fiy               = 0.0;
321         fiz               = 0.0;
322         
323                 /* Load limits for loop over neighbors */
324                 nj0              = aadata->jindex[3*i];
325                 nj1              = aadata->jindex[3*i+1];
326                 nj2              = aadata->jindex[3*i+2];
327
328         mask             = aadata->exclusion_mask[i];
329                 
330         /* Prologue part, including exclusion mask */
331         for(j=nj0; j<nj1; j++,mask++)
332         {          
333             if(*mask!=0)
334             {
335                 k = j%natoms;
336                 
337                 /* load j atom coordinates */
338                 jx                = x[3*k];
339                 jy                = x[3*k+1];
340                 jz                = x[3*k+2];
341                 
342                 /* Calculate distance */
343                 dx                = ix - jx;      
344                 dy                = iy - jy;      
345                 dz                = iz - jz;      
346                 rsq               = dx*dx+dy*dy+dz*dz;
347                 
348                 /* Calculate 1/r and 1/r2 */
349                 rinv             = gmx_invsqrt(rsq);
350                 
351                 /* Load parameters for j atom */
352                 isaj             = invsqrta[k];  
353                 isaprod          = isai*isaj;      
354                 qq               = iq*charge[k]; 
355                 vcoul            = qq*rinv;      
356                 fscal            = vcoul*rinv;   
357                 qq               = isaprod*(-qq)*gbfactor;  
358                 gbscale          = isaprod*gbtabscale;
359                 c6                = pvdw[2*k];
360                 c12               = pvdw[2*k+1];
361                 rinvsq           = rinv*rinv;  
362                 
363                 /* Tabulated Generalized-Born interaction */
364                 dvdaj            = dvda[k];      
365                 r                = rsq*rinv;   
366                 
367                 /* Calculate table index */
368                 rt               = r*gbscale;      
369                 n0               = rt;             
370                 eps              = rt-n0;          
371                 eps2             = eps*eps;        
372                 nnn              = 4*n0;           
373                 Y                = GBtab[nnn];     
374                 F                = GBtab[nnn+1];   
375                 Geps             = eps*GBtab[nnn+2];
376                 Heps2            = eps2*GBtab[nnn+3];
377                 Fp               = F+Geps+Heps2;   
378                 VV               = Y+eps*Fp;       
379                 FF               = Fp+Geps+2.0*Heps2;
380                 vgb              = qq*VV;          
381                 fijC             = qq*FF*gbscale;  
382                 dvdatmp          = -0.5*(vgb+fijC*r);
383                 dvdasum          = dvdasum + dvdatmp;
384                 dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
385                 vctot            = vctot + vcoul;  
386                 vgbtot           = vgbtot + vgb;
387                 
388                 /* Lennard-Jones interaction */
389                 rinvsix          = rinvsq*rinvsq*rinvsq;
390                 Vvdw6            = c6*rinvsix;     
391                 Vvdw12           = c12*rinvsix*rinvsix;
392                 Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
393                 fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
394                                 
395                 /* Calculate temporary vectorial force */
396                 tx                = fscal*dx;     
397                 ty                = fscal*dy;     
398                 tz                = fscal*dz;     
399                 
400                 /* Increment i atom force */
401                 fix               = fix + tx;      
402                 fiy               = fiy + ty;      
403                 fiz               = fiz + tz;      
404             
405                 /* Decrement j atom force */
406                 f[3*k]            = f[3*k]   - tx;
407                 f[3*k+1]          = f[3*k+1] - ty;
408                 f[3*k+2]          = f[3*k+2] - tz;
409             }
410             /* Inner loop uses 38 flops/iteration */
411         }
412
413         /* Main part, no exclusions */
414         for(j=nj1; j<nj2; j++)
415         {       
416             k = j%natoms;
417
418             /* load j atom coordinates */
419             jx                = x[3*k];
420             jy                = x[3*k+1];
421             jz                = x[3*k+2];
422             
423             /* Calculate distance */
424             dx                = ix - jx;      
425             dy                = iy - jy;      
426             dz                = iz - jz;      
427             rsq               = dx*dx+dy*dy+dz*dz;
428             
429             /* Calculate 1/r and 1/r2 */
430             rinv             = gmx_invsqrt(rsq);
431             
432             /* Load parameters for j atom */
433             isaj             = invsqrta[k];  
434             isaprod          = isai*isaj;      
435             qq               = iq*charge[k]; 
436             vcoul            = qq*rinv;      
437             fscal            = vcoul*rinv;   
438             qq               = isaprod*(-qq)*gbfactor;  
439             gbscale          = isaprod*gbtabscale;
440             c6                = pvdw[2*k];
441             c12               = pvdw[2*k+1];
442             rinvsq           = rinv*rinv;  
443             
444             /* Tabulated Generalized-Born interaction */
445             dvdaj            = dvda[k];      
446             r                = rsq*rinv;   
447             
448             /* Calculate table index */
449             rt               = r*gbscale;      
450             n0               = rt;             
451             eps              = rt-n0;          
452             eps2             = eps*eps;        
453             nnn              = 4*n0;           
454             Y                = GBtab[nnn];     
455             F                = GBtab[nnn+1];   
456             Geps             = eps*GBtab[nnn+2];
457             Heps2            = eps2*GBtab[nnn+3];
458             Fp               = F+Geps+Heps2;   
459             VV               = Y+eps*Fp;       
460             FF               = Fp+Geps+2.0*Heps2;
461             vgb              = qq*VV;          
462             fijC             = qq*FF*gbscale;  
463             dvdatmp          = -0.5*(vgb+fijC*r);
464             dvdasum          = dvdasum + dvdatmp;
465             dvda[k]          = dvdaj+dvdatmp*isaj*isaj;
466             vctot            = vctot + vcoul;  
467             vgbtot           = vgbtot + vgb;
468
469             /* Lennard-Jones interaction */
470             rinvsix          = rinvsq*rinvsq*rinvsq;
471             Vvdw6            = c6*rinvsix;     
472             Vvdw12           = c12*rinvsix*rinvsix;
473             Vvdwtot          = Vvdwtot+Vvdw12-Vvdw6;
474             fscal            = (12.0*Vvdw12-6.0*Vvdw6)*rinvsq-(fijC-fscal)*rinv;
475             
476             /* Calculate temporary vectorial force */
477             tx                = fscal*dx;     
478             ty                = fscal*dy;     
479             tz                = fscal*dz;     
480             
481             /* Increment i atom force */
482             fix               = fix + tx;      
483             fiy               = fiy + ty;      
484             fiz               = fiz + tz;      
485             
486             /* Decrement j atom force */
487             f[3*k]            = f[3*k]   - tx;
488             f[3*k+1]          = f[3*k+1] - ty;
489             f[3*k+2]          = f[3*k+2] - tz;
490             
491             /* Inner loop uses 38 flops/iteration */
492         }
493         
494         f[3*i]   += fix;
495         f[3*i+1] += fiy;
496         f[3*i+2] += fiz;
497                 
498                 /* Add potential energies to the group for this list */
499                 ggid             = 0;         
500         
501                 Vc[ggid]         = Vc[ggid] + vctot;
502         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
503         vpol[ggid]       = vpol[ggid] + vgbtot;
504         dvda[i]          = dvda[i] + dvdasum*isai*isai;
505
506                 /* Outer loop uses 6 flops/iteration */
507         }    
508
509     /* Write outer/inner iteration count to pointers */
510     *outeriter       = ni1-ni0;         
511     *inneriter       = (ni1-ni0)*natoms/2;         
512 }
513
514