Update copyright statements and change license to LGPL
[alexxy/gromacs.git] / src / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team,
6  * check out http://www.gromacs.org for more information.
7  * Copyright (c) 2012, by the GROMACS development team, led by
8  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
9  * others, as listed in the AUTHORS file in the top-level source
10  * directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* IMPORTANT FOR DEVELOPERS:
39  *
40  * Triclinic pme stuff isn't entirely trivial, and we've experienced
41  * some bugs during development (many of them due to me). To avoid
42  * this in the future, please check the following things if you make
43  * changes in this file:
44  *
45  * 1. You should obtain identical (at least to the PME precision)
46  *    energies, forces, and virial for
47  *    a rectangular box and a triclinic one where the z (or y) axis is
48  *    tilted by a whole box side. For instance, you could use these boxes:
49  *
50  *    rectangular       triclinic
51  *     2  0  0           2  0  0
52  *     0  2  0           0  2  0
53  *     0  0  6           2  2  6
54  *
55  * 2. You should check the energy conservation in a triclinic box.
56  *
57  * It might seem like overkill, but better safe than sorry.
58  * /Erik 001109
59  */
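
/* A minimal illustration (not compiled; the array names are illustrative
 * only) of the two test boxes suggested above, written as row-vector box
 * matrices with rows as box vectors. Actually running both systems and
 * comparing the PME energies, forces and virial is left to the test setup.
 */
#if 0
static const double pme_test_box_rect[3][3] = { { 2, 0, 0 },
                                                { 0, 2, 0 },
                                                { 0, 0, 6 } };
/* Same box, but with the z box vector tilted by a whole box side in x and y */
static const double pme_test_box_tric[3][3] = { { 2, 0, 0 },
                                                { 0, 2, 0 },
                                                { 2, 2, 6 } };
#endif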
60
61 #ifdef HAVE_CONFIG_H
62 #include <config.h>
63 #endif
64
65 #ifdef GMX_LIB_MPI
66 #include <mpi.h>
67 #endif
68 #ifdef GMX_THREAD_MPI
69 #include "tmpi.h"
70 #endif
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93 #include "gmx_omp.h"
94
95 /* Single precision, with SSE2 or higher available */
96 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
97
98 #include "gmx_x86_simd_single.h"
99
100 #define PME_SSE
101 /* Some old AMD processors could have problems with unaligned loads+stores */
102 #ifndef GMX_FAHCORE
103 #define PME_SSE_UNALIGNED
104 #endif
105 #endif
106
107 #include "mpelogging.h"
108
109 #define DFT_TOL 1e-7
110 /* #define PRT_FORCE */
111 /* conditions for on-the-fly time measurement */
112 /* #define TAKETIME (step > 1 && timesteps < 10) */
113 #define TAKETIME FALSE
114
115 /* #define PME_TIME_THREADS */
116
117 #ifdef GMX_DOUBLE
118 #define mpi_type MPI_DOUBLE
119 #else
120 #define mpi_type MPI_FLOAT
121 #endif
122
123 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
124 #define GMX_CACHE_SEP 64
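
/* Illustrative sketch (not compiled; the helper name is not used anywhere in
 * this file): GMX_CACHE_SEP leaves a gap of GMX_CACHE_SEP reals between
 * consecutive per-thread grids carved out of one allocation, so threads
 * writing near the ends of neighboring grids do not touch the same cache
 * lines. See the grid_all layout set up in pmegrids_init below.
 */
#if 0
static real *thread_grid_start(real *grid_all, int thread, int gridsize)
{
    /* Layout: [sep][grid 0][sep][grid 1][sep] ... [grid nthread-1][sep] */
    return grid_all + GMX_CACHE_SEP + thread*(gridsize + GMX_CACHE_SEP);
}
#endif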
125
126 /* We only define a maximum to be able to use local arrays without allocation.
127  * An order larger than 12 should never be needed, even for test cases.
128  * If needed it can be changed here.
129  */
130 #define PME_ORDER_MAX 12
131
132 /* Internal data structures */
133 typedef struct {
134     int send_index0;
135     int send_nindex;
136     int recv_index0;
137     int recv_nindex;
138     int recv_size;   /* Receive buffer width, used with OpenMP */
139 } pme_grid_comm_t;
140
141 typedef struct {
142 #ifdef GMX_MPI
143     MPI_Comm mpi_comm;
144 #endif
145     int  nnodes,nodeid;
146     int  *s2g0;
147     int  *s2g1;
148     int  noverlap_nodes;
149     int  *send_id,*recv_id;
150     int  send_size;             /* Send buffer width, used with OpenMP */
151     pme_grid_comm_t *comm_data;
152     real *sendbuf;
153     real *recvbuf;
154 } pme_overlap_t;
155
156 typedef struct {
157     int *n;     /* Cumulative counts of the number of particles per thread */
158     int nalloc; /* Allocation size of i */
159     int *i;     /* Particle indices ordered on thread index (n) */
160 } thread_plist_t;
161
162 typedef struct {
163     int  *thread_one;
164     int  n;
165     int  *ind;
166     splinevec theta;
167     real *ptr_theta_z;
168     splinevec dtheta;
169     real *ptr_dtheta_z;
170 } splinedata_t;
171
172 typedef struct {
173     int  dimind;            /* The index of the dimension, 0=x, 1=y */
174     int  nslab;
175     int  nodeid;
176 #ifdef GMX_MPI
177     MPI_Comm mpi_comm;
178 #endif
179
180     int  *node_dest;        /* The nodes to send x and q to with DD */
181     int  *node_src;         /* The nodes to receive x and q from with DD */
182     int  *buf_index;        /* Index for commnode into the buffers */
183
184     int  maxshift;
185
186     int  npd;
187     int  pd_nalloc;
188     int  *pd;
189     int  *count;            /* The number of atoms to send to each node */
190     int  **count_thread;
191     int  *rcount;           /* The number of atoms to receive */
192
193     int  n;
194     int  nalloc;
195     rvec *x;
196     real *q;
197     rvec *f;
198     gmx_bool bSpread;       /* These coordinates are used for spreading */
199     int  pme_order;
200     ivec *idx;
201     rvec *fractx;            /* Fractional coordinate relative to the
202                               * lower cell boundary
203                               */
204     int  nthread;
205     int  *thread_idx;        /* Which thread should spread which charge */
206     thread_plist_t *thread_plist;
207     splinedata_t *spline;
208 } pme_atomcomm_t;
209
210 #define FLBS  3
211 #define FLBSZ 4
212
213 typedef struct {
214     ivec ci;     /* The spatial location of this grid         */
215     ivec n;      /* The used size of *grid, including order-1 */
216     ivec offset; /* The grid offset from the full node grid   */
217     int  order;  /* PME spreading order                       */
218     ivec s;      /* The allocated size of *grid, s >= n       */
219     real *grid;  /* The grid of the local thread, size n      */
220 } pmegrid_t;
221
222 typedef struct {
223     pmegrid_t grid;     /* The full node grid (non thread-local)            */
224     int  nthread;       /* The number of threads operating on this grid     */
225     ivec nc;            /* The local spatial decomposition over the threads */
226     pmegrid_t *grid_th; /* Array of grids for each thread                   */
227     real *grid_all;     /* Allocated array for the grids in *grid_th        */
228     int  **g2t;         /* The grid to thread index                         */
229     ivec nthread_comm;  /* The number of threads to communicate with        */
230 } pmegrids_t;
231
232
233 typedef struct {
234 #ifdef PME_SSE
235     /* Masks for SSE aligned spreading and gathering */
236     __m128 mask_SSE0[6],mask_SSE1[6];
237 #else
238     int dummy; /* C89 requires that struct has at least one member */
239 #endif
240 } pme_spline_work_t;
241
242 typedef struct {
243     /* work data for solve_pme */
244     int      nalloc;
245     real *   mhx;
246     real *   mhy;
247     real *   mhz;
248     real *   m2;
249     real *   denom;
250     real *   tmp1_alloc;
251     real *   tmp1;
252     real *   eterm;
253     real *   m2inv;
254
255     real     energy;
256     matrix   vir;
257 } pme_work_t;
258
259 typedef struct gmx_pme {
260     int  ndecompdim;         /* The number of decomposition dimensions */
261     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
262     int  nodeid_major;
263     int  nodeid_minor;
264     int  nnodes;             /* The number of nodes doing PME */
265     int  nnodes_major;
266     int  nnodes_minor;
267
268     MPI_Comm mpi_comm;
269     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
270 #ifdef GMX_MPI
271     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
272 #endif
273
274     int  nthread;            /* The number of threads doing PME */
275
276     gmx_bool bPPnode;        /* Node also does particle-particle forces */
277     gmx_bool bFEP;           /* Compute Free energy contribution */
278     int nkx,nky,nkz;         /* Grid dimensions */
279     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
280     int pme_order;
281     real epsilon_r;
282
283     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
284     pmegrids_t pmegridB;
285     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
286     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
287     /* pmegrid_nz might be larger than strictly necessary to ensure
288      * memory alignment, pmegrid_nz_base gives the real base size.
289      */
290     int     pmegrid_nz_base;
291     /* The local PME grid starting indices */
292     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
293
294     /* Work data for spreading and gathering */
295     pme_spline_work_t *spline_work;
296
297     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
298     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
299     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
300
301     t_complex *cfftgridA;             /* Grids for complex FFT data */
302     t_complex *cfftgridB;
303     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
304
305     gmx_parallel_3dfft_t  pfft_setupA;
306     gmx_parallel_3dfft_t  pfft_setupB;
307
308     int  *nnx,*nny,*nnz;
309     real *fshx,*fshy,*fshz;
310
311     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
312     matrix    recipbox;
313     splinevec bsp_mod;
314
315     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
316
317     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
318
319     rvec *bufv;             /* Communication buffer */
320     real *bufr;             /* Communication buffer */
321     int  buf_nalloc;        /* The communication buffer size */
322
323     /* thread local work data for solve_pme */
324     pme_work_t *work;
325
326     /* Work data for PME_redist */
327     gmx_bool redist_init;
328     int *    scounts;
329     int *    rcounts;
330     int *    sdispls;
331     int *    rdispls;
332     int *    sidx;
333     int *    idxa;
334     real *   redist_buf;
335     int      redist_buf_nalloc;
336
337     /* Work data for sum_qgrid */
338     real *   sum_qgrid_tmp;
339     real *   sum_qgrid_dd_tmp;
340 } t_gmx_pme;
341
342
343 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
344                                    int start,int end,int thread)
345 {
346     int  i;
347     int  *idxptr,tix,tiy,tiz;
348     real *xptr,*fptr,tx,ty,tz;
349     real rxx,ryx,ryy,rzx,rzy,rzz;
350     int  nx,ny,nz;
351     int  start_ix,start_iy,start_iz;
352     int  *g2tx,*g2ty,*g2tz;
353     gmx_bool bThreads;
354     int  *thread_idx=NULL;
355     thread_plist_t *tpl=NULL;
356     int  *tpl_n=NULL;
357     int  thread_i;
358
359     nx  = pme->nkx;
360     ny  = pme->nky;
361     nz  = pme->nkz;
362
363     start_ix = pme->pmegrid_start_ix;
364     start_iy = pme->pmegrid_start_iy;
365     start_iz = pme->pmegrid_start_iz;
366
367     rxx = pme->recipbox[XX][XX];
368     ryx = pme->recipbox[YY][XX];
369     ryy = pme->recipbox[YY][YY];
370     rzx = pme->recipbox[ZZ][XX];
371     rzy = pme->recipbox[ZZ][YY];
372     rzz = pme->recipbox[ZZ][ZZ];
373
374     g2tx = pme->pmegridA.g2t[XX];
375     g2ty = pme->pmegridA.g2t[YY];
376     g2tz = pme->pmegridA.g2t[ZZ];
377
378     bThreads = (atc->nthread > 1);
379     if (bThreads)
380     {
381         thread_idx = atc->thread_idx;
382
383         tpl   = &atc->thread_plist[thread];
384         tpl_n = tpl->n;
385         for(i=0; i<atc->nthread; i++)
386         {
387             tpl_n[i] = 0;
388         }
389     }
390
391     for(i=start; i<end; i++) {
392         xptr   = atc->x[i];
393         idxptr = atc->idx[i];
394         fptr   = atc->fractx[i];
395
396         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
397         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
398         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
399         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
400
401         tix = (int)(tx);
402         tiy = (int)(ty);
403         tiz = (int)(tz);
404
405         /* Because decomposition only occurs in x and y,
406          * we never have a fraction correction in z.
407          */
408         fptr[XX] = tx - tix + pme->fshx[tix];
409         fptr[YY] = ty - tiy + pme->fshy[tiy];
410         fptr[ZZ] = tz - tiz;
411
412         idxptr[XX] = pme->nnx[tix];
413         idxptr[YY] = pme->nny[tiy];
414         idxptr[ZZ] = pme->nnz[tiz];
415
416 #ifdef DEBUG
417         range_check(idxptr[XX],0,pme->pmegrid_nx);
418         range_check(idxptr[YY],0,pme->pmegrid_ny);
419         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
420 #endif
421
422         if (bThreads)
423         {
424             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
425             thread_idx[i] = thread_i;
426             tpl_n[thread_i]++;
427         }
428     }
429
430     if (bThreads)
431     {
432         /* Make a list of particle indices sorted on thread */
433
434         /* Get the cumulative count */
435         for(i=1; i<atc->nthread; i++)
436         {
437             tpl_n[i] += tpl_n[i-1];
438         }
439         /* The current implementation distributes particles equally
440          * over the threads, so we could actually allocate for that
441          * in pme_realloc_atomcomm_things.
442          */
443         if (tpl_n[atc->nthread-1] > tpl->nalloc)
444         {
445             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
446             srenew(tpl->i,tpl->nalloc);
447         }
448         /* Set tpl_n to the cumulative start */
449         for(i=atc->nthread-1; i>=1; i--)
450         {
451             tpl_n[i] = tpl_n[i-1];
452         }
453         tpl_n[0] = 0;
454
455         /* Fill our thread local array with indices sorted on thread */
456         for(i=start; i<end; i++)
457         {
458             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
459         }
460         /* Now tpl_n contains the cumulative count again */
461     }
462 }
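
/* A minimal 1-D sketch (not compiled; helper name illustrative only) of the
 * index calculation above for the rectangular case, where recip = 1/box:
 * the +2.0 shift keeps the scaled coordinate positive also for slightly
 * out-of-box (triclinic) coordinates, the integer part selects the grid
 * cell and the remainder is the fractional offset fed into the B-splines.
 * The fshx/fshy decomposition corrections and the nnx/nny/nnz wrap tables
 * of the real code are replaced here by a plain modulo.
 */
#if 0
static void calc_idx_frac_1d(real x, real recip, int n, int *idx, real *frac)
{
    real t  = n*(x*recip + 2.0);
    int  ti = (int)t;

    *frac = t - ti;
    *idx  = ti % n;
}
#endif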
463
464 static void make_thread_local_ind(pme_atomcomm_t *atc,
465                                   int thread,splinedata_t *spline)
466 {
467     int  n,t,i,start,end;
468     thread_plist_t *tpl;
469
470     /* Combine the indices made by each thread into one index */
471
472     n = 0;
473     start = 0;
474     for(t=0; t<atc->nthread; t++)
475     {
476         tpl = &atc->thread_plist[t];
477         /* Copy our part (start to end) from the list of thread t */
478         if (thread > 0)
479         {
480             start = tpl->n[thread-1];
481         }
482         end = tpl->n[thread];
483         for(i=start; i<end; i++)
484         {
485             spline->ind[n++] = tpl->i[i];
486         }
487     }
488
489     spline->n = n;
490 }
491
492
493 static void pme_calc_pidx(int start, int end,
494                           matrix recipbox, rvec x[],
495                           pme_atomcomm_t *atc, int *count)
496 {
497     int  nslab,i;
498     int  si;
499     real *xptr,s;
500     real rxx,ryx,rzx,ryy,rzy;
501     int *pd;
502
503     /* Calculate the PME task index (pidx) for each particle.
504      * Here we always assign equally sized slabs to each node
505      * for load balancing reasons (the PME grid spacing is not used).
506      */
507
508     nslab = atc->nslab;
509     pd    = atc->pd;
510
511     /* Reset the count */
512     for(i=0; i<nslab; i++)
513     {
514         count[i] = 0;
515     }
516
517     if (atc->dimind == 0)
518     {
519         rxx = recipbox[XX][XX];
520         ryx = recipbox[YY][XX];
521         rzx = recipbox[ZZ][XX];
522         /* Calculate the node index in x-dimension */
523         for(i=start; i<end; i++)
524         {
525             xptr   = x[i];
526             /* Fractional coordinates along box vectors */
527             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
528             si = (int)(s + 2*nslab) % nslab;
529             pd[i] = si;
530             count[si]++;
531         }
532     }
533     else
534     {
535         ryy = recipbox[YY][YY];
536         rzy = recipbox[ZZ][YY];
537         /* Calculate the node index in y-dimension */
538         for(i=start; i<end; i++)
539         {
540             xptr   = x[i];
541             /* Fractional coordinates along box vectors */
542             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
543             si = (int)(s + 2*nslab) % nslab;
544             pd[i] = si;
545             count[si]++;
546         }
547     }
548 }
549
550 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
551                                   pme_atomcomm_t *atc)
552 {
553     int nthread,thread,slab;
554
555     nthread = atc->nthread;
556
557 #pragma omp parallel for num_threads(nthread) schedule(static)
558     for(thread=0; thread<nthread; thread++)
559     {
560         pme_calc_pidx(natoms* thread   /nthread,
561                       natoms*(thread+1)/nthread,
562                       recipbox,x,atc,atc->count_thread[thread]);
563     }
564     /* Non-parallel reduction, since nslab is small */
565
566     for(thread=1; thread<nthread; thread++)
567     {
568         for(slab=0; slab<atc->nslab; slab++)
569         {
570             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
571         }
572     }
573 }
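
/* A minimal sketch (not compiled; names illustrative only) of the pattern
 * used above: each OpenMP thread processes a contiguous range of atoms
 * derived from its thread index, and the small per-thread slab counts are
 * then reduced serially into the thread-0 array, since nslab is tiny
 * compared to the number of atoms.
 */
#if 0
static void count_and_reduce(int natoms, int nthread, int nslab,
                             int **count_thread)
{
    int thread, slab;

#pragma omp parallel for num_threads(nthread) schedule(static)
    for (thread = 0; thread < nthread; thread++)
    {
        int start = (natoms* thread   )/nthread;
        int end   = (natoms*(thread+1))/nthread;

        /* ... count the atoms in [start,end) into count_thread[thread] ... */
    }

    /* Serial reduction over the short slab dimension */
    for (thread = 1; thread < nthread; thread++)
    {
        for (slab = 0; slab < nslab; slab++)
        {
            count_thread[0][slab] += count_thread[thread][slab];
        }
    }
}
#endif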
574
575 static void realloc_splinevec(splinevec th,real **ptr_z,int nalloc)
576 {
577     const int padding=4;
578     int i;
579
580     srenew(th[XX],nalloc);
581     srenew(th[YY],nalloc);
582     /* In z we add padding, this is only required for the aligned SSE code */
583     srenew(*ptr_z,nalloc+2*padding);
584     th[ZZ] = *ptr_z + padding;
585
586     for(i=0; i<padding; i++)
587     {
588         (*ptr_z)[               i] = 0;
589         (*ptr_z)[padding+nalloc+i] = 0;
590     }
591 }
592
593 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
594 {
595     int i,d;
596
597     srenew(spline->ind,atc->nalloc);
598     /* Initialize the index to identity so it works without threads */
599     for(i=0; i<atc->nalloc; i++)
600     {
601         spline->ind[i] = i;
602     }
603
604     realloc_splinevec(spline->theta,&spline->ptr_theta_z,
605                       atc->pme_order*atc->nalloc);
606     realloc_splinevec(spline->dtheta,&spline->ptr_dtheta_z,
607                       atc->pme_order*atc->nalloc);
608 }
609
610 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
611 {
612     int nalloc_old,i,j,nalloc_tpl;
613
614     /* We have to keep atc->x non-NULL to avoid possible
615      * fatal errors in MPI routines.
616      */
617     if (atc->n > atc->nalloc || atc->nalloc == 0)
618     {
619         nalloc_old = atc->nalloc;
620         atc->nalloc = over_alloc_dd(max(atc->n,1));
621
622         if (atc->nslab > 1) {
623             srenew(atc->x,atc->nalloc);
624             srenew(atc->q,atc->nalloc);
625             srenew(atc->f,atc->nalloc);
626             for(i=nalloc_old; i<atc->nalloc; i++)
627             {
628                 clear_rvec(atc->f[i]);
629             }
630         }
631         if (atc->bSpread) {
632             srenew(atc->fractx,atc->nalloc);
633             srenew(atc->idx   ,atc->nalloc);
634
635             if (atc->nthread > 1)
636             {
637                 srenew(atc->thread_idx,atc->nalloc);
638             }
639
640             for(i=0; i<atc->nthread; i++)
641             {
642                 pme_realloc_splinedata(&atc->spline[i],atc);
643             }
644         }
645     }
646 }
647
648 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
649                          int n, gmx_bool bXF, rvec *x_f, real *charge,
650                          pme_atomcomm_t *atc)
651 /* Redistribute particle data for PME calculation */
652 /* domain decomposition by x coordinate           */
653 {
654     int *idxa;
655     int i, ii;
656
657     if(FALSE == pme->redist_init) {
658         snew(pme->scounts,atc->nslab);
659         snew(pme->rcounts,atc->nslab);
660         snew(pme->sdispls,atc->nslab);
661         snew(pme->rdispls,atc->nslab);
662         snew(pme->sidx,atc->nslab);
663         pme->redist_init = TRUE;
664     }
665     if (n > pme->redist_buf_nalloc) {
666         pme->redist_buf_nalloc = over_alloc_dd(n);
667         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
668     }
669
670     pme->idxa = atc->pd;
671
672 #ifdef GMX_MPI
673     if (forw && bXF) {
674         /* forward, redistribution from pp to pme */
675
676         /* Calculate send counts and exchange them with other nodes */
677         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
678         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
679         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
680
681         /* Calculate send and receive displacements and index into send
682            buffer */
683         pme->sdispls[0]=0;
684         pme->rdispls[0]=0;
685         pme->sidx[0]=0;
686         for(i=1; i<atc->nslab; i++) {
687             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
688             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
689             pme->sidx[i]=pme->sdispls[i];
690         }
691         /* Total # of particles to be received */
692         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
693
694         pme_realloc_atomcomm_things(atc);
695
696         /* Copy particle coordinates into send buffer and exchange */
697         for(i=0; (i<n); i++) {
698             ii=DIM*pme->sidx[pme->idxa[i]];
699             pme->sidx[pme->idxa[i]]++;
700             pme->redist_buf[ii+XX]=x_f[i][XX];
701             pme->redist_buf[ii+YY]=x_f[i][YY];
702             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
703         }
704         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
705                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
706                       pme->rvec_mpi, atc->mpi_comm);
707     }
708     if (forw) {
709         /* Copy charges into send buffer and exchange */
710         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
711         for(i=0; (i<n); i++) {
712             ii=pme->sidx[pme->idxa[i]];
713             pme->sidx[pme->idxa[i]]++;
714             pme->redist_buf[ii]=charge[i];
715         }
716         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
717                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
718                       atc->mpi_comm);
719     }
720     else { /* backward, redistribution from pme to pp */
721         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
722                       pme->redist_buf, pme->scounts, pme->sdispls,
723                       pme->rvec_mpi, atc->mpi_comm);
724
725         /* Copy data from receive buffer */
726         for(i=0; i<atc->nslab; i++)
727             pme->sidx[i] = pme->sdispls[i];
728         for(i=0; (i<n); i++) {
729             ii = DIM*pme->sidx[pme->idxa[i]];
730             x_f[i][XX] += pme->redist_buf[ii+XX];
731             x_f[i][YY] += pme->redist_buf[ii+YY];
732             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
733             pme->sidx[pme->idxa[i]]++;
734         }
735     }
736 #endif
737 }
738
739 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
740                             gmx_bool bBackward,int shift,
741                             void *buf_s,int nbyte_s,
742                             void *buf_r,int nbyte_r)
743 {
744 #ifdef GMX_MPI
745     int dest,src;
746     MPI_Status stat;
747
748     if (bBackward == FALSE) {
749         dest = atc->node_dest[shift];
750         src  = atc->node_src[shift];
751     } else {
752         dest = atc->node_src[shift];
753         src  = atc->node_dest[shift];
754     }
755
756     if (nbyte_s > 0 && nbyte_r > 0) {
757         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
758                      dest,shift,
759                      buf_r,nbyte_r,MPI_BYTE,
760                      src,shift,
761                      atc->mpi_comm,&stat);
762     } else if (nbyte_s > 0) {
763         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
764                  dest,shift,
765                  atc->mpi_comm);
766     } else if (nbyte_r > 0) {
767         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
768                  src,shift,
769                  atc->mpi_comm,&stat);
770     }
771 #endif
772 }
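
/* A usage sketch (not compiled; the helper name and buffers are illustrative
 * only), modeled on dd_pmeredist_x_q below: for each communication pulse the
 * send count is exchanged first, so the receiving side can size its buffers
 * before the actual payload is exchanged with the same pme_dd_sendrecv call.
 */
#if 0
static void exchange_count_then_data(pme_atomcomm_t *atc, int pulse,
                                     real *sendbuf, int scount,
                                     real *recvbuf)
{
    int rcount;

    /* Exchange the element counts first */
    pme_dd_sendrecv(atc, FALSE, pulse,
                    &scount, sizeof(int),
                    &rcount, sizeof(int));

    /* ... make sure recvbuf has room for rcount reals here ... */

    /* Then exchange the actual data */
    pme_dd_sendrecv(atc, FALSE, pulse,
                    sendbuf, scount*sizeof(real),
                    recvbuf, rcount*sizeof(real));
}
#endif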
773
774 static void dd_pmeredist_x_q(gmx_pme_t pme,
775                              int n, gmx_bool bX, rvec *x, real *charge,
776                              pme_atomcomm_t *atc)
777 {
778     int *commnode,*buf_index;
779     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
780
781     commnode  = atc->node_dest;
782     buf_index = atc->buf_index;
783
784     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
785
786     nsend = 0;
787     for(i=0; i<nnodes_comm; i++) {
788         buf_index[commnode[i]] = nsend;
789         nsend += atc->count[commnode[i]];
790     }
791     if (bX) {
792         if (atc->count[atc->nodeid] + nsend != n)
793             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
794                       "This usually means that your system is not well equilibrated.",
795                       n - (atc->count[atc->nodeid] + nsend),
796                       pme->nodeid,'x'+atc->dimind);
797
798         if (nsend > pme->buf_nalloc) {
799             pme->buf_nalloc = over_alloc_dd(nsend);
800             srenew(pme->bufv,pme->buf_nalloc);
801             srenew(pme->bufr,pme->buf_nalloc);
802         }
803
804         atc->n = atc->count[atc->nodeid];
805         for(i=0; i<nnodes_comm; i++) {
806             scount = atc->count[commnode[i]];
807             /* Communicate the count */
808             if (debug)
809                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
810                         atc->dimind,atc->nodeid,commnode[i],scount);
811             pme_dd_sendrecv(atc,FALSE,i,
812                             &scount,sizeof(int),
813                             &atc->rcount[i],sizeof(int));
814             atc->n += atc->rcount[i];
815         }
816
817         pme_realloc_atomcomm_things(atc);
818     }
819
820     local_pos = 0;
821     for(i=0; i<n; i++) {
822         node = atc->pd[i];
823         if (node == atc->nodeid) {
824             /* Copy direct to the receive buffer */
825             if (bX) {
826                 copy_rvec(x[i],atc->x[local_pos]);
827             }
828             atc->q[local_pos] = charge[i];
829             local_pos++;
830         } else {
831             /* Copy to the send buffer */
832             if (bX) {
833                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
834             }
835             pme->bufr[buf_index[node]] = charge[i];
836             buf_index[node]++;
837         }
838     }
839
840     buf_pos = 0;
841     for(i=0; i<nnodes_comm; i++) {
842         scount = atc->count[commnode[i]];
843         rcount = atc->rcount[i];
844         if (scount > 0 || rcount > 0) {
845             if (bX) {
846                 /* Communicate the coordinates */
847                 pme_dd_sendrecv(atc,FALSE,i,
848                                 pme->bufv[buf_pos],scount*sizeof(rvec),
849                                 atc->x[local_pos],rcount*sizeof(rvec));
850             }
851             /* Communicate the charges */
852             pme_dd_sendrecv(atc,FALSE,i,
853                             pme->bufr+buf_pos,scount*sizeof(real),
854                             atc->q+local_pos,rcount*sizeof(real));
855             buf_pos   += scount;
856             local_pos += atc->rcount[i];
857         }
858     }
859 }
860
861 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
862                            int n, rvec *f,
863                            gmx_bool bAddF)
864 {
865     int *commnode,*buf_index;
866     int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
867
868     commnode  = atc->node_dest;
869     buf_index = atc->buf_index;
870
871     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
872
873     local_pos = atc->count[atc->nodeid];
874     buf_pos = 0;
875     for(i=0; i<nnodes_comm; i++) {
876         scount = atc->rcount[i];
877         rcount = atc->count[commnode[i]];
878         if (scount > 0 || rcount > 0) {
879             /* Communicate the forces */
880             pme_dd_sendrecv(atc,TRUE,i,
881                             atc->f[local_pos],scount*sizeof(rvec),
882                             pme->bufv[buf_pos],rcount*sizeof(rvec));
883             local_pos += scount;
884         }
885         buf_index[commnode[i]] = buf_pos;
886         buf_pos   += rcount;
887     }
888
889     local_pos = 0;
890     if (bAddF)
891     {
892         for(i=0; i<n; i++)
893         {
894             node = atc->pd[i];
895             if (node == atc->nodeid)
896             {
897                 /* Add from the local force array */
898                 rvec_inc(f[i],atc->f[local_pos]);
899                 local_pos++;
900             }
901             else
902             {
903                 /* Add from the receive buffer */
904                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
905                 buf_index[node]++;
906             }
907         }
908     }
909     else
910     {
911         for(i=0; i<n; i++)
912         {
913             node = atc->pd[i];
914             if (node == atc->nodeid)
915             {
916                 /* Copy from the local force array */
917                 copy_rvec(atc->f[local_pos],f[i]);
918                 local_pos++;
919             }
920             else
921             {
922                 /* Copy from the receive buffer */
923                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
924                 buf_index[node]++;
925             }
926         }
927     }
928 }
929
930 #ifdef GMX_MPI
931 static void
932 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
933 {
934     pme_overlap_t *overlap;
935     int send_index0,send_nindex;
936     int recv_index0,recv_nindex;
937     MPI_Status stat;
938     int i,j,k,ix,iy,iz,icnt;
939     int ipulse,send_id,recv_id,datasize;
940     real *p;
941     real *sendptr,*recvptr;
942
943     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
944     overlap = &pme->overlap[1];
945
946     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
947     {
948         /* Since we have already (un)wrapped the overlap in the z-dimension,
949          * we only have to communicate 0 to nkz (not pmegrid_nz).
950          */
951         if (direction==GMX_SUM_QGRID_FORWARD)
952         {
953             send_id = overlap->send_id[ipulse];
954             recv_id = overlap->recv_id[ipulse];
955             send_index0   = overlap->comm_data[ipulse].send_index0;
956             send_nindex   = overlap->comm_data[ipulse].send_nindex;
957             recv_index0   = overlap->comm_data[ipulse].recv_index0;
958             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
959         }
960         else
961         {
962             send_id = overlap->recv_id[ipulse];
963             recv_id = overlap->send_id[ipulse];
964             send_index0   = overlap->comm_data[ipulse].recv_index0;
965             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
966             recv_index0   = overlap->comm_data[ipulse].send_index0;
967             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
968         }
969
970         /* Copy data to contiguous send buffer */
971         if (debug)
972         {
973             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
974                     pme->nodeid,overlap->nodeid,send_id,
975                     pme->pmegrid_start_iy,
976                     send_index0-pme->pmegrid_start_iy,
977                     send_index0-pme->pmegrid_start_iy+send_nindex);
978         }
979         icnt = 0;
980         for(i=0;i<pme->pmegrid_nx;i++)
981         {
982             ix = i;
983             for(j=0;j<send_nindex;j++)
984             {
985                 iy = j + send_index0 - pme->pmegrid_start_iy;
986                 for(k=0;k<pme->nkz;k++)
987                 {
988                     iz = k;
989                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
990                 }
991             }
992         }
993
994         datasize      = pme->pmegrid_nx * pme->nkz;
995
996         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
997                      send_id,ipulse,
998                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
999                      recv_id,ipulse,
1000                      overlap->mpi_comm,&stat);
1001
1002         /* Get data from contiguous recv buffer */
1003         if (debug)
1004         {
1005             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1006                     pme->nodeid,overlap->nodeid,recv_id,
1007                     pme->pmegrid_start_iy,
1008                     recv_index0-pme->pmegrid_start_iy,
1009                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1010         }
1011         icnt = 0;
1012         for(i=0;i<pme->pmegrid_nx;i++)
1013         {
1014             ix = i;
1015             for(j=0;j<recv_nindex;j++)
1016             {
1017                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1018                 for(k=0;k<pme->nkz;k++)
1019                 {
1020                     iz = k;
1021                     if(direction==GMX_SUM_QGRID_FORWARD)
1022                     {
1023                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1024                     }
1025                     else
1026                     {
1027                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1028                     }
1029                 }
1030             }
1031         }
1032     }
1033
1034     /* Major dimension is easier, no copying required,
1035      * but we might have to sum to a separate array.
1036      * Since we don't copy, we have to communicate up to pmegrid_nz,
1037      * not nkz as for the minor direction.
1038      */
1039     overlap = &pme->overlap[0];
1040
1041     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1042     {
1043         if(direction==GMX_SUM_QGRID_FORWARD)
1044         {
1045             send_id = overlap->send_id[ipulse];
1046             recv_id = overlap->recv_id[ipulse];
1047             send_index0   = overlap->comm_data[ipulse].send_index0;
1048             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1049             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1050             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1051             recvptr   = overlap->recvbuf;
1052         }
1053         else
1054         {
1055             send_id = overlap->recv_id[ipulse];
1056             recv_id = overlap->send_id[ipulse];
1057             send_index0   = overlap->comm_data[ipulse].recv_index0;
1058             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1059             recv_index0   = overlap->comm_data[ipulse].send_index0;
1060             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1061             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1062         }
1063
1064         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1065         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1066
1067         if (debug)
1068         {
1069             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1070                     pme->nodeid,overlap->nodeid,send_id,
1071                     pme->pmegrid_start_ix,
1072                     send_index0-pme->pmegrid_start_ix,
1073                     send_index0-pme->pmegrid_start_ix+send_nindex);
1074             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1075                     pme->nodeid,overlap->nodeid,recv_id,
1076                     pme->pmegrid_start_ix,
1077                     recv_index0-pme->pmegrid_start_ix,
1078                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1079         }
1080
1081         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1082                      send_id,ipulse,
1083                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1084                      recv_id,ipulse,
1085                      overlap->mpi_comm,&stat);
1086
1087         /* ADD data from contiguous recv buffer */
1088         if(direction==GMX_SUM_QGRID_FORWARD)
1089         {
1090             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1091             for(i=0;i<recv_nindex*datasize;i++)
1092             {
1093                 p[i] += overlap->recvbuf[i];
1094             }
1095         }
1096     }
1097 }
1098 #endif
1099
1100
1101 static int
1102 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1103 {
1104     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1105     ivec    local_pme_size;
1106     int     i,ix,iy,iz;
1107     int     pmeidx,fftidx;
1108
1109     /* Dimensions should be identical for A/B grid, so we just use A here */
1110     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1111                                    local_fft_ndata,
1112                                    local_fft_offset,
1113                                    local_fft_size);
1114
1115     local_pme_size[0] = pme->pmegrid_nx;
1116     local_pme_size[1] = pme->pmegrid_ny;
1117     local_pme_size[2] = pme->pmegrid_nz;
1118
1119     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1120      * the offset is identical, and the PME grid always has more data (due to overlap).
1121      */
1122     {
1123 #ifdef DEBUG_PME
1124         FILE *fp,*fp2;
1125         char fn[STRLEN],format[STRLEN];
1126         real val;
1127         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1128         fp = ffopen(fn,"w");
1129         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1130         fp2 = ffopen(fn,"w");
1131         sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1132 #endif
1133
1134     for(ix=0;ix<local_fft_ndata[XX];ix++)
1135     {
1136         for(iy=0;iy<local_fft_ndata[YY];iy++)
1137         {
1138             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1139             {
1140                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1141                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1142                 fftgrid[fftidx] = pmegrid[pmeidx];
1143 #ifdef DEBUG_PME
1144                 val = 100*pmegrid[pmeidx];
1145                 if (pmegrid[pmeidx] != 0)
1146                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1147                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1148                 if (pmegrid[pmeidx] != 0)
1149                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1150                             "qgrid",
1151                             pme->pmegrid_start_ix + ix,
1152                             pme->pmegrid_start_iy + iy,
1153                             pme->pmegrid_start_iz + iz,
1154                             pmegrid[pmeidx]);
1155 #endif
1156             }
1157         }
1158     }
1159 #ifdef DEBUG_PME
1160     ffclose(fp);
1161     ffclose(fp2);
1162 #endif
1163     }
1164     return 0;
1165 }
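
/* A small sketch (not compiled; helper name illustrative only) of the
 * indexing used above: both grids are plain row-major arrays with x as the
 * slowest dimension, but each has its own (padded) y/z strides, so the same
 * (ix,iy,iz) point maps to a different flat index in each grid, e.g.
 *   fftgrid[flat_index(ix,iy,iz,local_fft_size)] =
 *       pmegrid[flat_index(ix,iy,iz,local_pme_size)];
 */
#if 0
static int flat_index(int ix, int iy, int iz, const ivec size)
{
    /* size[YY] and size[ZZ] act as strides; size[XX] is not needed */
    return ix*(size[YY]*size[ZZ]) + iy*size[ZZ] + iz;
}
#endif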
1166
1167
1168 static gmx_cycles_t omp_cyc_start()
1169 {
1170     return gmx_cycles_read();
1171 }
1172
1173 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1174 {
1175     return gmx_cycles_read() - c;
1176 }
1177
1178
1179 static int
1180 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1181                         int nthread,int thread)
1182 {
1183     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1184     ivec    local_pme_size;
1185     int     ixy0,ixy1,ixy,ix,iy,iz;
1186     int     pmeidx,fftidx;
1187 #ifdef PME_TIME_THREADS
1188     gmx_cycles_t c1;
1189     static double cs1=0;
1190     static int cnt=0;
1191 #endif
1192
1193 #ifdef PME_TIME_THREADS
1194     c1 = omp_cyc_start();
1195 #endif
1196     /* Dimensions should be identical for A/B grid, so we just use A here */
1197     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1198                                    local_fft_ndata,
1199                                    local_fft_offset,
1200                                    local_fft_size);
1201
1202     local_pme_size[0] = pme->pmegrid_nx;
1203     local_pme_size[1] = pme->pmegrid_ny;
1204     local_pme_size[2] = pme->pmegrid_nz;
1205
1206     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1207      * the offset is identical, and the PME grid always has more data (due to overlap).
1208      */
1209     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1210     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1211
1212     for(ixy=ixy0;ixy<ixy1;ixy++)
1213     {
1214         ix = ixy/local_fft_ndata[YY];
1215         iy = ixy - ix*local_fft_ndata[YY];
1216
1217         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1218         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1219         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1220         {
1221             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1222         }
1223     }
1224
1225 #ifdef PME_TIME_THREADS
1226     c1 = omp_cyc_end(c1);
1227     cs1 += (double)c1;
1228     cnt++;
1229     if (cnt % 20 == 0)
1230     {
1231         printf("copy %.2f\n",cs1*1e-9);
1232     }
1233 #endif
1234
1235     return 0;
1236 }
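
/* A minimal sketch (not compiled; helper name illustrative only) of the loop
 * flattening used above: the x and y loops are merged into a single ixy
 * index so the rows can be divided evenly over the threads, and ix/iy are
 * recovered with a division and a remainder.
 */
#if 0
static void for_each_xy_of_thread(int nx, int ny, int nthread, int thread)
{
    int ixy0 = ( thread   *nx*ny)/nthread;
    int ixy1 = ((thread+1)*nx*ny)/nthread;
    int ixy;

    for (ixy = ixy0; ixy < ixy1; ixy++)
    {
        int ix = ixy/ny;
        int iy = ixy - ix*ny;

        /* ... process the z-column at (ix,iy) ... */
    }
}
#endif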
1237
1238
1239 static void
1240 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1241 {
1242     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1243
1244     nx = pme->nkx;
1245     ny = pme->nky;
1246     nz = pme->nkz;
1247
1248     pnx = pme->pmegrid_nx;
1249     pny = pme->pmegrid_ny;
1250     pnz = pme->pmegrid_nz;
1251
1252     overlap = pme->pme_order - 1;
1253
1254     /* Add periodic overlap in z */
1255     for(ix=0; ix<pme->pmegrid_nx; ix++)
1256     {
1257         for(iy=0; iy<pme->pmegrid_ny; iy++)
1258         {
1259             for(iz=0; iz<overlap; iz++)
1260             {
1261                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1262                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1263             }
1264         }
1265     }
1266
1267     if (pme->nnodes_minor == 1)
1268     {
1269        for(ix=0; ix<pme->pmegrid_nx; ix++)
1270        {
1271            for(iy=0; iy<overlap; iy++)
1272            {
1273                for(iz=0; iz<nz; iz++)
1274                {
1275                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1276                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1277                }
1278            }
1279        }
1280     }
1281
1282     if (pme->nnodes_major == 1)
1283     {
1284         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1285
1286         for(ix=0; ix<overlap; ix++)
1287         {
1288             for(iy=0; iy<ny_x; iy++)
1289             {
1290                 for(iz=0; iz<nz; iz++)
1291                 {
1292                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1293                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1294                 }
1295             }
1296         }
1297     }
1298 }
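
/* A 1-D sketch (not compiled; helper names illustrative only) of the
 * wrapping above: the spreading grid has order-1 extra points beyond the n
 * periodic points, and wrapping folds those overlap points back onto the
 * start of the row; unwrap_periodic_pmegrid below performs the reverse copy
 * before gathering the forces.
 */
#if 0
static void wrap_row_1d(real *row, int n, int order)
{
    int i;

    for (i = 0; i < order-1; i++)
    {
        row[i] += row[n+i];
    }
}

static void unwrap_row_1d(real *row, int n, int order)
{
    int i;

    for (i = 0; i < order-1; i++)
    {
        row[n+i] = row[i];
    }
}
#endif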
1299
1300
1301 static void
1302 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1303 {
1304     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1305
1306     nx = pme->nkx;
1307     ny = pme->nky;
1308     nz = pme->nkz;
1309
1310     pnx = pme->pmegrid_nx;
1311     pny = pme->pmegrid_ny;
1312     pnz = pme->pmegrid_nz;
1313
1314     overlap = pme->pme_order - 1;
1315
1316     if (pme->nnodes_major == 1)
1317     {
1318         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1319
1320         for(ix=0; ix<overlap; ix++)
1321         {
1322             int iy,iz;
1323
1324             for(iy=0; iy<ny_x; iy++)
1325             {
1326                 for(iz=0; iz<nz; iz++)
1327                 {
1328                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1329                         pmegrid[(ix*pny+iy)*pnz+iz];
1330                 }
1331             }
1332         }
1333     }
1334
1335     if (pme->nnodes_minor == 1)
1336     {
1337 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1338        for(ix=0; ix<pme->pmegrid_nx; ix++)
1339        {
1340            int iy,iz;
1341
1342            for(iy=0; iy<overlap; iy++)
1343            {
1344                for(iz=0; iz<nz; iz++)
1345                {
1346                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1347                        pmegrid[(ix*pny+iy)*pnz+iz];
1348                }
1349            }
1350        }
1351     }
1352
1353     /* Copy periodic overlap in z */
1354 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1355     for(ix=0; ix<pme->pmegrid_nx; ix++)
1356     {
1357         int iy,iz;
1358
1359         for(iy=0; iy<pme->pmegrid_ny; iy++)
1360         {
1361             for(iz=0; iz<overlap; iz++)
1362             {
1363                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1364                     pmegrid[(ix*pny+iy)*pnz+iz];
1365             }
1366         }
1367     }
1368 }
1369
1370 static void clear_grid(int nx,int ny,int nz,real *grid,
1371                        ivec fs,int *flag,
1372                        int fx,int fy,int fz,
1373                        int order)
1374 {
1375     int nc,ncz;
1376     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1377     int flind;
1378
1379     nc  = 2 + (order - 2)/FLBS;
1380     ncz = 2 + (order - 2)/FLBSZ;
1381
1382     for(fsx=fx; fsx<fx+nc; fsx++)
1383     {
1384         for(fsy=fy; fsy<fy+nc; fsy++)
1385         {
1386             for(fsz=fz; fsz<fz+ncz; fsz++)
1387             {
1388                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1389                 if (flag[flind] == 0)
1390                 {
1391                     gx = fsx*FLBS;
1392                     gy = fsy*FLBS;
1393                     gz = fsz*FLBSZ;
1394                     g0x = (gx*ny + gy)*nz + gz;
1395                     for(x=0; x<FLBS; x++)
1396                     {
1397                         g0y = g0x;
1398                         for(y=0; y<FLBS; y++)
1399                         {
1400                             for(z=0; z<FLBSZ; z++)
1401                             {
1402                                 grid[g0y+z] = 0;
1403                             }
1404                             g0y += nz;
1405                         }
1406                         g0x += ny*nz;
1407                     }
1408
1409                     flag[flind] = 1;
1410                 }
1411             }
1412         }
1413     }
1414 }
1415
1416 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1417 #define DO_BSPLINE(order)                            \
1418 for(ithx=0; (ithx<order); ithx++)                    \
1419 {                                                    \
1420     index_x = (i0+ithx)*pny*pnz;                     \
1421     valx    = qn*thx[ithx];                          \
1422                                                      \
1423     for(ithy=0; (ithy<order); ithy++)                \
1424     {                                                \
1425         valxy    = valx*thy[ithy];                   \
1426         index_xy = index_x+(j0+ithy)*pnz;            \
1427                                                      \
1428         for(ithz=0; (ithz<order); ithz++)            \
1429         {                                            \
1430             index_xyz        = index_xy+(k0+ithz);   \
1431             grid[index_xyz] += valxy*thz[ithz];      \
1432         }                                            \
1433     }                                                \
1434 }
1435
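/* A plain-C sketch (not compiled; the function name is illustrative only) of
 * what DO_BSPLINE does for one charge: the charge qn is spread over an
 * order x order x order block of grid points, weighted by the separable
 * B-spline coefficients thx/thy/thz, starting at grid corner (i0,j0,k0);
 * pny and pnz are the y and z strides of the (padded) grid.
 */
#if 0
static void spread_one_charge(real *grid, int pny, int pnz,
                              int i0, int j0, int k0, real qn,
                              const real *thx, const real *thy,
                              const real *thz, int order)
{
    int ithx, ithy, ithz;

    for (ithx = 0; ithx < order; ithx++)
    {
        int  index_x = (i0+ithx)*pny*pnz;
        real valx    = qn*thx[ithx];

        for (ithy = 0; ithy < order; ithy++)
        {
            real valxy    = valx*thy[ithy];
            int  index_xy = index_x + (j0+ithy)*pnz;

            for (ithz = 0; ithz < order; ithz++)
            {
                grid[index_xy + k0 + ithz] += valxy*thz[ithz];
            }
        }
    }
}
#endif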
1436
1437 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1438                                      pme_atomcomm_t *atc, splinedata_t *spline,
1439                                      pme_spline_work_t *work)
1440 {
1441
1442     /* spread charges from home atoms to local grid */
1443     real     *grid;
1444     pme_overlap_t *ol;
1445     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1446     int *    idxptr;
1447     int      order,norder,index_x,index_xy,index_xyz;
1448     real     valx,valxy,qn;
1449     real     *thx,*thy,*thz;
1450     int      localsize, bndsize;
1451     int      pnx,pny,pnz,ndatatot;
1452     int      offx,offy,offz;
1453
1454     pnx = pmegrid->s[XX];
1455     pny = pmegrid->s[YY];
1456     pnz = pmegrid->s[ZZ];
1457
1458     offx = pmegrid->offset[XX];
1459     offy = pmegrid->offset[YY];
1460     offz = pmegrid->offset[ZZ];
1461
1462     ndatatot = pnx*pny*pnz;
1463     grid = pmegrid->grid;
1464     for(i=0;i<ndatatot;i++)
1465     {
1466         grid[i] = 0;
1467     }
1468     
1469     order = pmegrid->order;
1470
1471     for(nn=0; nn<spline->n; nn++)
1472     {
1473         n  = spline->ind[nn];
1474         qn = atc->q[n];
1475
1476         if (qn != 0)
1477         {
1478             idxptr = atc->idx[n];
1479             norder = nn*order;
1480
1481             i0   = idxptr[XX] - offx;
1482             j0   = idxptr[YY] - offy;
1483             k0   = idxptr[ZZ] - offz;
1484
1485             thx = spline->theta[XX] + norder;
1486             thy = spline->theta[YY] + norder;
1487             thz = spline->theta[ZZ] + norder;
1488
1489             switch (order) {
1490             case 4:
1491 #ifdef PME_SSE
1492 #ifdef PME_SSE_UNALIGNED
1493 #define PME_SPREAD_SSE_ORDER4
1494 #else
1495 #define PME_SPREAD_SSE_ALIGNED
1496 #define PME_ORDER 4
1497 #endif
1498 #include "pme_sse_single.h"
1499 #else
1500                 DO_BSPLINE(4);
1501 #endif
1502                 break;
1503             case 5:
1504 #ifdef PME_SSE
1505 #define PME_SPREAD_SSE_ALIGNED
1506 #define PME_ORDER 5
1507 #include "pme_sse_single.h"
1508 #else
1509                 DO_BSPLINE(5);
1510 #endif
1511                 break;
1512             default:
1513                 DO_BSPLINE(order);
1514                 break;
1515             }
1516         }
1517     }
1518 }
1519
1520 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1521 {
1522 #ifdef PME_SSE
1523     if (pme_order == 5
1524 #ifndef PME_SSE_UNALIGNED
1525         || pme_order == 4
1526 #endif
1527         )
1528     {
1529         /* Round nz up to a multiple of 4 to ensure alignment */
1530         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1531     }
1532 #endif
1533 }
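
/* The rounding above is the standard power-of-two round-up idiom: adding
 * align-1 and masking off the low bits, e.g. 13 -> 16 while 16 stays 16.
 * A generic form (not compiled; name illustrative only):
 */
#if 0
static int round_up_to_multiple(int n, int align)
{
    /* align must be a power of two */
    return (n + align - 1) & ~(align - 1);
}
#endif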
1534
1535 static void set_gridsize_alignment(int *gridsize,int pme_order)
1536 {
1537 #ifdef PME_SSE
1538 #ifndef PME_SSE_UNALIGNED
1539     if (pme_order == 4)
1540     {
1541         /* Add extra elements to ensure that aligned operations do not go
1542          * beyond the allocated grid size.
1543          * Note that for pme_order=5, the pme grid z-size alignment
1544          * ensures that we will not go beyond the grid size.
1545          */
1546          *gridsize += 4;
1547     }
1548 #endif
1549 #endif
1550 }
1551
1552 static void pmegrid_init(pmegrid_t *grid,
1553                          int cx, int cy, int cz,
1554                          int x0, int y0, int z0,
1555                          int x1, int y1, int z1,
1556                          gmx_bool set_alignment,
1557                          int pme_order,
1558                          real *ptr)
1559 {
1560     int nz,gridsize;
1561
1562     grid->ci[XX] = cx;
1563     grid->ci[YY] = cy;
1564     grid->ci[ZZ] = cz;
1565     grid->offset[XX] = x0;
1566     grid->offset[YY] = y0;
1567     grid->offset[ZZ] = z0;
1568     grid->n[XX]      = x1 - x0 + pme_order - 1;
1569     grid->n[YY]      = y1 - y0 + pme_order - 1;
1570     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1571     copy_ivec(grid->n,grid->s);
1572
1573     nz = grid->s[ZZ];
1574     set_grid_alignment(&nz,pme_order);
1575     if (set_alignment)
1576     {
1577         grid->s[ZZ] = nz;
1578     }
1579     else if (nz != grid->s[ZZ])
1580     {
1581         gmx_incons("pmegrid_init call with an unaligned z size");
1582     }
1583
1584     grid->order = pme_order;
1585     if (ptr == NULL)
1586     {
1587         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1588         set_gridsize_alignment(&gridsize,pme_order);
1589         snew_aligned(grid->grid,gridsize,16);
1590     }
1591     else
1592     {
1593         grid->grid = ptr;
1594     }
1595 }
1596
1597 static int div_round_up(int enumerator,int denominator)
1598 {
1599     return (enumerator + denominator - 1)/denominator;
1600 }
1601
1602 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1603                                   ivec nsub)
1604 {
1605     int gsize_opt,gsize;
1606     int nsx,nsy,nsz;
1607     char *env;
1608
1609     gsize_opt = -1;
1610     for(nsx=1; nsx<=nthread; nsx++)
1611     {
1612         if (nthread % nsx == 0)
1613         {
1614             for(nsy=1; nsy<=nthread; nsy++)
1615             {
1616                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1617                 {
1618                     nsz = nthread/(nsx*nsy);
1619
1620                     /* Determine the number of grid points per thread */
1621                     gsize =
1622                         (div_round_up(n[XX],nsx) + ovl)*
1623                         (div_round_up(n[YY],nsy) + ovl)*
1624                         (div_round_up(n[ZZ],nsz) + ovl);
1625
1626                     /* Minimize the number of grid points per thread
1627                      * and, secondarily, the number of cuts in minor dimensions.
1628                      */
1629                     if (gsize_opt == -1 ||
1630                         gsize < gsize_opt ||
1631                         (gsize == gsize_opt &&
1632                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1633                     {
1634                         nsub[XX] = nsx;
1635                         nsub[YY] = nsy;
1636                         nsub[ZZ] = nsz;
1637                         gsize_opt = gsize;
1638                     }
1639                 }
1640             }
1641         }
1642     }
1643
1644     env = getenv("GMX_PME_THREAD_DIVISION");
1645     if (env != NULL)
1646     {
1647         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1648     }
1649
1650     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1651     {
1652         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1653     }
1654 }
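     /* For example, with nthread=4, n={20,20,20}, ovl=3 and no
      * GMX_PME_THREAD_DIVISION override, the candidates 1x2x2, 2x1x2 and
      * 2x2x1 all give 13*13*23 grid points per thread; the tie-breaking
      * on the minor dimensions then selects nsub = 2x2x1.
      */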
1655
1656 static void pmegrids_init(pmegrids_t *grids,
1657                           int nx,int ny,int nz,int nz_base,
1658                           int pme_order,
1659                           int nthread,
1660                           int overlap_x,
1661                           int overlap_y)
1662 {
1663     ivec n,n_base,g0,g1;
1664     int t,x,y,z,d,i,tfac;
1665     int max_comm_lines=-1;
1666
1667     n[XX] = nx - (pme_order - 1);
1668     n[YY] = ny - (pme_order - 1);
1669     n[ZZ] = nz - (pme_order - 1);
1670
1671     copy_ivec(n,n_base);
1672     n_base[ZZ] = nz_base;
1673
1674     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1675                  NULL);
1676
1677     grids->nthread = nthread;
1678
1679     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1680
1681     if (grids->nthread > 1)
1682     {
1683         ivec nst;
1684         int gridsize;
1685
1686         for(d=0; d<DIM; d++)
1687         {
1688             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1689         }
1690         set_grid_alignment(&nst[ZZ],pme_order);
1691
1692         if (debug)
1693         {
1694             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1695                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1696             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1697                     nx,ny,nz,
1698                     nst[XX],nst[YY],nst[ZZ]);
1699         }
1700
1701         snew(grids->grid_th,grids->nthread);
1702         t = 0;
1703         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1704         set_gridsize_alignment(&gridsize,pme_order);
1705         snew_aligned(grids->grid_all,
1706                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1707                      16);
1708
1709         for(x=0; x<grids->nc[XX]; x++)
1710         {
1711             for(y=0; y<grids->nc[YY]; y++)
1712             {
1713                 for(z=0; z<grids->nc[ZZ]; z++)
1714                 {
1715                     pmegrid_init(&grids->grid_th[t],
1716                                  x,y,z,
1717                                  (n[XX]*(x  ))/grids->nc[XX],
1718                                  (n[YY]*(y  ))/grids->nc[YY],
1719                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1720                                  (n[XX]*(x+1))/grids->nc[XX],
1721                                  (n[YY]*(y+1))/grids->nc[YY],
1722                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1723                                  TRUE,
1724                                  pme_order,
1725                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1726                     t++;
1727                 }
1728             }
1729         }
1730     }
1731
1732     snew(grids->g2t,DIM);
1733     tfac = 1;
1734     for(d=DIM-1; d>=0; d--)
1735     {
1736         snew(grids->g2t[d],n[d]);
1737         t = 0;
1738         for(i=0; i<n[d]; i++)
1739         {
1740             /* The second check should match the parameters
1741              * of the pmegrid_init call above.
1742              */
1743             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1744             {
1745                 t++;
1746             }
1747             grids->g2t[d][i] = t*tfac;
1748         }
1749
1750         tfac *= grids->nc[d];
1751
1752         switch (d)
1753         {
1754         case XX: max_comm_lines = overlap_x;     break;
1755         case YY: max_comm_lines = overlap_y;     break;
1756         case ZZ: max_comm_lines = pme_order - 1; break;
1757         }
1758         grids->nthread_comm[d] = 0;
1759         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1760                grids->nthread_comm[d] < grids->nc[d])
1761         {
1762             grids->nthread_comm[d]++;
1763         }
1764         if (debug != NULL)
1765         {
1766             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1767                     'x'+d,grids->nthread_comm[d]);
1768         }
1769         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1770          * work, but this restriction is not a problem in practice.
1771          */
1772         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1773         {
1774             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1775         }
1776     }
1777 }
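     /* Note on the g2t tables built above: for a grid line with indices
      * (x,y,z), the owning thread is g2t[XX][x] + g2t[YY][y] + g2t[ZZ][z],
      * since tfac accumulates to 1, nc[ZZ] and nc[YY]*nc[ZZ] for the z, y
      * and x dimensions. This matches the order in which grid_th[] is
      * filled above (z fastest, x slowest).
      */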
1778
1779
1780 static void pmegrids_destroy(pmegrids_t *grids)
1781 {
1782     int t;
1783
1784     if (grids->grid.grid != NULL)
1785     {
1786         sfree(grids->grid.grid);
1787
1788         if (grids->nthread > 0)
1789         {
1790             for(t=0; t<grids->nthread; t++)
1791             {
1792                 sfree(grids->grid_th[t].grid);
1793             }
1794             sfree(grids->grid_th);
1795         }
1796     }
1797 }
1798
1799
1800 static void realloc_work(pme_work_t *work,int nkx)
1801 {
1802     if (nkx > work->nalloc)
1803     {
1804         work->nalloc = nkx;
1805         srenew(work->mhx  ,work->nalloc);
1806         srenew(work->mhy  ,work->nalloc);
1807         srenew(work->mhz  ,work->nalloc);
1808         srenew(work->m2   ,work->nalloc);
1809         /* Allocate an aligned pointer for SSE operations, including 3 extra
1810          * elements at the end since SSE operates on 4 elements at a time.
1811          */
1812         sfree_aligned(work->denom);
1813         sfree_aligned(work->tmp1);
1814         sfree_aligned(work->eterm);
1815         snew_aligned(work->denom,work->nalloc+3,16);
1816         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1817         snew_aligned(work->eterm,work->nalloc+3,16);
1818         srenew(work->m2inv,work->nalloc);
1819     }
1820 }
1821
1822
1823 static void free_work(pme_work_t *work)
1824 {
1825     sfree(work->mhx);
1826     sfree(work->mhy);
1827     sfree(work->mhz);
1828     sfree(work->m2);
1829     sfree_aligned(work->denom);
1830     sfree_aligned(work->tmp1);
1831     sfree_aligned(work->eterm);
1832     sfree(work->m2inv);
1833 }
1834
1835
1836 #ifdef PME_SSE
1837     /* Calculate exponentials using SSE in single precision */
1838 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1839 {
1840     {
1841         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1842         __m128 f_sse;
1843         __m128 lu;
1844         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1845         int kx;
1846         f_sse = _mm_load1_ps(&f);
1847         for(kx=0; kx<end; kx+=4)
1848         {
1849             tmp_d1   = _mm_load_ps(d_aligned+kx);
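                 /* _mm_rcp_ps gives only a ~12-bit approximate reciprocal;
                  * the next two lines refine it with one Newton-Raphson
                  * iteration, x1 = x0*(2 - d*x0), roughly doubling the
                  * number of correct bits.
                  */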
1850             lu       = _mm_rcp_ps(tmp_d1);
1851             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1852             tmp_r    = _mm_load_ps(r_aligned+kx);
1853             tmp_r    = gmx_mm_exp_ps(tmp_r);
1854             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1855             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1856             _mm_store_ps(e_aligned+kx,tmp_e);
1857         }
1858     }
1859 }
1860 #else
1861 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1862 {
1863     int kx;
1864     for(kx=start; kx<end; kx++)
1865     {
1866         d[kx] = 1.0/d[kx];
1867     }
1868     for(kx=start; kx<end; kx++)
1869     {
1870         r[kx] = exp(r[kx]);
1871     }
1872     for(kx=start; kx<end; kx++)
1873     {
1874         e[kx] = f*r[kx]*d[kx];
1875     }
1876 }
1877 #endif
1878
1879
1880 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1881                          real ewaldcoeff,real vol,
1882                          gmx_bool bEnerVir,
1883                          int nthread,int thread)
1884 {
1885     /* do recip sum over local cells in grid */
1886     /* y major, z middle, x minor or continuous */
1887     t_complex *p0;
1888     int     kx,ky,kz,maxkx,maxky,maxkz;
1889     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1890     real    mx,my,mz;
1891     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1892     real    ets2,struct2,vfactor,ets2vf;
1893     real    d1,d2,energy=0;
1894     real    by,bz;
1895     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1896     real    rxx,ryx,ryy,rzx,rzy,rzz;
1897     pme_work_t *work;
1898     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1899     real    mhxk,mhyk,mhzk,m2k;
1900     real    corner_fac;
1901     ivec    complex_order;
1902     ivec    local_ndata,local_offset,local_size;
1903     real    elfac;
1904
1905     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1906
1907     nx = pme->nkx;
1908     ny = pme->nky;
1909     nz = pme->nkz;
1910
1911     /* Dimensions should be identical for A/B grid, so we just use A here */
1912     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1913                                       complex_order,
1914                                       local_ndata,
1915                                       local_offset,
1916                                       local_size);
1917
1918     rxx = pme->recipbox[XX][XX];
1919     ryx = pme->recipbox[YY][XX];
1920     ryy = pme->recipbox[YY][YY];
1921     rzx = pme->recipbox[ZZ][XX];
1922     rzy = pme->recipbox[ZZ][YY];
1923     rzz = pme->recipbox[ZZ][ZZ];
1924
1925     maxkx = (nx+1)/2;
1926     maxky = (ny+1)/2;
1927     maxkz = nz/2+1;
1928
1929     work = &pme->work[thread];
1930     mhx   = work->mhx;
1931     mhy   = work->mhy;
1932     mhz   = work->mhz;
1933     m2    = work->m2;
1934     denom = work->denom;
1935     tmp1  = work->tmp1;
1936     eterm = work->eterm;
1937     m2inv = work->m2inv;
1938
1939     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1940     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1941
1942     for(iyz=iyz0; iyz<iyz1; iyz++)
1943     {
1944         iy = iyz/local_ndata[ZZ];
1945         iz = iyz - iy*local_ndata[ZZ];
1946
1947         ky = iy + local_offset[YY];
1948
1949         if (ky < maxky)
1950         {
1951             my = ky;
1952         }
1953         else
1954         {
1955             my = (ky - ny);
1956         }
1957
1958         by = M_PI*vol*pme->bsp_mod[YY][ky];
1959
1960         kz = iz + local_offset[ZZ];
1961
1962         mz = kz;
1963
1964         bz = pme->bsp_mod[ZZ][kz];
1965
1966         /* 0.5 correction for corner points */
1967         corner_fac = 1;
1968         if (kz == 0 || kz == (nz+1)/2)
1969         {
1970             corner_fac = 0.5;
1971         }
1972
1973         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1974
1975         /* We should skip the k-space point (0,0,0) */
1976         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1977         {
1978             kxstart = local_offset[XX];
1979         }
1980         else
1981         {
1982             kxstart = local_offset[XX] + 1;
1983             p0++;
1984         }
1985         kxend = local_offset[XX] + local_ndata[XX];
1986
1987         if (bEnerVir)
1988         {
1989             /* More expensive inner loop, especially because of the storage
1990              * of the mh elements in arrays.
1991              * Because x is the minor grid index, all mh elements
1992              * depend on kx for triclinic unit cells.
1993              */
1994
1995             /* Two explicit loops to avoid a conditional inside the loop */
1996             for(kx=kxstart; kx<maxkx; kx++)
1997             {
1998                 mx = kx;
1999
2000                 mhxk      = mx * rxx;
2001                 mhyk      = mx * ryx + my * ryy;
2002                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2003                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2004                 mhx[kx]   = mhxk;
2005                 mhy[kx]   = mhyk;
2006                 mhz[kx]   = mhzk;
2007                 m2[kx]    = m2k;
2008                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2009                 tmp1[kx]  = -factor*m2k;
2010             }
2011
2012             for(kx=maxkx; kx<kxend; kx++)
2013             {
2014                 mx = (kx - nx);
2015
2016                 mhxk      = mx * rxx;
2017                 mhyk      = mx * ryx + my * ryy;
2018                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2019                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2020                 mhx[kx]   = mhxk;
2021                 mhy[kx]   = mhyk;
2022                 mhz[kx]   = mhzk;
2023                 m2[kx]    = m2k;
2024                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2025                 tmp1[kx]  = -factor*m2k;
2026             }
2027
2028             for(kx=kxstart; kx<kxend; kx++)
2029             {
2030                 m2inv[kx] = 1.0/m2[kx];
2031             }
2032
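                 /* calc_exponentials computes, for each kx in [kxstart,kxend):
                  * eterm[kx] = elfac*exp(tmp1[kx])/denom[kx],
                  * i.e. elfac*exp(-pi^2 m^2/ewaldcoeff^2)/(pi V m^2 B),
                  * where B collects the bsp_mod factors entering through
                  * by, bz and denom above: the smooth-PME convolution term.
                  */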
2033             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2034
2035             for(kx=kxstart; kx<kxend; kx++,p0++)
2036             {
2037                 d1      = p0->re;
2038                 d2      = p0->im;
2039
2040                 p0->re  = d1*eterm[kx];
2041                 p0->im  = d2*eterm[kx];
2042
2043                 struct2 = 2.0*(d1*d1+d2*d2);
2044
2045                 tmp1[kx] = eterm[kx]*struct2;
2046             }
2047
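                 /* Energy/virial accumulation: tmp1[kx] now holds
                  * 2*eterm[kx]*|grid(m)|^2, and each virial component gets
                  * ets2*(vfactor*mh_a*mh_b - delta_ab) with
                  * vfactor = 2*(1 + factor*m2)/m2; the global factors 0.5
                  * (energy) and 0.25 (virial) are applied after the loop.
                  */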
2048             for(kx=kxstart; kx<kxend; kx++)
2049             {
2050                 ets2     = corner_fac*tmp1[kx];
2051                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2052                 energy  += ets2;
2053
2054                 ets2vf   = ets2*vfactor;
2055                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2056                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2057                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2058                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2059                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2060                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2061             }
2062         }
2063         else
2064         {
2065             /* We don't need to calculate the energy and the virial.
2066              * In this case the triclinic overhead is small.
2067              */
2068
2069             /* Two explicit loops to avoid a conditional inside the loop */
2070
2071             for(kx=kxstart; kx<maxkx; kx++)
2072             {
2073                 mx = kx;
2074
2075                 mhxk      = mx * rxx;
2076                 mhyk      = mx * ryx + my * ryy;
2077                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2078                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2079                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2080                 tmp1[kx]  = -factor*m2k;
2081             }
2082
2083             for(kx=maxkx; kx<kxend; kx++)
2084             {
2085                 mx = (kx - nx);
2086
2087                 mhxk      = mx * rxx;
2088                 mhyk      = mx * ryx + my * ryy;
2089                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2090                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2091                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2092                 tmp1[kx]  = -factor*m2k;
2093             }
2094
2095             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2096
2097             for(kx=kxstart; kx<kxend; kx++,p0++)
2098             {
2099                 d1      = p0->re;
2100                 d2      = p0->im;
2101
2102                 p0->re  = d1*eterm[kx];
2103                 p0->im  = d2*eterm[kx];
2104             }
2105         }
2106     }
2107
2108     if (bEnerVir)
2109     {
2110         /* Update virial with local values.
2111          * The virial is symmetric by definition.
2112          * This virial seems OK for isotropic scaling, but I'm
2113          * experiencing problems with semi-isotropic membranes.
2114          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2115          */
2116         work->vir[XX][XX] = 0.25*virxx;
2117         work->vir[YY][YY] = 0.25*viryy;
2118         work->vir[ZZ][ZZ] = 0.25*virzz;
2119         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2120         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2121         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2122
2123         /* This energy should be corrected for a charged system */
2124         work->energy = 0.5*energy;
2125     }
2126
2127     /* Return the loop count */
2128     return local_ndata[YY]*local_ndata[XX];
2129 }
2130
2131 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2132                              real *mesh_energy,matrix vir)
2133 {
2134     /* This function sums output over threads
2135      * and should therefore only be called after thread synchronization.
2136      */
2137     int thread;
2138
2139     *mesh_energy = pme->work[0].energy;
2140     copy_mat(pme->work[0].vir,vir);
2141
2142     for(thread=1; thread<nthread; thread++)
2143     {
2144         *mesh_energy += pme->work[thread].energy;
2145         m_add(vir,pme->work[thread].vir,vir);
2146     }
2147 }
2148
2149 #define DO_FSPLINE(order)                      \
2150 for(ithx=0; (ithx<order); ithx++)              \
2151 {                                              \
2152     index_x = (i0+ithx)*pny*pnz;               \
2153     tx      = thx[ithx];                       \
2154     dx      = dthx[ithx];                      \
2155                                                \
2156     for(ithy=0; (ithy<order); ithy++)          \
2157     {                                          \
2158         index_xy = index_x+(j0+ithy)*pnz;      \
2159         ty       = thy[ithy];                  \
2160         dy       = dthy[ithy];                 \
2161         fxy1     = fz1 = 0;                    \
2162                                                \
2163         for(ithz=0; (ithz<order); ithz++)      \
2164         {                                      \
2165             gval  = grid[index_xy+(k0+ithz)];  \
2166             fxy1 += thz[ithz]*gval;            \
2167             fz1  += dthz[ithz]*gval;           \
2168         }                                      \
2169         fx += dx*ty*fxy1;                      \
2170         fy += tx*dy*fxy1;                      \
2171         fz += tx*ty*fz1;                       \
2172     }                                          \
2173 }
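     /* DO_FSPLINE accumulates, over the order^3 stencil of one atom,
      * fx = sum dthx*thy*thz*grid, fy = sum thx*dthy*thz*grid and
      * fz = sum thx*thy*dthz*grid, i.e. the gradient of the interpolated
      * potential with respect to the scaled (grid) coordinates of the atom.
      */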
2174
2175
2176 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2177                               gmx_bool bClearF,pme_atomcomm_t *atc,
2178                               splinedata_t *spline,
2179                               real scale)
2180 {
2181     /* sum forces for local particles */
2182     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2183     int     index_x,index_xy;
2184     int     nx,ny,nz,pnx,pny,pnz;
2185     int *   idxptr;
2186     real    tx,ty,dx,dy,qn;
2187     real    fx,fy,fz,gval;
2188     real    fxy1,fz1;
2189     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2190     int     norder;
2191     real    rxx,ryx,ryy,rzx,rzy,rzz;
2192     int     order;
2193
2194     pme_spline_work_t *work;
2195
2196     work = pme->spline_work;
2197
2198     order = pme->pme_order;
2199     thx   = spline->theta[XX];
2200     thy   = spline->theta[YY];
2201     thz   = spline->theta[ZZ];
2202     dthx  = spline->dtheta[XX];
2203     dthy  = spline->dtheta[YY];
2204     dthz  = spline->dtheta[ZZ];
2205     nx    = pme->nkx;
2206     ny    = pme->nky;
2207     nz    = pme->nkz;
2208     pnx   = pme->pmegrid_nx;
2209     pny   = pme->pmegrid_ny;
2210     pnz   = pme->pmegrid_nz;
2211
2212     rxx   = pme->recipbox[XX][XX];
2213     ryx   = pme->recipbox[YY][XX];
2214     ryy   = pme->recipbox[YY][YY];
2215     rzx   = pme->recipbox[ZZ][XX];
2216     rzy   = pme->recipbox[ZZ][YY];
2217     rzz   = pme->recipbox[ZZ][ZZ];
2218
2219     for(nn=0; nn<spline->n; nn++)
2220     {
2221         n  = spline->ind[nn];
2222         qn = scale*atc->q[n];
2223
2224         if (bClearF)
2225         {
2226             atc->f[n][XX] = 0;
2227             atc->f[n][YY] = 0;
2228             atc->f[n][ZZ] = 0;
2229         }
2230         if (qn != 0)
2231         {
2232             fx     = 0;
2233             fy     = 0;
2234             fz     = 0;
2235             idxptr = atc->idx[n];
2236             norder = nn*order;
2237
2238             i0   = idxptr[XX];
2239             j0   = idxptr[YY];
2240             k0   = idxptr[ZZ];
2241
2242             /* Pointer arithmetic alert, next six statements */
2243             thx  = spline->theta[XX] + norder;
2244             thy  = spline->theta[YY] + norder;
2245             thz  = spline->theta[ZZ] + norder;
2246             dthx = spline->dtheta[XX] + norder;
2247             dthy = spline->dtheta[YY] + norder;
2248             dthz = spline->dtheta[ZZ] + norder;
2249
2250             switch (order) {
2251             case 4:
2252 #ifdef PME_SSE
2253 #ifdef PME_SSE_UNALIGNED
2254 #define PME_GATHER_F_SSE_ORDER4
2255 #else
2256 #define PME_GATHER_F_SSE_ALIGNED
2257 #define PME_ORDER 4
2258 #endif
2259 #include "pme_sse_single.h"
2260 #else
2261                 DO_FSPLINE(4);
2262 #endif
2263                 break;
2264             case 5:
2265 #ifdef PME_SSE
2266 #define PME_GATHER_F_SSE_ALIGNED
2267 #define PME_ORDER 5
2268 #include "pme_sse_single.h"
2269 #else
2270                 DO_FSPLINE(5);
2271 #endif
2272                 break;
2273             default:
2274                 DO_FSPLINE(order);
2275                 break;
2276             }
2277
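                 /* Convert the gradient in scaled grid coordinates to a
                  * Cartesian force: scale by the grid dimensions and
                  * multiply by the reciprocal box, which is lower
                  * triangular (hence the shape of the three lines below),
                  * with -qn as prefactor.
                  */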
2278             atc->f[n][XX] += -qn*( fx*nx*rxx );
2279             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2280             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2281         }
2282     }
2283     /* Since the energy and not the forces are interpolated,
2284      * the net force might not be exactly zero.
2285      * This can be solved by also interpolating F, but
2286      * that comes at a cost.
2287      * A better hack is to remove the net force every
2288      * step, but that must be done at a higher level
2289      * since this routine doesn't see all atoms if running
2290      * in parallel. How important this is remains unclear.  EL 990726
2291      */
2292 }
2293
2294
2295 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2296                                    pme_atomcomm_t *atc)
2297 {
2298     splinedata_t *spline;
2299     int     n,ithx,ithy,ithz,i0,j0,k0;
2300     int     index_x,index_xy;
2301     int *   idxptr;
2302     real    energy,pot,tx,ty,qn,gval;
2303     real    *thx,*thy,*thz;
2304     int     norder;
2305     int     order;
2306
2307     spline = &atc->spline[0];
2308
2309     order = pme->pme_order;
2310
2311     energy = 0;
2312     for(n=0; (n<atc->n); n++) {
2313         qn      = atc->q[n];
2314
2315         if (qn != 0) {
2316             idxptr = atc->idx[n];
2317             norder = n*order;
2318
2319             i0   = idxptr[XX];
2320             j0   = idxptr[YY];
2321             k0   = idxptr[ZZ];
2322
2323             /* Pointer arithmetic alert, next three statements */
2324             thx  = spline->theta[XX] + norder;
2325             thy  = spline->theta[YY] + norder;
2326             thz  = spline->theta[ZZ] + norder;
2327
2328             pot = 0;
2329             for(ithx=0; (ithx<order); ithx++)
2330             {
2331                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2332                 tx      = thx[ithx];
2333
2334                 for(ithy=0; (ithy<order); ithy++)
2335                 {
2336                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2337                     ty       = thy[ithy];
2338
2339                     for(ithz=0; (ithz<order); ithz++)
2340                     {
2341                         gval  = grid[index_xy+(k0+ithz)];
2342                         pot  += tx*ty*thz[ithz]*gval;
2343                     }
2344
2345                 }
2346             }
2347
2348             energy += pot*qn;
2349         }
2350     }
2351
2352     return energy;
2353 }
2354
2355 /* Macro to force loop unrolling by fixing order.
2356  * This gives a significant performance gain.
2357  */
2358 #define CALC_SPLINE(order)                     \
2359 {                                              \
2360     int j,k,l;                                 \
2361     real dr,div;                               \
2362     real data[PME_ORDER_MAX];                  \
2363     real ddata[PME_ORDER_MAX];                 \
2364                                                \
2365     for(j=0; (j<DIM); j++)                     \
2366     {                                          \
2367         dr  = xptr[j];                         \
2368                                                \
2369         /* dr is relative offset from lower cell limit */ \
2370         data[order-1] = 0;                     \
2371         data[1] = dr;                          \
2372         data[0] = 1 - dr;                      \
2373                                                \
2374         for(k=3; (k<order); k++)               \
2375         {                                      \
2376             div = 1.0/(k - 1.0);               \
2377             data[k-1] = div*dr*data[k-2];      \
2378             for(l=1; (l<(k-1)); l++)           \
2379             {                                  \
2380                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2381                                    data[k-l-1]);                \
2382             }                                  \
2383             data[0] = div*(1-dr)*data[0];      \
2384         }                                      \
2385         /* differentiate */                    \
2386         ddata[0] = -data[0];                   \
2387         for(k=1; (k<order); k++)               \
2388         {                                      \
2389             ddata[k] = data[k-1] - data[k];    \
2390         }                                      \
2391                                                \
2392         div = 1.0/(order - 1);                 \
2393         data[order-1] = div*dr*data[order-2];  \
2394         for(l=1; (l<(order-1)); l++)           \
2395         {                                      \
2396             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2397                                (order-l-dr)*data[order-l-1]); \
2398         }                                      \
2399         data[0] = div*(1 - dr)*data[0];        \
2400                                                \
2401         for(k=0; k<order; k++)                 \
2402         {                                      \
2403             theta[j][i*order+k]  = data[k];    \
2404             dtheta[j][i*order+k] = ddata[k];   \
2405         }                                      \
2406     }                                          \
2407 }
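     /* CALC_SPLINE implements the cardinal B-spline recursion used in
      * smooth PME (Essmann et al., J. Chem. Phys. 103, 8577 (1995)):
      *   M_n(u) = u/(n-1)*M_{n-1}(u) + (n-u)/(n-1)*M_{n-1}(u-1)
      * with derivative M_n'(u) = M_{n-1}(u) - M_{n-1}(u-1).
      * data[] holds the order values needed for one atom and ddata[]
      * the corresponding derivatives.
      */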
2408
2409 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2410                    rvec fractx[],int nr,int ind[],real charge[],
2411                    gmx_bool bFreeEnergy)
2412 {
2413     /* construct splines for local atoms */
2414     int  i,ii;
2415     real *xptr;
2416
2417     for(i=0; i<nr; i++)
2418     {
2419         /* With free energy we do not use the charge check.
2420          * In most cases this will be more efficient than calling make_bsplines
2421          * twice, since usually more than half the particles have charges.
2422          */
2423         ii = ind[i];
2424         if (bFreeEnergy || charge[ii] != 0.0) {
2425             xptr = fractx[ii];
2426             switch(order) {
2427             case 4:  CALC_SPLINE(4);     break;
2428             case 5:  CALC_SPLINE(5);     break;
2429             default: CALC_SPLINE(order); break;
2430             }
2431         }
2432     }
2433 }
2434
2435
2436 void make_dft_mod(real *mod,real *data,int ndata)
2437 {
2438   int i,j;
2439   real sc,ss,arg;
2440
2441   for(i=0;i<ndata;i++) {
2442     sc=ss=0;
2443     for(j=0;j<ndata;j++) {
2444       arg=(2.0*M_PI*i*j)/ndata;
2445       sc+=data[j]*cos(arg);
2446       ss+=data[j]*sin(arg);
2447     }
2448     mod[i]=sc*sc+ss*ss;
2449   }
2450   for(i=0;i<ndata;i++)
2451     if(mod[i]<1e-7)
2452       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2453 }
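     /* make_dft_mod computes mod[k] = |sum_j data[j]*exp(2*pi*i*j*k/ndata)|^2,
      * i.e. the squared modulus of the discrete Fourier transform of data[].
      * Entries that come out numerically zero are replaced by the average of
      * their neighbours to avoid dividing by zero later on.
      */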
2454
2455
2456 static void make_bspline_moduli(splinevec bsp_mod,
2457                                 int nx,int ny,int nz,int order)
2458 {
2459   int nmax=max(nx,max(ny,nz));
2460   real *data,*ddata,*bsp_data;
2461   int i,k,l;
2462   real div;
2463
2464   snew(data,order);
2465   snew(ddata,order);
2466   snew(bsp_data,nmax);
2467
2468   data[order-1]=0;
2469   data[1]=0;
2470   data[0]=1;
2471
2472   for(k=3;k<order;k++) {
2473     div=1.0/(k-1.0);
2474     data[k-1]=0;
2475     for(l=1;l<(k-1);l++)
2476       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2477     data[0]=div*data[0];
2478   }
2479   /* differentiate */
2480   ddata[0]=-data[0];
2481   for(k=1;k<order;k++)
2482     ddata[k]=data[k-1]-data[k];
2483   div=1.0/(order-1);
2484   data[order-1]=0;
2485   for(l=1;l<(order-1);l++)
2486     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2487   data[0]=div*data[0];
2488
2489   for(i=0;i<nmax;i++)
2490     bsp_data[i]=0;
2491   for(i=1;i<=order;i++)
2492     bsp_data[i]=data[i-1];
2493
2494   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2495   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2496   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2497
2498   sfree(data);
2499   sfree(ddata);
2500   sfree(bsp_data);
2501 }
2502
2503
2504 /* Return the P3M optimal influence function */
2505 static double do_p3m_influence(double z, int order)
2506 {
2507     double z2,z4;
2508
2509     z2 = z*z;
2510     z4 = z2*z2;
2511
2512     /* The formula and most constants can be found in:
2513      * Ballenegger et al., JCTC 8, 936 (2012)
2514      */
2515     switch(order)
2516     {
2517     case 2:
2518         return 1.0 - 2.0*z2/3.0;
2519         break;
2520     case 3:
2521         return 1.0 - z2 + 2.0*z4/15.0;
2522         break;
2523     case 4:
2524         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2525         break;
2526     case 5:
2527         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2528         break;
2529     case 6:
2530         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2531         break;
2532     case 7:
2533         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2534     case 8:
2535         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2536         break;
2537     }
2538
2539     return 0.0;
2540 }
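     /* Note that the caller below passes z = sin(pi*k/n), so the
      * polynomials above are evaluated at that argument.
      */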
2541
2542 /* Calculate the P3M B-spline moduli for one dimension */
2543 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2544 {
2545     double zarg,zai,sinzai,infl;
2546     int    maxk,i;
2547
2548     if (order > 8)
2549     {
2550         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2551     }
2552
2553     zarg = M_PI/n;
2554
2555     maxk = (n + 1)/2;
2556
2557     for(i=-maxk; i<0; i++)
2558     {
2559         zai    = zarg*i;
2560         sinzai = sin(zai);
2561         infl   = do_p3m_influence(sinzai,order);
2562         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2563     }
2564     bsp_mod[0] = 1.0;
2565     for(i=1; i<maxk; i++)
2566     {
2567         zai    = zarg*i;
2568         sinzai = sin(zai);
2569         infl   = do_p3m_influence(sinzai,order);
2570         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2571     }
2572 }
2573
2574 /* Calculate the P3M B-spline moduli */
2575 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2576                                     int nx,int ny,int nz,int order)
2577 {
2578     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2579     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2580     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2581 }
2582
2583
2584 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2585 {
2586   int nslab,n,i;
2587   int fw,bw;
2588
2589   nslab = atc->nslab;
2590
2591   n = 0;
2592   for(i=1; i<=nslab/2; i++) {
2593     fw = (atc->nodeid + i) % nslab;
2594     bw = (atc->nodeid - i + nslab) % nslab;
2595     if (n < nslab - 1) {
2596       atc->node_dest[n] = fw;
2597       atc->node_src[n]  = bw;
2598       n++;
2599     }
2600     if (n < nslab - 1) {
2601       atc->node_dest[n] = bw;
2602       atc->node_src[n]  = fw;
2603       n++;
2604     }
2605   }
2606 }
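     /* For example, with nslab=4 and nodeid=0 this gives
      * node_dest = {1,3,2} and node_src = {3,1,2}: alternating forward and
      * backward neighbours, with the half-way slab appearing only once.
      */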
2607
2608 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2609 {
2610     int thread;
2611
2612     if(NULL != log)
2613     {
2614         fprintf(log,"Destroying PME data structures.\n");
2615     }
2616
2617     sfree((*pmedata)->nnx);
2618     sfree((*pmedata)->nny);
2619     sfree((*pmedata)->nnz);
2620
2621     pmegrids_destroy(&(*pmedata)->pmegridA);
2622
2623     sfree((*pmedata)->fftgridA);
2624     sfree((*pmedata)->cfftgridA);
2625     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2626
2627     if ((*pmedata)->pmegridB.grid.grid != NULL)
2628     {
2629         pmegrids_destroy(&(*pmedata)->pmegridB);
2630         sfree((*pmedata)->fftgridB);
2631         sfree((*pmedata)->cfftgridB);
2632         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2633     }
2634     for(thread=0; thread<(*pmedata)->nthread; thread++)
2635     {
2636         free_work(&(*pmedata)->work[thread]);
2637     }
2638     sfree((*pmedata)->work);
2639
2640     sfree(*pmedata);
2641     *pmedata = NULL;
2642
2643     return 0;
2644 }
2645
2646 static int mult_up(int n,int f)
2647 {
2648     return ((n + f - 1)/f)*f;
2649 }
2650
2651
2652 static double pme_load_imbalance(gmx_pme_t pme)
2653 {
2654     int    nma,nmi;
2655     double n1,n2,n3;
2656
2657     nma = pme->nnodes_major;
2658     nmi = pme->nnodes_minor;
2659
2660     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2661     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2662     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2663
2664     /* pme_solve is roughly double the cost of an fft */
2665
2666     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2667 }
2668
2669 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2670                           int dimind,gmx_bool bSpread)
2671 {
2672     int nk,k,s,thread;
2673
2674     atc->dimind = dimind;
2675     atc->nslab  = 1;
2676     atc->nodeid = 0;
2677     atc->pd_nalloc = 0;
2678 #ifdef GMX_MPI
2679     if (pme->nnodes > 1)
2680     {
2681         atc->mpi_comm = pme->mpi_comm_d[dimind];
2682         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2683         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2684     }
2685     if (debug)
2686     {
2687         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2688     }
2689 #endif
2690
2691     atc->bSpread   = bSpread;
2692     atc->pme_order = pme->pme_order;
2693
2694     if (atc->nslab > 1)
2695     {
2696         /* These allocations are not required for particle decomposition. */
2697         snew(atc->node_dest,atc->nslab);
2698         snew(atc->node_src,atc->nslab);
2699         setup_coordinate_communication(atc);
2700
2701         snew(atc->count_thread,pme->nthread);
2702         for(thread=0; thread<pme->nthread; thread++)
2703         {
2704             snew(atc->count_thread[thread],atc->nslab);
2705         }
2706         atc->count = atc->count_thread[0];
2707         snew(atc->rcount,atc->nslab);
2708         snew(atc->buf_index,atc->nslab);
2709     }
2710
2711     atc->nthread = pme->nthread;
2712     if (atc->nthread > 1)
2713     {
2714         snew(atc->thread_plist,atc->nthread);
2715     }
2716     snew(atc->spline,atc->nthread);
2717     for(thread=0; thread<atc->nthread; thread++)
2718     {
2719         if (atc->nthread > 1)
2720         {
2721             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2722             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2723         }
2724         snew(atc->spline[thread].thread_one,pme->nthread);
2725         atc->spline[thread].thread_one[thread] = 1;
2726     }
2727 }
2728
2729 static void
2730 init_overlap_comm(pme_overlap_t *  ol,
2731                   int              norder,
2732 #ifdef GMX_MPI
2733                   MPI_Comm         comm,
2734 #endif
2735                   int              nnodes,
2736                   int              nodeid,
2737                   int              ndata,
2738                   int              commplainsize)
2739 {
2740     int lbnd,rbnd,maxlr,b,i;
2741     int exten;
2742     int nn,nk;
2743     pme_grid_comm_t *pgc;
2744     gmx_bool bCont;
2745     int fft_start,fft_end,send_index1,recv_index1;
2746 #ifdef GMX_MPI
2747     MPI_Status stat;
2748
2749     ol->mpi_comm = comm;
2750 #endif
2751
2752     ol->nnodes = nnodes;
2753     ol->nodeid = nodeid;
2754
2755     /* Linear translation of the PME grid won't affect reciprocal space
2756      * calculations, so to optimize we only interpolate "upwards",
2757      * which also means we only have to consider overlap in one direction.
2758      * I.e., particles on this node might also be spread to grid indices
2759      * that belong to higher nodes (modulo nnodes)
2760      */
2761
2762     snew(ol->s2g0,ol->nnodes+1);
2763     snew(ol->s2g1,ol->nnodes);
2764     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2765     for(i=0; i<nnodes; i++)
2766     {
2767         /* s2g0 is the local interpolation grid start.
2768          * s2g1 is the local interpolation grid end.
2769          * Because grid overlap communication only goes forward,
2770          * the FFT slab boundaries should be rounded down.
2771          */
2772         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2773         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2774
2775         if (debug)
2776         {
2777             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2778         }
2779     }
2780     ol->s2g0[nnodes] = ndata;
2781     if (debug) { fprintf(debug,"\n"); }
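         /* Example: ndata=15, nnodes=4, norder=4 gives
          * s2g0 = {0,3,7,11,15} and s2g1 = {7,11,15,18}, i.e. each slab
          * interpolates up to norder-1 lines beyond the FFT lines it owns,
          * wrapping around at the upper end.
          */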
2782
2783     /* Determine with how many nodes we need to communicate the grid overlap */
2784     b = 0;
2785     do
2786     {
2787         b++;
2788         bCont = FALSE;
2789         for(i=0; i<nnodes; i++)
2790         {
2791             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2792                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2793             {
2794                 bCont = TRUE;
2795             }
2796         }
2797     }
2798     while (bCont && b < nnodes);
2799     ol->noverlap_nodes = b - 1;
2800
2801     snew(ol->send_id,ol->noverlap_nodes);
2802     snew(ol->recv_id,ol->noverlap_nodes);
2803     for(b=0; b<ol->noverlap_nodes; b++)
2804     {
2805         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2806         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2807     }
2808     snew(ol->comm_data, ol->noverlap_nodes);
2809
2810     ol->send_size = 0;
2811     for(b=0; b<ol->noverlap_nodes; b++)
2812     {
2813         pgc = &ol->comm_data[b];
2814         /* Send */
2815         fft_start        = ol->s2g0[ol->send_id[b]];
2816         fft_end          = ol->s2g0[ol->send_id[b]+1];
2817         if (ol->send_id[b] < nodeid)
2818         {
2819             fft_start += ndata;
2820             fft_end   += ndata;
2821         }
2822         send_index1      = ol->s2g1[nodeid];
2823         send_index1      = min(send_index1,fft_end);
2824         pgc->send_index0 = fft_start;
2825         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2826         ol->send_size    += pgc->send_nindex;
2827
2828         /* We always start receiving at the first index of our slab */
2829         fft_start        = ol->s2g0[ol->nodeid];
2830         fft_end          = ol->s2g0[ol->nodeid+1];
2831         recv_index1      = ol->s2g1[ol->recv_id[b]];
2832         if (ol->recv_id[b] > nodeid)
2833         {
2834             recv_index1 -= ndata;
2835         }
2836         recv_index1      = min(recv_index1,fft_end);
2837         pgc->recv_index0 = fft_start;
2838         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2839     }
2840
2841 #ifdef GMX_MPI
2842     /* Communicate the buffer sizes to receive */
2843     for(b=0; b<ol->noverlap_nodes; b++)
2844     {
2845         MPI_Sendrecv(&ol->send_size             ,1,MPI_INT,ol->send_id[b],b,
2846                      &ol->comm_data[b].recv_size,1,MPI_INT,ol->recv_id[b],b,
2847                      ol->mpi_comm,&stat);
2848     }
2849 #endif
2850
2851     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2852     snew(ol->sendbuf,norder*commplainsize);
2853     snew(ol->recvbuf,norder*commplainsize);
2854 }
2855
2856 static void
2857 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2858                               int **global_to_local,
2859                               real **fraction_shift)
2860 {
2861     int i;
2862     int * gtl;
2863     real * fsh;
2864
2865     snew(gtl,5*n);
2866     snew(fsh,5*n);
2867     for(i=0; (i<5*n); i++)
2868     {
2869         /* Determine the global to local grid index */
2870         gtl[i] = (i - local_start + n) % n;
2871         /* For coordinates that fall within the local grid the fraction
2872          * is correct, we don't need to shift it.
2873          */
2874         fsh[i] = 0;
2875         if (local_range < n)
2876         {
2877             /* Due to rounding issues i could be 1 beyond the lower or
2878              * upper boundary of the local grid. Correct the index for this.
2879              * If we shift the index, we need to shift the fraction by
2880              * the same amount in the other direction to not affect
2881              * the weights.
2882              * Note that due to this shifting the weights at the end of
2883              * the spline might change, but that will only involve values
2884              * between zero and values close to the precision of a real,
2885              * which is anyhow the accuracy of the whole mesh calculation.
2886              */
2887             /* With local_range=0 we should not change i=local_start */
2888             if (i % n != local_start)
2889             {
2890                 if (gtl[i] == n-1)
2891                 {
2892                     gtl[i] = 0;
2893                     fsh[i] = -1;
2894                 }
2895                 else if (gtl[i] == local_range)
2896                 {
2897                     gtl[i] = local_range - 1;
2898                     fsh[i] = 1;
2899                 }
2900             }
2901         }
2902     }
2903
2904     *global_to_local = gtl;
2905     *fraction_shift  = fsh;
2906 }
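     /* Example of the mapping above: with n=10, local_start=3 and
      * local_range=4, gtl[i] = (i - 3 + 10) % 10 for i in [0,5n), so a
      * global index equal to 3 (mod 10) maps to local index 0; indices
      * that land exactly one grid line outside 0..local_range-1 are pulled
      * back in and compensated by fsh[i] = +/-1.
      */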
2907
2908 static pme_spline_work_t *make_pme_spline_work(int order)
2909 {
2910     pme_spline_work_t *work;
2911
2912 #ifdef PME_SSE
2913     float  tmp[8];
2914     __m128 zero_SSE;
2915     int    of,i;
2916
2917     snew_aligned(work,1,16);
2918
2919     zero_SSE = _mm_setzero_ps();
2920
2921     /* Generate bit masks to mask out the unused grid entries,
2922      * as we only operate on 'order' of the 8 grid entries that are
2923      * loaded into 2 SSE float registers.
2924      */
2925     for(of=0; of<8-(order-1); of++)
2926     {
2927         for(i=0; i<8; i++)
2928         {
2929             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2930         }
2931         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2932         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2933         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2934         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2935     }
2936 #else
2937     work = NULL;
2938 #endif
2939
2940     return work;
2941 }
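     /* For example, with order=4 and offset of=1 the tmp array is
      * {0,1,1,1,1,0,0,0}, so mask_SSE0[1] selects elements 1-3 of the first
      * SSE register and mask_SSE1[1] selects only element 0 of the second.
      */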
2942
2943 static void
2944 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2945 {
2946     int nk_new;
2947
2948     if (*nk % nnodes != 0)
2949     {
2950         nk_new = nnodes*(*nk/nnodes + 1);
2951
2952         if (2*nk_new >= 3*(*nk))
2953         {
2954             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2955                       dim,*nk,dim,nnodes,dim);
2956         }
2957
2958         if (fplog != NULL)
2959         {
2960             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2961                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2962         }
2963
2964         *nk = nk_new;
2965     }
2966 }
2967
2968 int gmx_pme_init(gmx_pme_t *         pmedata,
2969                  t_commrec *         cr,
2970                  int                 nnodes_major,
2971                  int                 nnodes_minor,
2972                  t_inputrec *        ir,
2973                  int                 homenr,
2974                  gmx_bool            bFreeEnergy,
2975                  gmx_bool            bReproducible,
2976                  int                 nthread)
2977 {
2978     gmx_pme_t pme=NULL;
2979
2980     pme_atomcomm_t *atc;
2981     ivec ndata;
2982
2983     if (debug)
2984         fprintf(debug,"Creating PME data structures.\n");
2985     snew(pme,1);
2986
2987     pme->redist_init         = FALSE;
2988     pme->sum_qgrid_tmp       = NULL;
2989     pme->sum_qgrid_dd_tmp    = NULL;
2990     pme->buf_nalloc          = 0;
2991     pme->redist_buf_nalloc   = 0;
2992
2993     pme->nnodes              = 1;
2994     pme->bPPnode             = TRUE;
2995
2996     pme->nnodes_major        = nnodes_major;
2997     pme->nnodes_minor        = nnodes_minor;
2998
2999 #ifdef GMX_MPI
3000     if (nnodes_major*nnodes_minor > 1)
3001     {
3002         pme->mpi_comm = cr->mpi_comm_mygroup;
3003
3004         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
3005         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
3006         if (pme->nnodes != nnodes_major*nnodes_minor)
3007         {
3008             gmx_incons("PME node count mismatch");
3009         }
3010     }
3011     else
3012     {
3013         pme->mpi_comm = MPI_COMM_NULL;
3014     }
3015 #endif
3016
3017     if (pme->nnodes == 1)
3018     {
3019 #ifdef GMX_MPI
3020         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3021         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3022 #endif
3023         pme->ndecompdim = 0;
3024         pme->nodeid_major = 0;
3025         pme->nodeid_minor = 0;
3029     }
3030     else
3031     {
3032         if (nnodes_minor == 1)
3033         {
3034 #ifdef GMX_MPI
3035             pme->mpi_comm_d[0] = pme->mpi_comm;
3036             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3037 #endif
3038             pme->ndecompdim = 1;
3039             pme->nodeid_major = pme->nodeid;
3040             pme->nodeid_minor = 0;
3041
3042         }
3043         else if (nnodes_major == 1)
3044         {
3045 #ifdef GMX_MPI
3046             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3047             pme->mpi_comm_d[1] = pme->mpi_comm;
3048 #endif
3049             pme->ndecompdim = 1;
3050             pme->nodeid_major = 0;
3051             pme->nodeid_minor = pme->nodeid;
3052         }
3053         else
3054         {
3055             if (pme->nnodes % nnodes_major != 0)
3056             {
3057                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3058             }
3059             pme->ndecompdim = 2;
3060
3061 #ifdef GMX_MPI
3062             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3063                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3064             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3065                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3066
3067             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3068             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3069             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3070             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3071 #endif
3072         }
3073         pme->bPPnode = (cr->duty & DUTY_PP);
3074     }
3075
3076     pme->nthread = nthread;
3077
3078     if (ir->ePBC == epbcSCREW)
3079     {
3080         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3081     }
3082
3083     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3084     pme->nkx         = ir->nkx;
3085     pme->nky         = ir->nky;
3086     pme->nkz         = ir->nkz;
3087     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3088     pme->pme_order   = ir->pme_order;
3089     pme->epsilon_r   = ir->epsilon_r;
3090
3091     if (pme->pme_order > PME_ORDER_MAX)
3092     {
3093         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3094                   pme->pme_order,PME_ORDER_MAX);
3095     }
3096
3097     /* Currently pme.c supports only the fft5d FFT code.
3098      * Therefore the grid always needs to be divisible by nnodes.
3099      * When the old 1D code is also supported again, change this check.
3100      *
3101      * This check should be done before calling gmx_pme_init
3102      * and fplog should be passed instead of stderr.
3103      *
3104     if (pme->ndecompdim >= 2)
3105     */
3106     if (pme->ndecompdim >= 1)
3107     {
3108         /*
3109         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3110                                         'x',nnodes_major,&pme->nkx);
3111         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3112                                         'y',nnodes_minor,&pme->nky);
3113         */
3114     }
3115
3116     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3117         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3118         pme->nkz <= pme->pme_order)
3119     {
3120         gmx_fatal(FARGS,"The PME grid sizes need to be larger than pme_order (%d) and, for dimensions with domain decomposition, larger than 2*pme_order",pme->pme_order);
3121     }
3122
3123     if (pme->nnodes > 1) {
3124         double imbal;
3125
3126 #ifdef GMX_MPI
3127         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3128         MPI_Type_commit(&(pme->rvec_mpi));
3129 #endif
3130
3131         /* Note that the charge spreading and force gathering, which usually
3132          * takes about the same amount of time as FFT+solve_pme,
3133          * is always fully load balanced
3134          * (unless the charge distribution is inhomogeneous).
3135          */
3136
3137         imbal = pme_load_imbalance(pme);
3138         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3139         {
3140             fprintf(stderr,
3141                     "\n"
3142                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3143                     "      For optimal PME load balancing\n"
3144                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3145                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3146                     "\n",
3147                     (int)((imbal-1)*100 + 0.5),
3148                     pme->nkx,pme->nky,pme->nnodes_major,
3149                     pme->nky,pme->nkz,pme->nnodes_minor);
3150         }
3151     }
3152
3153     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3154     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3155      * y is always copied through a buffer: we don't need padding in z,
3156      * but we do need the overlap in x because of the communication order.
3157      */
3158     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3159 #ifdef GMX_MPI
3160                       pme->mpi_comm_d[0],
3161 #endif
3162                       pme->nnodes_major,pme->nodeid_major,
3163                       pme->nkx,
3164                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3165
3166     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3167      * We do this with an offset buffer of equal size, so we need to allocate
3168      * extra for the offset. That's what the (+1)*pme->nkz is for.
3169      */
3170     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3171 #ifdef GMX_MPI
3172                       pme->mpi_comm_d[1],
3173 #endif
3174                       pme->nnodes_minor,pme->nodeid_minor,
3175                       pme->nky,
3176                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3177
3178     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3179      * We only allow multiple communication pulses in dim 1, not in dim 0.
3180      */
3181     if (pme->nthread > 1 && (pme->overlap[0].noverlap_nodes > 1 ||
3182                              pme->nkx < pme->nnodes_major*pme->pme_order))
3183     {
3184         gmx_fatal(FARGS,"The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d). To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3185                   pme->nkx/(double)pme->nnodes_major,pme->pme_order);
3186     }
3187
3188     snew(pme->bsp_mod[XX],pme->nkx);
3189     snew(pme->bsp_mod[YY],pme->nky);
3190     snew(pme->bsp_mod[ZZ],pme->nkz);
3191
3192     /* The required size of the interpolation grid, including overlap.
3193      * The allocated size (pmegrid_n?) might be slightly larger.
3194      */
3195     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3196                       pme->overlap[0].s2g0[pme->nodeid_major];
3197     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3198                       pme->overlap[1].s2g0[pme->nodeid_minor];
3199     pme->pmegrid_nz_base = pme->nkz;
3200     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3201     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3202
3203     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3204     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3205     pme->pmegrid_start_iz = 0;
3206
3207     make_gridindex5_to_localindex(pme->nkx,
3208                                   pme->pmegrid_start_ix,
3209                                   pme->pmegrid_nx - (pme->pme_order-1),
3210                                   &pme->nnx,&pme->fshx);
3211     make_gridindex5_to_localindex(pme->nky,
3212                                   pme->pmegrid_start_iy,
3213                                   pme->pmegrid_ny - (pme->pme_order-1),
3214                                   &pme->nny,&pme->fshy);
3215     make_gridindex5_to_localindex(pme->nkz,
3216                                   pme->pmegrid_start_iz,
3217                                   pme->pmegrid_nz_base,
3218                                   &pme->nnz,&pme->fshz);
3219
3220     pmegrids_init(&pme->pmegridA,
3221                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3222                   pme->pmegrid_nz_base,
3223                   pme->pme_order,
3224                   pme->nthread,
3225                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3226                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3227
3228     pme->spline_work = make_pme_spline_work(pme->pme_order);
3229
3230     ndata[0] = pme->nkx;
3231     ndata[1] = pme->nky;
3232     ndata[2] = pme->nkz;
3233
3234     /* This routine will allocate the grid data to fit the FFTs */
3235     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3236                             &pme->fftgridA,&pme->cfftgridA,
3237                             pme->mpi_comm_d,
3238                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3239                             bReproducible,pme->nthread);
3240
3241     if (bFreeEnergy)
3242     {
3243         pmegrids_init(&pme->pmegridB,
3244                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3245                       pme->pmegrid_nz_base,
3246                       pme->pme_order,
3247                       pme->nthread,
3248                       pme->nkx % pme->nnodes_major != 0,
3249                       pme->nky % pme->nnodes_minor != 0);
3250
3251         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3252                                 &pme->fftgridB,&pme->cfftgridB,
3253                                 pme->mpi_comm_d,
3254                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3255                                 bReproducible,pme->nthread);
3256     }
3257     else
3258     {
3259         pme->pmegridB.grid.grid = NULL;
3260         pme->fftgridB           = NULL;
3261         pme->cfftgridB          = NULL;
3262     }
3263
3264     if (!pme->bP3M)
3265     {
3266         /* Use plain SPME B-spline interpolation */
3267         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3268     }
3269     else
3270     {
3271         /* Use the P3M grid-optimized influence function */
3272         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3273     }
3274
3275     /* Use atc[0] for spreading */
3276     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3277     if (pme->ndecompdim >= 2)
3278     {
3279         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3280     }
3281
3282     if (pme->nnodes == 1) {
3283         pme->atc[0].n = homenr;
3284         pme_realloc_atomcomm_things(&pme->atc[0]);
3285     }
3286
3287     {
3288         int thread;
3289
3290         /* Use fft5d, order after FFT is y major, z, x minor */
3291
3292         snew(pme->work,pme->nthread);
3293         for(thread=0; thread<pme->nthread; thread++)
3294         {
3295             realloc_work(&pme->work[thread],pme->nkx);
3296         }
3297     }
3298
3299     *pmedata = pme;
3300     
3301     return 0;
3302 }
3303
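/* Hand the grid storage of 'old' over to 'new' when 'new' fits within the
 * allocation of 'old' in every dimension; otherwise leave 'new' untouched.
 * Per-thread grids are only reused when the thread counts match.
 */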
3304 static void reuse_pmegrids(const pmegrids_t *old,pmegrids_t *new)
3305 {
3306     int d,t;
3307
3308     for(d=0; d<DIM; d++)
3309     {
3310         if (new->grid.n[d] > old->grid.n[d])
3311         {
3312             return;
3313         }
3314     }
3315
3316     sfree_aligned(new->grid.grid);
3317     new->grid.grid = old->grid.grid;
3318
3319     if (new->nthread > 1 && new->nthread == old->nthread)
3320     {
3321         sfree_aligned(new->grid_all);
3322         for(t=0; t<new->nthread; t++)
3323         {
3324             new->grid_th[t].grid = old->grid_th[t].grid;
3325         }
3326     }
3327 }
3328
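/* Reinitialize PME for a different grid size (used for PME grid tuning):
 * the run parameters are copied from pme_src, only the grid dimensions
 * change, and the real-space grids of pme_src are reused where possible.
 */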
3329 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3330                    t_commrec *         cr,
3331                    gmx_pme_t           pme_src,
3332                    const t_inputrec *  ir,
3333                    ivec                grid_size)
3334 {
3335     t_inputrec irc;
3336     int homenr;
3337     int ret;
3338
3339     irc = *ir;
3340     irc.nkx = grid_size[XX];
3341     irc.nky = grid_size[YY];
3342     irc.nkz = grid_size[ZZ];
3343
3344     if (pme_src->nnodes == 1)
3345     {
3346         homenr = pme_src->atc[0].n;
3347     }
3348     else
3349     {
3350         homenr = -1;
3351     }
3352
3353     ret = gmx_pme_init(pmedata,cr,pme_src->nnodes_major,pme_src->nnodes_minor,
3354                        &irc,homenr,pme_src->bFEP,FALSE,pme_src->nthread);
3355
3356     if (ret == 0)
3357     {
3358         /* We can easily reuse the allocated pme grids in pme_src */
3359         reuse_pmegrids(&pme_src->pmegridA,&(*pmedata)->pmegridA);
3360         /* We would like to reuse the fft grids, but that's harder */
3361     }
3362
3363     return ret;
3364 }
3365
3366
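/* Copy the interior (non-overlapping) part of this thread's spread grid
 * into the node-local FFT grid. The overlapping borders are summed later
 * by reduce_threadgrid_overlap() and, across nodes, by sum_fftgrid_dd().
 */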
3367 static void copy_local_grid(gmx_pme_t pme,
3368                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3369 {
3370     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3371     int  fft_my,fft_mz;
3372     int  nsx,nsy,nsz;
3373     ivec nf;
3374     int  offx,offy,offz,x,y,z,i0,i0t;
3375     int  d;
3376     pmegrid_t *pmegrid;
3377     real *grid_th;
3378
3379     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3380                                    local_fft_ndata,
3381                                    local_fft_offset,
3382                                    local_fft_size);
3383     fft_my = local_fft_size[YY];
3384     fft_mz = local_fft_size[ZZ];
3385
3386     pmegrid = &pmegrids->grid_th[thread];
3387
3388     nsx = pmegrid->s[XX];
3389     nsy = pmegrid->s[YY];
3390     nsz = pmegrid->s[ZZ];
3391
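    /* nf[d]: the number of grid points in dimension d that this thread
     * owns exclusively (its block minus the pme_order-1 spreading
     * overlap), clipped to the extent of the local FFT grid.
     */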
3392     for(d=0; d<DIM; d++)
3393     {
3394         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3395                     local_fft_ndata[d] - pmegrid->offset[d]);
3396     }
3397
3398     offx = pmegrid->offset[XX];
3399     offy = pmegrid->offset[YY];
3400     offz = pmegrid->offset[ZZ];
3401
3402     /* Directly copy the non-overlapping parts of the local grids.
3403      * This also initializes the full grid.
3404      */
3405     grid_th = pmegrid->grid;
3406     for(x=0; x<nf[XX]; x++)
3407     {
3408         for(y=0; y<nf[YY]; y++)
3409         {
3410             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3411             i0t = (x*nsy + y)*nsz;
3412             for(z=0; z<nf[ZZ]; z++)
3413             {
3414                 fftgrid[i0+z] = grid_th[i0t+z];
3415             }
3416         }
3417     }
3418 }
3419
3420 static void
3421 reduce_threadgrid_overlap(gmx_pme_t pme,
3422                           const pmegrids_t *pmegrids,int thread,
3423                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3424 {
3425     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3426     int  fft_nx,fft_ny,fft_nz;
3427     int  fft_my,fft_mz;
3428     int  buf_my=-1;
3429     int  nsx,nsy,nsz;
3430     ivec ne;
3431     int  offx,offy,offz,x,y,z,i0,i0t;
3432     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3433     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3434     gmx_bool bCommX,bCommY;
3435     int  d;
3436     int  thread_f;
3437     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3438     const real *grid_th;
3439     real *commbuf=NULL;
3440
3441     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3442                                    local_fft_ndata,
3443                                    local_fft_offset,
3444                                    local_fft_size);
3445     fft_nx = local_fft_ndata[XX];
3446     fft_ny = local_fft_ndata[YY];
3447     fft_nz = local_fft_ndata[ZZ];
3448
3449     fft_my = local_fft_size[YY];
3450     fft_mz = local_fft_size[ZZ];
3451
3452     /* This routine is called when all threads have finished spreading.
3453      * Here each thread sums grid contributions calculated by other threads
3454      * into its thread-local grid volume.
3455      * To minimize the number of grid copying operations,
3456      * this routine sums directly from the pmegrid into the fftgrid.
3457      */
3458
3459     /* Determine which part of the full node grid we should operate on:
3460      * this is our thread-local part of the full grid.
3461      */
3462     pmegrid = &pmegrids->grid_th[thread];
3463
3464     for(d=0; d<DIM; d++)
3465     {
3466         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3467                     local_fft_ndata[d]);
3468     }
3469
3470     offx = pmegrid->offset[XX];
3471     offy = pmegrid->offset[YY];
3472     offz = pmegrid->offset[ZZ];
3473
3474
3475     bClearBufX  = TRUE;
3476     bClearBufY  = TRUE;
3477     bClearBufXY = TRUE;
3478
3479     /* Now loop over all the thread data blocks that contribute
3480      * to the grid region we (our thread) are operating on.
3481      */
3482     /* Note that fft_nx/fft_ny equals the number of grid points between
3483      * the first point of our node's grid and that of the next node's.
3484      */
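    /* sx,sy,sz run from 0 down to -nthread_comm[d]: 0 is our own thread
     * block, negative values are the neighboring thread blocks whose
     * spreading overlaps into our region; fx/fy/fz are wrapped
     * periodically when they become negative.
     */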
3485     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3486     {
3487         fx = pmegrid->ci[XX] + sx;
3488         ox = 0;
3489         bCommX = FALSE;
3490         if (fx < 0) {
3491             fx += pmegrids->nc[XX];
3492             ox -= fft_nx;
3493             bCommX = (pme->nnodes_major > 1);
3494         }
3495         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3496         ox += pmegrid_g->offset[XX];
3497         if (!bCommX)
3498         {
3499             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3500         }
3501         else
3502         {
3503             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3504         }
3505
3506         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3507         {
3508             fy = pmegrid->ci[YY] + sy;
3509             oy = 0;
3510             bCommY = FALSE;
3511             if (fy < 0) {
3512                 fy += pmegrids->nc[YY];
3513                 oy -= fft_ny;
3514                 bCommY = (pme->nnodes_minor > 1);
3515             }
3516             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3517             oy += pmegrid_g->offset[YY];
3518             if (!bCommY)
3519             {
3520                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3521             }
3522             else
3523             {
3524                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3525             }
3526
3527             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3528             {
3529                 fz = pmegrid->ci[ZZ] + sz;
3530                 oz = 0;
3531                 if (fz < 0)
3532                 {
3533                     fz += pmegrids->nc[ZZ];
3534                     oz -= fft_nz;
3535                 }
3536                 pmegrid_g = &pmegrids->grid_th[fz];
3537                 oz += pmegrid_g->offset[ZZ];
3538                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3539
3540                 if (sx == 0 && sy == 0 && sz == 0)
3541                 {
3542                     /* We have already added our local contribution
3543                      * before calling this routine, so skip it here.
3544                      */
3545                     continue;
3546                 }
3547
3548                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3549
3550                 pmegrid_f = &pmegrids->grid_th[thread_f];
3551
3552                 grid_th = pmegrid_f->grid;
3553
3554                 nsx = pmegrid_f->s[XX];
3555                 nsy = pmegrid_f->s[YY];
3556                 nsz = pmegrid_f->s[ZZ];
3557
3558 #ifdef DEBUG_PME_REDUCE
3559                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3560                        pme->nodeid,thread,thread_f,
3561                        pme->pmegrid_start_ix,
3562                        pme->pmegrid_start_iy,
3563                        pme->pmegrid_start_iz,
3564                        sx,sy,sz,
3565                        offx-ox,tx1-ox,offx,tx1,
3566                        offy-oy,ty1-oy,offy,ty1,
3567                        offz-oz,tz1-oz,offz,tz1);
3568 #endif
3569
3570                 if (!(bCommX || bCommY))
3571                 {
3572                     /* Copy from the thread local grid to the node grid */
3573                     for(x=offx; x<tx1; x++)
3574                     {
3575                         for(y=offy; y<ty1; y++)
3576                         {
3577                             i0  = (x*fft_my + y)*fft_mz;
3578                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3579                             for(z=offz; z<tz1; z++)
3580                             {
3581                                 fftgrid[i0+z] += grid_th[i0t+z];
3582                             }
3583                         }
3584                     }
3585                 }
3586                 else
3587                 {
3588                     /* The order of this conditional decides where the
3589                      * corner volume gets stored with x+y decomposition.
3590                      */
3591                     if (bCommY)
3592                     {
3593                         commbuf = commbuf_y;
3594                         buf_my  = ty1 - offy;
3595                         if (bCommX)
3596                         {
3597                             /* We index commbuf modulo the local grid size */
3598                             commbuf += buf_my*fft_nx*fft_nz;
3599
3600                             bClearBuf  = bClearBufXY;
3601                             bClearBufXY = FALSE;
3602                         }
3603                         else
3604                         {
3605                             bClearBuf  = bClearBufY;
3606                             bClearBufY = FALSE;
3607                         }
3608                     }
3609                     else
3610                     {
3611                         commbuf = commbuf_x;
3612                         buf_my  = fft_ny;
3613                         bClearBuf  = bClearBufX;
3614                         bClearBufX = FALSE;
3615                     }
3616
3617                     /* Copy to the communication buffer */
3618                     for(x=offx; x<tx1; x++)
3619                     {
3620                         for(y=offy; y<ty1; y++)
3621                         {
3622                             i0  = (x*buf_my + y)*fft_nz;
3623                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3624
3625                             if (bClearBuf)
3626                             {
3627                                 /* First access of commbuf, initialize it */
3628                                 for(z=offz; z<tz1; z++)
3629                                 {
3630                                     commbuf[i0+z]  = grid_th[i0t+z];
3631                                 }
3632                             }
3633                             else
3634                             {
3635                                 for(z=offz; z<tz1; z++)
3636                                 {
3637                                     commbuf[i0+z] += grid_th[i0t+z];
3638                                 }
3639                             }
3640                         }
3641                     }
3642                 }
3643             }
3644         }
3645     }
3646 }
3647
3648
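/* Sum the grid overlap between PME nodes directly on the FFT grid,
 * for the thread-parallel spreading path. The send buffers were filled
 * by reduce_threadgrid_overlap(); first the minor (y) dimension is
 * communicated, then the major (x) dimension.
 */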
3649 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3650 {
3651     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3652     pme_overlap_t *overlap;
3653     int  send_index0,send_nindex;
3654     int  recv_nindex;
3655 #ifdef GMX_MPI
3656     MPI_Status stat;
3657 #endif
3658     int  send_size_y,recv_size_y;
3659     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3660     real *sendptr,*recvptr;
3661     int  x,y,z,indg,indb;
3662
3663     /* Note that this routine is only used for forward communication.
3664      * Since the force gathering, unlike the charge spreading,
3665      * can be trivially parallelized over the particles,
3666      * the backwards process is much simpler and can use the "old"
3667      * communication setup.
3668      */
3669
3670     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3671                                    local_fft_ndata,
3672                                    local_fft_offset,
3673                                    local_fft_size);
3674
3675     if (pme->nnodes_minor > 1)
3676     {
3677         /* Minor dimension */
3678         overlap = &pme->overlap[1];
3679
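        /* With additional decomposition along x, the y communication
         * buffers also carry size_yx extra x-rows: the corner region
         * that is added into the dim-0 send buffer after reception.
         */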
3680         if (pme->nnodes_major > 1)
3681         {
3682             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3683         }
3684         else
3685         {
3686             size_yx = 0;
3687         }
3688         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3689
3690         send_size_y = overlap->send_size;
3691
3692         for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
3693         {
3694             send_id = overlap->send_id[ipulse];
3695             recv_id = overlap->recv_id[ipulse];
3696             send_index0   =
3697                 overlap->comm_data[ipulse].send_index0 -
3698                 overlap->comm_data[0].send_index0;
3699             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3700             /* We don't use recv_index0, as we always receive starting at 0 */
3701             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3702             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3703
3704             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3705             recvptr = overlap->recvbuf;
3706
3707 #ifdef GMX_MPI
3708             MPI_Sendrecv(sendptr,send_size_y*datasize,GMX_MPI_REAL,
3709                          send_id,ipulse,
3710                          recvptr,recv_size_y*datasize,GMX_MPI_REAL,
3711                          recv_id,ipulse,
3712                          overlap->mpi_comm,&stat);
3713 #endif
3714
3715             for(x=0; x<local_fft_ndata[XX]; x++)
3716             {
3717                 for(y=0; y<recv_nindex; y++)
3718                 {
3719                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3720                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3721                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3722                     {
3723                         fftgrid[indg+z] += recvptr[indb+z];
3724                     }
3725                 }
3726             }
3727
3728             if (pme->nnodes_major > 1)
3729             {
3730                 /* Copy from the received buffer to the send buffer for dim 0 */
3731                 sendptr = pme->overlap[0].sendbuf;
3732                 for(x=0; x<size_yx; x++)
3733                 {
3734                     for(y=0; y<recv_nindex; y++)
3735                     {
3736                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3737                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3738                         for(z=0; z<local_fft_ndata[ZZ]; z++)
3739                         {
3740                             sendptr[indg+z] += recvptr[indb+z];
3741                         }
3742                     }
3743                 }
3744             }
3745         }
3746     }
3747
3748     /* We only support a single communication pulse here.
3749      * This is not a severe limitation, as this code is only used
3750      * with OpenMP, and with OpenMP the (PME) domains can be larger.
3751      */
3752     if (pme->nnodes_major > 1)
3753     {
3754         /* Major dimension */
3755         overlap = &pme->overlap[0];
3756
3757         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3758         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3759
3760         ipulse = 0;
3761
3762         send_id = overlap->send_id[ipulse];
3763         recv_id = overlap->recv_id[ipulse];
3764         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3765         /* We don't use recv_index0, as we always receive starting at 0 */
3766         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3767
3768         sendptr = overlap->sendbuf;
3769         recvptr = overlap->recvbuf;
3770
3771         if (debug != NULL)
3772         {
3773             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3774                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3775         }
3776
3777 #ifdef GMX_MPI
3778         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3779                      send_id,ipulse,
3780                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3781                      recv_id,ipulse,
3782                      overlap->mpi_comm,&stat);
3783 #endif
3784
3785         for(x=0; x<recv_nindex; x++)
3786         {
3787             for(y=0; y<local_fft_ndata[YY]; y++)
3788             {
3789                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3790                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3791                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3792                 {
3793                     fftgrid[indg+z] += recvptr[indb+z];
3794                 }
3795             }
3796         }
3797     }
3798 }
3799
3800
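/* Thread-parallel charge spreading: compute interpolation indices,
 * construct B-splines, spread the charges onto the (per-thread) grids,
 * then reduce the thread grids onto the FFT grid and, with domain
 * decomposition, sum the grid overlap between nodes.
 */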
3801 static void spread_on_grid(gmx_pme_t pme,
3802                            pme_atomcomm_t *atc,pmegrids_t *grids,
3803                            gmx_bool bCalcSplines,gmx_bool bSpread,
3804                            real *fftgrid)
3805 {
3806     int nthread,thread;
3807 #ifdef PME_TIME_THREADS
3808     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3809     static double cs1=0,cs2=0,cs3=0;
3810     static double cs1a[6]={0,0,0,0,0,0};
3811     static int cnt=0;
3812 #endif
3813
3814     nthread = pme->nthread;
3815     assert(nthread>0);
3816
3817 #ifdef PME_TIME_THREADS
3818     c1 = omp_cyc_start();
3819 #endif
3820     if (bCalcSplines)
3821     {
3822 #pragma omp parallel for num_threads(nthread) schedule(static)
3823         for(thread=0; thread<nthread; thread++)
3824         {
3825             int start,end;
3826
3827             start = atc->n* thread   /nthread;
3828             end   = atc->n*(thread+1)/nthread;
3829
3830             /* Compute the fftgrid index for all atoms,
3831              * with the help of some extra variables.
3832              */
3833             calc_interpolation_idx(pme,atc,start,end,thread);
3834         }
3835     }
3836 #ifdef PME_TIME_THREADS
3837     c1 = omp_cyc_end(c1);
3838     cs1 += (double)c1;
3839 #endif
3840
3841 #ifdef PME_TIME_THREADS
3842     c2 = omp_cyc_start();
3843 #endif
3844 #pragma omp parallel for num_threads(nthread) schedule(static)
3845     for(thread=0; thread<nthread; thread++)
3846     {
3847         splinedata_t *spline;
3848         pmegrid_t *grid;
3849
3850         /* make local bsplines  */
3851         if (grids == NULL || grids->nthread == 1)
3852         {
3853             spline = &atc->spline[0];
3854
3855             spline->n = atc->n;
3856
3857             grid = &grids->grid;
3858         }
3859         else
3860         {
3861             spline = &atc->spline[thread];
3862
3863             make_thread_local_ind(atc,thread,spline);
3864
3865             grid = &grids->grid_th[thread];
3866         }
3867
3868         if (bCalcSplines)
3869         {
3870             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3871                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3872         }
3873
3874         if (bSpread)
3875         {
3876             /* put local atoms on grid. */
3877 #ifdef PME_TIME_SPREAD
3878             ct1a = omp_cyc_start();
3879 #endif
3880             spread_q_bsplines_thread(grid,atc,spline,pme->spline_work);
3881
3882             if (grids->nthread > 1)
3883             {
3884                 copy_local_grid(pme,grids,thread,fftgrid);
3885             }
3886 #ifdef PME_TIME_SPREAD
3887             ct1a = omp_cyc_end(ct1a);
3888             cs1a[thread] += (double)ct1a;
3889 #endif
3890         }
3891     }
3892 #ifdef PME_TIME_THREADS
3893     c2 = omp_cyc_end(c2);
3894     cs2 += (double)c2;
3895 #endif
3896
3897     if (bSpread && grids->nthread > 1)
3898     {
3899 #ifdef PME_TIME_THREADS
3900         c3 = omp_cyc_start();
3901 #endif
3902 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3903         for(thread=0; thread<grids->nthread; thread++)
3904         {
3905             reduce_threadgrid_overlap(pme,grids,thread,
3906                                       fftgrid,
3907                                       pme->overlap[0].sendbuf,
3908                                       pme->overlap[1].sendbuf);
3909         }
3910 #ifdef PME_TIME_THREADS
3911         c3 = omp_cyc_end(c3);
3912         cs3 += (double)c3;
3913 #endif
3914
3915         if (pme->nnodes > 1)
3916         {
3917             /* Communicate the overlapping part of the fftgrid */
3918             sum_fftgrid_dd(pme,fftgrid);
3919         }
3920     }
3921
3922 #ifdef PME_TIME_THREADS
3923     cnt++;
3924     if (cnt % 20 == 0)
3925     {
3926         printf("idx %.2f spread %.2f red %.2f",
3927                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3928 #ifdef PME_TIME_SPREAD
3929         for(thread=0; thread<nthread; thread++)
3930             printf(" %.2f",cs1a[thread]*1e-9);
3931 #endif
3932         printf("\n");
3933     }
3934 #endif
3935 }
3936
3937
3938 static void dump_grid(FILE *fp,
3939                       int sx,int sy,int sz,int nx,int ny,int nz,
3940                       int my,int mz,const real *g)
3941 {
3942     int x,y,z;
3943
3944     for(x=0; x<nx; x++)
3945     {
3946         for(y=0; y<ny; y++)
3947         {
3948             for(z=0; z<nz; z++)
3949             {
3950                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3951                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3952             }
3953         }
3954     }
3955 }
3956
3957 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3958 {
3959     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3960
3961     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3962                                    local_fft_ndata,
3963                                    local_fft_offset,
3964                                    local_fft_size);
3965
3966     dump_grid(stderr,
3967               pme->pmegrid_start_ix,
3968               pme->pmegrid_start_iy,
3969               pme->pmegrid_start_iz,
3970               pme->pmegrid_nx-pme->pme_order+1,
3971               pme->pmegrid_ny-pme->pme_order+1,
3972               pme->pmegrid_nz-pme->pme_order+1,
3973               local_fft_size[YY],
3974               local_fft_size[ZZ],
3975               fftgrid);
3976 }
3977
3978
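/* Compute the mesh energy of n charges at positions x: B-splines are
 * calculated for these charges without spreading them, and the energy is
 * gathered from the current contents of the A-charge grid.
 * Only supported on a single PME node and without free energy.
 */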
3979 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3980 {
3981     pme_atomcomm_t *atc;
3982     pmegrids_t *grid;
3983
3984     if (pme->nnodes > 1)
3985     {
3986         gmx_incons("gmx_pme_calc_energy called in parallel");
3987     }
3988     if (pme->bFEP)
3989     {
3990         gmx_incons("gmx_pme_calc_energy with free energy");
3991     }
3992
3993     atc = &pme->atc_energy;
3994     atc->nthread   = 1;
3995     if (atc->spline == NULL)
3996     {
3997         snew(atc->spline,atc->nthread);
3998     }
3999     atc->nslab     = 1;
4000     atc->bSpread   = TRUE;
4001     atc->pme_order = pme->pme_order;
4002     atc->n         = n;
4003     pme_realloc_atomcomm_things(atc);
4004     atc->x         = x;
4005     atc->q         = q;
4006
4007     /* We only use the A-charges grid */
4008     grid = &pme->pmegridA;
4009
4010     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
4011
4012     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
4013 }
4014
4015
4016 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
4017         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
4018 {
4019     /* Reset all the counters related to performance over the run */
4020     wallcycle_stop(wcycle,ewcRUN);
4021     wallcycle_reset_all(wcycle);
4022     init_nrnb(nrnb);
4023     ir->init_step += step_rel;
4024     ir->nsteps    -= step_rel;
4025     wallcycle_start(wcycle,ewcRUN);
4026 }
4027
4028
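/* Return in *pme_ret a PME structure matching grid_size: reuse an entry
 * from the *pmedata list if its grid dimensions match, otherwise create a
 * new entry with gmx_pme_reinit and append it to the list.
 */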
4029 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4030                                ivec grid_size,
4031                                t_commrec *cr, t_inputrec *ir,
4032                                gmx_pme_t *pme_ret)
4033 {
4034     int ind;
4035     gmx_pme_t pme = NULL;
4036
4037     ind = 0;
4038     while (ind < *npmedata)
4039     {
4040         pme = (*pmedata)[ind];
4041         if (pme->nkx == grid_size[XX] &&
4042             pme->nky == grid_size[YY] &&
4043             pme->nkz == grid_size[ZZ])
4044         {
4045             *pme_ret = pme;
4046
4047             return;
4048         }
4049
4050         ind++;
4051     }
4052
4053     (*npmedata)++;
4054     srenew(*pmedata,*npmedata);
4055
4056     /* Generate a new PME data structure, copying part of the old pointers */
4057     gmx_pme_reinit(&((*pmedata)[ind]),cr,pme,ir,grid_size);
4058
4059     *pme_ret = (*pmedata)[ind];
4060 }
4061
4062
4063 int gmx_pmeonly(gmx_pme_t pme,
4064                 t_commrec *cr,    t_nrnb *nrnb,
4065                 gmx_wallcycle_t wcycle,
4066                 real ewaldcoeff,  gmx_bool bGatherOnly,
4067                 t_inputrec *ir)
4068 {
4069     int npmedata;
4070     gmx_pme_t *pmedata;
4071     gmx_pme_pp_t pme_pp;
4072     int  natoms;
4073     matrix box;
4074     rvec *x_pp=NULL,*f_pp=NULL;
4075     real *chargeA=NULL,*chargeB=NULL;
4076     real lambda=0;
4077     int  maxshift_x=0,maxshift_y=0;
4078     real energy,dvdlambda;
4079     matrix vir;
4080     float cycles;
4081     int  count;
4082     gmx_bool bEnerVir;
4083     gmx_large_int_t step,step_rel;
4084     ivec grid_switch;
4085
4086     /* This data is only used with PME tuning, i.e. switching PME grids */
4087     npmedata = 1;
4088     snew(pmedata,npmedata);
4089     pmedata[0] = pme;
4090
4091     pme_pp = gmx_pme_pp_init(cr);
4092
4093     init_nrnb(nrnb);
4094
4095     count = 0;
4096     do /****** this is a quasi-loop over time steps! */
4097     {
4098         /* The reason for having a loop here is PME grid tuning/switching */
4099         do
4100         {
4101             /* Domain decomposition */
4102             natoms = gmx_pme_recv_q_x(pme_pp,
4103                                       &chargeA,&chargeB,box,&x_pp,&f_pp,
4104                                       &maxshift_x,&maxshift_y,
4105                                       &pme->bFEP,&lambda,
4106                                       &bEnerVir,
4107                                       &step,
4108                                       grid_switch,&ewaldcoeff);
4109
4110             if (natoms == -2)
4111             {
4112                 /* Switch the PME grid to grid_switch */
4113                 gmx_pmeonly_switch(&npmedata,&pmedata,grid_switch,cr,ir,&pme);
4114             }
4115         }
4116         while (natoms == -2);
4117
4118         if (natoms == -1)
4119         {
4120             /* We should stop: break out of the loop */
4121             break;
4122         }
4123
4124         step_rel = step - ir->init_step;
4125
4126         if (count == 0)
4127             wallcycle_start(wcycle,ewcRUN);
4128
4129         wallcycle_start(wcycle,ewcPMEMESH);
4130
4131         dvdlambda = 0;
4132         clear_mat(vir);
4133         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
4134                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
4135                    &energy,lambda,&dvdlambda,
4136                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4137
4138         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
4139
4140         gmx_pme_send_force_vir_ener(pme_pp,
4141                                     f_pp,vir,energy,dvdlambda,
4142                                     cycles);
4143
4144         count++;
4145
4146         if (step_rel == wcycle_get_reset_counters(wcycle))
4147         {
4148             /* Reset all the counters related to performance over the run */
4149             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4150             wcycle_set_reset_counters(wcycle, 0);
4151         }
4152
4153     } /***** end of quasi-loop, we stop with the break above */
4154     while (TRUE);
4155
4156     return 0;
4157 }
4158
4159 int gmx_pme_do(gmx_pme_t pme,
4160                int start,       int homenr,
4161                rvec x[],        rvec f[],
4162                real *chargeA,   real *chargeB,
4163                matrix box, t_commrec *cr,
4164                int  maxshift_x, int maxshift_y,
4165                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4166                matrix vir,      real ewaldcoeff,
4167                real *energy,    real lambda,
4168                real *dvdlambda, int flags)
4169 {
4170     int     q,d,i,j,ntot,npme;
4171     int     nx,ny,nz;
4172     int     n_d,local_ny;
4173     pme_atomcomm_t *atc=NULL;
4174     pmegrids_t *pmegrid=NULL;
4175     real    *grid=NULL;
4176     real    *ptr;
4177     rvec    *x_d,*f_d;
4178     real    *charge=NULL,*q_d;
4179     real    energy_AB[2];
4180     matrix  vir_AB[2];
4181     gmx_bool bClearF;
4182     gmx_parallel_3dfft_t pfft_setup;
4183     real *  fftgrid;
4184     t_complex * cfftgrid;
4185     int     thread;
4186     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4187     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4188
4189     assert(pme->nnodes > 0);
4190     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4191
4192     if (pme->nnodes > 1) {
4193         atc = &pme->atc[0];
4194         atc->npd = homenr;
4195         if (atc->npd > atc->pd_nalloc) {
4196             atc->pd_nalloc = over_alloc_dd(atc->npd);
4197             srenew(atc->pd,atc->pd_nalloc);
4198         }
4199         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4200     }
4201     else
4202     {
4203         /* This could be necessary for TPI */
4204         pme->atc[0].n = homenr;
4205     }
4206
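    /* One pass over the A-state charges for normal runs, two passes
     * (A- and B-state charges) with free-energy perturbation.
     */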
4207     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4208         if (q == 0) {
4209             pmegrid = &pme->pmegridA;
4210             fftgrid = pme->fftgridA;
4211             cfftgrid = pme->cfftgridA;
4212             pfft_setup = pme->pfft_setupA;
4213             charge = chargeA+start;
4214         } else {
4215             pmegrid = &pme->pmegridB;
4216             fftgrid = pme->fftgridB;
4217             cfftgrid = pme->cfftgridB;
4218             pfft_setup = pme->pfft_setupB;
4219             charge = chargeB+start;
4220         }
4221         grid = pmegrid->grid.grid;
4222         /* Unpack structure */
4223         if (debug) {
4224             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4225                     cr->nnodes,cr->nodeid);
4226             fprintf(debug,"Grid = %p\n",(void*)grid);
4227             if (grid == NULL)
4228                 gmx_fatal(FARGS,"No grid!");
4229         }
4230         where();
4231
4232         m_inv_ur0(box,pme->recipbox);
4233
4234         if (pme->nnodes == 1) {
4235             atc = &pme->atc[0];
4236             if (DOMAINDECOMP(cr)) {
4237                 atc->n = homenr;
4238                 pme_realloc_atomcomm_things(atc);
4239             }
4240             atc->x = x;
4241             atc->q = charge;
4242             atc->f = f;
4243         } else {
4244             wallcycle_start(wcycle,ewcPME_REDISTXF);
4245             for(d=pme->ndecompdim-1; d>=0; d--)
4246             {
4247                 if (d == pme->ndecompdim-1)
4248                 {
4249                     n_d = homenr;
4250                     x_d = x + start;
4251                     q_d = charge;
4252                 }
4253                 else
4254                 {
4255                     n_d = pme->atc[d+1].n;
4256                     x_d = atc->x;
4257                     q_d = atc->q;
4258                 }
4259                 atc = &pme->atc[d];
4260                 atc->npd = n_d;
4261                 if (atc->npd > atc->pd_nalloc) {
4262                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4263                     srenew(atc->pd,atc->pd_nalloc);
4264                 }
4265                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4266                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4267                 where();
4268
4269                 GMX_BARRIER(cr->mpi_comm_mygroup);
4270                 /* Redistribute x (only once) and qA or qB */
4271                 if (DOMAINDECOMP(cr)) {
4272                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4273                 } else {
4274                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4275                 }
4276             }
4277             where();
4278
4279             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4280         }
4281
4282         if (debug)
4283             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4284                     cr->nodeid,atc->n);
4285
4286         if (flags & GMX_PME_SPREAD_Q)
4287         {
4288             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4289
4290             /* Spread the charges on a grid */
4291             GMX_MPE_LOG(ev_spread_on_grid_start);
4292
4294             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4295             GMX_MPE_LOG(ev_spread_on_grid_finish);
4296
4297             if (q == 0)
4298             {
4299                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4300             }
4301             inc_nrnb(nrnb,eNR_SPREADQBSP,
4302                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4303
4304             if (pme->nthread == 1)
4305             {
4306                 wrap_periodic_pmegrid(pme,grid);
4307
4308                 /* sum contributions to local grid from other nodes */
4309 #ifdef GMX_MPI
4310                 if (pme->nnodes > 1)
4311                 {
4312                     GMX_BARRIER(cr->mpi_comm_mygroup);
4313                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4314                     where();
4315                 }
4316 #endif
4317
4318                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4319             }
4320
4321             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4322
4323             /*
4324             dump_local_fftgrid(pme,fftgrid);
4325             exit(0);
4326             */
4327         }
4328
4329         /* Here we start a large thread parallel region */
4330 #pragma omp parallel num_threads(pme->nthread) private(thread)
4331         {
4332             thread=gmx_omp_get_thread_num();
4333             if (flags & GMX_PME_SOLVE)
4334             {
4335                 int loop_count;
4336
4337                 /* do 3d-fft */
4338                 if (thread == 0)
4339                 {
4340                     GMX_BARRIER(cr->mpi_comm_mygroup);
4341                     GMX_MPE_LOG(ev_gmxfft3d_start);
4342                     wallcycle_start(wcycle,ewcPME_FFT);
4343                 }
4344                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4345                                            fftgrid,cfftgrid,thread,wcycle);
4346                 if (thread == 0)
4347                 {
4348                     wallcycle_stop(wcycle,ewcPME_FFT);
4349                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4350                 }
4351                 where();
4352
4353                 /* solve in k-space for our local cells */
4354                 if (thread == 0)
4355                 {
4356                     GMX_BARRIER(cr->mpi_comm_mygroup);
4357                     GMX_MPE_LOG(ev_solve_pme_start);
4358                     wallcycle_start(wcycle,ewcPME_SOLVE);
4359                 }
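                /* box[XX][XX]*box[YY][YY]*box[ZZ][ZZ] is the box volume:
                 * GROMACS boxes are lower triangular, so the determinant
                 * is the product of the diagonal elements.
                 */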
4360                 loop_count =
4361                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4362                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4363                                   bCalcEnerVir,
4364                                   pme->nthread,thread);
4365                 if (thread == 0)
4366                 {
4367                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4368                     where();
4369                     GMX_MPE_LOG(ev_solve_pme_finish);
4370                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4371                 }
4372             }
4373
4374             if (bCalcF)
4375             {
4376                 /* do 3d-invfft */
4377                 if (thread == 0)
4378                 {
4379                     GMX_BARRIER(cr->mpi_comm_mygroup);
4380                     GMX_MPE_LOG(ev_gmxfft3d_start);
4381                     where();
4382                     wallcycle_start(wcycle,ewcPME_FFT);
4383                 }
4384                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4385                                            cfftgrid,fftgrid,thread,wcycle);
4386                 if (thread == 0)
4387                 {
4388                     wallcycle_stop(wcycle,ewcPME_FFT);
4389                     
4390                     where();
4391                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4392
4393                     if (pme->nodeid == 0)
4394                     {
4395                         ntot = pme->nkx*pme->nky*pme->nkz;
4396                         npme  = ntot*log((real)ntot)/log(2.0);
4397                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4398                     }
4399
4400                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4401                 }
4402
4403                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4404             }
4405         }
4406         /* End of thread parallel section.
4407          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4408          */
4409
4410         if (bCalcF)
4411         {
4412             /* distribute local grid to all nodes */
4413 #ifdef GMX_MPI
4414             if (pme->nnodes > 1) {
4415                 GMX_BARRIER(cr->mpi_comm_mygroup);
4416                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4417             }
4418 #endif
4419             where();
4420
4421             unwrap_periodic_pmegrid(pme,grid);
4422
4423             /* interpolate forces for our local atoms */
4424             GMX_BARRIER(cr->mpi_comm_mygroup);
4425             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4426
4427             where();
4428
4429             /* If we are running without parallelization,
4430              * atc->f is the actual force array, not a buffer,
4431              * so we should not clear it.
4432              */
4433             bClearF = (q == 0 && PAR(cr));
4434 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4435             for(thread=0; thread<pme->nthread; thread++)
4436             {
4437                 gather_f_bsplines(pme,grid,bClearF,atc,
4438                                   &atc->spline[thread],
4439                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4440             }
4441
4442             where();
4443
4444             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4445
4446             inc_nrnb(nrnb,eNR_GATHERFBSP,
4447                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4448             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4449         }
4450
4451         if (bCalcEnerVir)
4452         {
4453             /* This should only be called on the master thread
4454              * and after the threads have synchronized.
4455              */
4456             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4457         }
4458     } /* of q-loop */
4459
4460     if (bCalcF && pme->nnodes > 1) {
4461         wallcycle_start(wcycle,ewcPME_REDISTXF);
4462         for(d=0; d<pme->ndecompdim; d++)
4463         {
4464             atc = &pme->atc[d];
4465             if (d == pme->ndecompdim - 1)
4466             {
4467                 n_d = homenr;
4468                 f_d = f + start;
4469             }
4470             else
4471             {
4472                 n_d = pme->atc[d+1].n;
4473                 f_d = pme->atc[d+1].f;
4474             }
4475             GMX_BARRIER(cr->mpi_comm_mygroup);
4476             if (DOMAINDECOMP(cr)) {
4477                 dd_pmeredist_f(pme,atc,n_d,f_d,
4478                                d==pme->ndecompdim-1 && pme->bPPnode);
4479             } else {
4480                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4481             }
4482         }
4483
4484         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4485     }
4486     where();
4487
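    /* With free energy the mesh energy and virial are interpolated
     * linearly in lambda between the A and B charge states, and
     * dV/dlambda gets the difference of the two end-state energies.
     */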
4488     if (bCalcEnerVir)
4489     {
4490         if (!pme->bFEP) {
4491             *energy = energy_AB[0];
4492             m_add(vir,vir_AB[0],vir);
4493         } else {
4494             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4495             *dvdlambda += energy_AB[1] - energy_AB[0];
4496             for(i=0; i<DIM; i++)
4497             {
4498                 for(j=0; j<DIM; j++)
4499                 {
4500                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4501                         lambda*vir_AB[1][i][j];
4502                 }
4503             }
4504         }
4505     }
4506     else
4507     {
4508         *energy = 0;
4509     }
4510
4511     if (debug)
4512     {
4513         fprintf(debug,"PME mesh energy: %g\n",*energy);
4514     }
4515
4516     return 0;
4517 }