introduce gmx_omp wrapper for the OpenMP API
[alexxy/gromacs.git] / src / mdlib / pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include "gmx_omp.h"
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93
94 /* Single precision, with SSE2 or higher available */
95 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
96
97 #include "gmx_x86_sse2.h"
98 #include "gmx_math_x86_sse2_single.h"
99
100 #define PME_SSE
101 /* Some old AMD processors could have problems with unaligned loads+stores */
102 #ifndef GMX_FAHCORE
103 #define PME_SSE_UNALIGNED
104 #endif
105 #endif
106
107 #include "mpelogging.h"
108
109 #define DFT_TOL 1e-7
110 /* #define PRT_FORCE */
111 /* conditions for on-the-fly time measurement */
112 /* #define TAKETIME (step > 1 && timesteps < 10) */
113 #define TAKETIME FALSE
114
115 /* #define PME_TIME_THREADS */
116
117 #ifdef GMX_DOUBLE
118 #define mpi_type MPI_DOUBLE
119 #else
120 #define mpi_type MPI_FLOAT
121 #endif
122
123 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
124 #define GMX_CACHE_SEP 64
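/* (GMX_CACHE_SEP is used in pmegrids_init below as padding, in reals,
 *  between the per-thread grids that share one 16-byte aligned
 *  allocation.)
 */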
125
126 /* We only define a maximum to be able to use local arrays without allocation.
127  * An order larger than 12 should never be needed, even for test cases.
128  * If needed it can be changed here.
129  */
130 #define PME_ORDER_MAX 12
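/* Example (illustrative only): spline coefficients for a single atom can
 * then be kept on the stack, e.g.
 *   real theta[PME_ORDER_MAX], dtheta[PME_ORDER_MAX];
 * without any dynamic allocation.
 */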
131
132 /* Internal datastructures */
133 typedef struct {
134     int send_index0;
135     int send_nindex;
136     int recv_index0;
137     int recv_nindex;
138 } pme_grid_comm_t;
139
140 typedef struct {
141 #ifdef GMX_MPI
142     MPI_Comm mpi_comm;
143 #endif
144     int  nnodes,nodeid;
145     int  *s2g0;
146     int  *s2g1;
147     int  noverlap_nodes;
148     int  *send_id,*recv_id;
149     pme_grid_comm_t *comm_data;
150     real *sendbuf;
151     real *recvbuf;
152 } pme_overlap_t;
153
154 typedef struct {
155     int *n;     /* Cumulative counts of the number of particles per thread */
156     int nalloc; /* Allocation size of i */
157     int *i;     /* Particle indices ordered on thread index (n) */
158 } thread_plist_t;
159
160 typedef struct {
161     int  n;
162     int  *ind;
163     splinevec theta;
164     splinevec dtheta;
165 } splinedata_t;
166
167 typedef struct {
168     int  dimind;            /* The index of the dimension, 0=x, 1=y */
169     int  nslab;
170     int  nodeid;
171 #ifdef GMX_MPI
172     MPI_Comm mpi_comm;
173 #endif
174
175     int  *node_dest;        /* The nodes to send x and q to with DD */
176     int  *node_src;         /* The nodes to receive x and q from with DD */
177     int  *buf_index;        /* Index for commnode into the buffers */
178
179     int  maxshift;
180
181     int  npd;
182     int  pd_nalloc;
183     int  *pd;
184     int  *count;            /* The number of atoms to send to each node */
185     int  **count_thread;
186     int  *rcount;           /* The number of atoms to receive */
187
188     int  n;
189     int  nalloc;
190     rvec *x;
191     real *q;
192     rvec *f;
193     gmx_bool bSpread;       /* These coordinates are used for spreading */
194     int  pme_order;
195     ivec *idx;
196     rvec *fractx;            /* Fractional coordinate relative to the
197                               * lower cell boundary
198                               */
199     int  nthread;
200     int  *thread_idx;        /* Which thread should spread which charge */
201     thread_plist_t *thread_plist;
202     splinedata_t *spline;
203 } pme_atomcomm_t;
204
205 #define FLBS  3
206 #define FLBSZ 4
207
208 typedef struct {
209     ivec ci;     /* The spatial location of this grid       */
210     ivec n;      /* The size of *grid, including order-1    */
211     ivec offset; /* The grid offset from the full node grid */
212     int  order;  /* PME spreading order                     */
213     real *grid;  /* The local thread's grid, size n         */
214 } pmegrid_t;
215
216 typedef struct {
217     pmegrid_t grid;     /* The full node grid (non thread-local)            */
218     int  nthread;       /* The number of threads operating on this grid     */
219     ivec nc;            /* The local spatial decomposition over the threads */
220     pmegrid_t *grid_th; /* Array of grids for each thread                   */
221     int  **g2t;         /* The grid to thread index                         */
222     ivec nthread_comm;  /* The number of threads to communicate with        */
223 } pmegrids_t;
224
225
226 typedef struct {
227 #ifdef PME_SSE
228     /* Masks for SSE aligned spreading and gathering */
229     __m128 mask_SSE0[6],mask_SSE1[6];
230 #else
231     int dummy; /* C89 requires that struct has at least one member */
232 #endif
233 } pme_spline_work_t;
234
235 typedef struct {
236     /* work data for solve_pme */
237     int      nalloc;
238     real *   mhx;
239     real *   mhy;
240     real *   mhz;
241     real *   m2;
242     real *   denom;
243     real *   tmp1_alloc;
244     real *   tmp1;
245     real *   eterm;
246     real *   m2inv;
247
248     real     energy;
249     matrix   vir;
250 } pme_work_t;
251
252 typedef struct gmx_pme {
253     int  ndecompdim;         /* The number of decomposition dimensions */
254     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
255     int  nodeid_major;
256     int  nodeid_minor;
257     int  nnodes;             /* The number of nodes doing PME */
258     int  nnodes_major;
259     int  nnodes_minor;
260
261     MPI_Comm mpi_comm;
262     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
263 #ifdef GMX_MPI
264     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
265 #endif
266
267     int  nthread;            /* The number of threads doing PME */
268
269     gmx_bool bPPnode;        /* Node also does particle-particle forces */
270     gmx_bool bFEP;           /* Compute Free energy contribution */
271     int nkx,nky,nkz;         /* Grid dimensions */
272     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
273     int pme_order;
274     real epsilon_r;
275
276     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
277     pmegrids_t pmegridB;
278     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
279     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
280     /* pmegrid_nz might be larger than strictly necessary to ensure
281      * memory alignment, pmegrid_nz_base gives the real base size.
282      */
283     int     pmegrid_nz_base;
284     /* The local PME grid starting indices */
285     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
286
287     /* Work data for spreading and gathering */
288     pme_spline_work_t spline_work;
289
290     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
291     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
292     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
293
294     t_complex *cfftgridA;             /* Grids for complex FFT data */
295     t_complex *cfftgridB;
296     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
297
298     gmx_parallel_3dfft_t  pfft_setupA;
299     gmx_parallel_3dfft_t  pfft_setupB;
300
301     int  *nnx,*nny,*nnz;
302     real *fshx,*fshy,*fshz;
303
304     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
305     matrix    recipbox;
306     splinevec bsp_mod;
307
308     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
309
310     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
311
312     rvec *bufv;             /* Communication buffer */
313     real *bufr;             /* Communication buffer */
314     int  buf_nalloc;        /* The communication buffer size */
315
316     /* thread local work data for solve_pme */
317     pme_work_t *work;
318
319     /* Work data for PME_redist */
320     gmx_bool redist_init;
321     int *    scounts;
322     int *    rcounts;
323     int *    sdispls;
324     int *    rdispls;
325     int *    sidx;
326     int *    idxa;
327     real *   redist_buf;
328     int      redist_buf_nalloc;
329
330     /* Work data for sum_qgrid */
331     real *   sum_qgrid_tmp;
332     real *   sum_qgrid_dd_tmp;
333 } t_gmx_pme;
334
335
336 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
337                                    int start,int end,int thread)
338 {
339     int  i;
340     int  *idxptr,tix,tiy,tiz;
341     real *xptr,*fptr,tx,ty,tz;
342     real rxx,ryx,ryy,rzx,rzy,rzz;
343     int  nx,ny,nz;
344     int  start_ix,start_iy,start_iz;
345     int  *g2tx,*g2ty,*g2tz;
346     gmx_bool bThreads;
347     int  *thread_idx=NULL;
348     thread_plist_t *tpl=NULL;
349     int  *tpl_n=NULL;
350     int  thread_i;
351
352     nx  = pme->nkx;
353     ny  = pme->nky;
354     nz  = pme->nkz;
355
356     start_ix = pme->pmegrid_start_ix;
357     start_iy = pme->pmegrid_start_iy;
358     start_iz = pme->pmegrid_start_iz;
359
360     rxx = pme->recipbox[XX][XX];
361     ryx = pme->recipbox[YY][XX];
362     ryy = pme->recipbox[YY][YY];
363     rzx = pme->recipbox[ZZ][XX];
364     rzy = pme->recipbox[ZZ][YY];
365     rzz = pme->recipbox[ZZ][ZZ];
366
367     g2tx = pme->pmegridA.g2t[XX];
368     g2ty = pme->pmegridA.g2t[YY];
369     g2tz = pme->pmegridA.g2t[ZZ];
370
371     bThreads = (atc->nthread > 1);
372     if (bThreads)
373     {
374         thread_idx = atc->thread_idx;
375
376         tpl   = &atc->thread_plist[thread];
377         tpl_n = tpl->n;
378         for(i=0; i<atc->nthread; i++)
379         {
380             tpl_n[i] = 0;
381         }
382     }
383
384     for(i=start; i<end; i++) {
385         xptr   = atc->x[i];
386         idxptr = atc->idx[i];
387         fptr   = atc->fractx[i];
388
389         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
390         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
391         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
392         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
393
394         tix = (int)(tx);
395         tiy = (int)(ty);
396         tiz = (int)(tz);
397
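        /* Worked example (orthorhombic box, illustrative): with box
         * length Lx and nx grid points, rxx = 1/Lx and ryx = rzx = 0,
         * so tx = nx*(x/Lx + 2.0) and tix lies in [2*nx, 3*nx). The
         * shift of 2*nx is folded back onto the local grid by the nnx[]
         * lookup below, while fshx[] adds the fractional correction
         * associated with the decomposition in x (z needs none, see the
         * comment below).
         */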
398         /* Because decomposition only occurs in x and y,
399          * we never have a fraction correction in z.
400          */
401         fptr[XX] = tx - tix + pme->fshx[tix];
402         fptr[YY] = ty - tiy + pme->fshy[tiy];
403         fptr[ZZ] = tz - tiz;
404
405         idxptr[XX] = pme->nnx[tix];
406         idxptr[YY] = pme->nny[tiy];
407         idxptr[ZZ] = pme->nnz[tiz];
408
409 #ifdef DEBUG
410         range_check(idxptr[XX],0,pme->pmegrid_nx);
411         range_check(idxptr[YY],0,pme->pmegrid_ny);
412         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
413 #endif
414
415         if (bThreads)
416         {
417             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
418             thread_idx[i] = thread_i;
419             tpl_n[thread_i]++;
420         }
421     }
422
423     if (bThreads)
424     {
425         /* Make a list of particle indices sorted on thread */
426
427         /* Get the cumulative count */
428         for(i=1; i<atc->nthread; i++)
429         {
430             tpl_n[i] += tpl_n[i-1];
431         }
432         /* The current implementation distributes particles equally
433          * over the threads, so we could actually allocate for that
434          * in pme_realloc_atomcomm_things.
435          */
436         if (tpl_n[atc->nthread-1] > tpl->nalloc)
437         {
438             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
439             srenew(tpl->i,tpl->nalloc);
440         }
441         /* Set tpl_n to the cumulative start */
442         for(i=atc->nthread-1; i>=1; i--)
443         {
444             tpl_n[i] = tpl_n[i-1];
445         }
446         tpl_n[0] = 0;
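        /* Example: per-thread counts {2,3,1} become cumulative {2,5,6}
         * above and have just been shifted to start offsets {0,2,5};
         * the fill loop below restores the cumulative form {2,5,6}.
         */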
447
448         /* Fill our thread local array with indices sorted on thread */
449         for(i=start; i<end; i++)
450         {
451             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
452         }
453         /* Now tpl_n contains the cumulative count again */
454     }
455 }
456
457 static void make_thread_local_ind(pme_atomcomm_t *atc,
458                                   int thread,splinedata_t *spline)
459 {
460     int  n,t,i,start,end;
461     thread_plist_t *tpl;
462
463     /* Combine the indices made by each thread into one index */
464
465     n = 0;
466     start = 0;
467     for(t=0; t<atc->nthread; t++)
468     {
469         tpl = &atc->thread_plist[t];
470         /* Copy our part (start - end) from the list of thread t */
471         if (thread > 0)
472         {
473             start = tpl->n[thread-1];
474         }
475         end = tpl->n[thread];
476         for(i=start; i<end; i++)
477         {
478             spline->ind[n++] = tpl->i[i];
479         }
480     }
481
482     spline->n = n;
483 }
484
485
486 static void pme_calc_pidx(int start, int end,
487                           matrix recipbox, rvec x[],
488                           pme_atomcomm_t *atc, int *count)
489 {
490     int  nslab,i;
491     int  si;
492     real *xptr,s;
493     real rxx,ryx,rzx,ryy,rzy;
494     int *pd;
495
496     /* Calculate the PME task index (pidx) for each atom.
497      * Here we always assign equally sized slabs to each node
498      * for load balancing reasons (the PME grid spacing is not used).
499      */
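    /* Example: with nslab = 4 slabs and a fractional coordinate of 0.6,
     * s = 4*0.6 = 2.4 and si = (int)(2.4 + 8) % 4 = 2, so the atom is
     * assigned to slab 2; the +2*nslab keeps the argument positive.
     */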
500
501     nslab = atc->nslab;
502     pd    = atc->pd;
503
504     /* Reset the count */
505     for(i=0; i<nslab; i++)
506     {
507         count[i] = 0;
508     }
509
510     if (atc->dimind == 0)
511     {
512         rxx = recipbox[XX][XX];
513         ryx = recipbox[YY][XX];
514         rzx = recipbox[ZZ][XX];
515         /* Calculate the node index in x-dimension */
516         for(i=start; i<end; i++)
517         {
518             xptr   = x[i];
519             /* Fractional coordinates along box vectors */
520             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
521             si = (int)(s + 2*nslab) % nslab;
522             pd[i] = si;
523             count[si]++;
524         }
525     }
526     else
527     {
528         ryy = recipbox[YY][YY];
529         rzy = recipbox[ZZ][YY];
530         /* Calculate the node index in y-dimension */
531         for(i=start; i<end; i++)
532         {
533             xptr   = x[i];
534             /* Fractional coordinates along box vectors */
535             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
536             si = (int)(s + 2*nslab) % nslab;
537             pd[i] = si;
538             count[si]++;
539         }
540     }
541 }
542
543 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
544                                   pme_atomcomm_t *atc)
545 {
546     int nthread,thread,slab;
547
548     nthread = atc->nthread;
549
550 #pragma omp parallel for num_threads(nthread) schedule(static)
551     for(thread=0; thread<nthread; thread++)
552     {
553         pme_calc_pidx(natoms* thread   /nthread,
554                       natoms*(thread+1)/nthread,
555                       recipbox,x,atc,atc->count_thread[thread]);
556     }
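    /* Each thread handles the atom range [natoms*t/nthread,
     * natoms*(t+1)/nthread), e.g. natoms = 10 with 4 threads gives the
     * ranges [0,2), [2,5), [5,7) and [7,10).
     */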
557     /* Non-parallel reduction, since nslab is small */
558
559     for(thread=1; thread<nthread; thread++)
560     {
561         for(slab=0; slab<atc->nslab; slab++)
562         {
563             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
564         }
565     }
566 }
567
568 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
569 {
570     int i,d;
571
572     srenew(spline->ind,atc->nalloc);
573     /* Initialize the index to identity so it works without threads */
574     for(i=0; i<atc->nalloc; i++)
575     {
576         spline->ind[i] = i;
577     }
578
579     for(d=0;d<DIM;d++)
580     {
581         srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
582         srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
583     }
584 }
585
586 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
587 {
588     int nalloc_old,i,j,nalloc_tpl;
589
590     /* We have to avoid a NULL pointer for atc->x to prevent
591      * possible fatal errors in MPI routines.
592      */
593     if (atc->n > atc->nalloc || atc->nalloc == 0)
594     {
595         nalloc_old = atc->nalloc;
596         atc->nalloc = over_alloc_dd(max(atc->n,1));
597
598         if (atc->nslab > 1) {
599             srenew(atc->x,atc->nalloc);
600             srenew(atc->q,atc->nalloc);
601             srenew(atc->f,atc->nalloc);
602             for(i=nalloc_old; i<atc->nalloc; i++)
603             {
604                 clear_rvec(atc->f[i]);
605             }
606         }
607         if (atc->bSpread) {
608             srenew(atc->fractx,atc->nalloc);
609             srenew(atc->idx   ,atc->nalloc);
610
611             if (atc->nthread > 1)
612             {
613                 srenew(atc->thread_idx,atc->nalloc);
614             }
615
616             for(i=0; i<atc->nthread; i++)
617             {
618                 pme_realloc_splinedata(&atc->spline[i],atc);
619             }
620         }
621     }
622 }
623
624 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
625                          int n, gmx_bool bXF, rvec *x_f, real *charge,
626                          pme_atomcomm_t *atc)
627 /* Redistribute particle data for PME calculation */
628 /* domain decomposition by x coordinate           */
629 {
630     int *idxa;
631     int i, ii;
632
633     if(FALSE == pme->redist_init) {
634         snew(pme->scounts,atc->nslab);
635         snew(pme->rcounts,atc->nslab);
636         snew(pme->sdispls,atc->nslab);
637         snew(pme->rdispls,atc->nslab);
638         snew(pme->sidx,atc->nslab);
639         pme->redist_init = TRUE;
640     }
641     if (n > pme->redist_buf_nalloc) {
642         pme->redist_buf_nalloc = over_alloc_dd(n);
643         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
644     }
645
646     pme->idxa = atc->pd;
647
648 #ifdef GMX_MPI
649     if (forw && bXF) {
650         /* forward, redistribution from pp to pme */
651
652         /* Calculate send counts and exchange them with other nodes */
653         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
654         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
655         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
656
657         /* Calculate send and receive displacements and index into send
658            buffer */
659         pme->sdispls[0]=0;
660         pme->rdispls[0]=0;
661         pme->sidx[0]=0;
662         for(i=1; i<atc->nslab; i++) {
663             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
664             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
665             pme->sidx[i]=pme->sdispls[i];
666         }
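        /* Example: scounts = {3,1,2} gives sdispls = {0,3,4}; the data
         * for slab i starts at offset sdispls[i] in the send buffer.
         */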
667         /* Total # of particles to be received */
668         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
669
670         pme_realloc_atomcomm_things(atc);
671
672         /* Copy particle coordinates into send buffer and exchange */
673         for(i=0; (i<n); i++) {
674             ii=DIM*pme->sidx[pme->idxa[i]];
675             pme->sidx[pme->idxa[i]]++;
676             pme->redist_buf[ii+XX]=x_f[i][XX];
677             pme->redist_buf[ii+YY]=x_f[i][YY];
678             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
679         }
680         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
681                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
682                       pme->rvec_mpi, atc->mpi_comm);
683     }
684     if (forw) {
685         /* Copy charge into send buffer and exchange */
686         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
687         for(i=0; (i<n); i++) {
688             ii=pme->sidx[pme->idxa[i]];
689             pme->sidx[pme->idxa[i]]++;
690             pme->redist_buf[ii]=charge[i];
691         }
692         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
693                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
694                       atc->mpi_comm);
695     }
696     else { /* backward, redistribution from pme to pp */
697         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
698                       pme->redist_buf, pme->scounts, pme->sdispls,
699                       pme->rvec_mpi, atc->mpi_comm);
700
701         /* Copy data from receive buffer */
702         for(i=0; i<atc->nslab; i++)
703             pme->sidx[i] = pme->sdispls[i];
704         for(i=0; (i<n); i++) {
705             ii = DIM*pme->sidx[pme->idxa[i]];
706             x_f[i][XX] += pme->redist_buf[ii+XX];
707             x_f[i][YY] += pme->redist_buf[ii+YY];
708             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
709             pme->sidx[pme->idxa[i]]++;
710         }
711     }
712 #endif
713 }
714
715 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
716                             gmx_bool bBackward,int shift,
717                             void *buf_s,int nbyte_s,
718                             void *buf_r,int nbyte_r)
719 {
720 #ifdef GMX_MPI
721     int dest,src;
722     MPI_Status stat;
723
724     if (bBackward == FALSE) {
725         dest = atc->node_dest[shift];
726         src  = atc->node_src[shift];
727     } else {
728         dest = atc->node_src[shift];
729         src  = atc->node_dest[shift];
730     }
731
732     if (nbyte_s > 0 && nbyte_r > 0) {
733         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
734                      dest,shift,
735                      buf_r,nbyte_r,MPI_BYTE,
736                      src,shift,
737                      atc->mpi_comm,&stat);
738     } else if (nbyte_s > 0) {
739         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
740                  dest,shift,
741                  atc->mpi_comm);
742     } else if (nbyte_r > 0) {
743         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
744                  src,shift,
745                  atc->mpi_comm,&stat);
746     }
747 #endif
748 }
749
750 static void dd_pmeredist_x_q(gmx_pme_t pme,
751                              int n, gmx_bool bX, rvec *x, real *charge,
752                              pme_atomcomm_t *atc)
753 {
754     int *commnode,*buf_index;
755     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
756
757     commnode  = atc->node_dest;
758     buf_index = atc->buf_index;
759
760     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
761
762     nsend = 0;
763     for(i=0; i<nnodes_comm; i++) {
764         buf_index[commnode[i]] = nsend;
765         nsend += atc->count[commnode[i]];
766     }
767     if (bX) {
768         if (atc->count[atc->nodeid] + nsend != n)
769             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
770                       "This usually means that your system is not well equilibrated.",
771                       n - (atc->count[atc->nodeid] + nsend),
772                       pme->nodeid,'x'+atc->dimind);
773
774         if (nsend > pme->buf_nalloc) {
775             pme->buf_nalloc = over_alloc_dd(nsend);
776             srenew(pme->bufv,pme->buf_nalloc);
777             srenew(pme->bufr,pme->buf_nalloc);
778         }
779
780         atc->n = atc->count[atc->nodeid];
781         for(i=0; i<nnodes_comm; i++) {
782             scount = atc->count[commnode[i]];
783             /* Communicate the count */
784             if (debug)
785                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
786                         atc->dimind,atc->nodeid,commnode[i],scount);
787             pme_dd_sendrecv(atc,FALSE,i,
788                             &scount,sizeof(int),
789                             &atc->rcount[i],sizeof(int));
790             atc->n += atc->rcount[i];
791         }
792
793         pme_realloc_atomcomm_things(atc);
794     }
795
796     local_pos = 0;
797     for(i=0; i<n; i++) {
798         node = atc->pd[i];
799         if (node == atc->nodeid) {
800             /* Copy direct to the receive buffer */
801             if (bX) {
802                 copy_rvec(x[i],atc->x[local_pos]);
803             }
804             atc->q[local_pos] = charge[i];
805             local_pos++;
806         } else {
807             /* Copy to the send buffer */
808             if (bX) {
809                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
810             }
811             pme->bufr[buf_index[node]] = charge[i];
812             buf_index[node]++;
813         }
814     }
815
816     buf_pos = 0;
817     for(i=0; i<nnodes_comm; i++) {
818         scount = atc->count[commnode[i]];
819         rcount = atc->rcount[i];
820         if (scount > 0 || rcount > 0) {
821             if (bX) {
822                 /* Communicate the coordinates */
823                 pme_dd_sendrecv(atc,FALSE,i,
824                                 pme->bufv[buf_pos],scount*sizeof(rvec),
825                                 atc->x[local_pos],rcount*sizeof(rvec));
826             }
827             /* Communicate the charges */
828             pme_dd_sendrecv(atc,FALSE,i,
829                             pme->bufr+buf_pos,scount*sizeof(real),
830                             atc->q+local_pos,rcount*sizeof(real));
831             buf_pos   += scount;
832             local_pos += atc->rcount[i];
833         }
834     }
835 }
836
837 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
838                            int n, rvec *f,
839                            gmx_bool bAddF)
840 {
841     int *commnode,*buf_index;
842     int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
843
844     commnode  = atc->node_dest;
845     buf_index = atc->buf_index;
846
847     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
848
849     local_pos = atc->count[atc->nodeid];
850     buf_pos = 0;
851     for(i=0; i<nnodes_comm; i++) {
852         scount = atc->rcount[i];
853         rcount = atc->count[commnode[i]];
854         if (scount > 0 || rcount > 0) {
855             /* Communicate the forces */
856             pme_dd_sendrecv(atc,TRUE,i,
857                             atc->f[local_pos],scount*sizeof(rvec),
858                             pme->bufv[buf_pos],rcount*sizeof(rvec));
859             local_pos += scount;
860         }
861         buf_index[commnode[i]] = buf_pos;
862         buf_pos   += rcount;
863     }
864
865     local_pos = 0;
866     if (bAddF)
867     {
868         for(i=0; i<n; i++)
869         {
870             node = atc->pd[i];
871             if (node == atc->nodeid)
872             {
873                 /* Add from the local force array */
874                 rvec_inc(f[i],atc->f[local_pos]);
875                 local_pos++;
876             }
877             else
878             {
879                 /* Add from the receive buffer */
880                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
881                 buf_index[node]++;
882             }
883         }
884     }
885     else
886     {
887         for(i=0; i<n; i++)
888         {
889             node = atc->pd[i];
890             if (node == atc->nodeid)
891             {
892                 /* Copy from the local force array */
893                 copy_rvec(atc->f[local_pos],f[i]);
894                 local_pos++;
895             }
896             else
897             {
898                 /* Copy from the receive buffer */
899                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
900                 buf_index[node]++;
901             }
902         }
903     }
904 }
905
906 #ifdef GMX_MPI
907 static void
908 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
909 {
910     pme_overlap_t *overlap;
911     int send_index0,send_nindex;
912     int recv_index0,recv_nindex;
913     MPI_Status stat;
914     int i,j,k,ix,iy,iz,icnt;
915     int ipulse,send_id,recv_id,datasize;
916     real *p;
917     real *sendptr,*recvptr;
918
919     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
920     overlap = &pme->overlap[1];
921
922     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
923     {
924         /* Since we have already (un)wrapped the overlap in the z-dimension,
925          * we only have to communicate 0 to nkz (not pmegrid_nz).
926          */
927         if (direction==GMX_SUM_QGRID_FORWARD)
928         {
929             send_id = overlap->send_id[ipulse];
930             recv_id = overlap->recv_id[ipulse];
931             send_index0   = overlap->comm_data[ipulse].send_index0;
932             send_nindex   = overlap->comm_data[ipulse].send_nindex;
933             recv_index0   = overlap->comm_data[ipulse].recv_index0;
934             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
935         }
936         else
937         {
938             send_id = overlap->recv_id[ipulse];
939             recv_id = overlap->send_id[ipulse];
940             send_index0   = overlap->comm_data[ipulse].recv_index0;
941             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
942             recv_index0   = overlap->comm_data[ipulse].send_index0;
943             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
944         }
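        /* In the backward direction the transfer is simply reversed:
         * the ranges that were received during the forward sum are now
         * sent back, so the same comm_data entries are reused with the
         * send and receive roles swapped.
         */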
945
946         /* Copy data to contiguous send buffer */
947         if (debug)
948         {
949             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
950                     pme->nodeid,overlap->nodeid,send_id,
951                     pme->pmegrid_start_iy,
952                     send_index0-pme->pmegrid_start_iy,
953                     send_index0-pme->pmegrid_start_iy+send_nindex);
954         }
955         icnt = 0;
956         for(i=0;i<pme->pmegrid_nx;i++)
957         {
958             ix = i;
959             for(j=0;j<send_nindex;j++)
960             {
961                 iy = j + send_index0 - pme->pmegrid_start_iy;
962                 for(k=0;k<pme->nkz;k++)
963                 {
964                     iz = k;
965                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
966                 }
967             }
968         }
969
970         datasize      = pme->pmegrid_nx * pme->nkz;
971
972         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
973                      send_id,ipulse,
974                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
975                      recv_id,ipulse,
976                      overlap->mpi_comm,&stat);
977
978         /* Get data from contiguous recv buffer */
979         if (debug)
980         {
981             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
982                     pme->nodeid,overlap->nodeid,recv_id,
983                     pme->pmegrid_start_iy,
984                     recv_index0-pme->pmegrid_start_iy,
985                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
986         }
987         icnt = 0;
988         for(i=0;i<pme->pmegrid_nx;i++)
989         {
990             ix = i;
991             for(j=0;j<recv_nindex;j++)
992             {
993                 iy = j + recv_index0 - pme->pmegrid_start_iy;
994                 for(k=0;k<pme->nkz;k++)
995                 {
996                     iz = k;
997                     if(direction==GMX_SUM_QGRID_FORWARD)
998                     {
999                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1000                     }
1001                     else
1002                     {
1003                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1004                     }
1005                 }
1006             }
1007         }
1008     }
1009
1010     /* Major dimension is easier, no copying required,
1011      * but we might have to sum into a separate array.
1012      * Since we don't copy, we have to communicate up to pmegrid_nz,
1013      * not nkz as for the minor direction.
1014      */
1015     overlap = &pme->overlap[0];
1016
1017     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1018     {
1019         if(direction==GMX_SUM_QGRID_FORWARD)
1020         {
1021             send_id = overlap->send_id[ipulse];
1022             recv_id = overlap->recv_id[ipulse];
1023             send_index0   = overlap->comm_data[ipulse].send_index0;
1024             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1025             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1026             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1027             recvptr   = overlap->recvbuf;
1028         }
1029         else
1030         {
1031             send_id = overlap->recv_id[ipulse];
1032             recv_id = overlap->send_id[ipulse];
1033             send_index0   = overlap->comm_data[ipulse].recv_index0;
1034             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1035             recv_index0   = overlap->comm_data[ipulse].send_index0;
1036             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1037             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1038         }
1039
1040         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1041         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1042
1043         if (debug)
1044         {
1045             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1046                     pme->nodeid,overlap->nodeid,send_id,
1047                     pme->pmegrid_start_ix,
1048                     send_index0-pme->pmegrid_start_ix,
1049                     send_index0-pme->pmegrid_start_ix+send_nindex);
1050             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1051                     pme->nodeid,overlap->nodeid,recv_id,
1052                     pme->pmegrid_start_ix,
1053                     recv_index0-pme->pmegrid_start_ix,
1054                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1055         }
1056
1057         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1058                      send_id,ipulse,
1059                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1060                      recv_id,ipulse,
1061                      overlap->mpi_comm,&stat);
1062
1063         /* ADD data from contiguous recv buffer */
1064         if(direction==GMX_SUM_QGRID_FORWARD)
1065         {
1066             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1067             for(i=0;i<recv_nindex*datasize;i++)
1068             {
1069                 p[i] += overlap->recvbuf[i];
1070             }
1071         }
1072     }
1073 }
1074 #endif
1075
1076
1077 static int
1078 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1079 {
1080     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1081     ivec    local_pme_size;
1082     int     i,ix,iy,iz;
1083     int     pmeidx,fftidx;
1084
1085     /* Dimensions should be identical for A/B grid, so we just use A here */
1086     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1087                                    local_fft_ndata,
1088                                    local_fft_offset,
1089                                    local_fft_size);
1090
1091     local_pme_size[0] = pme->pmegrid_nx;
1092     local_pme_size[1] = pme->pmegrid_ny;
1093     local_pme_size[2] = pme->pmegrid_nz;
1094
1095     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1096      the offset is identical, and the PME grid always has more data (due to overlap)
1097      */
1098     {
1099 #ifdef DEBUG_PME
1100         FILE *fp,*fp2;
1101         char fn[STRLEN],format[STRLEN];
1102         real val;
1103         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1104         fp = ffopen(fn,"w");
1105         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1106         fp2 = ffopen(fn,"w");
1107      sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1108 #endif
1109
1110     for(ix=0;ix<local_fft_ndata[XX];ix++)
1111     {
1112         for(iy=0;iy<local_fft_ndata[YY];iy++)
1113         {
1114             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1115             {
1116                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1117                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1118                 fftgrid[fftidx] = pmegrid[pmeidx];
1119 #ifdef DEBUG_PME
1120                 val = 100*pmegrid[pmeidx];
1121                 if (pmegrid[pmeidx] != 0)
1122                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1123                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1124                 if (pmegrid[pmeidx] != 0)
1125                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1126                             "qgrid",
1127                             pme->pmegrid_start_ix + ix,
1128                             pme->pmegrid_start_iy + iy,
1129                             pme->pmegrid_start_iz + iz,
1130                             pmegrid[pmeidx]);
1131 #endif
1132             }
1133         }
1134     }
1135 #ifdef DEBUG_PME
1136     ffclose(fp);
1137     ffclose(fp2);
1138 #endif
1139     }
1140     return 0;
1141 }
1142
1143
1144 static gmx_cycles_t omp_cyc_start()
1145 {
1146     return gmx_cycles_read();
1147 }
1148
1149 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1150 {
1151     return gmx_cycles_read() - c;
1152 }
1153
1154
1155 static int
1156 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1157                         int nthread,int thread)
1158 {
1159     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1160     ivec    local_pme_size;
1161     int     ixy0,ixy1,ixy,ix,iy,iz;
1162     int     pmeidx,fftidx;
1163 #ifdef PME_TIME_THREADS
1164     gmx_cycles_t c1;
1165     static double cs1=0;
1166     static int cnt=0;
1167 #endif
1168
1169 #ifdef PME_TIME_THREADS
1170     c1 = omp_cyc_start();
1171 #endif
1172     /* Dimensions should be identical for A/B grid, so we just use A here */
1173     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1174                                    local_fft_ndata,
1175                                    local_fft_offset,
1176                                    local_fft_size);
1177
1178     local_pme_size[0] = pme->pmegrid_nx;
1179     local_pme_size[1] = pme->pmegrid_ny;
1180     local_pme_size[2] = pme->pmegrid_nz;
1181
1182     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1183      the offset is identical, and the PME grid always has more data (due to overlap)
1184      */
1185     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1186     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1187
1188     for(ixy=ixy0;ixy<ixy1;ixy++)
1189     {
1190         ix = ixy/local_fft_ndata[YY];
1191         iy = ixy - ix*local_fft_ndata[YY];
1192
1193         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1194         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1195         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1196         {
1197             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1198         }
1199     }
1200
1201 #ifdef PME_TIME_THREADS
1202     c1 = omp_cyc_end(c1);
1203     cs1 += (double)c1;
1204     cnt++;
1205     if (cnt % 20 == 0)
1206     {
1207         printf("copy %.2f\n",cs1*1e-9);
1208     }
1209 #endif
1210
1211     return 0;
1212 }
1213
1214
1215 static void
1216 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1217 {
1218     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1219
1220     nx = pme->nkx;
1221     ny = pme->nky;
1222     nz = pme->nkz;
1223
1224     pnx = pme->pmegrid_nx;
1225     pny = pme->pmegrid_ny;
1226     pnz = pme->pmegrid_nz;
1227
1228     overlap = pme->pme_order - 1;
1229
1230     /* Add periodic overlap in z */
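    /* (Example: with nz = 16 and pme_order = 4, i.e. overlap = 3, the
     *  spread contributions on planes z = 16..18 are added back onto
     *  planes z = 0..2.)
     */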
1231     for(ix=0; ix<pme->pmegrid_nx; ix++)
1232     {
1233         for(iy=0; iy<pme->pmegrid_ny; iy++)
1234         {
1235             for(iz=0; iz<overlap; iz++)
1236             {
1237                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1238                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1239             }
1240         }
1241     }
1242
1243     if (pme->nnodes_minor == 1)
1244     {
1245        for(ix=0; ix<pme->pmegrid_nx; ix++)
1246        {
1247            for(iy=0; iy<overlap; iy++)
1248            {
1249                for(iz=0; iz<nz; iz++)
1250                {
1251                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1252                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1253                }
1254            }
1255        }
1256     }
1257
1258     if (pme->nnodes_major == 1)
1259     {
1260         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1261
1262         for(ix=0; ix<overlap; ix++)
1263         {
1264             for(iy=0; iy<ny_x; iy++)
1265             {
1266                 for(iz=0; iz<nz; iz++)
1267                 {
1268                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1269                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1270                 }
1271             }
1272         }
1273     }
1274 }
1275
1276
1277 static void
1278 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1279 {
1280     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1281
1282     nx = pme->nkx;
1283     ny = pme->nky;
1284     nz = pme->nkz;
1285
1286     pnx = pme->pmegrid_nx;
1287     pny = pme->pmegrid_ny;
1288     pnz = pme->pmegrid_nz;
1289
1290     overlap = pme->pme_order - 1;
1291
1292     if (pme->nnodes_major == 1)
1293     {
1294         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1295
1296         for(ix=0; ix<overlap; ix++)
1297         {
1298             int iy,iz;
1299
1300             for(iy=0; iy<ny_x; iy++)
1301             {
1302                 for(iz=0; iz<nz; iz++)
1303                 {
1304                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1305                         pmegrid[(ix*pny+iy)*pnz+iz];
1306                 }
1307             }
1308         }
1309     }
1310
1311     if (pme->nnodes_minor == 1)
1312     {
1313 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1314        for(ix=0; ix<pme->pmegrid_nx; ix++)
1315        {
1316            int iy,iz;
1317
1318            for(iy=0; iy<overlap; iy++)
1319            {
1320                for(iz=0; iz<nz; iz++)
1321                {
1322                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1323                        pmegrid[(ix*pny+iy)*pnz+iz];
1324                }
1325            }
1326        }
1327     }
1328
1329     /* Copy periodic overlap in z */
1330 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1331     for(ix=0; ix<pme->pmegrid_nx; ix++)
1332     {
1333         int iy,iz;
1334
1335         for(iy=0; iy<pme->pmegrid_ny; iy++)
1336         {
1337             for(iz=0; iz<overlap; iz++)
1338             {
1339                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1340                     pmegrid[(ix*pny+iy)*pnz+iz];
1341             }
1342         }
1343     }
1344 }
1345
1346 static void clear_grid(int nx,int ny,int nz,real *grid,
1347                        ivec fs,int *flag,
1348                        int fx,int fy,int fz,
1349                        int order)
1350 {
1351     int nc,ncz;
1352     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1353     int flind;
1354
1355     nc  = 2 + (order - 2)/FLBS;
1356     ncz = 2 + (order - 2)/FLBSZ;
1357
1358     for(fsx=fx; fsx<fx+nc; fsx++)
1359     {
1360         for(fsy=fy; fsy<fy+nc; fsy++)
1361         {
1362             for(fsz=fz; fsz<fz+ncz; fsz++)
1363             {
1364                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1365                 if (flag[flind] == 0)
1366                 {
1367                     gx = fsx*FLBS;
1368                     gy = fsy*FLBS;
1369                     gz = fsz*FLBSZ;
1370                     g0x = (gx*ny + gy)*nz + gz;
1371                     for(x=0; x<FLBS; x++)
1372                     {
1373                         g0y = g0x;
1374                         for(y=0; y<FLBS; y++)
1375                         {
1376                             for(z=0; z<FLBSZ; z++)
1377                             {
1378                                 grid[g0y+z] = 0;
1379                             }
1380                             g0y += nz;
1381                         }
1382                         g0x += ny*nz;
1383                     }
1384
1385                     flag[flind] = 1;
1386                 }
1387             }
1388         }
1389     }
1390 }
1391
1392 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1393 #define DO_BSPLINE(order)                            \
1394 for(ithx=0; (ithx<order); ithx++)                    \
1395 {                                                    \
1396     index_x = (i0+ithx)*pny*pnz;                     \
1397     valx    = qn*thx[ithx];                          \
1398                                                      \
1399     for(ithy=0; (ithy<order); ithy++)                \
1400     {                                                \
1401         valxy    = valx*thy[ithy];                   \
1402         index_xy = index_x+(j0+ithy)*pnz;            \
1403                                                      \
1404         for(ithz=0; (ithz<order); ithz++)            \
1405         {                                            \
1406             index_xyz        = index_xy+(k0+ithz);   \
1407             grid[index_xyz] += valxy*thz[ithz];      \
1408         }                                            \
1409     }                                                \
1410 }
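/* In effect, DO_BSPLINE spreads one charge qn over an order^3 block of
 * grid points using separable B-spline weights:
 *   grid[(i0+i)*pny*pnz + (j0+j)*pnz + (k0+k)] += qn*thx[i]*thy[j]*thz[k]
 */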
1411
1412
1413 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1414                                      pme_atomcomm_t *atc, splinedata_t *spline,
1415                                      pme_spline_work_t *work)
1416 {
1417
1418     /* spread charges from home atoms to local grid */
1419     real     *grid;
1420     pme_overlap_t *ol;
1421     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1422     int *    idxptr;
1423     int      order,norder,index_x,index_xy,index_xyz;
1424     real     valx,valxy,qn;
1425     real     *thx,*thy,*thz;
1426     int      localsize, bndsize;
1427     int      pnx,pny,pnz,ndatatot;
1428     int      offx,offy,offz;
1429
1430     pnx = pmegrid->n[XX];
1431     pny = pmegrid->n[YY];
1432     pnz = pmegrid->n[ZZ];
1433
1434     offx = pmegrid->offset[XX];
1435     offy = pmegrid->offset[YY];
1436     offz = pmegrid->offset[ZZ];
1437
1438     ndatatot = pnx*pny*pnz;
1439     grid = pmegrid->grid;
1440     for(i=0;i<ndatatot;i++)
1441     {
1442         grid[i] = 0;
1443     }
1444
1445     order = pmegrid->order;
1446
1447     for(nn=0; nn<spline->n; nn++)
1448     {
1449         n  = spline->ind[nn];
1450         qn = atc->q[n];
1451
1452         if (qn != 0)
1453         {
1454             idxptr = atc->idx[n];
1455             norder = nn*order;
1456
1457             i0   = idxptr[XX] - offx;
1458             j0   = idxptr[YY] - offy;
1459             k0   = idxptr[ZZ] - offz;
1460
1461             thx = spline->theta[XX] + norder;
1462             thy = spline->theta[YY] + norder;
1463             thz = spline->theta[ZZ] + norder;
1464
1465             switch (order) {
1466             case 4:
1467 #ifdef PME_SSE
1468 #ifdef PME_SSE_UNALIGNED
1469 #define PME_SPREAD_SSE_ORDER4
1470 #else
1471 #define PME_SPREAD_SSE_ALIGNED
1472 #define PME_ORDER 4
1473 #endif
1474 #include "pme_sse_single.h"
1475 #else
1476                 DO_BSPLINE(4);
1477 #endif
1478                 break;
1479             case 5:
1480 #ifdef PME_SSE
1481 #define PME_SPREAD_SSE_ALIGNED
1482 #define PME_ORDER 5
1483 #include "pme_sse_single.h"
1484 #else
1485                 DO_BSPLINE(5);
1486 #endif
1487                 break;
1488             default:
1489                 DO_BSPLINE(order);
1490                 break;
1491             }
1492         }
1493     }
1494 }
1495
1496 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1497 {
1498 #ifdef PME_SSE
1499     if (pme_order == 5
1500 #ifndef PME_SSE_UNALIGNED
1501         || pme_order == 4
1502 #endif
1503         )
1504     {
1505         /* Round nz up to a multiple of 4 to ensure alignment */
1506         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1507     }
1508 #endif
1509 }
1510
1511 static void set_gridsize_alignment(int *gridsize,int pme_order)
1512 {
1513 #ifdef PME_SSE
1514 #ifndef PME_SSE_UNALIGNED
1515     if (pme_order == 4)
1516     {
1517         /* Add extra elements to ensure that aligned operations do not go
1518          * beyond the allocated grid size.
1519          * Note that for pme_order=5, the pme grid z-size alignment
1520          * ensures that we will not go beyond the grid size.
1521          */
1522          *gridsize += 4;
1523     }
1524 #endif
1525 #endif
1526 }
1527
1528 static void pmegrid_init(pmegrid_t *grid,
1529                          int cx, int cy, int cz,
1530                          int x0, int y0, int z0,
1531                          int x1, int y1, int z1,
1532                          gmx_bool set_alignment,
1533                          int pme_order,
1534                          real *ptr)
1535 {
1536     int nz,gridsize;
1537
1538     grid->ci[XX] = cx;
1539     grid->ci[YY] = cy;
1540     grid->ci[ZZ] = cz;
1541     grid->offset[XX] = x0;
1542     grid->offset[YY] = y0;
1543     grid->offset[ZZ] = z0;
1544     grid->n[XX]      = x1 - x0 + pme_order - 1;
1545     grid->n[YY]      = y1 - y0 + pme_order - 1;
1546     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1547
1548     nz = grid->n[ZZ];
1549     set_grid_alignment(&nz,pme_order);
1550     if (set_alignment)
1551     {
1552         grid->n[ZZ] = nz;
1553     }
1554     else if (nz != grid->n[ZZ])
1555     {
1556         gmx_incons("pmegrid_init call with an unaligned z size");
1557     }
1558
1559     grid->order = pme_order;
1560     if (ptr == NULL)
1561     {
1562         gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1563         set_gridsize_alignment(&gridsize,pme_order);
1564         snew_aligned(grid->grid,gridsize,16);
1565     }
1566     else
1567     {
1568         grid->grid = ptr;
1569     }
1570 }
1571
1572 static int div_round_up(int enumerator,int denominator)
1573 {
1574     return (enumerator + denominator - 1)/denominator;
1575 }
1576
1577 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1578                                   ivec nsub)
1579 {
1580     int gsize_opt,gsize;
1581     int nsx,nsy,nsz;
1582     char *env;
1583
1584     gsize_opt = -1;
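    /* Illustrative example: for n = (32,32,32), ovl = 3 and nthread = 8,
     * the 2x2x2 division needs (16+3)^3 = 6859 grid points per thread,
     * fewer than e.g. 8x1x1 with (4+3)*(32+3)*(32+3) = 8575, so 2x2x2
     * is chosen.
     */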
1585     for(nsx=1; nsx<=nthread; nsx++)
1586     {
1587         if (nthread % nsx == 0)
1588         {
1589             for(nsy=1; nsy<=nthread; nsy++)
1590             {
1591                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1592                 {
1593                     nsz = nthread/(nsx*nsy);
1594
1595                     /* Determine the number of grid points per thread */
1596                     gsize =
1597                         (div_round_up(n[XX],nsx) + ovl)*
1598                         (div_round_up(n[YY],nsy) + ovl)*
1599                         (div_round_up(n[ZZ],nsz) + ovl);
1600
1601                     /* Minimize the number of grid points per thread
1602                      * and, secondarily, the number of cuts in minor dimensions.
1603                      */
1604                     if (gsize_opt == -1 ||
1605                         gsize < gsize_opt ||
1606                         (gsize == gsize_opt &&
1607                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1608                     {
1609                         nsub[XX] = nsx;
1610                         nsub[YY] = nsy;
1611                         nsub[ZZ] = nsz;
1612                         gsize_opt = gsize;
1613                     }
1614                 }
1615             }
1616         }
1617     }
1618
1619     env = getenv("GMX_PME_THREAD_DIVISION");
1620     if (env != NULL)
1621     {
1622         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1623     }
1624
1625     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1626     {
1627         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1628     }
1629 }
1630
1631 static void pmegrids_init(pmegrids_t *grids,
1632                           int nx,int ny,int nz,int nz_base,
1633                           int pme_order,
1634                           int nthread,
1635                           int overlap_x,
1636                           int overlap_y)
1637 {
1638     ivec n,n_base,g0,g1;
1639     int t,x,y,z,d,i,tfac;
1640     int max_comm_lines;
1641
1642     n[XX] = nx - (pme_order - 1);
1643     n[YY] = ny - (pme_order - 1);
1644     n[ZZ] = nz - (pme_order - 1);
1645
1646     copy_ivec(n,n_base);
1647     n_base[ZZ] = nz_base;
1648
1649     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1650                  NULL);
1651
1652     grids->nthread = nthread;
1653
1654     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1655
1656     if (grids->nthread > 1)
1657     {
1658         ivec nst;
1659         int gridsize;
1660         real *grid_all;
1661
1662         for(d=0; d<DIM; d++)
1663         {
1664             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1665         }
1666         set_grid_alignment(&nst[ZZ],pme_order);
1667
1668         if (debug)
1669         {
1670             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1671                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1672             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1673                     nx,ny,nz,
1674                     nst[XX],nst[YY],nst[ZZ]);
1675         }
1676
1677         snew(grids->grid_th,grids->nthread);
1678         t = 0;
1679         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1680         set_gridsize_alignment(&gridsize,pme_order);
1681         snew_aligned(grid_all,
1682                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1683                      16);
1684
1685         for(x=0; x<grids->nc[XX]; x++)
1686         {
1687             for(y=0; y<grids->nc[YY]; y++)
1688             {
1689                 for(z=0; z<grids->nc[ZZ]; z++)
1690                 {
1691                     pmegrid_init(&grids->grid_th[t],
1692                                  x,y,z,
1693                                  (n[XX]*(x  ))/grids->nc[XX],
1694                                  (n[YY]*(y  ))/grids->nc[YY],
1695                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1696                                  (n[XX]*(x+1))/grids->nc[XX],
1697                                  (n[YY]*(y+1))/grids->nc[YY],
1698                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1699                                  TRUE,
1700                                  pme_order,
1701                                  grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1702                     t++;
1703                 }
1704             }
1705         }
1706     }
1707
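    /* g2t[d][i] stores the contribution of dimension d to the flattened
     * thread index for interior grid line i: summing
     * g2t[XX][x] + g2t[YY][y] + g2t[ZZ][z] gives the thread owning that
     * sub-grid cell, in the same x-major, z-minor order used when filling
     * grid_th[] above.
     */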
1708     snew(grids->g2t,DIM);
1709     tfac = 1;
1710     for(d=DIM-1; d>=0; d--)
1711     {
1712         snew(grids->g2t[d],n[d]);
1713         t = 0;
1714         for(i=0; i<n[d]; i++)
1715         {
1716             /* The second check should match the parameters
1717              * of the pmegrid_init call above.
1718              */
1719             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1720             {
1721                 t++;
1722             }
1723             grids->g2t[d][i] = t*tfac;
1724         }
1725
1726         tfac *= grids->nc[d];
1727
1728         switch (d)
1729         {
1730         case XX: max_comm_lines = overlap_x;     break;
1731         case YY: max_comm_lines = overlap_y;     break;
1732         case ZZ: max_comm_lines = pme_order - 1; break;
1733         }
1734         grids->nthread_comm[d] = 0;
1735         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1736         {
1737             grids->nthread_comm[d]++;
1738         }
1739         if (debug != NULL)
1740         {
1741             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1742                     'x'+d,grids->nthread_comm[d]);
1743         }
1744         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1745          * work, but this is not a problematic restriction.
1746          */
1747         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1748         {
1749             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1750         }
1751     }
1752 }
1753
1754
1755 static void pmegrids_destroy(pmegrids_t *grids)
1756 {
1757     int t;
1758
1759     if (grids->grid.grid != NULL)
1760     {
1761         sfree(grids->grid.grid);
1762
1763         if (grids->nthread > 0)
1764         {
1765             for(t=0; t<grids->nthread; t++)
1766             {
1767                 sfree(grids->grid_th[t].grid);
1768             }
1769             sfree(grids->grid_th);
1770         }
1771     }
1772 }
1773
1774
1775 static void realloc_work(pme_work_t *work,int nkx)
1776 {
1777     if (nkx > work->nalloc)
1778     {
1779         work->nalloc = nkx;
1780         srenew(work->mhx  ,work->nalloc);
1781         srenew(work->mhy  ,work->nalloc);
1782         srenew(work->mhz  ,work->nalloc);
1783         srenew(work->m2   ,work->nalloc);
1784         /* Allocate an aligned pointer for SSE operations, including 3 extra
1785          * elements at the end since SSE operates on 4 elements at a time.
1786          */
1787         sfree_aligned(work->denom);
1788         sfree_aligned(work->tmp1);
1789         sfree_aligned(work->eterm);
1790         snew_aligned(work->denom,work->nalloc+3,16);
1791         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1792         snew_aligned(work->eterm,work->nalloc+3,16);
1793         srenew(work->m2inv,work->nalloc);
1794     }
1795 }
1796
1797
1798 static void free_work(pme_work_t *work)
1799 {
1800     sfree(work->mhx);
1801     sfree(work->mhy);
1802     sfree(work->mhz);
1803     sfree(work->m2);
1804     sfree_aligned(work->denom);
1805     sfree_aligned(work->tmp1);
1806     sfree_aligned(work->eterm);
1807     sfree(work->m2inv);
1808 }
1809
1810
1811 #ifdef PME_SSE
1812     /* Calculate exponentials through SSE in float precision */
1813 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1814 {
1815     {
1816         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1817         __m128 f_sse;
1818         __m128 lu;
1819         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1820         int kx;
1821         f_sse = _mm_load1_ps(&f);
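        /* Note: the loop below starts at 0 and steps by 4; it relies on the
         * 3 padding elements allocated in realloc_work() so that the last
         * (partial) SSE load/store past 'end' stays within bounds.
         */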
1822         for(kx=0; kx<end; kx+=4)
1823         {
1824             tmp_d1   = _mm_load_ps(d_aligned+kx);
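            /* rcp + one Newton-Raphson step: x1 = x0*(2 - d*x0) refines the
             * ~12-bit _mm_rcp_ps estimate to nearly full single precision.
             */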
1825             lu       = _mm_rcp_ps(tmp_d1);
1826             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1827             tmp_r    = _mm_load_ps(r_aligned+kx);
1828             tmp_r    = gmx_mm_exp_ps(tmp_r);
1829             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1830             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1831             _mm_store_ps(e_aligned+kx,tmp_e);
1832         }
1833     }
1834 }
1835 #else
1836 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1837 {
1838     int kx;
1839     for(kx=start; kx<end; kx++)
1840     {
1841         d[kx] = 1.0/d[kx];
1842     }
1843     for(kx=start; kx<end; kx++)
1844     {
1845         r[kx] = exp(r[kx]);
1846     }
1847     for(kx=start; kx<end; kx++)
1848     {
1849         e[kx] = f*r[kx]*d[kx];
1850     }
1851 }
1852 #endif
1853
1854
1855 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1856                          real ewaldcoeff,real vol,
1857                          gmx_bool bEnerVir,
1858                          int nthread,int thread)
1859 {
1860     /* do the reciprocal-space sum over the local cells of the grid */
1861     /* y major, z middle, x minor or contiguous */
1862     t_complex *p0;
1863     int     kx,ky,kz,maxkx,maxky,maxkz;
1864     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1865     real    mx,my,mz;
1866     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1867     real    ets2,struct2,vfactor,ets2vf;
1868     real    d1,d2,energy=0;
1869     real    by,bz;
1870     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1871     real    rxx,ryx,ryy,rzx,rzy,rzz;
1872     pme_work_t *work;
1873     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1874     real    mhxk,mhyk,mhzk,m2k;
1875     real    corner_fac;
1876     ivec    complex_order;
1877     ivec    local_ndata,local_offset,local_size;
1878     real    elfac;
1879
1880     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1881
1882     nx = pme->nkx;
1883     ny = pme->nky;
1884     nz = pme->nkz;
1885
1886     /* Dimensions should be identical for A/B grid, so we just use A here */
1887     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1888                                       complex_order,
1889                                       local_ndata,
1890                                       local_offset,
1891                                       local_size);
1892
1893     rxx = pme->recipbox[XX][XX];
1894     ryx = pme->recipbox[YY][XX];
1895     ryy = pme->recipbox[YY][YY];
1896     rzx = pme->recipbox[ZZ][XX];
1897     rzy = pme->recipbox[ZZ][YY];
1898     rzz = pme->recipbox[ZZ][ZZ];
1899
1900     maxkx = (nx+1)/2;
1901     maxky = (ny+1)/2;
1902     maxkz = nz/2+1;
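    /* Schematically (smooth-PME reciprocal-space sum; a sketch, not the
     * authoritative derivation):
     *
     *   E_rec ~ (1/(2 pi V)) * ONE_4PI_EPS0/epsilon_r *
     *           sum_{m != 0} exp(-pi^2 m^2/ewaldcoeff^2) / (m^2 B(m)) * |S(m)|^2
     *
     * Below, eterm[kx] holds the per-m prefactor (denom supplies m^2*B(m)*pi*V
     * and tmp1 the exponent), S(m) is read from the complex FFT grid via p0,
     * and corner_fac compensates for the half-complex storage along z.
     */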
1903
1904     work = &pme->work[thread];
1905     mhx   = work->mhx;
1906     mhy   = work->mhy;
1907     mhz   = work->mhz;
1908     m2    = work->m2;
1909     denom = work->denom;
1910     tmp1  = work->tmp1;
1911     eterm = work->eterm;
1912     m2inv = work->m2inv;
1913
1914     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1915     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1916
1917     for(iyz=iyz0; iyz<iyz1; iyz++)
1918     {
1919         iy = iyz/local_ndata[ZZ];
1920         iz = iyz - iy*local_ndata[ZZ];
1921
1922         ky = iy + local_offset[YY];
1923
1924         if (ky < maxky)
1925         {
1926             my = ky;
1927         }
1928         else
1929         {
1930             my = (ky - ny);
1931         }
1932
1933         by = M_PI*vol*pme->bsp_mod[YY][ky];
1934
1935         kz = iz + local_offset[ZZ];
1936
1937         mz = kz;
1938
1939         bz = pme->bsp_mod[ZZ][kz];
1940
1941         /* 0.5 correction for corner points */
1942         corner_fac = 1;
1943         if (kz == 0 || kz == (nz+1)/2)
1944         {
1945             corner_fac = 0.5;
1946         }
1947
1948         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1949
1950         /* We should skip the k-space point (0,0,0) */
1951         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1952         {
1953             kxstart = local_offset[XX];
1954         }
1955         else
1956         {
1957             kxstart = local_offset[XX] + 1;
1958             p0++;
1959         }
1960         kxend = local_offset[XX] + local_ndata[XX];
1961
1962         if (bEnerVir)
1963         {
1964             /* More expensive inner loop, especially because of the storage
1965              * of the mh elements in arrays.
1966              * Because x is the minor grid index, all mh elements
1967              * depend on kx for triclinic unit cells.
1968              */
1969
1970             /* Two explicit loops to avoid a conditional inside the loop */
1971             for(kx=kxstart; kx<maxkx; kx++)
1972             {
1973                 mx = kx;
1974
1975                 mhxk      = mx * rxx;
1976                 mhyk      = mx * ryx + my * ryy;
1977                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1978                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1979                 mhx[kx]   = mhxk;
1980                 mhy[kx]   = mhyk;
1981                 mhz[kx]   = mhzk;
1982                 m2[kx]    = m2k;
1983                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1984                 tmp1[kx]  = -factor*m2k;
1985             }
1986
1987             for(kx=maxkx; kx<kxend; kx++)
1988             {
1989                 mx = (kx - nx);
1990
1991                 mhxk      = mx * rxx;
1992                 mhyk      = mx * ryx + my * ryy;
1993                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1994                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1995                 mhx[kx]   = mhxk;
1996                 mhy[kx]   = mhyk;
1997                 mhz[kx]   = mhzk;
1998                 m2[kx]    = m2k;
1999                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2000                 tmp1[kx]  = -factor*m2k;
2001             }
2002
2003             for(kx=kxstart; kx<kxend; kx++)
2004             {
2005                 m2inv[kx] = 1.0/m2[kx];
2006             }
2007
2008             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2009
2010             for(kx=kxstart; kx<kxend; kx++,p0++)
2011             {
2012                 d1      = p0->re;
2013                 d2      = p0->im;
2014
2015                 p0->re  = d1*eterm[kx];
2016                 p0->im  = d2*eterm[kx];
2017
2018                 struct2 = 2.0*(d1*d1+d2*d2);
2019
2020                 tmp1[kx] = eterm[kx]*struct2;
2021             }
2022
2023             for(kx=kxstart; kx<kxend; kx++)
2024             {
2025                 ets2     = corner_fac*tmp1[kx];
2026                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2027                 energy  += ets2;
2028
2029                 ets2vf   = ets2*vfactor;
2030                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2031                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2032                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2033                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2034                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2035                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2036             }
2037         }
2038         else
2039         {
2040             /* We don't need to calculate the energy and the virial.
2041              * In this case the triclinic overhead is small.
2042              */
2043
2044             /* Two explicit loops to avoid a conditional inside the loop */
2045
2046             for(kx=kxstart; kx<maxkx; kx++)
2047             {
2048                 mx = kx;
2049
2050                 mhxk      = mx * rxx;
2051                 mhyk      = mx * ryx + my * ryy;
2052                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2053                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2054                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2055                 tmp1[kx]  = -factor*m2k;
2056             }
2057
2058             for(kx=maxkx; kx<kxend; kx++)
2059             {
2060                 mx = (kx - nx);
2061
2062                 mhxk      = mx * rxx;
2063                 mhyk      = mx * ryx + my * ryy;
2064                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2065                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2066                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2067                 tmp1[kx]  = -factor*m2k;
2068             }
2069
2070             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2071
2072             for(kx=kxstart; kx<kxend; kx++,p0++)
2073             {
2074                 d1      = p0->re;
2075                 d2      = p0->im;
2076
2077                 p0->re  = d1*eterm[kx];
2078                 p0->im  = d2*eterm[kx];
2079             }
2080         }
2081     }
2082
2083     if (bEnerVir)
2084     {
2085         /* Update virial with local values.
2086          * The virial is symmetric by definition.
2087          * This virial seems OK for isotropic scaling, but I'm
2088          * experiencing problems with semi-isotropic membranes.
2089          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2090          */
2091         work->vir[XX][XX] = 0.25*virxx;
2092         work->vir[YY][YY] = 0.25*viryy;
2093         work->vir[ZZ][ZZ] = 0.25*virzz;
2094         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2095         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2096         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2097
2098         /* This energy should be corrected for a charged system */
2099         work->energy = 0.5*energy;
2100     }
2101
2102     /* Return the loop count */
2103     return local_ndata[YY]*local_ndata[XX];
2104 }
2105
2106 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2107                              real *mesh_energy,matrix vir)
2108 {
2109     /* This function sums output over threads
2110      * and should therefore only be called after thread synchronization.
2111      */
2112     int thread;
2113
2114     *mesh_energy = pme->work[0].energy;
2115     copy_mat(pme->work[0].vir,vir);
2116
2117     for(thread=1; thread<nthread; thread++)
2118     {
2119         *mesh_energy += pme->work[thread].energy;
2120         m_add(vir,pme->work[thread].vir,vir);
2121     }
2122 }
2123
2124 #define DO_FSPLINE(order)                      \
2125 for(ithx=0; (ithx<order); ithx++)              \
2126 {                                              \
2127     index_x = (i0+ithx)*pny*pnz;               \
2128     tx      = thx[ithx];                       \
2129     dx      = dthx[ithx];                      \
2130                                                \
2131     for(ithy=0; (ithy<order); ithy++)          \
2132     {                                          \
2133         index_xy = index_x+(j0+ithy)*pnz;      \
2134         ty       = thy[ithy];                  \
2135         dy       = dthy[ithy];                 \
2136         fxy1     = fz1 = 0;                    \
2137                                                \
2138         for(ithz=0; (ithz<order); ithz++)      \
2139         {                                      \
2140             gval  = grid[index_xy+(k0+ithz)];  \
2141             fxy1 += thz[ithz]*gval;            \
2142             fz1  += dthz[ithz]*gval;           \
2143         }                                      \
2144         fx += dx*ty*fxy1;                      \
2145         fy += tx*dy*fxy1;                      \
2146         fz += tx*ty*fz1;                       \
2147     }                                          \
2148 }
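/* DO_FSPLINE(order) accumulates, for one atom, the gradient of the
 * interpolated potential over the order^3 stencil around its grid cell:
 * fx ~ sum dthx*thy*thz*grid, fy ~ sum thx*dthy*thz*grid and
 * fz ~ sum thx*thy*dthz*grid. The conversion of these fractional-coordinate
 * derivatives to Cartesian forces (via recipbox) is done by the caller,
 * gather_f_bsplines().
 */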
2149
2150
2151 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2152                               gmx_bool bClearF,pme_atomcomm_t *atc,
2153                               splinedata_t *spline,
2154                               real scale)
2155 {
2156     /* sum forces for local particles */
2157     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2158     int     index_x,index_xy;
2159     int     nx,ny,nz,pnx,pny,pnz;
2160     int *   idxptr;
2161     real    tx,ty,dx,dy,qn;
2162     real    fx,fy,fz,gval;
2163     real    fxy1,fz1;
2164     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2165     int     norder;
2166     real    rxx,ryx,ryy,rzx,rzy,rzz;
2167     int     order;
2168
2169     pme_spline_work_t *work;
2170
2171     work = &pme->spline_work;
2172
2173     order = pme->pme_order;
2174     thx   = spline->theta[XX];
2175     thy   = spline->theta[YY];
2176     thz   = spline->theta[ZZ];
2177     dthx  = spline->dtheta[XX];
2178     dthy  = spline->dtheta[YY];
2179     dthz  = spline->dtheta[ZZ];
2180     nx    = pme->nkx;
2181     ny    = pme->nky;
2182     nz    = pme->nkz;
2183     pnx   = pme->pmegrid_nx;
2184     pny   = pme->pmegrid_ny;
2185     pnz   = pme->pmegrid_nz;
2186
2187     rxx   = pme->recipbox[XX][XX];
2188     ryx   = pme->recipbox[YY][XX];
2189     ryy   = pme->recipbox[YY][YY];
2190     rzx   = pme->recipbox[ZZ][XX];
2191     rzy   = pme->recipbox[ZZ][YY];
2192     rzz   = pme->recipbox[ZZ][ZZ];
2193
2194     for(nn=0; nn<spline->n; nn++)
2195     {
2196         n  = spline->ind[nn];
2197         qn = scale*atc->q[n];
2198
2199         if (bClearF)
2200         {
2201             atc->f[n][XX] = 0;
2202             atc->f[n][YY] = 0;
2203             atc->f[n][ZZ] = 0;
2204         }
2205         if (qn != 0)
2206         {
2207             fx     = 0;
2208             fy     = 0;
2209             fz     = 0;
2210             idxptr = atc->idx[n];
2211             norder = nn*order;
2212
2213             i0   = idxptr[XX];
2214             j0   = idxptr[YY];
2215             k0   = idxptr[ZZ];
2216
2217             /* Pointer arithmetic alert, next six statements */
2218             thx  = spline->theta[XX] + norder;
2219             thy  = spline->theta[YY] + norder;
2220             thz  = spline->theta[ZZ] + norder;
2221             dthx = spline->dtheta[XX] + norder;
2222             dthy = spline->dtheta[YY] + norder;
2223             dthz = spline->dtheta[ZZ] + norder;
2224
2225             switch (order) {
2226             case 4:
2227 #ifdef PME_SSE
2228 #ifdef PME_SSE_UNALIGNED
2229 #define PME_GATHER_F_SSE_ORDER4
2230 #else
2231 #define PME_GATHER_F_SSE_ALIGNED
2232 #define PME_ORDER 4
2233 #endif
2234 #include "pme_sse_single.h"
2235 #else
2236                 DO_FSPLINE(4);
2237 #endif
2238                 break;
2239             case 5:
2240 #ifdef PME_SSE
2241 #define PME_GATHER_F_SSE_ALIGNED
2242 #define PME_ORDER 5
2243 #include "pme_sse_single.h"
2244 #else
2245                 DO_FSPLINE(5);
2246 #endif
2247                 break;
2248             default:
2249                 DO_FSPLINE(order);
2250                 break;
2251             }
2252
2253             atc->f[n][XX] += -qn*( fx*nx*rxx );
2254             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2255             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2256         }
2257     }
2258     /* Since the energy and not the forces are interpolated,
2259      * the net force might not be exactly zero.
2260      * This can be solved by also interpolating F, but
2261      * that comes at a cost.
2262      * A better hack is to remove the net force every
2263      * step, but that must be done at a higher level
2264      * since this routine doesn't see all atoms when running
2265      * in parallel. It is unclear how important this is.  EL 990726
2266      */
2267 }
2268
2269
2270 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2271                                    pme_atomcomm_t *atc)
2272 {
2273     splinedata_t *spline;
2274     int     n,ithx,ithy,ithz,i0,j0,k0;
2275     int     index_x,index_xy;
2276     int *   idxptr;
2277     real    energy,pot,tx,ty,qn,gval;
2278     real    *thx,*thy,*thz;
2279     int     norder;
2280     int     order;
2281
2282     spline = &atc->spline[0];
2283
2284     order = pme->pme_order;
2285
2286     energy = 0;
2287     for(n=0; (n<atc->n); n++) {
2288         qn      = atc->q[n];
2289
2290         if (qn != 0) {
2291             idxptr = atc->idx[n];
2292             norder = n*order;
2293
2294             i0   = idxptr[XX];
2295             j0   = idxptr[YY];
2296             k0   = idxptr[ZZ];
2297
2298             /* Pointer arithmetic alert, next three statements */
2299             thx  = spline->theta[XX] + norder;
2300             thy  = spline->theta[YY] + norder;
2301             thz  = spline->theta[ZZ] + norder;
2302
2303             pot = 0;
2304             for(ithx=0; (ithx<order); ithx++)
2305             {
2306                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2307                 tx      = thx[ithx];
2308
2309                 for(ithy=0; (ithy<order); ithy++)
2310                 {
2311                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2312                     ty       = thy[ithy];
2313
2314                     for(ithz=0; (ithz<order); ithz++)
2315                     {
2316                         gval  = grid[index_xy+(k0+ithz)];
2317                         pot  += tx*ty*thz[ithz]*gval;
2318                     }
2319
2320                 }
2321             }
2322
2323             energy += pot*qn;
2324         }
2325     }
2326
2327     return energy;
2328 }
2329
2330 /* Macro to force loop unrolling by fixing order.
2331  * This gives a significant performance gain.
2332  */
2333 #define CALC_SPLINE(order)                     \
2334 {                                              \
2335     int j,k,l;                                 \
2336     real dr,div;                               \
2337     real data[PME_ORDER_MAX];                  \
2338     real ddata[PME_ORDER_MAX];                 \
2339                                                \
2340     for(j=0; (j<DIM); j++)                     \
2341     {                                          \
2342         dr  = xptr[j];                         \
2343                                                \
2344         /* dr is relative offset from lower cell limit */ \
2345         data[order-1] = 0;                     \
2346         data[1] = dr;                          \
2347         data[0] = 1 - dr;                      \
2348                                                \
2349         for(k=3; (k<order); k++)               \
2350         {                                      \
2351             div = 1.0/(k - 1.0);               \
2352             data[k-1] = div*dr*data[k-2];      \
2353             for(l=1; (l<(k-1)); l++)           \
2354             {                                  \
2355                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2356                                    data[k-l-1]);                \
2357             }                                  \
2358             data[0] = div*(1-dr)*data[0];      \
2359         }                                      \
2360         /* differentiate */                    \
2361         ddata[0] = -data[0];                   \
2362         for(k=1; (k<order); k++)               \
2363         {                                      \
2364             ddata[k] = data[k-1] - data[k];    \
2365         }                                      \
2366                                                \
2367         div = 1.0/(order - 1);                 \
2368         data[order-1] = div*dr*data[order-2];  \
2369         for(l=1; (l<(order-1)); l++)           \
2370         {                                      \
2371             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2372                                (order-l-dr)*data[order-l-1]); \
2373         }                                      \
2374         data[0] = div*(1 - dr)*data[0];        \
2375                                                \
2376         for(k=0; k<order; k++)                 \
2377         {                                      \
2378             theta[j][i*order+k]  = data[k];    \
2379             dtheta[j][i*order+k] = ddata[k];   \
2380         }                                      \
2381     }                                          \
2382 }
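/* For reference, CALC_SPLINE follows the standard cardinal B-spline
 * recursion used in smooth PME (sketched here, not a derivation):
 *
 *   M_n(u)  = u/(n-1) * M_{n-1}(u) + (n-u)/(n-1) * M_{n-1}(u-1)
 *   M_n'(u) = M_{n-1}(u) - M_{n-1}(u-1)
 *
 * data[] holds the order B-spline values at the grid points surrounding the
 * atom and ddata[] the corresponding derivatives used for force gathering.
 */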
2383
2384 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2385                    rvec fractx[],int nr,int ind[],real charge[],
2386                    gmx_bool bFreeEnergy)
2387 {
2388     /* construct splines for local atoms */
2389     int  i,ii;
2390     real *xptr;
2391
2392     for(i=0; i<nr; i++)
2393     {
2394         /* With free energy we do not use the charge check.
2395          * In most cases this will be more efficient than calling make_bsplines
2396          * twice, since usually more than half the particles have charges.
2397          */
2398         ii = ind[i];
2399         if (bFreeEnergy || charge[ii] != 0.0) {
2400             xptr = fractx[ii];
2401             switch(order) {
2402             case 4:  CALC_SPLINE(4);     break;
2403             case 5:  CALC_SPLINE(5);     break;
2404             default: CALC_SPLINE(order); break;
2405             }
2406         }
2407     }
2408 }
2409
2410
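/* make_dft_mod computes mod[k] = |sum_j data[j] * exp(2*pi*i*j*k/ndata)|^2,
 * i.e. the squared modulus of the DFT of the B-spline values; near-zero
 * moduli are replaced by the average of their neighbours so that later
 * divisions by bsp_mod stay well behaved.
 */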
2411 void make_dft_mod(real *mod,real *data,int ndata)
2412 {
2413   int i,j;
2414   real sc,ss,arg;
2415
2416   for(i=0;i<ndata;i++) {
2417     sc=ss=0;
2418     for(j=0;j<ndata;j++) {
2419       arg=(2.0*M_PI*i*j)/ndata;
2420       sc+=data[j]*cos(arg);
2421       ss+=data[j]*sin(arg);
2422     }
2423     mod[i]=sc*sc+ss*ss;
2424   }
2425   for(i=0;i<ndata;i++)
2426     if(mod[i]<1e-7)
2427       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2428 }
2429
2430
2431 static void make_bspline_moduli(splinevec bsp_mod,
2432                                 int nx,int ny,int nz,int order)
2433 {
2434   int nmax=max(nx,max(ny,nz));
2435   real *data,*ddata,*bsp_data;
2436   int i,k,l;
2437   real div;
2438
2439   snew(data,order);
2440   snew(ddata,order);
2441   snew(bsp_data,nmax);
2442
2443   data[order-1]=0;
2444   data[1]=0;
2445   data[0]=1;
2446
2447   for(k=3;k<order;k++) {
2448     div=1.0/(k-1.0);
2449     data[k-1]=0;
2450     for(l=1;l<(k-1);l++)
2451       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2452     data[0]=div*data[0];
2453   }
2454   /* differentiate */
2455   ddata[0]=-data[0];
2456   for(k=1;k<order;k++)
2457     ddata[k]=data[k-1]-data[k];
2458   div=1.0/(order-1);
2459   data[order-1]=0;
2460   for(l=1;l<(order-1);l++)
2461     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2462   data[0]=div*data[0];
2463
2464   for(i=0;i<nmax;i++)
2465     bsp_data[i]=0;
2466   for(i=1;i<=order;i++)
2467     bsp_data[i]=data[i-1];
2468
2469   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2470   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2471   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2472
2473   sfree(data);
2474   sfree(ddata);
2475   sfree(bsp_data);
2476 }
2477
2478
2479 /* Return the P3M optimal influence function */
2480 static double do_p3m_influence(double z, int order)
2481 {
2482     double z2,z4;
2483
2484     z2 = z*z;
2485     z4 = z2*z2;
2486
2487     /* The formula and most constants can be found in:
2488      * Ballenegger et al., JCTC 8, 936 (2012)
2489      */
2490     switch(order)
2491     {
2492     case 2:
2493         return 1.0 - 2.0*z2/3.0;
2494         break;
2495     case 3:
2496         return 1.0 - z2 + 2.0*z4/15.0;
2497         break;
2498     case 4:
2499         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2500         break;
2501     case 5:
2502         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2503         break;
2504     case 6:
2505         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2506         break;
2507     case 7:
2508         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2509     case 8:
2510         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2511         break;
2512     }
2513
2514     return 0.0;
2515 }
2516
2517 /* Calculate the P3M B-spline moduli for one dimension */
2518 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2519 {
2520     double zarg,zai,sinzai,infl;
2521     int    maxk,i;
2522
2523     if (order > 8)
2524     {
2525         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2526     }
2527
2528     zarg = M_PI/n;
2529
2530     maxk = (n + 1)/2;
2531
2532     for(i=-maxk; i<0; i++)
2533     {
2534         zai    = zarg*i;
2535         sinzai = sin(zai);
2536         infl   = do_p3m_influence(sinzai,order);
2537         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2538     }
2539     bsp_mod[0] = 1.0;
2540     for(i=1; i<maxk; i++)
2541     {
2542         zai    = zarg*i;
2543         sinzai = sin(zai);
2544         infl   = do_p3m_influence(sinzai,order);
2545         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2546     }
2547 }
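/* In the P3M variant each modulus thus becomes
 *   bsp_mod[k] = P(sin z)^2 * (sin(z)/z)^(-2*order),  with z = pi*k/n,
 * where P is the order-dependent influence function above (see the
 * Ballenegger et al. reference cited there).
 */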
2548
2549 /* Calculate the P3M B-spline moduli */
2550 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2551                                     int nx,int ny,int nz,int order)
2552 {
2553     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2554     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2555     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2556 }
2557
2558
2559 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2560 {
2561   int nslab,n,i;
2562   int fw,bw;
2563
2564   nslab = atc->nslab;
2565
2566   n = 0;
2567   for(i=1; i<=nslab/2; i++) {
2568     fw = (atc->nodeid + i) % nslab;
2569     bw = (atc->nodeid - i + nslab) % nslab;
2570     if (n < nslab - 1) {
2571       atc->node_dest[n] = fw;
2572       atc->node_src[n]  = bw;
2573       n++;
2574     }
2575     if (n < nslab - 1) {
2576       atc->node_dest[n] = bw;
2577       atc->node_src[n]  = fw;
2578       n++;
2579     }
2580   }
2581 }
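/* Illustrative example (nslab = 4, nodeid = 0): the loop above pairs
 * forward/backward neighbours as node_dest = {1, 3, 2} and
 * node_src = {3, 1, 2}, so communication proceeds to ever more distant
 * slabs, with each send matched by a receive from the opposite direction.
 */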
2582
2583 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2584 {
2585     int thread;
2586
2587     if(NULL != log)
2588     {
2589         fprintf(log,"Destroying PME data structures.\n");
2590     }
2591
2592     sfree((*pmedata)->nnx);
2593     sfree((*pmedata)->nny);
2594     sfree((*pmedata)->nnz);
2595
2596     pmegrids_destroy(&(*pmedata)->pmegridA);
2597
2598     sfree((*pmedata)->fftgridA);
2599     sfree((*pmedata)->cfftgridA);
2600     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2601
2602     if ((*pmedata)->pmegridB.grid.grid != NULL)
2603     {
2604         pmegrids_destroy(&(*pmedata)->pmegridB);
2605         sfree((*pmedata)->fftgridB);
2606         sfree((*pmedata)->cfftgridB);
2607         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2608     }
2609     for(thread=0; thread<(*pmedata)->nthread; thread++)
2610     {
2611         free_work(&(*pmedata)->work[thread]);
2612     }
2613     sfree((*pmedata)->work);
2614
2615     sfree(*pmedata);
2616     *pmedata = NULL;
2617
2618   return 0;
2619 }
2620
2621 static int mult_up(int n,int f)
2622 {
2623     return ((n + f - 1)/f)*f;
2624 }
2625
2626
2627 static double pme_load_imbalance(gmx_pme_t pme)
2628 {
2629     int    nma,nmi;
2630     double n1,n2,n3;
2631
2632     nma = pme->nnodes_major;
2633     nmi = pme->nnodes_minor;
2634
2635     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2636     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2637     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2638
2639     /* pme_solve is roughly double the cost of an fft */
2640
2641     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2642 }
2643
2644 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2645                           int dimind,gmx_bool bSpread)
2646 {
2647     int nk,k,s,thread;
2648
2649     atc->dimind = dimind;
2650     atc->nslab  = 1;
2651     atc->nodeid = 0;
2652     atc->pd_nalloc = 0;
2653 #ifdef GMX_MPI
2654     if (pme->nnodes > 1)
2655     {
2656         atc->mpi_comm = pme->mpi_comm_d[dimind];
2657         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2658         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2659     }
2660     if (debug)
2661     {
2662         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2663     }
2664 #endif
2665
2666     atc->bSpread   = bSpread;
2667     atc->pme_order = pme->pme_order;
2668
2669     if (atc->nslab > 1)
2670     {
2671         /* These three allocations are not required for particle decomp. */
2672         snew(atc->node_dest,atc->nslab);
2673         snew(atc->node_src,atc->nslab);
2674         setup_coordinate_communication(atc);
2675
2676         snew(atc->count_thread,pme->nthread);
2677         for(thread=0; thread<pme->nthread; thread++)
2678         {
2679             snew(atc->count_thread[thread],atc->nslab);
2680         }
2681         atc->count = atc->count_thread[0];
2682         snew(atc->rcount,atc->nslab);
2683         snew(atc->buf_index,atc->nslab);
2684     }
2685
2686     atc->nthread = pme->nthread;
2687     if (atc->nthread > 1)
2688     {
2689         snew(atc->thread_plist,atc->nthread);
2690     }
2691     snew(atc->spline,atc->nthread);
2692     for(thread=0; thread<atc->nthread; thread++)
2693     {
2694         if (atc->nthread > 1)
2695         {
2696             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2697             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2698         }
2699     }
2700 }
2701
2702 static void
2703 init_overlap_comm(pme_overlap_t *  ol,
2704                   int              norder,
2705 #ifdef GMX_MPI
2706                   MPI_Comm         comm,
2707 #endif
2708                   int              nnodes,
2709                   int              nodeid,
2710                   int              ndata,
2711                   int              commplainsize)
2712 {
2713     int lbnd,rbnd,maxlr,b,i;
2714     int exten;
2715     int nn,nk;
2716     pme_grid_comm_t *pgc;
2717     gmx_bool bCont;
2718     int fft_start,fft_end,send_index1,recv_index1;
2719
2720 #ifdef GMX_MPI
2721     ol->mpi_comm = comm;
2722 #endif
2723
2724     ol->nnodes = nnodes;
2725     ol->nodeid = nodeid;
2726
2727     /* Linear translation of the PME grid won't affect reciprocal space
2728      * calculations, so to optimize we only interpolate "upwards",
2729      * which also means we only have to consider overlap in one direction.
2730      * I.e., particles on this node might also be spread to grid indices
2731      * that belong to higher nodes (modulo nnodes)
2732      */
2733
2734     snew(ol->s2g0,ol->nnodes+1);
2735     snew(ol->s2g1,ol->nnodes);
2736     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2737     for(i=0; i<nnodes; i++)
2738     {
2739         /* s2g0 the local interpolation grid start.
2740          * s2g1 the local interpolation grid end.
2741          * Because grid overlap communication only goes forward,
2742          * the grid slabs for the FFTs should be rounded down.
2743          */
2744         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2745         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2746
2747         if (debug)
2748         {
2749             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2750         }
2751     }
2752     ol->s2g0[nnodes] = ndata;
2753     if (debug) { fprintf(debug,"\n"); }
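    /* Worked example of the formulas above (ndata = 15, nnodes = 4,
     * norder = 4):
     *   s2g0 = {0, 3, 7, 11, 15}
     *   s2g1 = {7, 11, 15, 18}
     * i.e. each node's interpolation region [s2g0[i], s2g1[i]) overlaps the
     * following slabs, with the last one wrapping past ndata.
     */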
2754
2755     /* Determine how many nodes we need to communicate the grid overlap with */
2756     b = 0;
2757     do
2758     {
2759         b++;
2760         bCont = FALSE;
2761         for(i=0; i<nnodes; i++)
2762         {
2763             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2764                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2765             {
2766                 bCont = TRUE;
2767             }
2768         }
2769     }
2770     while (bCont && b < nnodes);
2771     ol->noverlap_nodes = b - 1;
2772
2773     snew(ol->send_id,ol->noverlap_nodes);
2774     snew(ol->recv_id,ol->noverlap_nodes);
2775     for(b=0; b<ol->noverlap_nodes; b++)
2776     {
2777         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2778         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2779     }
2780     snew(ol->comm_data, ol->noverlap_nodes);
2781
2782     for(b=0; b<ol->noverlap_nodes; b++)
2783     {
2784         pgc = &ol->comm_data[b];
2785         /* Send */
2786         fft_start        = ol->s2g0[ol->send_id[b]];
2787         fft_end          = ol->s2g0[ol->send_id[b]+1];
2788         if (ol->send_id[b] < nodeid)
2789         {
2790             fft_start += ndata;
2791             fft_end   += ndata;
2792         }
2793         send_index1      = ol->s2g1[nodeid];
2794         send_index1      = min(send_index1,fft_end);
2795         pgc->send_index0 = fft_start;
2796         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2797
2798         /* We always start receiving at the first index of our slab */
2799         fft_start        = ol->s2g0[ol->nodeid];
2800         fft_end          = ol->s2g0[ol->nodeid+1];
2801         recv_index1      = ol->s2g1[ol->recv_id[b]];
2802         if (ol->recv_id[b] > nodeid)
2803         {
2804             recv_index1 -= ndata;
2805         }
2806         recv_index1      = min(recv_index1,fft_end);
2807         pgc->recv_index0 = fft_start;
2808         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2809     }
2810
2811     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2812     snew(ol->sendbuf,norder*commplainsize);
2813     snew(ol->recvbuf,norder*commplainsize);
2814 }
2815
2816 static void
2817 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2818                               int **global_to_local,
2819                               real **fraction_shift)
2820 {
2821     int i;
2822     int * gtl;
2823     real * fsh;
2824
2825     snew(gtl,5*n);
2826     snew(fsh,5*n);
2827     for(i=0; (i<5*n); i++)
2828     {
2829         /* Determine the global to local grid index */
2830         gtl[i] = (i - local_start + n) % n;
2831         /* For coordinates that fall within the local grid the fraction
2832          * is correct; we don't need to shift it.
2833          */
2834         fsh[i] = 0;
2835         if (local_range < n)
2836         {
2837             /* Due to rounding issues i could be 1 beyond the lower or
2838              * upper boundary of the local grid. Correct the index for this.
2839              * If we shift the index, we need to shift the fraction by
2840              * the same amount in the other direction to not affect
2841              * the weights.
2842              * Note that due to this shifting the weights at the end of
2843              * the spline might change, but that will only involve values
2844              * between zero and values close to the precision of a real,
2845              * which is anyhow the accuracy of the whole mesh calculation.
2846              */
2847             /* With local_range=0 we should not change i=local_start */
2848             if (i % n != local_start)
2849             {
2850                 if (gtl[i] == n-1)
2851                 {
2852                     gtl[i] = 0;
2853                     fsh[i] = -1;
2854                 }
2855                 else if (gtl[i] == local_range)
2856                 {
2857                     gtl[i] = local_range - 1;
2858                     fsh[i] = 1;
2859                 }
2860             }
2861         }
2862     }
2863
2864     *global_to_local = gtl;
2865     *fraction_shift  = fsh;
2866 }
2867
2868 static void sse_mask_init(pme_spline_work_t *work,int order)
2869 {
2870 #ifdef PME_SSE
2871     float  tmp[8];
2872     __m128 zero_SSE;
2873     int    of,i;
2874
2875     zero_SSE = _mm_setzero_ps();
2876
2877     for(of=0; of<8-(order-1); of++)
2878     {
2879         for(i=0; i<8; i++)
2880         {
2881             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2882         }
2883         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2884         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2885         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2886         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2887     }
2888 #endif
2889 }
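/* The masks above select 'order' consecutive single-precision elements
 * starting at offset 'of' within an 8-wide window; they are presumably
 * consumed by the aligned SSE spread/gather kernels included from
 * pme_sse_single.h via pme->spline_work.
 */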
2890
2891 static void
2892 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2893 {
2894     int nk_new;
2895
2896     if (*nk % nnodes != 0)
2897     {
2898         nk_new = nnodes*(*nk/nnodes + 1);
2899
2900         if (2*nk_new >= 3*(*nk))
2901         {
2902             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2903                       dim,*nk,dim,nnodes,dim);
2904         }
2905
2906         if (fplog != NULL)
2907         {
2908             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2909                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2910         }
2911
2912         *nk = nk_new;
2913     }
2914 }
2915
2916 int gmx_pme_init(gmx_pme_t *         pmedata,
2917                  t_commrec *         cr,
2918                  int                 nnodes_major,
2919                  int                 nnodes_minor,
2920                  t_inputrec *        ir,
2921                  int                 homenr,
2922                  gmx_bool            bFreeEnergy,
2923                  gmx_bool            bReproducible,
2924                  int                 nthread)
2925 {
2926     gmx_pme_t pme=NULL;
2927
2928     pme_atomcomm_t *atc;
2929     ivec ndata;
2930
2931     if (debug)
2932         fprintf(debug,"Creating PME data structures.\n");
2933     snew(pme,1);
2934
2935     pme->redist_init         = FALSE;
2936     pme->sum_qgrid_tmp       = NULL;
2937     pme->sum_qgrid_dd_tmp    = NULL;
2938     pme->buf_nalloc          = 0;
2939     pme->redist_buf_nalloc   = 0;
2940
2941     pme->nnodes              = 1;
2942     pme->bPPnode             = TRUE;
2943
2944     pme->nnodes_major        = nnodes_major;
2945     pme->nnodes_minor        = nnodes_minor;
2946
2947 #ifdef GMX_MPI
2948     if (nnodes_major*nnodes_minor > 1)
2949     {
2950         pme->mpi_comm = cr->mpi_comm_mygroup;
2951
2952         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2953         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2954         if (pme->nnodes != nnodes_major*nnodes_minor)
2955         {
2956             gmx_incons("PME node count mismatch");
2957         }
2958     }
2959     else
2960     {
2961         pme->mpi_comm = MPI_COMM_NULL;
2962     }
2963 #endif
2964
2965     if (pme->nnodes == 1)
2966     {
2967         pme->ndecompdim = 0;
2968         pme->nodeid_major = 0;
2969         pme->nodeid_minor = 0;
2970 #ifdef GMX_MPI
2971         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2972 #endif
2973     }
2974     else
2975     {
2976         if (nnodes_minor == 1)
2977         {
2978 #ifdef GMX_MPI
2979             pme->mpi_comm_d[0] = pme->mpi_comm;
2980             pme->mpi_comm_d[1] = MPI_COMM_NULL;
2981 #endif
2982             pme->ndecompdim = 1;
2983             pme->nodeid_major = pme->nodeid;
2984             pme->nodeid_minor = 0;
2985
2986         }
2987         else if (nnodes_major == 1)
2988         {
2989 #ifdef GMX_MPI
2990             pme->mpi_comm_d[0] = MPI_COMM_NULL;
2991             pme->mpi_comm_d[1] = pme->mpi_comm;
2992 #endif
2993             pme->ndecompdim = 1;
2994             pme->nodeid_major = 0;
2995             pme->nodeid_minor = pme->nodeid;
2996         }
2997         else
2998         {
2999             if (pme->nnodes % nnodes_major != 0)
3000             {
3001                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3002             }
3003             pme->ndecompdim = 2;
3004
3005 #ifdef GMX_MPI
3006             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3007                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3008             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3009                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3010
3011             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3012             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3013             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3014             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3015 #endif
3016         }
3017         pme->bPPnode = (cr->duty & DUTY_PP);
3018     }
3019
3020     pme->nthread = nthread;
3021
3022     if (ir->ePBC == epbcSCREW)
3023     {
3024         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3025     }
3026
3027     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3028     pme->nkx         = ir->nkx;
3029     pme->nky         = ir->nky;
3030     pme->nkz         = ir->nkz;
3031     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3032     pme->pme_order   = ir->pme_order;
3033     pme->epsilon_r   = ir->epsilon_r;
3034
3035     if (pme->pme_order > PME_ORDER_MAX)
3036     {
3037         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3038                   pme->pme_order,PME_ORDER_MAX);
3039     }
3040
3041     /* Currently pme.c supports only the fft5d FFT code.
3042      * Therefore the grid always needs to be divisible by nnodes.
3043      * When the old 1D code is also supported again, change this check.
3044      *
3045      * This check should be done before calling gmx_pme_init
3046      * and fplog should be passed instead of stderr.
3047      *
3048     if (pme->ndecompdim >= 2)
3049     */
3050     if (pme->ndecompdim >= 1)
3051     {
3052         /*
3053         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3054                                         'x',nnodes_major,&pme->nkx);
3055         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3056                                         'y',nnodes_minor,&pme->nky);
3057         */
3058     }
3059
3060     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3061         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3062         pme->nkz <= pme->pme_order)
3063     {
3064         gmx_fatal(FARGS,"The PME grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_order for x and/or y",pme->pme_order);
3065     }
3066
3067     if (pme->nnodes > 1) {
3068         double imbal;
3069
3070 #ifdef GMX_MPI
3071         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3072         MPI_Type_commit(&(pme->rvec_mpi));
3073 #endif
3074
3075         /* Note that the charge spreading and force gathering, which usually
3076          * take about the same amount of time as FFT+solve_pme,
3077          * are always fully load balanced
3078          * (unless the charge distribution is inhomogeneous).
3079          */
3080
3081         imbal = pme_load_imbalance(pme);
3082         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3083         {
3084             fprintf(stderr,
3085                     "\n"
3086                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3087                     "      For optimal PME load balancing\n"
3088                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3089                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3090                     "\n",
3091                     (int)((imbal-1)*100 + 0.5),
3092                     pme->nkx,pme->nky,pme->nnodes_major,
3093                     pme->nky,pme->nkz,pme->nnodes_minor);
3094         }
3095     }
3096
3097     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3098     /* In sum_qgrid_dd the x overlap is copied in place, so take the padding into account.
3099      * y is always copied through a buffer: we don't need padding in z,
3100      * but we do need the overlap in x because of the communication order.
3101      */
3102     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3103 #ifdef GMX_MPI
3104                       pme->mpi_comm_d[0],
3105 #endif
3106                       pme->nnodes_major,pme->nodeid_major,
3107                       pme->nkx,
3108                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3109
3110     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3111 #ifdef GMX_MPI
3112                       pme->mpi_comm_d[1],
3113 #endif
3114                       pme->nnodes_minor,pme->nodeid_minor,
3115                       pme->nky,
3116                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3117
3118     /* Check for a limitation of the (current) sum_fftgrid_dd code */
3119     if (pme->nthread > 1 &&
3120         (pme->overlap[0].noverlap_nodes > 1 ||
3121          pme->overlap[1].noverlap_nodes > 1))
3122     {
3123         gmx_fatal(FARGS,"With threads the number of grid lines per node along x and/or y should be pme_order (%d) or more, or exactly pme_order-1",pme->pme_order);
3124     }
3125
3126     snew(pme->bsp_mod[XX],pme->nkx);
3127     snew(pme->bsp_mod[YY],pme->nky);
3128     snew(pme->bsp_mod[ZZ],pme->nkz);
3129
3130     /* The required size of the interpolation grid, including overlap.
3131      * The allocated size (pmegrid_n?) might be slightly larger.
3132      */
3133     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3134                       pme->overlap[0].s2g0[pme->nodeid_major];
3135     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3136                       pme->overlap[1].s2g0[pme->nodeid_minor];
3137     pme->pmegrid_nz_base = pme->nkz;
3138     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3139     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3140
3141     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3142     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3143     pme->pmegrid_start_iz = 0;
3144
3145     make_gridindex5_to_localindex(pme->nkx,
3146                                   pme->pmegrid_start_ix,
3147                                   pme->pmegrid_nx - (pme->pme_order-1),
3148                                   &pme->nnx,&pme->fshx);
3149     make_gridindex5_to_localindex(pme->nky,
3150                                   pme->pmegrid_start_iy,
3151                                   pme->pmegrid_ny - (pme->pme_order-1),
3152                                   &pme->nny,&pme->fshy);
3153     make_gridindex5_to_localindex(pme->nkz,
3154                                   pme->pmegrid_start_iz,
3155                                   pme->pmegrid_nz_base,
3156                                   &pme->nnz,&pme->fshz);
3157
3158     pmegrids_init(&pme->pmegridA,
3159                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3160                   pme->pmegrid_nz_base,
3161                   pme->pme_order,
3162                   pme->nthread,
3163                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3164                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3165
3166     sse_mask_init(&pme->spline_work,pme->pme_order);
3167
3168     ndata[0] = pme->nkx;
3169     ndata[1] = pme->nky;
3170     ndata[2] = pme->nkz;
3171
3172     /* This routine will allocate the grid data to fit the FFTs */
3173     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3174                             &pme->fftgridA,&pme->cfftgridA,
3175                             pme->mpi_comm_d,
3176                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3177                             bReproducible,pme->nthread);
3178
3179     if (bFreeEnergy)
3180     {
3181         pmegrids_init(&pme->pmegridB,
3182                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3183                       pme->pmegrid_nz_base,
3184                       pme->pme_order,
3185                       pme->nthread,
3186                       pme->nkx % pme->nnodes_major != 0,
3187                       pme->nky % pme->nnodes_minor != 0);
3188
3189         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3190                                 &pme->fftgridB,&pme->cfftgridB,
3191                                 pme->mpi_comm_d,
3192                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3193                                 bReproducible,pme->nthread);
3194     }
3195     else
3196     {
3197         pme->pmegridB.grid.grid = NULL;
3198         pme->fftgridB           = NULL;
3199         pme->cfftgridB          = NULL;
3200     }
3201
3202     if (!pme->bP3M)
3203     {
3204         /* Use plain SPME B-spline interpolation */
3205         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3206     }
3207     else
3208     {
3209         /* Use the P3M grid-optimized influence function */
3210         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3211     }
3212
3213     /* Use atc[0] for spreading */
3214     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3215     if (pme->ndecompdim >= 2)
3216     {
3217         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3218     }
3219
3220     if (pme->nnodes == 1) {
3221         pme->atc[0].n = homenr;
3222         pme_realloc_atomcomm_things(&pme->atc[0]);
3223     }
3224
3225     {
3226         int thread;
3227
3228         /* Use fft5d, order after FFT is y major, z, x minor */
3229
3230         snew(pme->work,pme->nthread);
3231         for(thread=0; thread<pme->nthread; thread++)
3232         {
3233             realloc_work(&pme->work[thread],pme->nkx);
3234         }
3235     }
3236
3237     *pmedata = pme;
3238
3239     return 0;
3240 }
3241
3242
3243 static void copy_local_grid(gmx_pme_t pme,
3244                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3245 {
3246     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3247     int  fft_my,fft_mz;
3248     int  nsx,nsy,nsz;
3249     ivec nf;
3250     int  offx,offy,offz,x,y,z,i0,i0t;
3251     int  d;
3252     pmegrid_t *pmegrid;
3253     real *grid_th;
3254
3255     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3256                                    local_fft_ndata,
3257                                    local_fft_offset,
3258                                    local_fft_size);
3259     fft_my = local_fft_size[YY];
3260     fft_mz = local_fft_size[ZZ];
3261
3262     pmegrid = &pmegrids->grid_th[thread];
3263
3264     nsx = pmegrid->n[XX];
3265     nsy = pmegrid->n[YY];
3266     nsz = pmegrid->n[ZZ];
3267
3268     for(d=0; d<DIM; d++)
3269     {
3270         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3271                     local_fft_ndata[d] - pmegrid->offset[d]);
3272     }
3273
3274     offx = pmegrid->offset[XX];
3275     offy = pmegrid->offset[YY];
3276     offz = pmegrid->offset[ZZ];
3277
3278     /* Directly copy the non-overlapping parts of the local grids.
3279      * This also initializes the full grid.
3280      */
3281     grid_th = pmegrid->grid;
3282     for(x=0; x<nf[XX]; x++)
3283     {
3284         for(y=0; y<nf[YY]; y++)
3285         {
3286             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3287             i0t = (x*nsy + y)*nsz;
3288             for(z=0; z<nf[ZZ]; z++)
3289             {
3290                 fftgrid[i0+z] = grid_th[i0t+z];
3291             }
3292         }
3293     }
3294 }
3295
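     /* Debugging helper: for each x-slab in the major-dimension send buffer,
      * print per y-row the number of non-zero z-entries, so the fill pattern
      * of the overlap communication buffer can be inspected by eye.
      */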
3296 static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3297 {
3298     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3299     pme_overlap_t *overlap;
3300     int nind;
3301     int i,x,y,z,n;
3302
3303     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3304                                    local_fft_ndata,
3305                                    local_fft_offset,
3306                                    local_fft_size);
3307     /* Major dimension */
3308     overlap = &pme->overlap[0];
3309
3310     nind   = overlap->comm_data[0].send_nindex;
3311
3312     for(y=0; y<local_fft_ndata[YY]; y++) {
3313          printf(" %2d",y);
3314     }
3315     printf("\n");
3316
3317     i = 0;
3318     for(x=0; x<nind; x++) {
3319         for(y=0; y<local_fft_ndata[YY]; y++) {
3320             n = 0;
3321             for(z=0; z<local_fft_ndata[ZZ]; z++) {
3322                 if (sendbuf[i] != 0) n++;
3323                 i++;
3324             }
3325             printf(" %2d",n);
3326         }
3327         printf("\n");
3328     }
3329 }
3330
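     /* Sum the overlapping edges of the thread-local spreading grids into the
      * node-local FFT grid. Contributions that belong to a neighboring PME
      * node are not added to fftgrid but staged in commbuf_x/commbuf_y, which
      * at the call site are the send buffers of overlap[0] (major dimension)
      * and overlap[1] (minor dimension), to be exchanged in sum_fftgrid_dd().
      */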
3331 static void
3332 reduce_threadgrid_overlap(gmx_pme_t pme,
3333                           const pmegrids_t *pmegrids,int thread,
3334                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3335 {
3336     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3337     int  fft_nx,fft_ny,fft_nz;
3338     int  fft_my,fft_mz;
3339     int  buf_my=-1;
3340     int  nsx,nsy,nsz;
3341     ivec ne;
3342     int  offx,offy,offz,x,y,z,i0,i0t;
3343     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3344     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3345     gmx_bool bCommX,bCommY;
3346     int  d;
3347     int  thread_f;
3348     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3349     const real *grid_th;
3350     real *commbuf=NULL;
3351
3352     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3353                                    local_fft_ndata,
3354                                    local_fft_offset,
3355                                    local_fft_size);
3356     fft_nx = local_fft_ndata[XX];
3357     fft_ny = local_fft_ndata[YY];
3358     fft_nz = local_fft_ndata[ZZ];
3359
3360     fft_my = local_fft_size[YY];
3361     fft_mz = local_fft_size[ZZ];
3362
3363     /* This routine is called after all threads have finished spreading.
3364      * Here each thread sums the grid contributions calculated by other
3365      * threads into its thread-local part of the full grid.
3366      * To minimize the number of grid copying operations,
3367      * this routine sums immediately from the pmegrid into the fftgrid.
3368      */
3369
3370     /* Determine which part of the full node grid we should operate on:
3371      * this is our thread-local part of the full grid.
3372      */
3373     pmegrid = &pmegrids->grid_th[thread];
3374
3375     for(d=0; d<DIM; d++)
3376     {
3377         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3378                     local_fft_ndata[d]);
3379     }
3380
3381     offx = pmegrid->offset[XX];
3382     offy = pmegrid->offset[YY];
3383     offz = pmegrid->offset[ZZ];
3384
3385
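         /* The communication buffers are written to incrementally while looping
          * over the contributing thread blocks. The bClearBuf* flags record
          * whether the x, y and x+y-corner buffers still need to be initialized:
          * the first write assigns, later writes accumulate.
          */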
3386     bClearBufX  = TRUE;
3387     bClearBufY  = TRUE;
3388     bClearBufXY = TRUE;
3389
3390     /* Now loop over all the thread data blocks that contribute
3391      * to the grid region we (our thread) are operating on.
3392      */
3393     /* Note that fft_nx/fft_ny equal the number of grid points
3394      * between the first point of our node grid and that of the next node.
3395      */
3396     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3397     {
3398         fx = pmegrid->ci[XX] + sx;
3399         ox = 0;
3400         bCommX = FALSE;
3401         if (fx < 0) {
3402             fx += pmegrids->nc[XX];
3403             ox -= fft_nx;
3404             bCommX = (pme->nnodes_major > 1);
3405         }
3406         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3407         ox += pmegrid_g->offset[XX];
3408         if (!bCommX)
3409         {
3410             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3411         }
3412         else
3413         {
3414             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3415         }
3416
3417         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3418         {
3419             fy = pmegrid->ci[YY] + sy;
3420             oy = 0;
3421             bCommY = FALSE;
3422             if (fy < 0) {
3423                 fy += pmegrids->nc[YY];
3424                 oy -= fft_ny;
3425                 bCommY = (pme->nnodes_minor > 1);
3426             }
3427             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3428             oy += pmegrid_g->offset[YY];
3429             if (!bCommY)
3430             {
3431                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3432             }
3433             else
3434             {
3435                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3436             }
3437
3438             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3439             {
3440                 fz = pmegrid->ci[ZZ] + sz;
3441                 oz = 0;
3442                 if (fz < 0)
3443                 {
3444                     fz += pmegrids->nc[ZZ];
3445                     oz -= fft_nz;
3446                 }
3447                 pmegrid_g = &pmegrids->grid_th[fz];
3448                 oz += pmegrid_g->offset[ZZ];
3449                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3450
3451                 if (sx == 0 && sy == 0 && sz == 0)
3452                 {
3453                     /* We have already added our local contribution
3454                      * before calling this routine, so skip it here.
3455                      */
3456                     continue;
3457                 }
3458
3459                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3460
3461                 pmegrid_f = &pmegrids->grid_th[thread_f];
3462
3463                 grid_th = pmegrid_f->grid;
3464
3465                 nsx = pmegrid_f->n[XX];
3466                 nsy = pmegrid_f->n[YY];
3467                 nsz = pmegrid_f->n[ZZ];
3468
3469 #ifdef DEBUG_PME_REDUCE
3470                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3471                        pme->nodeid,thread,thread_f,
3472                        pme->pmegrid_start_ix,
3473                        pme->pmegrid_start_iy,
3474                        pme->pmegrid_start_iz,
3475                        sx,sy,sz,
3476                        offx-ox,tx1-ox,offx,tx1,
3477                        offy-oy,ty1-oy,offy,ty1,
3478                        offz-oz,tz1-oz,offz,tz1);
3479 #endif
3480
3481                 if (!(bCommX || bCommY))
3482                 {
3483                     /* Copy from the thread local grid to the node grid */
3484                     for(x=offx; x<tx1; x++)
3485                     {
3486                         for(y=offy; y<ty1; y++)
3487                         {
3488                             i0  = (x*fft_my + y)*fft_mz;
3489                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3490                             for(z=offz; z<tz1; z++)
3491                             {
3492                                 fftgrid[i0+z] += grid_th[i0t+z];
3493                             }
3494                         }
3495                     }
3496                 }
3497                 else
3498                 {
3499                     /* The order of this conditional decides
3500                      * where the corner volume gets stored with x+y decomp.
3501                      */
3502                     if (bCommY)
3503                     {
3504                         commbuf = commbuf_y;
3505                         buf_my  = ty1 - offy;
3506                         if (bCommX)
3507                         {
3508                             /* We index commbuf modulo the local grid size */
3509                             commbuf += buf_my*fft_nx*fft_nz;
3510
3511                             bClearBuf  = bClearBufXY;
3512                             bClearBufXY = FALSE;
3513                         }
3514                         else
3515                         {
3516                             bClearBuf  = bClearBufY;
3517                             bClearBufY = FALSE;
3518                         }
3519                     }
3520                     else
3521                     {
3522                         commbuf = commbuf_x;
3523                         buf_my  = fft_ny;
3524                         bClearBuf  = bClearBufX;
3525                         bClearBufX = FALSE;
3526                     }
3527
3528                     /* Copy to the communication buffer */
3529                     for(x=offx; x<tx1; x++)
3530                     {
3531                         for(y=offy; y<ty1; y++)
3532                         {
3533                             i0  = (x*buf_my + y)*fft_nz;
3534                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3535
3536                             if (bClearBuf)
3537                             {
3538                                 /* First access of commbuf, initialize it */
3539                                 for(z=offz; z<tz1; z++)
3540                                 {
3541                                     commbuf[i0+z]  = grid_th[i0t+z];
3542                                 }
3543                             }
3544                             else
3545                             {
3546                                 for(z=offz; z<tz1; z++)
3547                                 {
3548                                     commbuf[i0+z] += grid_th[i0t+z];
3549                                 }
3550                             }
3551                         }
3552                     }
3553                 }
3554             }
3555         }
3556     }
3557 }
3558
3559
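     /* Sum the node-overlap parts of the FFT grid over the PME nodes.
      * The send buffers were filled by reduce_threadgrid_overlap(); here they
      * are exchanged with MPI_Sendrecv, first along the minor (y) dimension,
      * then along the major (x) dimension, and the received data is added to
      * the local fftgrid.
      */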
3560 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3561 {
3562     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3563     pme_overlap_t *overlap;
3564     int  send_nindex;
3565     int  recv_index0,recv_nindex;
3566 #ifdef GMX_MPI
3567     MPI_Status stat;
3568 #endif
3569     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3570     real *sendptr,*recvptr;
3571     int  x,y,z,indg,indb;
3572
3573     /* Note that this routine is only used for forward communication.
3574      * Since the force gathering, unlike the charge spreading,
3575      * can be trivially parallelized over the particles,
3576      * the backwards process is much simpler and can use the "old"
3577      * communication setup.
3578      */
3579
3580     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3581                                    local_fft_ndata,
3582                                    local_fft_offset,
3583                                    local_fft_size);
3584
3585     /* Currently supports only a single communication pulse */
3586
3587 /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3588     if (pme->nnodes_minor > 1)
3589     {
3590         /* Minor dimension */
3591         overlap = &pme->overlap[1];
3592
3593         if (pme->nnodes_major > 1)
3594         {
3595              size_yx = pme->overlap[0].comm_data[0].send_nindex;
3596         }
3597         else
3598         {
3599             size_yx = 0;
3600         }
3601         datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3602
3603         ipulse = 0;
3604
3605         send_id = overlap->send_id[ipulse];
3606         recv_id = overlap->recv_id[ipulse];
3607         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3608         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3609         recv_index0 = 0;
3610         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3611
3612         sendptr = overlap->sendbuf;
3613         recvptr = overlap->recvbuf;
3614
3615         /*
3616         printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3617                local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3618         printf("node %d send %f, %f\n",pme->nodeid,
3619                sendptr[0],sendptr[send_nindex*datasize-1]);
3620         */
3621
3622 #ifdef GMX_MPI
3623         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3624                      send_id,ipulse,
3625                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3626                      recv_id,ipulse,
3627                      overlap->mpi_comm,&stat);
3628 #endif
3629
3630         for(x=0; x<local_fft_ndata[XX]; x++)
3631         {
3632             for(y=0; y<recv_nindex; y++)
3633             {
3634                 indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3635                 indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3636                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3637                 {
3638                     fftgrid[indg+z] += recvptr[indb+z];
3639                 }
3640             }
3641         }
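             /* With an additional major-dimension decomposition, the received
              * buffer carries size_yx extra x-slabs holding the x+y corner
              * volume; accumulate these into the major-dimension send buffer
              * so they are forwarded in the exchange below.
              */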
3642         if (pme->nnodes_major > 1)
3643         {
3644             sendptr = pme->overlap[0].sendbuf;
3645             for(x=0; x<size_yx; x++)
3646             {
3647                 for(y=0; y<recv_nindex; y++)
3648                 {
3649                     indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3650                     indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3651                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3652                     {
3653                         sendptr[indg+z] += recvptr[indb+z];
3654                     }
3655                 }
3656             }
3657         }
3658     }
3659
3660     /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3661     if (pme->nnodes_major > 1)
3662     {
3663         /* Major dimension */
3664         overlap = &pme->overlap[0];
3665
3666         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3667         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3668
3669         ipulse = 0;
3670
3671         send_id = overlap->send_id[ipulse];
3672         recv_id = overlap->recv_id[ipulse];
3673         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3674         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3675         recv_index0 = 0;
3676         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3677
3678         sendptr = overlap->sendbuf;
3679         recvptr = overlap->recvbuf;
3680
3681         if (debug != NULL)
3682         {
3683             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3684                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3685         }
3686
3687 #ifdef GMX_MPI
3688         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3689                      send_id,ipulse,
3690                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3691                      recv_id,ipulse,
3692                      overlap->mpi_comm,&stat);
3693 #endif
3694
3695         for(x=0; x<recv_nindex; x++)
3696         {
3697             for(y=0; y<local_fft_ndata[YY]; y++)
3698             {
3699                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3700                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3701                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3702                 {
3703                     fftgrid[indg+z] += recvptr[indb+z];
3704                 }
3705             }
3706         }
3707     }
3708 }
3709
3710
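     /* Spread the charges of all local atoms onto the real-space grid.
      * The work is split over OpenMP threads in three phases:
      *   1) compute the grid interpolation indices for all atoms,
      *   2) compute B-spline coefficients and spread each thread's atoms onto
      *      its thread-local grid (copied into fftgrid when nthread > 1),
      *   3) with more than one thread, reduce the thread-grid overlaps and,
      *      with domain decomposition, communicate the node-overlap regions
      *      (sum_fftgrid_dd).
      */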
3711 static void spread_on_grid(gmx_pme_t pme,
3712                            pme_atomcomm_t *atc,pmegrids_t *grids,
3713                            gmx_bool bCalcSplines,gmx_bool bSpread,
3714                            real *fftgrid)
3715 {
3716     int nthread,thread;
3717 #ifdef PME_TIME_THREADS
3718     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3719     static double cs1=0,cs2=0,cs3=0;
3720     static double cs1a[6]={0,0,0,0,0,0};
3721     static int cnt=0;
3722 #endif
3723
3724     nthread = pme->nthread;
3725     assert(nthread>0);
3726
3727 #ifdef PME_TIME_THREADS
3728     c1 = omp_cyc_start();
3729 #endif
3730     if (bCalcSplines)
3731     {
3732 #pragma omp parallel for num_threads(nthread) schedule(static)
3733         for(thread=0; thread<nthread; thread++)
3734         {
3735             int start,end;
3736
3737             start = atc->n* thread   /nthread;
3738             end   = atc->n*(thread+1)/nthread;
3739
3740             /* Compute fftgrid index for all atoms,
3741              * with help of some extra variables.
3742              */
3743             calc_interpolation_idx(pme,atc,start,end,thread);
3744         }
3745     }
3746 #ifdef PME_TIME_THREADS
3747     c1 = omp_cyc_end(c1);
3748     cs1 += (double)c1;
3749 #endif
3750
3751 #ifdef PME_TIME_THREADS
3752     c2 = omp_cyc_start();
3753 #endif
3754 #pragma omp parallel for num_threads(nthread) schedule(static)
3755     for(thread=0; thread<nthread; thread++)
3756     {
3757         splinedata_t *spline;
3758         pmegrid_t *grid;
3759
3760         /* make local bsplines  */
3761         if (grids == NULL || grids->nthread == 1)
3762         {
3763             spline = &atc->spline[0];
3764
3765             spline->n = atc->n;
3766
3767             grid = &grids->grid;
3768         }
3769         else
3770         {
3771             spline = &atc->spline[thread];
3772
3773             make_thread_local_ind(atc,thread,spline);
3774
3775             grid = &grids->grid_th[thread];
3776         }
3777
3778         if (bCalcSplines)
3779         {
3780             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3781                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3782         }
3783
3784         if (bSpread)
3785         {
3786             /* put local atoms on grid. */
3787 #ifdef PME_TIME_SPREAD
3788             ct1a = omp_cyc_start();
3789 #endif
3790             spread_q_bsplines_thread(grid,atc,spline,&pme->spline_work);
3791
3792             if (grids->nthread > 1)
3793             {
3794                 copy_local_grid(pme,grids,thread,fftgrid);
3795             }
3796 #ifdef PME_TIME_SPREAD
3797             ct1a = omp_cyc_end(ct1a);
3798             cs1a[thread] += (double)ct1a;
3799 #endif
3800         }
3801     }
3802 #ifdef PME_TIME_THREADS
3803     c2 = omp_cyc_end(c2);
3804     cs2 += (double)c2;
3805 #endif
3806
3807     if (bSpread && grids->nthread > 1)
3808     {
3809 #ifdef PME_TIME_THREADS
3810         c3 = omp_cyc_start();
3811 #endif
3812 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3813         for(thread=0; thread<grids->nthread; thread++)
3814         {
3815             reduce_threadgrid_overlap(pme,grids,thread,
3816                                       fftgrid,
3817                                       pme->overlap[0].sendbuf,
3818                                       pme->overlap[1].sendbuf);
3819 #ifdef PRINT_PME_SENDBUF
3820             print_sendbuf(pme,pme->overlap[0].sendbuf);
3821 #endif
3822         }
3823 #ifdef PME_TIME_THREADS
3824         c3 = omp_cyc_end(c3);
3825         cs3 += (double)c3;
3826 #endif
3827
3828         if (pme->nnodes > 1)
3829         {
3830             /* Communicate the overlapping part of the fftgrid */
3831             sum_fftgrid_dd(pme,fftgrid);
3832         }
3833     }
3834
3835 #ifdef PME_TIME_THREADS
3836     cnt++;
3837     if (cnt % 20 == 0)
3838     {
3839         printf("idx %.2f spread %.2f red %.2f",
3840                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3841 #ifdef PME_TIME_SPREAD
3842         for(thread=0; thread<nthread; thread++)
3843             printf(" %.2f",cs1a[thread]*1e-9);
3844 #endif
3845         printf("\n");
3846     }
3847 #endif
3848 }
3849
3850
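     /* Debugging helper: print an nx x ny x nz block of grid g, with leading
      * dimensions my and mz, using the global start indices (sx,sy,sz) as labels.
      */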
3851 static void dump_grid(FILE *fp,
3852                       int sx,int sy,int sz,int nx,int ny,int nz,
3853                       int my,int mz,const real *g)
3854 {
3855     int x,y,z;
3856
3857     for(x=0; x<nx; x++)
3858     {
3859         for(y=0; y<ny; y++)
3860         {
3861             for(z=0; z<nz; z++)
3862             {
3863                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3864                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3865             }
3866         }
3867     }
3868 }
3869
3870 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3871 {
3872     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3873
3874     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3875                                    local_fft_ndata,
3876                                    local_fft_offset,
3877                                    local_fft_size);
3878
3879     dump_grid(stderr,
3880               pme->pmegrid_start_ix,
3881               pme->pmegrid_start_iy,
3882               pme->pmegrid_start_iz,
3883               pme->pmegrid_nx-pme->pme_order+1,
3884               pme->pmegrid_ny-pme->pme_order+1,
3885               pme->pmegrid_nz-pme->pme_order+1,
3886               local_fft_size[YY],
3887               local_fft_size[ZZ],
3888               fftgrid);
3889 }
3890
3891
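     /* Compute the reciprocal-space energy of n charges q at positions x,
      * using the A grid as left by a preceding gmx_pme_do() call: only the
      * B-splines are computed here (spread_on_grid is called with
      * bSpread=FALSE) and the energy is gathered from the existing grid.
      * Only supported in serial; presumably used for test-particle insertion.
      */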
3892 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3893 {
3894     pme_atomcomm_t *atc;
3895     pmegrids_t *grid;
3896
3897     if (pme->nnodes > 1)
3898     {
3899         gmx_incons("gmx_pme_calc_energy called in parallel");
3900     }
3901     if (pme->bFEP)
3902     {
3903         gmx_incons("gmx_pme_calc_energy with free energy");
3904     }
3905
3906     atc = &pme->atc_energy;
3907     atc->nthread   = 1;
3908     if (atc->spline == NULL)
3909     {
3910         snew(atc->spline,atc->nthread);
3911     }
3912     atc->nslab     = 1;
3913     atc->bSpread   = TRUE;
3914     atc->pme_order = pme->pme_order;
3915     atc->n         = n;
3916     pme_realloc_atomcomm_things(atc);
3917     atc->x         = x;
3918     atc->q         = q;
3919
3920     /* We only use the A-charges grid */
3921     grid = &pme->pmegridA;
3922
3923     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
3924
3925     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3926 }
3927
3928
3929 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3930         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3931 {
3932     /* Reset all the counters related to performance over the run */
3933     wallcycle_stop(wcycle,ewcRUN);
3934     wallcycle_reset_all(wcycle);
3935     init_nrnb(nrnb);
3936     ir->init_step += step_rel;
3937     ir->nsteps    -= step_rel;
3938     wallcycle_start(wcycle,ewcRUN);
3939 }
3940
3941
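     /* Main loop of a dedicated PME-only node: repeatedly receive coordinates,
      * charges and the box from the PP nodes, run gmx_pme_do(), and send back
      * the forces, energy and virial together with the cycle count, until a
      * finish signal (natoms == -1) is received.
      */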
3942 int gmx_pmeonly(gmx_pme_t pme,
3943                 t_commrec *cr,    t_nrnb *nrnb,
3944                 gmx_wallcycle_t wcycle,
3945                 real ewaldcoeff,  gmx_bool bGatherOnly,
3946                 t_inputrec *ir)
3947 {
3948     gmx_pme_pp_t pme_pp;
3949     int  natoms;
3950     matrix box;
3951     rvec *x_pp=NULL,*f_pp=NULL;
3952     real *chargeA=NULL,*chargeB=NULL;
3953     real lambda=0;
3954     int  maxshift_x=0,maxshift_y=0;
3955     real energy,dvdlambda;
3956     matrix vir;
3957     float cycles;
3958     int  count;
3959     gmx_bool bEnerVir;
3960     gmx_large_int_t step,step_rel;
3961
3962
3963     pme_pp = gmx_pme_pp_init(cr);
3964
3965     init_nrnb(nrnb);
3966
3967     count = 0;
3968     do /****** this is a quasi-loop over time steps! */
3969     {
3970         /* Domain decomposition */
3971         natoms = gmx_pme_recv_q_x(pme_pp,
3972                                   &chargeA,&chargeB,box,&x_pp,&f_pp,
3973                                   &maxshift_x,&maxshift_y,
3974                                   &pme->bFEP,&lambda,
3975                                   &bEnerVir,
3976                                   &step);
3977
3978         if (natoms == -1) {
3979             /* We should stop: break out of the loop */
3980             break;
3981         }
3982
3983         step_rel = step - ir->init_step;
3984
3985         if (count == 0)
3986             wallcycle_start(wcycle,ewcRUN);
3987
3988         wallcycle_start(wcycle,ewcPMEMESH);
3989
3990         dvdlambda = 0;
3991         clear_mat(vir);
3992         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
3993                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
3994                    &energy,lambda,&dvdlambda,
3995                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
3996
3997         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
3998
3999         gmx_pme_send_force_vir_ener(pme_pp,
4000                                     f_pp,vir,energy,dvdlambda,
4001                                     cycles);
4002
4003         count++;
4004
4005         if (step_rel == wcycle_get_reset_counters(wcycle))
4006         {
4007             /* Reset all the counters related to performance over the run */
4008             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4009             wcycle_set_reset_counters(wcycle, 0);
4010         }
4011
4012     } /***** end of quasi-loop, we stop with the break above */
4013     while (TRUE);
4014
4015     return 0;
4016 }
4017
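     /* Perform a full PME mesh calculation for the local atoms. Depending on
      * flags: redistribute x and q over the PME nodes if needed, spread the
      * charges, forward 3D-FFT, solve in reciprocal space, backward 3D-FFT,
      * gather the forces, and redistribute the forces back. With free energy
      * (bFEP) the whole procedure is done twice, for the A and the B charges,
      * and the results are combined with lambda at the end.
      */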
4018 int gmx_pme_do(gmx_pme_t pme,
4019                int start,       int homenr,
4020                rvec x[],        rvec f[],
4021                real *chargeA,   real *chargeB,
4022                matrix box, t_commrec *cr,
4023                int  maxshift_x, int maxshift_y,
4024                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4025                matrix vir,      real ewaldcoeff,
4026                real *energy,    real lambda,
4027                real *dvdlambda, int flags)
4028 {
4029     int     q,d,i,j,ntot,npme;
4030     int     nx,ny,nz;
4031     int     n_d,local_ny;
4032     pme_atomcomm_t *atc=NULL;
4033     pmegrids_t *pmegrid=NULL;
4034     real    *grid=NULL;
4035     real    *ptr;
4036     rvec    *x_d,*f_d;
4037     real    *charge=NULL,*q_d;
4038     real    energy_AB[2];
4039     matrix  vir_AB[2];
4040     gmx_bool bClearF;
4041     gmx_parallel_3dfft_t pfft_setup;
4042     real *  fftgrid;
4043     t_complex * cfftgrid;
4044     int     thread;
4045     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4046     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4047
4048     assert(pme->nnodes > 0);
4049     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4050
4051     if (pme->nnodes > 1) {
4052         atc = &pme->atc[0];
4053         atc->npd = homenr;
4054         if (atc->npd > atc->pd_nalloc) {
4055             atc->pd_nalloc = over_alloc_dd(atc->npd);
4056             srenew(atc->pd,atc->pd_nalloc);
4057         }
4058         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4059     }
4060     else
4061     {
4062         /* This could be necessary for TPI */
4063         pme->atc[0].n = homenr;
4064     }
4065
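         /* Loop over the charge sets: q=0 uses the A charges; with free energy
          * q=1 repeats the calculation with the B charges on grid B.
          */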
4066     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4067         if (q == 0) {
4068             pmegrid = &pme->pmegridA;
4069             fftgrid = pme->fftgridA;
4070             cfftgrid = pme->cfftgridA;
4071             pfft_setup = pme->pfft_setupA;
4072             charge = chargeA+start;
4073         } else {
4074             pmegrid = &pme->pmegridB;
4075             fftgrid = pme->fftgridB;
4076             cfftgrid = pme->cfftgridB;
4077             pfft_setup = pme->pfft_setupB;
4078             charge = chargeB+start;
4079         }
4080         grid = pmegrid->grid.grid;
4081         /* Unpack structure */
4082         if (debug) {
4083             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4084                     cr->nnodes,cr->nodeid);
4085             fprintf(debug,"Grid = %p\n",(void*)grid);
4086             if (grid == NULL)
4087                 gmx_fatal(FARGS,"No grid!");
4088         }
4089         where();
4090
4091         m_inv_ur0(box,pme->recipbox);
4092
4093         if (pme->nnodes == 1) {
4094             atc = &pme->atc[0];
4095             if (DOMAINDECOMP(cr)) {
4096                 atc->n = homenr;
4097                 pme_realloc_atomcomm_things(atc);
4098             }
4099             atc->x = x;
4100             atc->q = charge;
4101             atc->f = f;
4102         } else {
4103             wallcycle_start(wcycle,ewcPME_REDISTXF);
4104             for(d=pme->ndecompdim-1; d>=0; d--)
4105             {
4106                 if (d == pme->ndecompdim-1)
4107                 {
4108                     n_d = homenr;
4109                     x_d = x + start;
4110                     q_d = charge;
4111                 }
4112                 else
4113                 {
4114                     n_d = pme->atc[d+1].n;
4115                     x_d = atc->x;
4116                     q_d = atc->q;
4117                 }
4118                 atc = &pme->atc[d];
4119                 atc->npd = n_d;
4120                 if (atc->npd > atc->pd_nalloc) {
4121                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4122                     srenew(atc->pd,atc->pd_nalloc);
4123                 }
4124                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4125                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4126                 where();
4127
4128                 GMX_BARRIER(cr->mpi_comm_mygroup);
4129                 /* Redistribute x (only once) and qA or qB */
4130                 if (DOMAINDECOMP(cr)) {
4131                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4132                 } else {
4133                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4134                 }
4135             }
4136             where();
4137
4138             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4139         }
4140
4141         if (debug)
4142             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4143                     cr->nodeid,atc->n);
4144
4145         if (flags & GMX_PME_SPREAD_Q)
4146         {
4147             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4148
4149             /* Log the start of the charge-spreading phase (MPE) */
4150             GMX_MPE_LOG(ev_spread_on_grid_start);
4151
4152             /* Spread the charges on a grid */
4153             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4154             GMX_MPE_LOG(ev_spread_on_grid_finish);
4155
4156             if (q == 0)
4157             {
4158                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4159             }
4160             inc_nrnb(nrnb,eNR_SPREADQBSP,
4161                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4162
4163             if (pme->nthread == 1)
4164             {
4165                 wrap_periodic_pmegrid(pme,grid);
4166
4167                 /* sum contributions to local grid from other nodes */
4168 #ifdef GMX_MPI
4169                 if (pme->nnodes > 1)
4170                 {
4171                     GMX_BARRIER(cr->mpi_comm_mygroup);
4172                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4173                     where();
4174                 }
4175 #endif
4176
4177                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4178             }
4179
4180             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4181
4182             /*
4183             dump_local_fftgrid(pme,fftgrid);
4184             exit(0);
4185             */
4186         }
4187
4188         /* Here we start a large thread parallel region */
4189 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4190         for(thread=0; thread<pme->nthread; thread++)
4191         {
4192             if (flags & GMX_PME_SOLVE)
4193             {
4194                 int loop_count;
4195
4196                 /* do 3d-fft */
4197                 if (thread == 0)
4198                 {
4199                     GMX_BARRIER(cr->mpi_comm_mygroup);
4200                     GMX_MPE_LOG(ev_gmxfft3d_start);
4201                     wallcycle_start(wcycle,ewcPME_FFT);
4202                 }
4203                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4204                                            fftgrid,cfftgrid,thread,wcycle);
4205                 if (thread == 0)
4206                 {
4207                     wallcycle_stop(wcycle,ewcPME_FFT);
4208                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4209                 }
4210                 where();
4211
4212                 /* solve in k-space for our local cells */
4213                 if (thread == 0)
4214                 {
4215                     GMX_BARRIER(cr->mpi_comm_mygroup);
4216                     GMX_MPE_LOG(ev_solve_pme_start);
4217                     wallcycle_start(wcycle,ewcPME_SOLVE);
4218                 }
4219                 loop_count =
4220                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4221                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4222                                   bCalcEnerVir,
4223                                   pme->nthread,thread);
4224                 if (thread == 0)
4225                 {
4226                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4227                     where();
4228                     GMX_MPE_LOG(ev_solve_pme_finish);
4229                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4230                 }
4231             }
4232
4233             if (bCalcF)
4234             {
4235                 /* do 3d-invfft */
4236                 if (thread == 0)
4237                 {
4238                     GMX_BARRIER(cr->mpi_comm_mygroup);
4239                     GMX_MPE_LOG(ev_gmxfft3d_start);
4240                     where();
4241                     wallcycle_start(wcycle,ewcPME_FFT);
4242                 }
4243                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4244                                            cfftgrid,fftgrid,thread,wcycle);
4245                 if (thread == 0)
4246                 {
4247                     wallcycle_stop(wcycle,ewcPME_FFT);
4248
4249                     where();
4250                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4251
4252                     if (pme->nodeid == 0)
4253                     {
4254                         ntot = pme->nkx*pme->nky*pme->nkz;
4255                         npme  = ntot*log((real)ntot)/log(2.0);
4256                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4257                     }
4258
4259                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4260                 }
4261
4262                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4263             }
4264         }
4265         /* End of thread parallel section.
4266          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4267          */
4268
4269         if (bCalcF)
4270         {
4271             /* distribute local grid to all nodes */
4272 #ifdef GMX_MPI
4273             if (pme->nnodes > 1) {
4274                 GMX_BARRIER(cr->mpi_comm_mygroup);
4275                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4276             }
4277 #endif
4278             where();
4279
4280             unwrap_periodic_pmegrid(pme,grid);
4281
4282             /* interpolate forces for our local atoms */
4283             GMX_BARRIER(cr->mpi_comm_mygroup);
4284             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4285
4286             where();
4287
4288             /* If we are running without parallelization,
4289              * atc->f is the actual force array, not a buffer,
4290              * therefore we should not clear it.
4291              */
4292             bClearF = (q == 0 && PAR(cr));
4293 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4294             for(thread=0; thread<pme->nthread; thread++)
4295             {
4296                 gather_f_bsplines(pme,grid,bClearF,atc,
4297                                   &atc->spline[thread],
4298                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4299             }
4300
4301             where();
4302
4303             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4304
4305             inc_nrnb(nrnb,eNR_GATHERFBSP,
4306                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4307             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4308         }
4309
4310         if (bCalcEnerVir)
4311         {
4312             /* This should only be called on the master thread
4313              * and after the threads have synchronized.
4314              */
4315             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4316         }
4317     } /* of q-loop */
4318
4319     if (bCalcF && pme->nnodes > 1) {
4320         wallcycle_start(wcycle,ewcPME_REDISTXF);
4321         for(d=0; d<pme->ndecompdim; d++)
4322         {
4323             atc = &pme->atc[d];
4324             if (d == pme->ndecompdim - 1)
4325             {
4326                 n_d = homenr;
4327                 f_d = f + start;
4328             }
4329             else
4330             {
4331                 n_d = pme->atc[d+1].n;
4332                 f_d = pme->atc[d+1].f;
4333             }
4334             GMX_BARRIER(cr->mpi_comm_mygroup);
4335             if (DOMAINDECOMP(cr)) {
4336                 dd_pmeredist_f(pme,atc,n_d,f_d,
4337                                d==pme->ndecompdim-1 && pme->bPPnode);
4338             } else {
4339                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4340             }
4341         }
4342
4343         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4344     }
4345     where();
4346
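         /* Report the energy and virial. Without free energy only the A set
          * contributes; with free energy the A and B results are interpolated
          * linearly in lambda and dv/dlambda receives their energy difference.
          */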
4347     if (bCalcEnerVir)
4348     {
4349         if (!pme->bFEP) {
4350             *energy = energy_AB[0];
4351             m_add(vir,vir_AB[0],vir);
4352         } else {
4353             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4354             *dvdlambda += energy_AB[1] - energy_AB[0];
4355             for(i=0; i<DIM; i++)
4356             {
4357                 for(j=0; j<DIM; j++)
4358                 {
4359                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4360                         lambda*vir_AB[1][i][j];
4361                 }
4362             }
4363         }
4364     }
4365     else
4366     {
4367         *energy = 0;
4368     }
4369
4370     if (debug)
4371     {
4372         fprintf(debug,"PME mesh energy: %g\n",*energy);
4373     }
4374
4375     return 0;
4376 }