Merge branch 'release-4-6', adds the nbnxn functionality
[alexxy/gromacs.git] / src / gromacs / gmxlib / nonbonded / nb_kernel_c / nb_kernel_allvsall.c
1 /*
2  *                This source code is part of
3  *
4  *                 G   R   O   M   A   C   S
5  *
6  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
7  * Copyright (c) 2001-2009, The GROMACS Development Team
8  *
9  * Gromacs is a library for molecular simulation and trajectory analysis,
10  * written by Erik Lindahl, David van der Spoel, Berk Hess, and others - for
11  * a full list of developers and information, check out http://www.gromacs.org
12  *
13  * This program is free software; you can redistribute it and/or modify it under 
14  * the terms of the GNU Lesser General Public License as published by the Free 
15  * Software Foundation; either version 2 of the License, or (at your option) any 
16  * later version.
17  * As a special exception, you may use this file as part of a free software
18  * library without restriction.  Specifically, if other files instantiate
19  * templates or use macros or inline functions from this file, or you compile
20  * this file and link it with other files to produce an executable, this
21  * file does not by itself cause the resulting executable to be covered by
22  * the GNU Lesser General Public License.  
23  *
24  * In plain-speak: do not worry about classes/macros/templates either - only
25  * changes to the library have to be LGPL, not an application linking with it.
26  *
27  * To help fund GROMACS development, we humbly ask that you cite
28  * the papers people have written on it - you can find them on the website!
29  */
30 #ifdef HAVE_CONFIG_H
31 #include <config.h>
32 #endif
33
34 #include <math.h>
35
36 #include "types/simple.h"
37
38 #include "vec.h"
39 #include "smalloc.h"
40
41 #include "nb_kernel_allvsall.h"
42
43 typedef struct 
44 {
45     real **    pvdwparam;
46     int *      jindex;
47     int **     exclusion_mask;
48
49 gmx_allvsall_data_t;
50                 
51 static int 
52 calc_maxoffset(int i,int natoms)
53 {
54     int maxoffset;
55     
56     if ((natoms % 2) == 1)
57     {
58         /* Odd number of atoms, easy */
59         maxoffset = natoms/2;
60     }
61     else if ((natoms % 4) == 0)
62     {
63         /* Multiple of four is hard */
64         if (i < natoms/2)
65         {
66             if ((i % 2) == 0)
67             {
68                 maxoffset=natoms/2;
69             }
70             else
71             {
72                 maxoffset=natoms/2-1;
73             }
74         }
75         else
76         {
77             if ((i % 2) == 1)
78             {
79                 maxoffset=natoms/2;
80             }
81             else
82             {
83                 maxoffset=natoms/2-1;
84             }
85         }
86     }
87     else
88     {
89         /* natoms/2 = odd */
90         if ((i % 2) == 0)
91         {
92             maxoffset=natoms/2;
93         }
94         else
95         {
96             maxoffset=natoms/2-1;
97         }
98     }
99     
100     return maxoffset;
101 }
102
103
104 static void
105 setup_exclusions_and_indices(gmx_allvsall_data_t *   aadata,
106                              t_blocka *              excl, 
107                              int                     natoms)
108 {
109     int i,j,k,iexcl;
110     int nj0,nj1;
111     int max_offset;
112     int max_excl_offset;
113     int nj;
114     
115     /* This routine can appear to be a bit complex, but it is mostly book-keeping.
116      * To enable the fast all-vs-all kernel we need to be able to stream through all coordinates
117      * whether they should interact or not. 
118      *
119      * To avoid looping over the exclusions, we create a simple mask that is 1 if the interaction
120      * should be present, otherwise 0. Since exclusions typically only occur when i & j are close,
121      * we create a jindex array with three elements per i atom: the starting point, the point to
122      * which we need to check exclusions, and the end point.
123      * This way we only have to allocate a short exclusion mask per i atom.
124      */
125     
126     /* Allocate memory for our modified jindex array */
127     snew(aadata->jindex,3*natoms);
128     
129     /* Pointer to lists with exclusion masks */
130     snew(aadata->exclusion_mask,natoms);
131     
132     for(i=0;i<natoms;i++)
133     {
134         /* Start */
135         aadata->jindex[3*i]   = i+1;
136         max_offset = calc_maxoffset(i,natoms);
137         
138         /* Exclusions */
139         nj0   = excl->index[i];
140         nj1   = excl->index[i+1];
141
142         /* first check the max range */
143         max_excl_offset = -1;
144         
145         for(j=nj0; j<nj1; j++)
146         {
147             iexcl = excl->a[j];
148                         
149             k = iexcl - i;
150             
151             if( k+natoms <= max_offset )
152             {
153                 k+=natoms;
154             }
155                
156             max_excl_offset = (k > max_excl_offset) ? k : max_excl_offset;
157         }
158         
159         max_excl_offset = (max_offset < max_excl_offset) ? max_offset : max_excl_offset;
160         
161         aadata->jindex[3*i+1] = i+1+max_excl_offset;        
162
163         
164         snew(aadata->exclusion_mask[i],max_excl_offset);
165         /* Include everything by default */
166         for(j=0;j<max_excl_offset;j++)
167         {
168             /* Use all-ones to mark interactions that should be present, compatible with SSE */
169             aadata->exclusion_mask[i][j] = 0xFFFFFFFF;
170         }
171         
172         /* Go through exclusions again */
173         for(j=nj0; j<nj1; j++)
174         {
175             iexcl = excl->a[j];
176             
177             k = iexcl - i;
178             
179             if( k+natoms <= max_offset )
180             {
181                 k+=natoms;
182             }
183             
184             if(k>0 && k<=max_excl_offset)
185             {
186                 /* Excluded, kill it! */
187                 aadata->exclusion_mask[i][k-1] = 0;
188             }
189         }
190         
191         /* End */
192         aadata->jindex[3*i+2] = i+1+max_offset;        
193     }
194 }
195
196 static void
197 setup_aadata(gmx_allvsall_data_t **  p_aadata,
198                          t_blocka *              excl, 
199              int                     natoms,
200              int *                   type,
201              int                     ntype,
202              real *                  pvdwparam)
203 {
204         int i,j,idx;
205         gmx_allvsall_data_t *aadata;
206     real *p;
207         
208         snew(aadata,1);
209         *p_aadata = aadata;
210     
211     /* Generate vdw params */
212     snew(aadata->pvdwparam,ntype);
213         
214     for(i=0;i<ntype;i++)
215     {
216         snew(aadata->pvdwparam[i],2*natoms);
217         p = aadata->pvdwparam[i];
218
219         /* Lets keep it simple and use multiple steps - first create temp. c6/c12 arrays */
220         for(j=0;j<natoms;j++)
221         {
222             idx             = i*ntype+type[j];
223             p[2*j]          = pvdwparam[2*idx];
224             p[2*j+1]        = pvdwparam[2*idx+1];
225         }        
226     }
227     
228     setup_exclusions_and_indices(aadata,excl,natoms);
229 }
230
231
232
233 void
234 nb_kernel_allvsall(t_forcerec *           fr,
235                                    t_mdatoms *            mdatoms,
236                                    t_blocka *             excl,    
237                                    real *                 x,
238                                    real *                 f,
239                                    real *                 Vc,
240                                    real *                 Vvdw,
241                                    int *                  outeriter,
242                                    int *                  inneriter,
243                                    void *                 work)
244 {
245         gmx_allvsall_data_t *aadata;
246         int        natoms;
247         int        ni0,ni1;
248         int        nj0,nj1,nj2;
249         int        i,j,k;
250         real *     charge;
251         int *      type;
252     real       facel;
253         real *     pvdw;
254         int        ggid;
255     int *      mask;
256     
257     real       ix,iy,iz,iq;
258     real       fix,fiy,fiz;
259     real       jx,jy,jz,qq;
260     real       dx,dy,dz;
261     real       tx,ty,tz;
262     real       rsq,rinv,rinvsq,rinvsix;
263     real       vcoul,vctot;
264     real       c6,c12,Vvdw6,Vvdw12,Vvdwtot;
265     real       fscal;
266     
267         charge              = mdatoms->chargeA;
268         type                = mdatoms->typeA;
269         facel               = fr->epsfac;
270     natoms              = mdatoms->nr;
271         ni0                 = mdatoms->start;
272         ni1                 = mdatoms->start+mdatoms->homenr;
273     
274     aadata = *((gmx_allvsall_data_t **)work);
275
276         if(aadata==NULL)
277         {
278                 setup_aadata(&aadata,excl,natoms,type,fr->ntype,fr->nbfp);
279         *((gmx_allvsall_data_t **)work) = aadata;
280         }
281         
282         for(i=ni0; i<ni1; i++)
283         {
284                 /* We assume shifts are NOT used for all-vs-all interactions */
285                 
286                 /* Load i atom data */
287         ix                = x[3*i];
288         iy                = x[3*i+1];
289         iz                = x[3*i+2];
290         iq                = facel*charge[i];
291
292         pvdw              = aadata->pvdwparam[type[i]];
293         
294                 /* Zero the potential energy for this list */
295                 Vvdwtot           = 0.0;
296         vctot             = 0.0;
297
298                 /* Clear i atom forces */
299         fix               = 0.0;
300         fiy               = 0.0;
301         fiz               = 0.0;
302         
303                 /* Load limits for loop over neighbors */
304                 nj0              = aadata->jindex[3*i];
305                 nj1              = aadata->jindex[3*i+1];
306                 nj2              = aadata->jindex[3*i+2];
307
308         mask             = aadata->exclusion_mask[i];
309                 
310         /* Prologue part, including exclusion mask */
311         for(j=nj0; j<nj1; j++,mask++)
312         {          
313             if(*mask!=0)
314             {
315                 k = j%natoms;
316                 
317                 /* load j atom coordinates */
318                 jx                = x[3*k];
319                 jy                = x[3*k+1];
320                 jz                = x[3*k+2];
321                 
322                 /* Calculate distance */
323                 dx                = ix - jx;      
324                 dy                = iy - jy;      
325                 dz                = iz - jz;      
326                 rsq               = dx*dx+dy*dy+dz*dz;
327                 
328                 /* Calculate 1/r and 1/r2 */
329                 rinv              = gmx_invsqrt(rsq);
330                 rinvsq            = rinv*rinv;  
331                 
332                 /* Load parameters for j atom */
333                 qq                = iq*charge[k]; 
334                 c6                = pvdw[2*k];
335                 c12               = pvdw[2*k+1];
336                 
337                 /* Coulomb interaction */
338                 vcoul             = qq*rinv;      
339                 vctot             = vctot+vcoul;    
340                 
341                 /* Lennard-Jones interaction */
342                 rinvsix           = rinvsq*rinvsq*rinvsq;
343                 Vvdw6             = c6*rinvsix;     
344                 Vvdw12            = c12*rinvsix*rinvsix;
345                 Vvdwtot           = Vvdwtot+Vvdw12-Vvdw6;
346                 fscal             = (vcoul+12.0*Vvdw12-6.0*Vvdw6)*rinvsq;
347                 
348                 /* Calculate temporary vectorial force */
349                 tx                = fscal*dx;     
350                 ty                = fscal*dy;     
351                 tz                = fscal*dz;     
352                 
353                 /* Increment i atom force */
354                 fix               = fix + tx;      
355                 fiy               = fiy + ty;      
356                 fiz               = fiz + tz;      
357             
358                 /* Decrement j atom force */
359                 f[3*k]            = f[3*k]   - tx;
360                 f[3*k+1]          = f[3*k+1] - ty;
361                 f[3*k+2]          = f[3*k+2] - tz;
362             }
363             /* Inner loop uses 38 flops/iteration */
364         }
365
366         /* Main part, no exclusions */
367         for(j=nj1; j<nj2; j++)
368         {       
369             k = j%natoms;
370
371             /* load j atom coordinates */
372             jx                = x[3*k];
373             jy                = x[3*k+1];
374             jz                = x[3*k+2];
375             
376             /* Calculate distance */
377             dx                = ix - jx;      
378             dy                = iy - jy;      
379             dz                = iz - jz;      
380             rsq               = dx*dx+dy*dy+dz*dz;
381             
382             /* Calculate 1/r and 1/r2 */
383             rinv              = gmx_invsqrt(rsq);
384             rinvsq            = rinv*rinv;  
385             
386             /* Load parameters for j atom */
387             qq                = iq*charge[k]; 
388             c6                = pvdw[2*k];
389             c12               = pvdw[2*k+1];
390             
391             /* Coulomb interaction */
392             vcoul             = qq*rinv;      
393             vctot             = vctot+vcoul;    
394             
395             /* Lennard-Jones interaction */
396             rinvsix           = rinvsq*rinvsq*rinvsq;
397             Vvdw6             = c6*rinvsix;     
398             Vvdw12            = c12*rinvsix*rinvsix;
399             Vvdwtot           = Vvdwtot+Vvdw12-Vvdw6;
400             fscal             = (vcoul+12.0*Vvdw12-6.0*Vvdw6)*rinvsq;
401                         
402             /* Calculate temporary vectorial force */
403             tx                = fscal*dx;     
404             ty                = fscal*dy;     
405             tz                = fscal*dz;     
406             
407             /* Increment i atom force */
408             fix               = fix + tx;      
409             fiy               = fiy + ty;      
410             fiz               = fiz + tz;      
411
412             /* Decrement j atom force */
413             f[3*k]            = f[3*k]   - tx;
414             f[3*k+1]          = f[3*k+1] - ty;
415             f[3*k+2]          = f[3*k+2] - tz;
416             
417             /* Inner loop uses 38 flops/iteration */
418         }
419         
420         f[3*i]   += fix;
421         f[3*i+1] += fiy;
422         f[3*i+2] += fiz;
423                 
424                 /* Add potential energies to the group for this list */
425                 ggid             = 0;         
426         
427                 Vc[ggid]         = Vc[ggid] + vctot;
428         Vvdw[ggid]       = Vvdw[ggid] + Vvdwtot;
429                 
430                 /* Outer loop uses 6 flops/iteration */
431         }    
432       
433     /* Write outer/inner iteration count to pointers */
434     *outeriter       = ni1-ni0;         
435     *inneriter       = (ni1-ni0)*natoms/2;         
436 }
437
438