1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem overkill, but better safe than sorry.
56  * /Erik 001109
57  */
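/* As a sketch of check 1 above, assuming the usual GROMACS convention that
 * the rows of the box matrix are the box vectors (lower triangular): the
 * tilted box from the table,
 *
 *     box      =  ( 2 0 0 ;  0 2 0 ;  2 2 6 )
 *
 * has the reciprocal box (computed elsewhere, e.g. by calc_recipbox())
 *
 *     recipbox =  ( 1/2 0 0 ;  0 1/2 0 ;  -1/6 -1/6 1/6 )
 *
 * so the tilt only switches on the recipbox[ZZ][XX] and recipbox[ZZ][YY]
 * terms in the fractional-coordinate formulas used below (see
 * calc_interpolation_idx), while the rectangular box leaves them zero.
 * That is what makes the two boxes a good consistency pair.
 */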
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include <stdio.h>
71 #include <string.h>
72 #include <math.h>
73 #include <assert.h>
74 #include "typedefs.h"
75 #include "txtdump.h"
76 #include "vec.h"
77 #include "gmxcomplex.h"
78 #include "smalloc.h"
79 #include "futil.h"
80 #include "coulomb.h"
81 #include "gmx_fatal.h"
82 #include "pme.h"
83 #include "network.h"
84 #include "physics.h"
85 #include "nrnb.h"
86 #include "copyrite.h"
87 #include "gmx_wallcycle.h"
88 #include "gmx_parallel_3dfft.h"
89 #include "pdbio.h"
90 #include "gmx_cyclecounter.h"
91
92 /* Single precision, with SSE2 or higher available */
93 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
94
95 #include "gmx_x86_sse2.h"
96 #include "gmx_math_x86_sse2_single.h"
97
98 #define PME_SSE
99 /* Some old AMD processors could have problems with unaligned loads+stores */
100 #ifndef GMX_FAHCORE
101 #define PME_SSE_UNALIGNED
102 #endif
103 #endif
104
105 #include "mpelogging.h"
106
107 #define DFT_TOL 1e-7
108 /* #define PRT_FORCE */
109 /* conditions for on-the-fly time measurement */
110 /* #define TAKETIME (step > 1 && timesteps < 10) */
111 #define TAKETIME FALSE
112
113 /* #define PME_TIME_THREADS */
114
115 #ifdef GMX_DOUBLE
116 #define mpi_type MPI_DOUBLE
117 #else
118 #define mpi_type MPI_FLOAT
119 #endif
120
121 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
122 #define GMX_CACHE_SEP 64
123
124 /* We only define a maximum to be able to use local arrays without allocation.
125  * An order larger than 12 should never be needed, even for test cases.
126  * If needed it can be changed here.
127  */
128 #define PME_ORDER_MAX 12
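/* A minimal sketch of what the bound buys us: per-atom spline scratch can
 * live on the stack inside inner loops instead of being heap allocated.
 * Hypothetical names, not part of the build.
 */
#if 0
static void per_atom_scratch_sketch(int order)
{
    real thx[PME_ORDER_MAX];    /* spline weights along x for one atom */
    real dthx[PME_ORDER_MAX];   /* and their derivatives               */

    assert(order <= PME_ORDER_MAX);
    /* ... fill thx/dthx for this atom, then spread or gather ... */
    (void)thx; (void)dthx;
}
#endif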
129
130 /* Internal datastructures */
131 typedef struct {
132     int send_index0;
133     int send_nindex;
134     int recv_index0;
135     int recv_nindex;
136 } pme_grid_comm_t;
137
138 typedef struct {
139 #ifdef GMX_MPI
140     MPI_Comm mpi_comm;
141 #endif
142     int  nnodes,nodeid;
143     int  *s2g0;
144     int  *s2g1;
145     int  noverlap_nodes;
146     int  *send_id,*recv_id;
147     pme_grid_comm_t *comm_data;
148     real *sendbuf;
149     real *recvbuf;
150 } pme_overlap_t;
151
152 typedef struct {
153     int *n;     /* Cumulative counts of the number of particles per thread */
154     int nalloc; /* Allocation size of i */
155     int *i;     /* Particle indices ordered on thread index (n) */
156 } thread_plist_t;
157
158 typedef struct {
159     int  n;
160     int  *ind;
161     splinevec theta;
162     splinevec dtheta;
163 } splinedata_t;
164
165 typedef struct {
166     int  dimind;            /* The index of the dimension, 0=x, 1=y */
167     int  nslab;
168     int  nodeid;
169 #ifdef GMX_MPI
170     MPI_Comm mpi_comm;
171 #endif
172
173     int  *node_dest;        /* The nodes to send x and q to with DD */
174     int  *node_src;         /* The nodes to receive x and q from with DD */
175     int  *buf_index;        /* Index for commnode into the buffers */
176
177     int  maxshift;
178
179     int  npd;
180     int  pd_nalloc;
181     int  *pd;
182     int  *count;            /* The number of atoms to send to each node */
183     int  **count_thread;
184     int  *rcount;           /* The number of atoms to receive */
185
186     int  n;
187     int  nalloc;
188     rvec *x;
189     real *q;
190     rvec *f;
191     gmx_bool bSpread;       /* These coordinates are used for spreading */
192     int  pme_order;
193     ivec *idx;
194     rvec *fractx;            /* Fractional coordinate relative to the
195                               * lower cell boundary
196                               */
197     int  nthread;
198     int  *thread_idx;        /* Which thread should spread which charge */
199     thread_plist_t *thread_plist;
200     splinedata_t *spline;
201 } pme_atomcomm_t;
202
203 #define FLBS  3
204 #define FLBSZ 4
205
206 typedef struct {
207     ivec ci;     /* The spatial location of this grid       */
208     ivec n;      /* The size of *grid, including order-1    */
209     ivec offset; /* The grid offset from the full node grid */
210     int  order;  /* PME spreading order                     */
211     real *grid;  /* The local thread grid, size n           */
212 } pmegrid_t;
213
214 typedef struct {
215     pmegrid_t grid;     /* The full node grid (non thread-local)            */
216     int  nthread;       /* The number of threads operating on this grid     */
217     ivec nc;            /* The local spatial decomposition over the threads */
218     pmegrid_t *grid_th; /* Array of grids for each thread                   */
219     int  **g2t;         /* The grid to thread index                         */
220     ivec nthread_comm;  /* The number of threads to communicate with        */
221 } pmegrids_t;
222
223
224 typedef struct {
225 #ifdef PME_SSE
226     /* Masks for SSE aligned spreading and gathering */
227     __m128 mask_SSE0[6],mask_SSE1[6];
228 #else
229     int dummy; /* C89 requires that struct has at least one member */
230 #endif
231 } pme_spline_work_t;
232
233 typedef struct {
234     /* work data for solve_pme */
235     int      nalloc;
236     real *   mhx;
237     real *   mhy;
238     real *   mhz;
239     real *   m2;
240     real *   denom;
241     real *   tmp1_alloc;
242     real *   tmp1;
243     real *   eterm;
244     real *   m2inv;
245
246     real     energy;
247     matrix   vir;
248 } pme_work_t;
249
250 typedef struct gmx_pme {
251     int  ndecompdim;         /* The number of decomposition dimensions */
252     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
253     int  nodeid_major;
254     int  nodeid_minor;
255     int  nnodes;             /* The number of nodes doing PME */
256     int  nnodes_major;
257     int  nnodes_minor;
258
259     MPI_Comm mpi_comm;
260     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
261 #ifdef GMX_MPI
262     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
263 #endif
264
265     int  nthread;            /* The number of threads doing PME */
266
267     gmx_bool bPPnode;        /* Node also does particle-particle forces */
268     gmx_bool bFEP;           /* Compute Free energy contribution */
269     int nkx,nky,nkz;         /* Grid dimensions */
270     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
271     int pme_order;
272     real epsilon_r;
273
274     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
275     pmegrids_t pmegridB;
276     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
277     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
278     /* pmegrid_nz might be larger than strictly necessary to ensure
279      * memory alignment; pmegrid_nz_base gives the real base size.
280      */
281     int     pmegrid_nz_base;
282     /* The local PME grid starting indices */
283     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
284
285     /* Work data for spreading and gathering */
286     pme_spline_work_t spline_work;
287
288     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
289     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
290     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
291
292     t_complex *cfftgridA;             /* Grids for complex FFT data */
293     t_complex *cfftgridB;
294     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
295
296     gmx_parallel_3dfft_t  pfft_setupA;
297     gmx_parallel_3dfft_t  pfft_setupB;
298
299     int  *nnx,*nny,*nnz;
300     real *fshx,*fshy,*fshz;
301
302     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
303     matrix    recipbox;
304     splinevec bsp_mod;
305
306     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
307
308     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
309
310     rvec *bufv;             /* Communication buffer */
311     real *bufr;             /* Communication buffer */
312     int  buf_nalloc;        /* The communication buffer size */
313
314     /* thread local work data for solve_pme */
315     pme_work_t *work;
316
317     /* Work data for PME_redist */
318     gmx_bool redist_init;
319     int *    scounts;
320     int *    rcounts;
321     int *    sdispls;
322     int *    rdispls;
323     int *    sidx;
324     int *    idxa;
325     real *   redist_buf;
326     int      redist_buf_nalloc;
327
328     /* Work data for sum_qgrid */
329     real *   sum_qgrid_tmp;
330     real *   sum_qgrid_dd_tmp;
331 } t_gmx_pme;
332
333
334 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
335                                    int start,int end,int thread)
336 {
337     int  i;
338     int  *idxptr,tix,tiy,tiz;
339     real *xptr,*fptr,tx,ty,tz;
340     real rxx,ryx,ryy,rzx,rzy,rzz;
341     int  nx,ny,nz;
342     int  start_ix,start_iy,start_iz;
343     int  *g2tx,*g2ty,*g2tz;
344     gmx_bool bThreads;
345     int  *thread_idx=NULL;
346     thread_plist_t *tpl=NULL;
347     int  *tpl_n=NULL;
348     int  thread_i;
349
350     nx  = pme->nkx;
351     ny  = pme->nky;
352     nz  = pme->nkz;
353
354     start_ix = pme->pmegrid_start_ix;
355     start_iy = pme->pmegrid_start_iy;
356     start_iz = pme->pmegrid_start_iz;
357
358     rxx = pme->recipbox[XX][XX];
359     ryx = pme->recipbox[YY][XX];
360     ryy = pme->recipbox[YY][YY];
361     rzx = pme->recipbox[ZZ][XX];
362     rzy = pme->recipbox[ZZ][YY];
363     rzz = pme->recipbox[ZZ][ZZ];
364
365     g2tx = pme->pmegridA.g2t[XX];
366     g2ty = pme->pmegridA.g2t[YY];
367     g2tz = pme->pmegridA.g2t[ZZ];
368
369     bThreads = (atc->nthread > 1);
370     if (bThreads)
371     {
372         thread_idx = atc->thread_idx;
373
374         tpl   = &atc->thread_plist[thread];
375         tpl_n = tpl->n;
376         for(i=0; i<atc->nthread; i++)
377         {
378             tpl_n[i] = 0;
379         }
380     }
381
382     for(i=start; i<end; i++) {
383         xptr   = atc->x[i];
384         idxptr = atc->idx[i];
385         fptr   = atc->fractx[i];
386
387         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
388         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
389         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
390         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
391
392         tix = (int)(tx);
393         tiy = (int)(ty);
394         tiz = (int)(tz);
395
396         /* Because decomposition only occurs in x and y,
397          * we never have a fractional shift correction in z.
398          */
399         fptr[XX] = tx - tix + pme->fshx[tix];
400         fptr[YY] = ty - tiy + pme->fshy[tiy];
401         fptr[ZZ] = tz - tiz;
402
403         idxptr[XX] = pme->nnx[tix];
404         idxptr[YY] = pme->nny[tiy];
405         idxptr[ZZ] = pme->nnz[tiz];
406
407 #ifdef DEBUG
408         range_check(idxptr[XX],0,pme->pmegrid_nx);
409         range_check(idxptr[YY],0,pme->pmegrid_ny);
410         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
411 #endif
412
413         if (bThreads)
414         {
415             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
416             thread_idx[i] = thread_i;
417             tpl_n[thread_i]++;
418         }
419     }
420
421     if (bThreads)
422     {
423         /* Make a list of particle indices sorted on thread */
424
425         /* Get the cumulative count */
426         for(i=1; i<atc->nthread; i++)
427         {
428             tpl_n[i] += tpl_n[i-1];
429         }
430         /* The current implementation distributes particles equally
431          * over the threads, so we could actually allocate for that
432          * in pme_realloc_atomcomm_things.
433          */
434         if (tpl_n[atc->nthread-1] > tpl->nalloc)
435         {
436             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
437             srenew(tpl->i,tpl->nalloc);
438         }
439         /* Set tpl_n to the cumulative start */
440         for(i=atc->nthread-1; i>=1; i--)
441         {
442             tpl_n[i] = tpl_n[i-1];
443         }
444         tpl_n[0] = 0;
445
446         /* Fill our thread local array with indices sorted on thread */
447         for(i=start; i<end; i++)
448         {
449             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
450         }
451         /* Now tpl_n contains the cumulative count again */
452     }
453 }
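/* The thread binning above is a two-pass counting sort: count how many
 * particles each thread gets, turn the counts into cumulative offsets,
 * then scatter the particle indices.  A stand-alone sketch of the same
 * pattern (hypothetical names, not part of the build):
 */
#if 0
static void counting_sort_sketch(int n, const int *bin, int nbin,
                                 int *offset, int *out)
{
    int i, b;

    for(b=0; b<nbin; b++)
    {
        offset[b] = 0;
    }
    for(i=0; i<n; i++)
    {
        offset[bin[i]]++;           /* pass 1: histogram              */
    }
    for(b=1; b<nbin; b++)
    {
        offset[b] += offset[b-1];   /* cumulative end positions       */
    }
    for(b=nbin-1; b>=1; b--)
    {
        offset[b] = offset[b-1];    /* shift to start positions       */
    }
    offset[0] = 0;
    for(i=0; i<n; i++)
    {
        out[offset[bin[i]]++] = i;  /* pass 2: scatter the indices    */
    }
}
#endif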
454
455 static void make_thread_local_ind(pme_atomcomm_t *atc,
456                                   int thread,splinedata_t *spline)
457 {
458     int  n,t,i,start,end;
459     thread_plist_t *tpl;
460
461     /* Combine the indices made by each thread into one index */
462
463     n = 0;
464     start = 0;
465     for(t=0; t<atc->nthread; t++)
466     {
467         tpl = &atc->thread_plist[t];
468         /* Copy our part (start to end) from the list of thread t */
469         if (thread > 0)
470         {
471             start = tpl->n[thread-1];
472         }
473         end = tpl->n[thread];
474         for(i=start; i<end; i++)
475         {
476             spline->ind[n++] = tpl->i[i];
477         }
478     }
479
480     spline->n = n;
481 }
482
483
484 static void pme_calc_pidx(int start, int end,
485                           matrix recipbox, rvec x[],
486                           pme_atomcomm_t *atc, int *count)
487 {
488     int  nslab,i;
489     int  si;
490     real *xptr,s;
491     real rxx,ryx,rzx,ryy,rzy;
492     int *pd;
493
494     /* Calculate the PME task index (pidx) for each particle.
495      * Here we always assign equally sized slabs to each node
496      * for load balancing reasons (the PME grid spacing is not used).
497      */
498
499     nslab = atc->nslab;
500     pd    = atc->pd;
501
502     /* Reset the count */
503     for(i=0; i<nslab; i++)
504     {
505         count[i] = 0;
506     }
507
508     if (atc->dimind == 0)
509     {
510         rxx = recipbox[XX][XX];
511         ryx = recipbox[YY][XX];
512         rzx = recipbox[ZZ][XX];
513         /* Calculate the node index in x-dimension */
514         for(i=start; i<end; i++)
515         {
516             xptr   = x[i];
517             /* Fractional coordinates along box vectors */
518             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
519             si = (int)(s + 2*nslab) % nslab;
520             pd[i] = si;
521             count[si]++;
522         }
523     }
524     else
525     {
526         ryy = recipbox[YY][YY];
527         rzy = recipbox[ZZ][YY];
528         /* Calculate the node index in y-dimension */
529         for(i=start; i<end; i++)
530         {
531             xptr   = x[i];
532             /* Fractional coordinates along box vectors */
533             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
534             si = (int)(s + 2*nslab) % nslab;
535             pd[i] = si;
536             count[si]++;
537         }
538     }
539 }
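/* A worked example of the slab assignment above: the fractional slab
 * coordinate s can fall slightly outside [0,nslab) for atoms sticking out
 * of the triclinic unit cell, so 2*nslab is added before truncation to
 * keep the argument of (int) positive.  With nslab = 4:
 *     s = -0.3  ->  (int)(7.7)  =  7  ->   7 % 4 = 3   (last slab)
 *     s =  4.2  ->  (int)(12.2) = 12  ->  12 % 4 = 0   (wraps around)
 */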
540
541 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
542                                   pme_atomcomm_t *atc)
543 {
544     int nthread,thread,slab;
545
546     nthread = atc->nthread;
547
548 #pragma omp parallel for num_threads(nthread) schedule(static)
549     for(thread=0; thread<nthread; thread++)
550     {
551         pme_calc_pidx(natoms* thread   /nthread,
552                       natoms*(thread+1)/nthread,
553                       recipbox,x,atc,atc->count_thread[thread]);
554     }
555     /* Non-parallel reduction, since nslab is small */
556
557     for(thread=1; thread<nthread; thread++)
558     {
559         for(slab=0; slab<atc->nslab; slab++)
560         {
561             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
562         }
563     }
564 }
565
566 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
567 {
568     int i,d;
569
570     srenew(spline->ind,atc->nalloc);
571     /* Initialize the index to identity so it works without threads */
572     for(i=0; i<atc->nalloc; i++)
573     {
574         spline->ind[i] = i;
575     }
576
577     for(d=0;d<DIM;d++)
578     {
579         srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
580         srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
581     }
582 }
583
584 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
585 {
586     int nalloc_old,i,j,nalloc_tpl;
587
588     /* We have to avoid a NULL pointer for atc->x, since that can cause
589      * fatal errors in MPI routines.
590      */
591     if (atc->n > atc->nalloc || atc->nalloc == 0)
592     {
593         nalloc_old = atc->nalloc;
594         atc->nalloc = over_alloc_dd(max(atc->n,1));
595
596         if (atc->nslab > 1) {
597             srenew(atc->x,atc->nalloc);
598             srenew(atc->q,atc->nalloc);
599             srenew(atc->f,atc->nalloc);
600             for(i=nalloc_old; i<atc->nalloc; i++)
601             {
602                 clear_rvec(atc->f[i]);
603             }
604         }
605         if (atc->bSpread) {
606             srenew(atc->fractx,atc->nalloc);
607             srenew(atc->idx   ,atc->nalloc);
608
609             if (atc->nthread > 1)
610             {
611                 srenew(atc->thread_idx,atc->nalloc);
612             }
613
614             for(i=0; i<atc->nthread; i++)
615             {
616                 pme_realloc_splinedata(&atc->spline[i],atc);
617             }
618         }
619     }
620 }
621
622 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
623                          int n, gmx_bool bXF, rvec *x_f, real *charge,
624                          pme_atomcomm_t *atc)
625 /* Redistribute particle data for PME calculation */
626 /* domain decomposition by x coordinate           */
627 {
628     int *idxa;
629     int i, ii;
630
631     if(FALSE == pme->redist_init) {
632         snew(pme->scounts,atc->nslab);
633         snew(pme->rcounts,atc->nslab);
634         snew(pme->sdispls,atc->nslab);
635         snew(pme->rdispls,atc->nslab);
636         snew(pme->sidx,atc->nslab);
637         pme->redist_init = TRUE;
638     }
639     if (n > pme->redist_buf_nalloc) {
640         pme->redist_buf_nalloc = over_alloc_dd(n);
641         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
642     }
643
644     pme->idxa = atc->pd;
645
646 #ifdef GMX_MPI
647     if (forw && bXF) {
648         /* forward, redistribution from pp to pme */
649
650         /* Calculate send counts and exchange them with other nodes */
651         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
652         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
653         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
654
655         /* Calculate send and receive displacements and index into send
656            buffer */
657         pme->sdispls[0]=0;
658         pme->rdispls[0]=0;
659         pme->sidx[0]=0;
660         for(i=1; i<atc->nslab; i++) {
661             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
662             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
663             pme->sidx[i]=pme->sdispls[i];
664         }
665         /* Total # of particles to be received */
666         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
667
668         pme_realloc_atomcomm_things(atc);
669
670         /* Copy particle coordinates into send buffer and exchange */
671         for(i=0; (i<n); i++) {
672             ii=DIM*pme->sidx[pme->idxa[i]];
673             pme->sidx[pme->idxa[i]]++;
674             pme->redist_buf[ii+XX]=x_f[i][XX];
675             pme->redist_buf[ii+YY]=x_f[i][YY];
676             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
677         }
678         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
679                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
680                       pme->rvec_mpi, atc->mpi_comm);
681     }
682     if (forw) {
683         /* Copy charge into send buffer and exchange */
684         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
685         for(i=0; (i<n); i++) {
686             ii=pme->sidx[pme->idxa[i]];
687             pme->sidx[pme->idxa[i]]++;
688             pme->redist_buf[ii]=charge[i];
689         }
690         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
691                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
692                       atc->mpi_comm);
693     }
694     else { /* backward, redistribution from pme to pp */
695         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
696                       pme->redist_buf, pme->scounts, pme->sdispls,
697                       pme->rvec_mpi, atc->mpi_comm);
698
699         /* Copy data from receive buffer */
700         for(i=0; i<atc->nslab; i++)
701             pme->sidx[i] = pme->sdispls[i];
702         for(i=0; (i<n); i++) {
703             ii = DIM*pme->sidx[pme->idxa[i]];
704             x_f[i][XX] += pme->redist_buf[ii+XX];
705             x_f[i][YY] += pme->redist_buf[ii+YY];
706             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
707             pme->sidx[pme->idxa[i]]++;
708         }
709     }
710 #endif
711 }
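/* The routine above follows the usual MPI all-to-all redistribution recipe:
 * exchange per-destination counts, build displacements as prefix sums, pack
 * a contiguous send buffer, then call MPI_Alltoallv.  A stand-alone sketch
 * of that recipe (hypothetical names, one real per particle; not part of
 * the build):
 */
#if 0
static void alltoallv_sketch(MPI_Comm comm, int nrank,
                             int *scount, int *rcount,
                             int *sdispl, int *rdispl,
                             real *sendbuf, real *recvbuf)
{
    int i;

    /* Every rank learns how much it will receive from every other rank */
    MPI_Alltoall(scount, 1, MPI_INT, rcount, 1, MPI_INT, comm);

    sdispl[0] = 0;
    rdispl[0] = 0;
    for(i=1; i<nrank; i++)
    {
        sdispl[i] = sdispl[i-1] + scount[i-1];
        rdispl[i] = rdispl[i-1] + rcount[i-1];
    }

    /* pme.c sends rvecs/charges with rvec_mpi/mpi_type instead of one real */
    MPI_Alltoallv(sendbuf, scount, sdispl, mpi_type,
                  recvbuf, rcount, rdispl, mpi_type,
                  comm);
}
#endif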
712
713 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
714                             gmx_bool bBackward,int shift,
715                             void *buf_s,int nbyte_s,
716                             void *buf_r,int nbyte_r)
717 {
718 #ifdef GMX_MPI
719     int dest,src;
720     MPI_Status stat;
721
722     if (bBackward == FALSE) {
723         dest = atc->node_dest[shift];
724         src  = atc->node_src[shift];
725     } else {
726         dest = atc->node_src[shift];
727         src  = atc->node_dest[shift];
728     }
729
730     if (nbyte_s > 0 && nbyte_r > 0) {
731         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
732                      dest,shift,
733                      buf_r,nbyte_r,MPI_BYTE,
734                      src,shift,
735                      atc->mpi_comm,&stat);
736     } else if (nbyte_s > 0) {
737         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
738                  dest,shift,
739                  atc->mpi_comm);
740     } else if (nbyte_r > 0) {
741         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
742                  src,shift,
743                  atc->mpi_comm,&stat);
744     }
745 #endif
746 }
747
748 static void dd_pmeredist_x_q(gmx_pme_t pme,
749                              int n, gmx_bool bX, rvec *x, real *charge,
750                              pme_atomcomm_t *atc)
751 {
752     int *commnode,*buf_index;
753     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
754
755     commnode  = atc->node_dest;
756     buf_index = atc->buf_index;
757
758     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
759
760     nsend = 0;
761     for(i=0; i<nnodes_comm; i++) {
762         buf_index[commnode[i]] = nsend;
763         nsend += atc->count[commnode[i]];
764     }
765     if (bX) {
766         if (atc->count[atc->nodeid] + nsend != n)
767             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
768                       "This usually means that your system is not well equilibrated.",
769                       n - (atc->count[atc->nodeid] + nsend),
770                       pme->nodeid,'x'+atc->dimind);
771
772         if (nsend > pme->buf_nalloc) {
773             pme->buf_nalloc = over_alloc_dd(nsend);
774             srenew(pme->bufv,pme->buf_nalloc);
775             srenew(pme->bufr,pme->buf_nalloc);
776         }
777
778         atc->n = atc->count[atc->nodeid];
779         for(i=0; i<nnodes_comm; i++) {
780             scount = atc->count[commnode[i]];
781             /* Communicate the count */
782             if (debug)
783                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
784                         atc->dimind,atc->nodeid,commnode[i],scount);
785             pme_dd_sendrecv(atc,FALSE,i,
786                             &scount,sizeof(int),
787                             &atc->rcount[i],sizeof(int));
788             atc->n += atc->rcount[i];
789         }
790
791         pme_realloc_atomcomm_things(atc);
792     }
793
794     local_pos = 0;
795     for(i=0; i<n; i++) {
796         node = atc->pd[i];
797         if (node == atc->nodeid) {
798             /* Copy direct to the receive buffer */
799             if (bX) {
800                 copy_rvec(x[i],atc->x[local_pos]);
801             }
802             atc->q[local_pos] = charge[i];
803             local_pos++;
804         } else {
805             /* Copy to the send buffer */
806             if (bX) {
807                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
808             }
809             pme->bufr[buf_index[node]] = charge[i];
810             buf_index[node]++;
811         }
812     }
813
814     buf_pos = 0;
815     for(i=0; i<nnodes_comm; i++) {
816         scount = atc->count[commnode[i]];
817         rcount = atc->rcount[i];
818         if (scount > 0 || rcount > 0) {
819             if (bX) {
820                 /* Communicate the coordinates */
821                 pme_dd_sendrecv(atc,FALSE,i,
822                                 pme->bufv[buf_pos],scount*sizeof(rvec),
823                                 atc->x[local_pos],rcount*sizeof(rvec));
824             }
825             /* Communicate the charges */
826             pme_dd_sendrecv(atc,FALSE,i,
827                             pme->bufr+buf_pos,scount*sizeof(real),
828                             atc->q+local_pos,rcount*sizeof(real));
829             buf_pos   += scount;
830             local_pos += atc->rcount[i];
831         }
832     }
833 }
834
835 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
836                            int n, rvec *f,
837                            gmx_bool bAddF)
838 {
839   int *commnode,*buf_index;
840   int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
841
842   commnode  = atc->node_dest;
843   buf_index = atc->buf_index;
844
845   nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
846
847   local_pos = atc->count[atc->nodeid];
848   buf_pos = 0;
849   for(i=0; i<nnodes_comm; i++) {
850     scount = atc->rcount[i];
851     rcount = atc->count[commnode[i]];
852     if (scount > 0 || rcount > 0) {
853       /* Communicate the forces */
854       pme_dd_sendrecv(atc,TRUE,i,
855                       atc->f[local_pos],scount*sizeof(rvec),
856                       pme->bufv[buf_pos],rcount*sizeof(rvec));
857       local_pos += scount;
858     }
859     buf_index[commnode[i]] = buf_pos;
860     buf_pos   += rcount;
861   }
862
863     local_pos = 0;
864     if (bAddF)
865     {
866         for(i=0; i<n; i++)
867         {
868             node = atc->pd[i];
869             if (node == atc->nodeid)
870             {
871                 /* Add from the local force array */
872                 rvec_inc(f[i],atc->f[local_pos]);
873                 local_pos++;
874             }
875             else
876             {
877                 /* Add from the receive buffer */
878                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
879                 buf_index[node]++;
880             }
881         }
882     }
883     else
884     {
885         for(i=0; i<n; i++)
886         {
887             node = atc->pd[i];
888             if (node == atc->nodeid)
889             {
890                 /* Copy from the local force array */
891                 copy_rvec(atc->f[local_pos],f[i]);
892                 local_pos++;
893             }
894             else
895             {
896                 /* Copy from the receive buffer */
897                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
898                 buf_index[node]++;
899             }
900         }
901     }
902 }
903
904 #ifdef GMX_MPI
905 static void
906 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
907 {
908     pme_overlap_t *overlap;
909     int send_index0,send_nindex;
910     int recv_index0,recv_nindex;
911     MPI_Status stat;
912     int i,j,k,ix,iy,iz,icnt;
913     int ipulse,send_id,recv_id,datasize;
914     real *p;
915     real *sendptr,*recvptr;
916
917     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
918     overlap = &pme->overlap[1];
919
920     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
921     {
922         /* Since we have already (un)wrapped the overlap in the z-dimension,
923          * we only have to communicate 0 to nkz (not pmegrid_nz).
924          */
925         if (direction==GMX_SUM_QGRID_FORWARD)
926         {
927             send_id = overlap->send_id[ipulse];
928             recv_id = overlap->recv_id[ipulse];
929             send_index0   = overlap->comm_data[ipulse].send_index0;
930             send_nindex   = overlap->comm_data[ipulse].send_nindex;
931             recv_index0   = overlap->comm_data[ipulse].recv_index0;
932             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
933         }
934         else
935         {
936             send_id = overlap->recv_id[ipulse];
937             recv_id = overlap->send_id[ipulse];
938             send_index0   = overlap->comm_data[ipulse].recv_index0;
939             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
940             recv_index0   = overlap->comm_data[ipulse].send_index0;
941             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
942         }
943
944         /* Copy data to contiguous send buffer */
945         if (debug)
946         {
947             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
948                     pme->nodeid,overlap->nodeid,send_id,
949                     pme->pmegrid_start_iy,
950                     send_index0-pme->pmegrid_start_iy,
951                     send_index0-pme->pmegrid_start_iy+send_nindex);
952         }
953         icnt = 0;
954         for(i=0;i<pme->pmegrid_nx;i++)
955         {
956             ix = i;
957             for(j=0;j<send_nindex;j++)
958             {
959                 iy = j + send_index0 - pme->pmegrid_start_iy;
960                 for(k=0;k<pme->nkz;k++)
961                 {
962                     iz = k;
963                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
964                 }
965             }
966         }
967
968         datasize      = pme->pmegrid_nx * pme->nkz;
969
970         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
971                      send_id,ipulse,
972                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
973                      recv_id,ipulse,
974                      overlap->mpi_comm,&stat);
975
976         /* Get data from contiguous recv buffer */
977         if (debug)
978         {
979             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
980                     pme->nodeid,overlap->nodeid,recv_id,
981                     pme->pmegrid_start_iy,
982                     recv_index0-pme->pmegrid_start_iy,
983                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
984         }
985         icnt = 0;
986         for(i=0;i<pme->pmegrid_nx;i++)
987         {
988             ix = i;
989             for(j=0;j<recv_nindex;j++)
990             {
991                 iy = j + recv_index0 - pme->pmegrid_start_iy;
992                 for(k=0;k<pme->nkz;k++)
993                 {
994                     iz = k;
995                     if(direction==GMX_SUM_QGRID_FORWARD)
996                     {
997                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
998                     }
999                     else
1000                     {
1001                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1002                     }
1003                 }
1004             }
1005         }
1006     }
1007
1008     /* Major dimension is easier, no copying required,
1009      * but we might have to sum to a separate array.
1010      * Since we don't copy, we have to communicate up to pmegrid_nz,
1011      * not nkz as for the minor direction.
1012      */
1013     overlap = &pme->overlap[0];
1014
1015     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1016     {
1017         if(direction==GMX_SUM_QGRID_FORWARD)
1018         {
1019             send_id = overlap->send_id[ipulse];
1020             recv_id = overlap->recv_id[ipulse];
1021             send_index0   = overlap->comm_data[ipulse].send_index0;
1022             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1023             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1024             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1025             recvptr   = overlap->recvbuf;
1026         }
1027         else
1028         {
1029             send_id = overlap->recv_id[ipulse];
1030             recv_id = overlap->send_id[ipulse];
1031             send_index0   = overlap->comm_data[ipulse].recv_index0;
1032             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1033             recv_index0   = overlap->comm_data[ipulse].send_index0;
1034             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1035             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1036         }
1037
1038         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1039         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1040
1041         if (debug)
1042         {
1043             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1044                     pme->nodeid,overlap->nodeid,send_id,
1045                     pme->pmegrid_start_ix,
1046                     send_index0-pme->pmegrid_start_ix,
1047                     send_index0-pme->pmegrid_start_ix+send_nindex);
1048             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1049                     pme->nodeid,overlap->nodeid,recv_id,
1050                     pme->pmegrid_start_ix,
1051                     recv_index0-pme->pmegrid_start_ix,
1052                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1053         }
1054
1055         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1056                      send_id,ipulse,
1057                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1058                      recv_id,ipulse,
1059                      overlap->mpi_comm,&stat);
1060
1061         /* ADD data from contiguous recv buffer */
1062         if(direction==GMX_SUM_QGRID_FORWARD)
1063         {
1064             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1065             for(i=0;i<recv_nindex*datasize;i++)
1066             {
1067                 p[i] += overlap->recvbuf[i];
1068             }
1069         }
1070     }
1071 }
1072 #endif
1073
1074
1075 static int
1076 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1077 {
1078     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1079     ivec    local_pme_size;
1080     int     i,ix,iy,iz;
1081     int     pmeidx,fftidx;
1082
1083     /* Dimensions should be identical for A/B grid, so we just use A here */
1084     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1085                                    local_fft_ndata,
1086                                    local_fft_offset,
1087                                    local_fft_size);
1088
1089     local_pme_size[0] = pme->pmegrid_nx;
1090     local_pme_size[1] = pme->pmegrid_ny;
1091     local_pme_size[2] = pme->pmegrid_nz;
1092
1093     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1094      * the offset is identical, and the PME grid always has more data (due to overlap)
1095      */
1096     {
1097 #ifdef DEBUG_PME
1098         FILE *fp,*fp2;
1099         char fn[STRLEN],format[STRLEN];
1100         real val;
1101         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1102         fp = ffopen(fn,"w");
1103         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1104         fp2 = ffopen(fn,"w");
1105         sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1106 #endif
1107
1108     for(ix=0;ix<local_fft_ndata[XX];ix++)
1109     {
1110         for(iy=0;iy<local_fft_ndata[YY];iy++)
1111         {
1112             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1113             {
1114                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1115                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1116                 fftgrid[fftidx] = pmegrid[pmeidx];
1117 #ifdef DEBUG_PME
1118                 val = 100*pmegrid[pmeidx];
1119                 if (pmegrid[pmeidx] != 0)
1120                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1121                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1122                 if (pmegrid[pmeidx] != 0)
1123                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1124                             "qgrid",
1125                             pme->pmegrid_start_ix + ix,
1126                             pme->pmegrid_start_iy + iy,
1127                             pme->pmegrid_start_iz + iz,
1128                             pmegrid[pmeidx]);
1129 #endif
1130             }
1131         }
1132     }
1133 #ifdef DEBUG_PME
1134     ffclose(fp);
1135     ffclose(fp2);
1136 #endif
1137     }
1138     return 0;
1139 }
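/* Both grids are flat row-major arrays: a grid point (ix,iy,iz) lives at
 *     idx = (ix*size_y + iy)*size_z + iz
 * but with different strides, pmegrid_ny/pmegrid_nz for the spreading grid
 * versus local_fft_size[YY]/[ZZ] for the FFT grid.  For example (made-up
 * strides), with pme strides (32,36) and fft strides (32,34) the point
 * (1,2,3) sits at (1*32+2)*36+3 = 1227 in the pme grid but at
 * (1*32+2)*34+3 = 1159 in the fft grid, which is why pmeidx and fftidx are
 * recomputed separately in the copy above.
 */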
1140
1141
1142 static gmx_cycles_t omp_cyc_start()
1143 {
1144     return gmx_cycles_read();
1145 }
1146
1147 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1148 {
1149     return gmx_cycles_read() - c;
1150 }
1151
1152
1153 static int
1154 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1155                         int nthread,int thread)
1156 {
1157     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1158     ivec    local_pme_size;
1159     int     ixy0,ixy1,ixy,ix,iy,iz;
1160     int     pmeidx,fftidx;
1161 #ifdef PME_TIME_THREADS
1162     gmx_cycles_t c1;
1163     static double cs1=0;
1164     static int cnt=0;
1165 #endif
1166
1167 #ifdef PME_TIME_THREADS
1168     c1 = omp_cyc_start();
1169 #endif
1170     /* Dimensions should be identical for A/B grid, so we just use A here */
1171     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1172                                    local_fft_ndata,
1173                                    local_fft_offset,
1174                                    local_fft_size);
1175
1176     local_pme_size[0] = pme->pmegrid_nx;
1177     local_pme_size[1] = pme->pmegrid_ny;
1178     local_pme_size[2] = pme->pmegrid_nz;
1179
1180     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1181      * the offset is identical, and the PME grid always has more data (due to overlap)
1182      */
1183     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1184     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1185
1186     for(ixy=ixy0;ixy<ixy1;ixy++)
1187     {
1188         ix = ixy/local_fft_ndata[YY];
1189         iy = ixy - ix*local_fft_ndata[YY];
1190
1191         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1192         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1193         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1194         {
1195             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1196         }
1197     }
1198
1199 #ifdef PME_TIME_THREADS
1200     c1 = omp_cyc_end(c1);
1201     cs1 += (double)c1;
1202     cnt++;
1203     if (cnt % 20 == 0)
1204     {
1205         printf("copy %.2f\n",cs1*1e-9);
1206     }
1207 #endif
1208
1209     return 0;
1210 }
1211
1212
1213 static void
1214 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1215 {
1216     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1217
1218     nx = pme->nkx;
1219     ny = pme->nky;
1220     nz = pme->nkz;
1221
1222     pnx = pme->pmegrid_nx;
1223     pny = pme->pmegrid_ny;
1224     pnz = pme->pmegrid_nz;
1225
1226     overlap = pme->pme_order - 1;
1227
1228     /* Add periodic overlap in z */
1229     for(ix=0; ix<pme->pmegrid_nx; ix++)
1230     {
1231         for(iy=0; iy<pme->pmegrid_ny; iy++)
1232         {
1233             for(iz=0; iz<overlap; iz++)
1234             {
1235                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1236                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1237             }
1238         }
1239     }
1240
1241     if (pme->nnodes_minor == 1)
1242     {
1243        for(ix=0; ix<pme->pmegrid_nx; ix++)
1244        {
1245            for(iy=0; iy<overlap; iy++)
1246            {
1247                for(iz=0; iz<nz; iz++)
1248                {
1249                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1250                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1251                }
1252            }
1253        }
1254     }
1255
1256     if (pme->nnodes_major == 1)
1257     {
1258         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1259
1260         for(ix=0; ix<overlap; ix++)
1261         {
1262             for(iy=0; iy<ny_x; iy++)
1263             {
1264                 for(iz=0; iz<nz; iz++)
1265                 {
1266                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1267                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1268                 }
1269             }
1270         }
1271     }
1272 }
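/* A worked example of the wrapping above: with nkz = 32 and pme_order = 4
 * the overlap is 3, so charge that was spread onto the out-of-cell planes
 * z = 32,33,34 is folded back onto z = 0,1,2 (and likewise in y and x when
 * there is no node decomposition in that dimension).
 * unwrap_periodic_pmegrid() below performs the reverse copies before the
 * force gather.
 */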
1273
1274
1275 static void
1276 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1277 {
1278     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1279
1280     nx = pme->nkx;
1281     ny = pme->nky;
1282     nz = pme->nkz;
1283
1284     pnx = pme->pmegrid_nx;
1285     pny = pme->pmegrid_ny;
1286     pnz = pme->pmegrid_nz;
1287
1288     overlap = pme->pme_order - 1;
1289
1290     if (pme->nnodes_major == 1)
1291     {
1292         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1293
1294         for(ix=0; ix<overlap; ix++)
1295         {
1296             int iy,iz;
1297
1298             for(iy=0; iy<ny_x; iy++)
1299             {
1300                 for(iz=0; iz<nz; iz++)
1301                 {
1302                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1303                         pmegrid[(ix*pny+iy)*pnz+iz];
1304                 }
1305             }
1306         }
1307     }
1308
1309     if (pme->nnodes_minor == 1)
1310     {
1311 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1312        for(ix=0; ix<pme->pmegrid_nx; ix++)
1313        {
1314            int iy,iz;
1315
1316            for(iy=0; iy<overlap; iy++)
1317            {
1318                for(iz=0; iz<nz; iz++)
1319                {
1320                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1321                        pmegrid[(ix*pny+iy)*pnz+iz];
1322                }
1323            }
1324        }
1325     }
1326
1327     /* Copy periodic overlap in z */
1328 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1329     for(ix=0; ix<pme->pmegrid_nx; ix++)
1330     {
1331         int iy,iz;
1332
1333         for(iy=0; iy<pme->pmegrid_ny; iy++)
1334         {
1335             for(iz=0; iz<overlap; iz++)
1336             {
1337                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1338                     pmegrid[(ix*pny+iy)*pnz+iz];
1339             }
1340         }
1341     }
1342 }
1343
1344 static void clear_grid(int nx,int ny,int nz,real *grid,
1345                        ivec fs,int *flag,
1346                        int fx,int fy,int fz,
1347                        int order)
1348 {
1349     int nc,ncz;
1350     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1351     int flind;
1352
1353     nc  = 2 + (order - 2)/FLBS;
1354     ncz = 2 + (order - 2)/FLBSZ;
1355
1356     for(fsx=fx; fsx<fx+nc; fsx++)
1357     {
1358         for(fsy=fy; fsy<fy+nc; fsy++)
1359         {
1360             for(fsz=fz; fsz<fz+ncz; fsz++)
1361             {
1362                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1363                 if (flag[flind] == 0)
1364                 {
1365                     gx = fsx*FLBS;
1366                     gy = fsy*FLBS;
1367                     gz = fsz*FLBSZ;
1368                     g0x = (gx*ny + gy)*nz + gz;
1369                     for(x=0; x<FLBS; x++)
1370                     {
1371                         g0y = g0x;
1372                         for(y=0; y<FLBS; y++)
1373                         {
1374                             for(z=0; z<FLBSZ; z++)
1375                             {
1376                                 grid[g0y+z] = 0;
1377                             }
1378                             g0y += nz;
1379                         }
1380                         g0x += ny*nz;
1381                     }
1382
1383                     flag[flind] = 1;
1384                 }
1385             }
1386         }
1387     }
1388 }
1389
1390 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1391 #define DO_BSPLINE(order)                            \
1392 for(ithx=0; (ithx<order); ithx++)                    \
1393 {                                                    \
1394     index_x = (i0+ithx)*pny*pnz;                     \
1395     valx    = qn*thx[ithx];                          \
1396                                                      \
1397     for(ithy=0; (ithy<order); ithy++)                \
1398     {                                                \
1399         valxy    = valx*thy[ithy];                   \
1400         index_xy = index_x+(j0+ithy)*pnz;            \
1401                                                      \
1402         for(ithz=0; (ithz<order); ithz++)            \
1403         {                                            \
1404             index_xyz        = index_xy+(k0+ithz);   \
1405             grid[index_xyz] += valxy*thz[ithz];      \
1406         }                                            \
1407     }                                                \
1408 }
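/* DO_BSPLINE(order) is the separable (tensor-product) B-spline spread of
 * one charge qn onto an order^3 block of grid points:
 *
 *     grid[i0+ithx][j0+ithy][k0+ithz] += qn * thx[ithx]*thy[ithy]*thz[ithz]
 *
 * written out on the flattened grid with strides pny*pnz and pnz.  For the
 * default order of 4 this touches 4*4*4 = 64 grid points per charge.
 */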
1409
1410
1411 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1412                                      pme_atomcomm_t *atc, splinedata_t *spline,
1413                                      pme_spline_work_t *work)
1414 {
1415
1416     /* spread charges from home atoms to local grid */
1417     real     *grid;
1418     pme_overlap_t *ol;
1419     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1420     int *    idxptr;
1421     int      order,norder,index_x,index_xy,index_xyz;
1422     real     valx,valxy,qn;
1423     real     *thx,*thy,*thz;
1424     int      localsize, bndsize;
1425     int      pnx,pny,pnz,ndatatot;
1426     int      offx,offy,offz;
1427
1428     pnx = pmegrid->n[XX];
1429     pny = pmegrid->n[YY];
1430     pnz = pmegrid->n[ZZ];
1431
1432     offx = pmegrid->offset[XX];
1433     offy = pmegrid->offset[YY];
1434     offz = pmegrid->offset[ZZ];
1435
1436     ndatatot = pnx*pny*pnz;
1437     grid = pmegrid->grid;
1438     for(i=0;i<ndatatot;i++)
1439     {
1440         grid[i] = 0;
1441     }
1442
1443     order = pmegrid->order;
1444
1445     for(nn=0; nn<spline->n; nn++)
1446     {
1447         n  = spline->ind[nn];
1448         qn = atc->q[n];
1449
1450         if (qn != 0)
1451         {
1452             idxptr = atc->idx[n];
1453             norder = nn*order;
1454
1455             i0   = idxptr[XX] - offx;
1456             j0   = idxptr[YY] - offy;
1457             k0   = idxptr[ZZ] - offz;
1458
1459             thx = spline->theta[XX] + norder;
1460             thy = spline->theta[YY] + norder;
1461             thz = spline->theta[ZZ] + norder;
1462
1463             switch (order) {
1464             case 4:
1465 #ifdef PME_SSE
1466 #ifdef PME_SSE_UNALIGNED
1467 #define PME_SPREAD_SSE_ORDER4
1468 #else
1469 #define PME_SPREAD_SSE_ALIGNED
1470 #define PME_ORDER 4
1471 #endif
1472 #include "pme_sse_single.h"
1473 #else
1474                 DO_BSPLINE(4);
1475 #endif
1476                 break;
1477             case 5:
1478 #ifdef PME_SSE
1479 #define PME_SPREAD_SSE_ALIGNED
1480 #define PME_ORDER 5
1481 #include "pme_sse_single.h"
1482 #else
1483                 DO_BSPLINE(5);
1484 #endif
1485                 break;
1486             default:
1487                 DO_BSPLINE(order);
1488                 break;
1489             }
1490         }
1491     }
1492 }
1493
1494 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1495 {
1496 #ifdef PME_SSE
1497     if (pme_order == 5
1498 #ifndef PME_SSE_UNALIGNED
1499         || pme_order == 4
1500 #endif
1501         )
1502     {
1503         /* Round nz up to a multiple of 4 to ensure alignment */
1504         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1505     }
1506 #endif
1507 }
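/* The rounding above is the usual round-up-to-a-multiple-of-4 bit trick:
 *     (nz + 3) & ~3 :   30 -> 32,   32 -> 32,   33 -> 36
 * so that a z-row always spans a whole number of 4-float (16-byte) SSE
 * loads and stores.
 */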
1508
1509 static void set_gridsize_alignment(int *gridsize,int pme_order)
1510 {
1511 #ifdef PME_SSE
1512 #ifndef PME_SSE_UNALIGNED
1513     if (pme_order == 4)
1514     {
1515         /* Add extra elements to ensure aligned operations do not go
1516          * beyond the allocated grid size.
1517          * Note that for pme_order=5, the pme grid z-size alignment
1518          * ensures that we will not go beyond the grid size.
1519          */
1520          *gridsize += 4;
1521     }
1522 #endif
1523 #endif
1524 }
1525
1526 static void pmegrid_init(pmegrid_t *grid,
1527                          int cx, int cy, int cz,
1528                          int x0, int y0, int z0,
1529                          int x1, int y1, int z1,
1530                          gmx_bool set_alignment,
1531                          int pme_order,
1532                          real *ptr)
1533 {
1534     int nz,gridsize;
1535
1536     grid->ci[XX] = cx;
1537     grid->ci[YY] = cy;
1538     grid->ci[ZZ] = cz;
1539     grid->offset[XX] = x0;
1540     grid->offset[YY] = y0;
1541     grid->offset[ZZ] = z0;
1542     grid->n[XX]      = x1 - x0 + pme_order - 1;
1543     grid->n[YY]      = y1 - y0 + pme_order - 1;
1544     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1545
1546     nz = grid->n[ZZ];
1547     set_grid_alignment(&nz,pme_order);
1548     if (set_alignment)
1549     {
1550         grid->n[ZZ] = nz;
1551     }
1552     else if (nz != grid->n[ZZ])
1553     {
1554         gmx_incons("pmegrid_init call with an unaligned z size");
1555     }
1556
1557     grid->order = pme_order;
1558     if (ptr == NULL)
1559     {
1560         gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1561         set_gridsize_alignment(&gridsize,pme_order);
1562         snew_aligned(grid->grid,gridsize,16);
1563     }
1564     else
1565     {
1566         grid->grid = ptr;
1567     }
1568 }
1569
1570 static int div_round_up(int enumerator,int denominator)
1571 {
1572     return (enumerator + denominator - 1)/denominator;
1573 }
1574
1575 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1576                                   ivec nsub)
1577 {
1578     int gsize_opt,gsize;
1579     int nsx,nsy,nsz;
1580     char *env;
1581
1582     gsize_opt = -1;
1583     for(nsx=1; nsx<=nthread; nsx++)
1584     {
1585         if (nthread % nsx == 0)
1586         {
1587             for(nsy=1; nsy<=nthread; nsy++)
1588             {
1589                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1590                 {
1591                     nsz = nthread/(nsx*nsy);
1592
1593                     /* Determine the number of grid points per thread */
1594                     gsize =
1595                         (div_round_up(n[XX],nsx) + ovl)*
1596                         (div_round_up(n[YY],nsy) + ovl)*
1597                         (div_round_up(n[ZZ],nsz) + ovl);
1598
1599                     /* Minimize the number of grid points per thread
1600                      * and, secondarily, the number of cuts in minor dimensions.
1601                      */
1602                     if (gsize_opt == -1 ||
1603                         gsize < gsize_opt ||
1604                         (gsize == gsize_opt &&
1605                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1606                     {
1607                         nsub[XX] = nsx;
1608                         nsub[YY] = nsy;
1609                         nsub[ZZ] = nsz;
1610                         gsize_opt = gsize;
1611                     }
1612                 }
1613             }
1614         }
1615     }
1616
1617     env = getenv("GMX_PME_THREAD_DIVISION");
1618     if (env != NULL)
1619     {
1620         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1621     }
1622
1623     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1624     {
1625         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1626     }
1627 }
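/* A worked example of the division search above, assuming a 32x32x32 charge
 * grid, pme_order 4 (so ovl = 3) and nthread = 4: the candidates 1x2x2,
 * 2x1x2 and 2x2x1 all give 12635 grid points per thread (e.g. 19*19*35 for
 * 2x2x1) versus 13475 for 1x1x4, and the tie-break prefers the fewest cuts
 * in z, then in y, so the search settles on 2x2x1.  The
 * GMX_PME_THREAD_DIVISION override must still multiply up to nthread.
 */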
1628
1629 static void pmegrids_init(pmegrids_t *grids,
1630                           int nx,int ny,int nz,int nz_base,
1631                           int pme_order,
1632                           int nthread,
1633                           int overlap_x,
1634                           int overlap_y)
1635 {
1636     ivec n,n_base,g0,g1;
1637     int t,x,y,z,d,i,tfac;
1638     int max_comm_lines;
1639
1640     n[XX] = nx - (pme_order - 1);
1641     n[YY] = ny - (pme_order - 1);
1642     n[ZZ] = nz - (pme_order - 1);
1643
1644     copy_ivec(n,n_base);
1645     n_base[ZZ] = nz_base;
1646
1647     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1648                  NULL);
1649
1650     grids->nthread = nthread;
1651
1652     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1653
1654     if (grids->nthread > 1)
1655     {
1656         ivec nst;
1657         int gridsize;
1658         real *grid_all;
1659
1660         for(d=0; d<DIM; d++)
1661         {
1662             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1663         }
1664         set_grid_alignment(&nst[ZZ],pme_order);
1665
1666         if (debug)
1667         {
1668             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1669                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1670             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1671                     nx,ny,nz,
1672                     nst[XX],nst[YY],nst[ZZ]);
1673         }
1674
1675         snew(grids->grid_th,grids->nthread);
1676         t = 0;
1677         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1678         set_gridsize_alignment(&gridsize,pme_order);
1679         snew_aligned(grid_all,
1680                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1681                      16);
1682
1683         for(x=0; x<grids->nc[XX]; x++)
1684         {
1685             for(y=0; y<grids->nc[YY]; y++)
1686             {
1687                 for(z=0; z<grids->nc[ZZ]; z++)
1688                 {
1689                     pmegrid_init(&grids->grid_th[t],
1690                                  x,y,z,
1691                                  (n[XX]*(x  ))/grids->nc[XX],
1692                                  (n[YY]*(y  ))/grids->nc[YY],
1693                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1694                                  (n[XX]*(x+1))/grids->nc[XX],
1695                                  (n[YY]*(y+1))/grids->nc[YY],
1696                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1697                                  TRUE,
1698                                  pme_order,
1699                                  grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1700                     t++;
1701                 }
1702             }
1703         }
1704     }
1705
1706     snew(grids->g2t,DIM);
1707     tfac = 1;
1708     for(d=DIM-1; d>=0; d--)
1709     {
1710         snew(grids->g2t[d],n[d]);
1711         t = 0;
1712         for(i=0; i<n[d]; i++)
1713         {
1714             /* The second check should match the parameters
1715              * of the pmegrid_init call above.
1716              */
1717             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1718             {
1719                 t++;
1720             }
1721             grids->g2t[d][i] = t*tfac;
1722         }
1723
1724         tfac *= grids->nc[d];
1725
1726         switch (d)
1727         {
1728         case XX: max_comm_lines = overlap_x;     break;
1729         case YY: max_comm_lines = overlap_y;     break;
1730         case ZZ: max_comm_lines = pme_order - 1; break;
1731         }
1732         grids->nthread_comm[d] = 0;
1733         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1734         {
1735             grids->nthread_comm[d]++;
1736         }
1737         if (debug != NULL)
1738         {
1739             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1740                     'x'+d,grids->nthread_comm[d]);
1741         }
1742         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1743          * work, but this is not a problematic restriction.
1744          */
1745         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1746         {
1747             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1748         }
1749     }
1750 }
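/* Illustrative example of the g2t lookup built above: with n[d] = 10 usable
 * grid lines and grids->nc[d] = 2 thread cells in dimension d, lines 0-4 map
 * to cell 0 and lines 5-9 to cell 1, each entry scaled by tfac (the stride
 * accumulated over the dimensions handled earlier, i.e. z first, then y).
 */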
1751
1752
1753 static void pmegrids_destroy(pmegrids_t *grids)
1754 {
1755     int t;
1756
1757     if (grids->grid.grid != NULL)
1758     {
1759         sfree(grids->grid.grid);
1760
1761         if (grids->nthread > 0)
1762         {
1763             for(t=0; t<grids->nthread; t++)
1764             {
1765                 sfree(grids->grid_th[t].grid);
1766             }
1767             sfree(grids->grid_th);
1768         }
1769     }
1770 }
1771
1772
1773 static void realloc_work(pme_work_t *work,int nkx)
1774 {
1775     if (nkx > work->nalloc)
1776     {
1777         work->nalloc = nkx;
1778         srenew(work->mhx  ,work->nalloc);
1779         srenew(work->mhy  ,work->nalloc);
1780         srenew(work->mhz  ,work->nalloc);
1781         srenew(work->m2   ,work->nalloc);
1782         /* Allocate an aligned pointer for SSE operations, including 3 extra
1783          * elements at the end since SSE operates on 4 elements at a time.
1784          */
1785         sfree_aligned(work->denom);
1786         sfree_aligned(work->tmp1);
1787         sfree_aligned(work->eterm);
1788         snew_aligned(work->denom,work->nalloc+3,16);
1789         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1790         snew_aligned(work->eterm,work->nalloc+3,16);
1791         srenew(work->m2inv,work->nalloc);
1792     }
1793 }
1794
1795
1796 static void free_work(pme_work_t *work)
1797 {
1798     sfree(work->mhx);
1799     sfree(work->mhy);
1800     sfree(work->mhz);
1801     sfree(work->m2);
1802     sfree_aligned(work->denom);
1803     sfree_aligned(work->tmp1);
1804     sfree_aligned(work->eterm);
1805     sfree(work->m2inv);
1806 }
1807
1808
1809 #ifdef PME_SSE
1810     /* Calculate exponentials through SSE in float precision */
1811 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1812 {
1813     {
1814         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1815         __m128 f_sse;
1816         __m128 lu;
1817         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1818         int kx;
1819         f_sse = _mm_load1_ps(&f);
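        /* Note that the loop below starts at kx = 0 rather than at 'start' so
         * that the 16-byte aligned loads stay aligned; any extra leading
         * elements computed this way are not used by the caller, which only
         * reads the range [start,end).
         */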
1820         for(kx=0; kx<end; kx+=4)
1821         {
1822             tmp_d1   = _mm_load_ps(d_aligned+kx);
1823             lu       = _mm_rcp_ps(tmp_d1);
1824             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1825             tmp_r    = _mm_load_ps(r_aligned+kx);
1826             tmp_r    = gmx_mm_exp_ps(tmp_r);
1827             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1828             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1829             _mm_store_ps(e_aligned+kx,tmp_e);
1830         }
1831     }
1832 }
1833 #else
1834 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1835 {
1836     int kx;
1837     for(kx=start; kx<end; kx++)
1838     {
1839         d[kx] = 1.0/d[kx];
1840     }
1841     for(kx=start; kx<end; kx++)
1842     {
1843         r[kx] = exp(r[kx]);
1844     }
1845     for(kx=start; kx<end; kx++)
1846     {
1847         e[kx] = f*r[kx]*d[kx];
1848     }
1849 }
1850 #endif
1851
1852
1853 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1854                          real ewaldcoeff,real vol,
1855                          gmx_bool bEnerVir,
1856                          int nthread,int thread)
1857 {
1858     /* do recip sum over local cells in grid */
1859     /* y major, z middle, x minor or contiguous */
1860     t_complex *p0;
1861     int     kx,ky,kz,maxkx,maxky,maxkz;
1862     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1863     real    mx,my,mz;
1864     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1865     real    ets2,struct2,vfactor,ets2vf;
1866     real    d1,d2,energy=0;
1867     real    by,bz;
1868     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1869     real    rxx,ryx,ryy,rzx,rzy,rzz;
1870     pme_work_t *work;
1871     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1872     real    mhxk,mhyk,mhzk,m2k;
1873     real    corner_fac;
1874     ivec    complex_order;
1875     ivec    local_ndata,local_offset,local_size;
1876     real    elfac;
1877
1878     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1879
1880     nx = pme->nkx;
1881     ny = pme->nky;
1882     nz = pme->nkz;
1883
1884     /* Dimensions should be identical for A/B grid, so we just use A here */
1885     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1886                                       complex_order,
1887                                       local_ndata,
1888                                       local_offset,
1889                                       local_size);
1890
1891     rxx = pme->recipbox[XX][XX];
1892     ryx = pme->recipbox[YY][XX];
1893     ryy = pme->recipbox[YY][YY];
1894     rzx = pme->recipbox[ZZ][XX];
1895     rzy = pme->recipbox[ZZ][YY];
1896     rzz = pme->recipbox[ZZ][ZZ];
1897
1898     maxkx = (nx+1)/2;
1899     maxky = (ny+1)/2;
1900     maxkz = nz/2+1;
1901
1902     work = &pme->work[thread];
1903     mhx   = work->mhx;
1904     mhy   = work->mhy;
1905     mhz   = work->mhz;
1906     m2    = work->m2;
1907     denom = work->denom;
1908     tmp1  = work->tmp1;
1909     eterm = work->eterm;
1910     m2inv = work->m2inv;
1911
1912     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1913     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1914
1915     for(iyz=iyz0; iyz<iyz1; iyz++)
1916     {
1917         iy = iyz/local_ndata[ZZ];
1918         iz = iyz - iy*local_ndata[ZZ];
1919
1920         ky = iy + local_offset[YY];
1921
1922         if (ky < maxky)
1923         {
1924             my = ky;
1925         }
1926         else
1927         {
1928             my = (ky - ny);
1929         }
1930
1931         by = M_PI*vol*pme->bsp_mod[YY][ky];
1932
1933         kz = iz + local_offset[ZZ];
1934
1935         mz = kz;
1936
1937         bz = pme->bsp_mod[ZZ][kz];
1938
1939         /* 0.5 correction for corner points */
1940         corner_fac = 1;
1941         if (kz == 0 || kz == (nz+1)/2)
1942         {
1943             corner_fac = 0.5;
1944         }
1945
1946         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1947
1948         /* We should skip the k-space point (0,0,0) */
1949         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1950         {
1951             kxstart = local_offset[XX];
1952         }
1953         else
1954         {
1955             kxstart = local_offset[XX] + 1;
1956             p0++;
1957         }
1958         kxend = local_offset[XX] + local_ndata[XX];
1959
1960         if (bEnerVir)
1961         {
1962             /* More expensive inner loop, especially because of the storage
1963              * of the mh elements in arrays.
1964              * Because x is the minor grid index, all mh elements
1965              * depend on kx for triclinic unit cells.
1966              */
1967
1968                 /* Two explicit loops to avoid a conditional inside the loop */
1969             for(kx=kxstart; kx<maxkx; kx++)
1970             {
1971                 mx = kx;
1972
1973                 mhxk      = mx * rxx;
1974                 mhyk      = mx * ryx + my * ryy;
1975                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1976                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1977                 mhx[kx]   = mhxk;
1978                 mhy[kx]   = mhyk;
1979                 mhz[kx]   = mhzk;
1980                 m2[kx]    = m2k;
1981                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1982                 tmp1[kx]  = -factor*m2k;
1983             }
1984
1985             for(kx=maxkx; kx<kxend; kx++)
1986             {
1987                 mx = (kx - nx);
1988
1989                 mhxk      = mx * rxx;
1990                 mhyk      = mx * ryx + my * ryy;
1991                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1992                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1993                 mhx[kx]   = mhxk;
1994                 mhy[kx]   = mhyk;
1995                 mhz[kx]   = mhzk;
1996                 m2[kx]    = m2k;
1997                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1998                 tmp1[kx]  = -factor*m2k;
1999             }
2000
2001             for(kx=kxstart; kx<kxend; kx++)
2002             {
2003                 m2inv[kx] = 1.0/m2[kx];
2004             }
2005
2006             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2007
2008             for(kx=kxstart; kx<kxend; kx++,p0++)
2009             {
2010                 d1      = p0->re;
2011                 d2      = p0->im;
2012
2013                 p0->re  = d1*eterm[kx];
2014                 p0->im  = d2*eterm[kx];
2015
2016                 struct2 = 2.0*(d1*d1+d2*d2);
2017
2018                 tmp1[kx] = eterm[kx]*struct2;
2019             }
2020
2021             for(kx=kxstart; kx<kxend; kx++)
2022             {
2023                 ets2     = corner_fac*tmp1[kx];
2024                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2025                 energy  += ets2;
2026
2027                 ets2vf   = ets2*vfactor;
2028                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2029                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2030                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2031                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2032                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2033                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2034             }
2035         }
2036         else
2037         {
2038             /* We don't need to calculate the energy and the virial.
2039              * In this case the triclinic overhead is small.
2040              */
2041
2042             /* Two explicit loops to avoid a conditional inside the loop */
2043
2044             for(kx=kxstart; kx<maxkx; kx++)
2045             {
2046                 mx = kx;
2047
2048                 mhxk      = mx * rxx;
2049                 mhyk      = mx * ryx + my * ryy;
2050                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2051                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2052                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2053                 tmp1[kx]  = -factor*m2k;
2054             }
2055
2056             for(kx=maxkx; kx<kxend; kx++)
2057             {
2058                 mx = (kx - nx);
2059
2060                 mhxk      = mx * rxx;
2061                 mhyk      = mx * ryx + my * ryy;
2062                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2063                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2064                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2065                 tmp1[kx]  = -factor*m2k;
2066             }
2067
2068             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2069
2070             for(kx=kxstart; kx<kxend; kx++,p0++)
2071             {
2072                 d1      = p0->re;
2073                 d2      = p0->im;
2074
2075                 p0->re  = d1*eterm[kx];
2076                 p0->im  = d2*eterm[kx];
2077             }
2078         }
2079     }
2080
2081     if (bEnerVir)
2082     {
2083         /* Update virial with local values.
2084          * The virial is symmetric by definition.
2085          * This virial seems OK for isotropic scaling, but I'm
2086          * experiencing problems on semi-isotropic membranes.
2087          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2088          */
2089         work->vir[XX][XX] = 0.25*virxx;
2090         work->vir[YY][YY] = 0.25*viryy;
2091         work->vir[ZZ][ZZ] = 0.25*virzz;
2092         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2093         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2094         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2095
2096         /* This energy should be corrected for a charged system */
2097         work->energy = 0.5*energy;
2098     }
2099
2100     /* Return the loop count */
2101     return local_ndata[YY]*local_ndata[XX];
2102 }
2103
2104 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2105                              real *mesh_energy,matrix vir)
2106 {
2107     /* This function sums output over threads
2108      * and should therefore only be called after thread synchronization.
2109      */
2110     int thread;
2111
2112     *mesh_energy = pme->work[0].energy;
2113     copy_mat(pme->work[0].vir,vir);
2114
2115     for(thread=1; thread<nthread; thread++)
2116     {
2117         *mesh_energy += pme->work[thread].energy;
2118         m_add(vir,pme->work[thread].vir,vir);
2119     }
2120 }
2121
2122 #define DO_FSPLINE(order)                      \
2123 for(ithx=0; (ithx<order); ithx++)              \
2124 {                                              \
2125     index_x = (i0+ithx)*pny*pnz;               \
2126     tx      = thx[ithx];                       \
2127     dx      = dthx[ithx];                      \
2128                                                \
2129     for(ithy=0; (ithy<order); ithy++)          \
2130     {                                          \
2131         index_xy = index_x+(j0+ithy)*pnz;      \
2132         ty       = thy[ithy];                  \
2133         dy       = dthy[ithy];                 \
2134         fxy1     = fz1 = 0;                    \
2135                                                \
2136         for(ithz=0; (ithz<order); ithz++)      \
2137         {                                      \
2138             gval  = grid[index_xy+(k0+ithz)];  \
2139             fxy1 += thz[ithz]*gval;            \
2140             fz1  += dthz[ithz]*gval;           \
2141         }                                      \
2142         fx += dx*ty*fxy1;                      \
2143         fy += tx*dy*fxy1;                      \
2144         fz += tx*ty*fz1;                       \
2145     }                                          \
2146 }
2147
2148
2149 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2150                               gmx_bool bClearF,pme_atomcomm_t *atc,
2151                               splinedata_t *spline,
2152                               real scale)
2153 {
2154     /* sum forces for local particles */
2155     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2156     int     index_x,index_xy;
2157     int     nx,ny,nz,pnx,pny,pnz;
2158     int *   idxptr;
2159     real    tx,ty,dx,dy,qn;
2160     real    fx,fy,fz,gval;
2161     real    fxy1,fz1;
2162     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2163     int     norder;
2164     real    rxx,ryx,ryy,rzx,rzy,rzz;
2165     int     order;
2166
2167     pme_spline_work_t *work;
2168
2169     work = &pme->spline_work;
2170
2171     order = pme->pme_order;
2172     thx   = spline->theta[XX];
2173     thy   = spline->theta[YY];
2174     thz   = spline->theta[ZZ];
2175     dthx  = spline->dtheta[XX];
2176     dthy  = spline->dtheta[YY];
2177     dthz  = spline->dtheta[ZZ];
2178     nx    = pme->nkx;
2179     ny    = pme->nky;
2180     nz    = pme->nkz;
2181     pnx   = pme->pmegrid_nx;
2182     pny   = pme->pmegrid_ny;
2183     pnz   = pme->pmegrid_nz;
2184
2185     rxx   = pme->recipbox[XX][XX];
2186     ryx   = pme->recipbox[YY][XX];
2187     ryy   = pme->recipbox[YY][YY];
2188     rzx   = pme->recipbox[ZZ][XX];
2189     rzy   = pme->recipbox[ZZ][YY];
2190     rzz   = pme->recipbox[ZZ][ZZ];
2191
2192     for(nn=0; nn<spline->n; nn++)
2193     {
2194         n  = spline->ind[nn];
2195         qn = scale*atc->q[n];
2196
2197         if (bClearF)
2198         {
2199             atc->f[n][XX] = 0;
2200             atc->f[n][YY] = 0;
2201             atc->f[n][ZZ] = 0;
2202         }
2203         if (qn != 0)
2204         {
2205             fx     = 0;
2206             fy     = 0;
2207             fz     = 0;
2208             idxptr = atc->idx[n];
2209             norder = nn*order;
2210
2211             i0   = idxptr[XX];
2212             j0   = idxptr[YY];
2213             k0   = idxptr[ZZ];
2214
2215             /* Pointer arithmetic alert, next six statements */
2216             thx  = spline->theta[XX] + norder;
2217             thy  = spline->theta[YY] + norder;
2218             thz  = spline->theta[ZZ] + norder;
2219             dthx = spline->dtheta[XX] + norder;
2220             dthy = spline->dtheta[YY] + norder;
2221             dthz = spline->dtheta[ZZ] + norder;
2222
2223             switch (order) {
2224             case 4:
2225 #ifdef PME_SSE
2226 #ifdef PME_SSE_UNALIGNED
2227 #define PME_GATHER_F_SSE_ORDER4
2228 #else
2229 #define PME_GATHER_F_SSE_ALIGNED
2230 #define PME_ORDER 4
2231 #endif
2232 #include "pme_sse_single.h"
2233 #else
2234                 DO_FSPLINE(4);
2235 #endif
2236                 break;
2237             case 5:
2238 #ifdef PME_SSE
2239 #define PME_GATHER_F_SSE_ALIGNED
2240 #define PME_ORDER 5
2241 #include "pme_sse_single.h"
2242 #else
2243                 DO_FSPLINE(5);
2244 #endif
2245                 break;
2246             default:
2247                 DO_FSPLINE(order);
2248                 break;
2249             }
2250
2251             atc->f[n][XX] += -qn*( fx*nx*rxx );
2252             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2253             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2254         }
2255     }
2256     /* Since the energy and not the forces are interpolated,
2257      * the net force might not be exactly zero.
2258      * This can be solved by also interpolating F, but
2259      * that comes at a cost.
2260      * A better hack is to remove the net force every
2261      * step, but that must be done at a higher level
2262      * since this routine doesn't see all atoms if running
2263      * in parallel. I don't know how important this is.  EL 990726
2264      */
2265 }
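/* For reference: in gather_f_bsplines above, fx, fy and fz accumulate the
 * derivatives of the interpolated potential with respect to the fractional
 * coordinates; the last three statements of the loop convert these to
 * Cartesian forces using the reciprocal box (rxx, ryx, ...), the grid sizes
 * and a factor -qn.
 */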
2266
2267
2268 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2269                                    pme_atomcomm_t *atc)
2270 {
2271     splinedata_t *spline;
2272     int     n,ithx,ithy,ithz,i0,j0,k0;
2273     int     index_x,index_xy;
2274     int *   idxptr;
2275     real    energy,pot,tx,ty,qn,gval;
2276     real    *thx,*thy,*thz;
2277     int     norder;
2278     int     order;
2279
2280     spline = &atc->spline[0];
2281
2282     order = pme->pme_order;
2283
2284     energy = 0;
2285     for(n=0; (n<atc->n); n++) {
2286         qn      = atc->q[n];
2287
2288         if (qn != 0) {
2289             idxptr = atc->idx[n];
2290             norder = n*order;
2291
2292             i0   = idxptr[XX];
2293             j0   = idxptr[YY];
2294             k0   = idxptr[ZZ];
2295
2296             /* Pointer arithmetic alert, next three statements */
2297             thx  = spline->theta[XX] + norder;
2298             thy  = spline->theta[YY] + norder;
2299             thz  = spline->theta[ZZ] + norder;
2300
2301             pot = 0;
2302             for(ithx=0; (ithx<order); ithx++)
2303             {
2304                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2305                 tx      = thx[ithx];
2306
2307                 for(ithy=0; (ithy<order); ithy++)
2308                 {
2309                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2310                     ty       = thy[ithy];
2311
2312                     for(ithz=0; (ithz<order); ithz++)
2313                     {
2314                         gval  = grid[index_xy+(k0+ithz)];
2315                         pot  += tx*ty*thz[ithz]*gval;
2316                     }
2317
2318                 }
2319             }
2320
2321             energy += pot*qn;
2322         }
2323     }
2324
2325     return energy;
2326 }
2327
2328 /* Macro to force loop unrolling by fixing order.
2329  * This gives a significant performance gain.
2330  */
2331 #define CALC_SPLINE(order)                     \
2332 {                                              \
2333     int j,k,l;                                 \
2334     real dr,div;                               \
2335     real data[PME_ORDER_MAX];                  \
2336     real ddata[PME_ORDER_MAX];                 \
2337                                                \
2338     for(j=0; (j<DIM); j++)                     \
2339     {                                          \
2340         dr  = xptr[j];                         \
2341                                                \
2342         /* dr is relative offset from lower cell limit */ \
2343         data[order-1] = 0;                     \
2344         data[1] = dr;                          \
2345         data[0] = 1 - dr;                      \
2346                                                \
2347         for(k=3; (k<order); k++)               \
2348         {                                      \
2349             div = 1.0/(k - 1.0);               \
2350             data[k-1] = div*dr*data[k-2];      \
2351             for(l=1; (l<(k-1)); l++)           \
2352             {                                  \
2353                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2354                                    data[k-l-1]);                \
2355             }                                  \
2356             data[0] = div*(1-dr)*data[0];      \
2357         }                                      \
2358         /* differentiate */                    \
2359         ddata[0] = -data[0];                   \
2360         for(k=1; (k<order); k++)               \
2361         {                                      \
2362             ddata[k] = data[k-1] - data[k];    \
2363         }                                      \
2364                                                \
2365         div = 1.0/(order - 1);                 \
2366         data[order-1] = div*dr*data[order-2];  \
2367         for(l=1; (l<(order-1)); l++)           \
2368         {                                      \
2369             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2370                                (order-l-dr)*data[order-l-1]); \
2371         }                                      \
2372         data[0] = div*(1 - dr)*data[0];        \
2373                                                \
2374         for(k=0; k<order; k++)                 \
2375         {                                      \
2376             theta[j][i*order+k]  = data[k];    \
2377             dtheta[j][i*order+k] = ddata[k];   \
2378         }                                      \
2379     }                                          \
2380 }
2381
2382 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2383                    rvec fractx[],int nr,int ind[],real charge[],
2384                    gmx_bool bFreeEnergy)
2385 {
2386     /* construct splines for local atoms */
2387     int  i,ii;
2388     real *xptr;
2389
2390     for(i=0; i<nr; i++)
2391     {
2392         /* With free energy we do not use the charge check.
2393          * In most cases this will be more efficient than calling make_bsplines
2394          * twice, since usually more than half the particles have charges.
2395          */
2396         ii = ind[i];
2397         if (bFreeEnergy || charge[ii] != 0.0) {
2398             xptr = fractx[ii];
2399             switch(order) {
2400             case 4:  CALC_SPLINE(4);     break;
2401             case 5:  CALC_SPLINE(5);     break;
2402             default: CALC_SPLINE(order); break;
2403             }
2404         }
2405     }
2406 }
2407
2408
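/* Compute the squared modulus of the discrete Fourier transform of 'data':
 *   mod[i] = (sum_j data[j]*cos(2*pi*i*j/ndata))^2
 *          + (sum_j data[j]*sin(2*pi*i*j/ndata))^2
 * (Near-)zero moduli are patched up afterwards by averaging the two
 * neighbouring entries, which assumes such values only occur at interior
 * indices of the array.
 */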
2409 void make_dft_mod(real *mod,real *data,int ndata)
2410 {
2411   int i,j;
2412   real sc,ss,arg;
2413
2414   for(i=0;i<ndata;i++) {
2415     sc=ss=0;
2416     for(j=0;j<ndata;j++) {
2417       arg=(2.0*M_PI*i*j)/ndata;
2418       sc+=data[j]*cos(arg);
2419       ss+=data[j]*sin(arg);
2420     }
2421     mod[i]=sc*sc+ss*ss;
2422   }
2423   for(i=0;i<ndata;i++)
2424     if(mod[i]<1e-7)
2425       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2426 }
2427
2428
2429 static void make_bspline_moduli(splinevec bsp_mod,
2430                                 int nx,int ny,int nz,int order)
2431 {
2432   int nmax=max(nx,max(ny,nz));
2433   real *data,*ddata,*bsp_data;
2434   int i,k,l;
2435   real div;
2436
2437   snew(data,order);
2438   snew(ddata,order);
2439   snew(bsp_data,nmax);
2440
2441   data[order-1]=0;
2442   data[1]=0;
2443   data[0]=1;
2444
2445   for(k=3;k<order;k++) {
2446     div=1.0/(k-1.0);
2447     data[k-1]=0;
2448     for(l=1;l<(k-1);l++)
2449       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2450     data[0]=div*data[0];
2451   }
2452   /* differentiate */
2453   ddata[0]=-data[0];
2454   for(k=1;k<order;k++)
2455     ddata[k]=data[k-1]-data[k];
2456   div=1.0/(order-1);
2457   data[order-1]=0;
2458   for(l=1;l<(order-1);l++)
2459     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2460   data[0]=div*data[0];
2461
2462   for(i=0;i<nmax;i++)
2463     bsp_data[i]=0;
2464   for(i=1;i<=order;i++)
2465     bsp_data[i]=data[i-1];
2466
2467   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2468   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2469   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2470
2471   sfree(data);
2472   sfree(ddata);
2473   sfree(bsp_data);
2474 }
2475
2476
2477 /* Return the P3M optimal influence function */
2478 static double do_p3m_influence(double z, int order)
2479 {
2480     double z2,z4;
2481
2482     z2 = z*z;
2483     z4 = z2*z2;
2484
2485     /* The formula and most constants can be found in:
2486      * Ballenegger et al., JCTC 8, 936 (2012)
2487      */
2488     switch(order)
2489     {
2490     case 2:
2491         return 1.0 - 2.0*z2/3.0;
2492         break;
2493     case 3:
2494         return 1.0 - z2 + 2.0*z4/15.0;
2495         break;
2496     case 4:
2497         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2498         break;
2499     case 5:
2500         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2501         break;
2502     case 6:
2503         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2504         break;
2505     case 7:
2506         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2507     case 8:
2508         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2509         break;
2510     }
2511
2512     return 0.0;
2513 }
2514
2515 /* Calculate the P3M B-spline moduli for one dimension */
2516 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2517 {
2518     double zarg,zai,sinzai,infl;
2519     int    maxk,i;
2520
2521     if (order > 8)
2522     {
2523         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2524     }
2525
2526     zarg = M_PI/n;
2527
2528     maxk = (n + 1)/2;
2529
2530     for(i=-maxk; i<0; i++)
2531     {
2532         zai    = zarg*i;
2533         sinzai = sin(zai);
2534         infl   = do_p3m_influence(sinzai,order);
2535         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2536     }
2537     bsp_mod[0] = 1.0;
2538     for(i=1; i<maxk; i++)
2539     {
2540         zai    = zarg*i;
2541         sinzai = sin(zai);
2542         infl   = do_p3m_influence(sinzai,order);
2543         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2544     }
2545 }
2546
2547 /* Calculate the P3M B-spline moduli */
2548 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2549                                     int nx,int ny,int nz,int order)
2550 {
2551     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2552     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2553     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2554 }
2555
2556
2557 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2558 {
2559   int nslab,n,i;
2560   int fw,bw;
2561
2562   nslab = atc->nslab;
2563
2564   n = 0;
2565   for(i=1; i<=nslab/2; i++) {
2566     fw = (atc->nodeid + i) % nslab;
2567     bw = (atc->nodeid - i + nslab) % nslab;
2568     if (n < nslab - 1) {
2569       atc->node_dest[n] = fw;
2570       atc->node_src[n]  = bw;
2571       n++;
2572     }
2573     if (n < nslab - 1) {
2574       atc->node_dest[n] = bw;
2575       atc->node_src[n]  = fw;
2576       n++;
2577     }
2578   }
2579 }
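/* Illustrative example: for nslab = 4 the node with nodeid = 0 gets
 * node_dest = {1, 3, 2} and node_src = {3, 1, 2}: it first exchanges with
 * its direct neighbours in both directions and finally with the opposite
 * slab (for which forward and backward coincide).
 */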
2580
2581 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2582 {
2583     int thread;
2584
2585     if(NULL != log)
2586     {
2587         fprintf(log,"Destroying PME data structures.\n");
2588     }
2589
2590     sfree((*pmedata)->nnx);
2591     sfree((*pmedata)->nny);
2592     sfree((*pmedata)->nnz);
2593
2594     pmegrids_destroy(&(*pmedata)->pmegridA);
2595
2596     sfree((*pmedata)->fftgridA);
2597     sfree((*pmedata)->cfftgridA);
2598     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2599
2600     if ((*pmedata)->pmegridB.grid.grid != NULL)
2601     {
2602         pmegrids_destroy(&(*pmedata)->pmegridB);
2603         sfree((*pmedata)->fftgridB);
2604         sfree((*pmedata)->cfftgridB);
2605         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2606     }
2607     for(thread=0; thread<(*pmedata)->nthread; thread++)
2608     {
2609         free_work(&(*pmedata)->work[thread]);
2610     }
2611     sfree((*pmedata)->work);
2612
2613     sfree(*pmedata);
2614     *pmedata = NULL;
2615
2616   return 0;
2617 }
2618
2619 static int mult_up(int n,int f)
2620 {
2621     return ((n + f - 1)/f)*f;
2622 }
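/* For example, mult_up(90,8) == 96: the smallest multiple of f that is >= n,
 * used below to estimate the padded FFT grid sizes per node.
 */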
2623
2624
2625 static double pme_load_imbalance(gmx_pme_t pme)
2626 {
2627     int    nma,nmi;
2628     double n1,n2,n3;
2629
2630     nma = pme->nnodes_major;
2631     nmi = pme->nnodes_minor;
2632
2633     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2634     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2635     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2636
2637     /* pme_solve is roughly double the cost of an fft */
2638
2639     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2640 }
2641
2642 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2643                           int dimind,gmx_bool bSpread)
2644 {
2645     int nk,k,s,thread;
2646
2647     atc->dimind = dimind;
2648     atc->nslab  = 1;
2649     atc->nodeid = 0;
2650     atc->pd_nalloc = 0;
2651 #ifdef GMX_MPI
2652     if (pme->nnodes > 1)
2653     {
2654         atc->mpi_comm = pme->mpi_comm_d[dimind];
2655         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2656         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2657     }
2658     if (debug)
2659     {
2660         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2661     }
2662 #endif
2663
2664     atc->bSpread   = bSpread;
2665     atc->pme_order = pme->pme_order;
2666
2667     if (atc->nslab > 1)
2668     {
2669         /* These three allocations are not required for particle decomp. */
2670         snew(atc->node_dest,atc->nslab);
2671         snew(atc->node_src,atc->nslab);
2672         setup_coordinate_communication(atc);
2673
2674         snew(atc->count_thread,pme->nthread);
2675         for(thread=0; thread<pme->nthread; thread++)
2676         {
2677             snew(atc->count_thread[thread],atc->nslab);
2678         }
2679         atc->count = atc->count_thread[0];
2680         snew(atc->rcount,atc->nslab);
2681         snew(atc->buf_index,atc->nslab);
2682     }
2683
2684     atc->nthread = pme->nthread;
2685     if (atc->nthread > 1)
2686     {
2687         snew(atc->thread_plist,atc->nthread);
2688     }
2689     snew(atc->spline,atc->nthread);
2690     for(thread=0; thread<atc->nthread; thread++)
2691     {
2692         if (atc->nthread > 1)
2693         {
2694             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2695             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2696         }
2697     }
2698 }
2699
2700 static void
2701 init_overlap_comm(pme_overlap_t *  ol,
2702                   int              norder,
2703 #ifdef GMX_MPI
2704                   MPI_Comm         comm,
2705 #endif
2706                   int              nnodes,
2707                   int              nodeid,
2708                   int              ndata,
2709                   int              commplainsize)
2710 {
2711     int lbnd,rbnd,maxlr,b,i;
2712     int exten;
2713     int nn,nk;
2714     pme_grid_comm_t *pgc;
2715     gmx_bool bCont;
2716     int fft_start,fft_end,send_index1,recv_index1;
2717
2718 #ifdef GMX_MPI
2719     ol->mpi_comm = comm;
2720 #endif
2721
2722     ol->nnodes = nnodes;
2723     ol->nodeid = nodeid;
2724
2725     /* Linear translation of the PME grid won't affect reciprocal space
2726      * calculations, so to optimize we only interpolate "upwards",
2727      * which also means we only have to consider overlap in one direction.
2728      * I.e., particles on this node might also be spread to grid indices
2729      * that belong to higher nodes (modulo nnodes)
2730      */
2731
2732     snew(ol->s2g0,ol->nnodes+1);
2733     snew(ol->s2g1,ol->nnodes);
2734     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2735     for(i=0; i<nnodes; i++)
2736     {
2737         /* s2g0 is the local interpolation grid start,
2738          * s2g1 is the local interpolation grid end.
2739          * Because grid overlap communication only goes forward,
2740          * the grid slabs for the FFTs should be rounded down.
2741          */
2742         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2743         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2744
2745         if (debug)
2746         {
2747             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2748         }
2749     }
2750     ol->s2g0[nnodes] = ndata;
2751     if (debug) { fprintf(debug,"\n"); }
2752
2753     /* Determine with how many nodes we need to communicate the grid overlap */
2754     b = 0;
2755     do
2756     {
2757         b++;
2758         bCont = FALSE;
2759         for(i=0; i<nnodes; i++)
2760         {
2761             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2762                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2763             {
2764                 bCont = TRUE;
2765             }
2766         }
2767     }
2768     while (bCont && b < nnodes);
2769     ol->noverlap_nodes = b - 1;
2770
2771     snew(ol->send_id,ol->noverlap_nodes);
2772     snew(ol->recv_id,ol->noverlap_nodes);
2773     for(b=0; b<ol->noverlap_nodes; b++)
2774     {
2775         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2776         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2777     }
2778     snew(ol->comm_data, ol->noverlap_nodes);
2779
2780     for(b=0; b<ol->noverlap_nodes; b++)
2781     {
2782         pgc = &ol->comm_data[b];
2783         /* Send */
2784         fft_start        = ol->s2g0[ol->send_id[b]];
2785         fft_end          = ol->s2g0[ol->send_id[b]+1];
2786         if (ol->send_id[b] < nodeid)
2787         {
2788             fft_start += ndata;
2789             fft_end   += ndata;
2790         }
2791         send_index1      = ol->s2g1[nodeid];
2792         send_index1      = min(send_index1,fft_end);
2793         pgc->send_index0 = fft_start;
2794         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2795
2796         /* We always start receiving to the first index of our slab */
2797         fft_start        = ol->s2g0[ol->nodeid];
2798         fft_end          = ol->s2g0[ol->nodeid+1];
2799         recv_index1      = ol->s2g1[ol->recv_id[b]];
2800         if (ol->recv_id[b] > nodeid)
2801         {
2802             recv_index1 -= ndata;
2803         }
2804         recv_index1      = min(recv_index1,fft_end);
2805         pgc->recv_index0 = fft_start;
2806         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2807     }
2808
2809     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2810     snew(ol->sendbuf,norder*commplainsize);
2811     snew(ol->recvbuf,norder*commplainsize);
2812 }
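/* Illustrative example for the slab setup above: with ndata = 30 grid lines,
 * nnodes = 4 and norder = 4 (so norder-1 = 3 overlap lines) this gives
 * s2g0 = {0, 7, 15, 22, 30} and s2g1 = {11, 18, 26, 33}, i.e. each slab
 * interpolates up to 3 lines beyond its own FFT range, which in turn
 * determines noverlap_nodes and the send/receive index ranges.
 */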
2813
2814 static void
2815 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2816                               int **global_to_local,
2817                               real **fraction_shift)
2818 {
2819     int i;
2820     int * gtl;
2821     real * fsh;
2822
2823     snew(gtl,5*n);
2824     snew(fsh,5*n);
2825     for(i=0; (i<5*n); i++)
2826     {
2827         /* Determine the global to local grid index */
2828         gtl[i] = (i - local_start + n) % n;
2829         /* For coordinates that fall within the local grid the fraction
2830          * is correct, we don't need to shift it.
2831          */
2832         fsh[i] = 0;
2833         if (local_range < n)
2834         {
2835             /* Due to rounding issues i could be 1 beyond the lower or
2836              * upper boundary of the local grid. Correct the index for this.
2837              * If we shift the index, we need to shift the fraction by
2838              * the same amount in the other direction to not affect
2839              * the weights.
2840              * Note that due to this shifting the weights at the end of
2841              * the spline might change, but that will only involve values
2842              * between zero and values close to the precision of a real,
2843              * which is anyhow the accuracy of the whole mesh calculation.
2844              */
2845             /* With local_range=0 we should not change i=local_start */
2846             if (i % n != local_start)
2847             {
2848                 if (gtl[i] == n-1)
2849                 {
2850                     gtl[i] = 0;
2851                     fsh[i] = -1;
2852                 }
2853                 else if (gtl[i] == local_range)
2854                 {
2855                     gtl[i] = local_range - 1;
2856                     fsh[i] = 1;
2857                 }
2858             }
2859         }
2860     }
2861
2862     *global_to_local = gtl;
2863     *fraction_shift  = fsh;
2864 }
2865
2866 static void sse_mask_init(pme_spline_work_t *work,int order)
2867 {
2868 #ifdef PME_SSE
2869     float  tmp[8];
2870     __m128 zero_SSE;
2871     int    of,i;
2872
2873     zero_SSE = _mm_setzero_ps();
2874
2875     for(of=0; of<8-(order-1); of++)
2876     {
2877         for(i=0; i<8; i++)
2878         {
2879             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2880         }
2881         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2882         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2883         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2884         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2885     }
2886 #endif
2887 }
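/* For example, with order = 4 and offset of = 1 the temporary row is
 * {0,1,1,1,1,0,0,0}, so mask_SSE0 selects elements 1-3 and mask_SSE1
 * selects element 4 (after the comparison the masks are all-bits set
 * where tmp > 0 and zero elsewhere).
 */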
2888
2889 static void
2890 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2891 {
2892     int nk_new;
2893
2894     if (*nk % nnodes != 0)
2895     {
2896         nk_new = nnodes*(*nk/nnodes + 1);
2897
2898         if (2*nk_new >= 3*(*nk))
2899         {
2900             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2901                       dim,*nk,dim,nnodes,dim);
2902         }
2903
2904         if (fplog != NULL)
2905         {
2906             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2907                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2908         }
2909
2910         *nk = nk_new;
2911     }
2912 }
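/* Illustrative example: nk = 90 with nnodes = 8 is not divisible, so nk
 * would be rounded up to nnodes*(90/8 + 1) = 96 and a note printed; only
 * if the required increase were 50% or more (2*nk_new >= 3*nk) would the
 * routine abort instead.
 */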
2913
2914 int gmx_pme_init(gmx_pme_t *         pmedata,
2915                  t_commrec *         cr,
2916                  int                 nnodes_major,
2917                  int                 nnodes_minor,
2918                  t_inputrec *        ir,
2919                  int                 homenr,
2920                  gmx_bool            bFreeEnergy,
2921                  gmx_bool            bReproducible,
2922                  int                 nthread)
2923 {
2924     gmx_pme_t pme=NULL;
2925
2926     pme_atomcomm_t *atc;
2927     ivec ndata;
2928
2929     if (debug)
2930         fprintf(debug,"Creating PME data structures.\n");
2931     snew(pme,1);
2932
2933     pme->redist_init         = FALSE;
2934     pme->sum_qgrid_tmp       = NULL;
2935     pme->sum_qgrid_dd_tmp    = NULL;
2936     pme->buf_nalloc          = 0;
2937     pme->redist_buf_nalloc   = 0;
2938
2939     pme->nnodes              = 1;
2940     pme->bPPnode             = TRUE;
2941
2942     pme->nnodes_major        = nnodes_major;
2943     pme->nnodes_minor        = nnodes_minor;
2944
2945 #ifdef GMX_MPI
2946     if (nnodes_major*nnodes_minor > 1)
2947     {
2948         pme->mpi_comm = cr->mpi_comm_mygroup;
2949
2950         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2951         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2952         if (pme->nnodes != nnodes_major*nnodes_minor)
2953         {
2954             gmx_incons("PME node count mismatch");
2955         }
2956     }
2957     else
2958     {
2959         pme->mpi_comm = MPI_COMM_NULL;
2960     }
2961 #endif
2962
2963     if (pme->nnodes == 1)
2964     {
2965         pme->ndecompdim = 0;
2966         pme->nodeid_major = 0;
2967         pme->nodeid_minor = 0;
2968 #ifdef GMX_MPI
2969         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2970 #endif
2971     }
2972     else
2973     {
2974         if (nnodes_minor == 1)
2975         {
2976 #ifdef GMX_MPI
2977             pme->mpi_comm_d[0] = pme->mpi_comm;
2978             pme->mpi_comm_d[1] = MPI_COMM_NULL;
2979 #endif
2980             pme->ndecompdim = 1;
2981             pme->nodeid_major = pme->nodeid;
2982             pme->nodeid_minor = 0;
2983
2984         }
2985         else if (nnodes_major == 1)
2986         {
2987 #ifdef GMX_MPI
2988             pme->mpi_comm_d[0] = MPI_COMM_NULL;
2989             pme->mpi_comm_d[1] = pme->mpi_comm;
2990 #endif
2991             pme->ndecompdim = 1;
2992             pme->nodeid_major = 0;
2993             pme->nodeid_minor = pme->nodeid;
2994         }
2995         else
2996         {
2997             if (pme->nnodes % nnodes_major != 0)
2998             {
2999                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3000             }
3001             pme->ndecompdim = 2;
3002
3003 #ifdef GMX_MPI
3004             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3005                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3006             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3007                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3008
3009             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3010             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3011             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3012             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3013 #endif
3014         }
3015         pme->bPPnode = (cr->duty & DUTY_PP);
3016     }
3017
3018     pme->nthread = nthread;
3019
3020     if (ir->ePBC == epbcSCREW)
3021     {
3022         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3023     }
3024
3025     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3026     pme->nkx         = ir->nkx;
3027     pme->nky         = ir->nky;
3028     pme->nkz         = ir->nkz;
3029     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3030     pme->pme_order   = ir->pme_order;
3031     pme->epsilon_r   = ir->epsilon_r;
3032
3033     if (pme->pme_order > PME_ORDER_MAX)
3034     {
3035         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3036                   pme->pme_order,PME_ORDER_MAX);
3037     }
3038
3039     /* Currently pme.c supports only the fft5d FFT code.
3040      * Therefore the grid always needs to be divisible by nnodes.
3041      * When the old 1D code is also supported again, change this check.
3042      *
3043      * This check should be done before calling gmx_pme_init
3044      * and fplog should be passed instead of stderr.
3045      *
3046     if (pme->ndecompdim >= 2)
3047     */
3048     if (pme->ndecompdim >= 1)
3049     {
3050         /*
3051         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3052                                         'x',nnodes_major,&pme->nkx);
3053         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3054                                         'y',nnodes_minor,&pme->nky);
3055         */
3056     }
3057
3058     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3059         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3060         pme->nkz <= pme->pme_order)
3061     {
3062         gmx_fatal(FARGS,"The pme grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_order for x and/or y",pme->pme_order);
3063     }
3064
3065     if (pme->nnodes > 1) {
3066         double imbal;
3067
3068 #ifdef GMX_MPI
3069         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3070         MPI_Type_commit(&(pme->rvec_mpi));
3071 #endif
3072
3073         /* Note that the charge spreading and force gathering, which usually
3074          * take about the same amount of time as FFT+solve_pme,
3075          * are always fully load balanced
3076          * (unless the charge distribution is inhomogeneous).
3077          */
3078
3079         imbal = pme_load_imbalance(pme);
3080         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3081         {
3082             fprintf(stderr,
3083                     "\n"
3084                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3085                     "      For optimal PME load balancing\n"
3086                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3087                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3088                     "\n",
3089                     (int)((imbal-1)*100 + 0.5),
3090                     pme->nkx,pme->nky,pme->nnodes_major,
3091                     pme->nky,pme->nkz,pme->nnodes_minor);
3092         }
3093     }
3094
3095     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3096     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3097      * y is always copied through a buffer: we don't need padding in z,
3098      * but we do need the overlap in x because of the communication order.
3099      */
3100     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3101 #ifdef GMX_MPI
3102                       pme->mpi_comm_d[0],
3103 #endif
3104                       pme->nnodes_major,pme->nodeid_major,
3105                       pme->nkx,
3106                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3107
3108     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3109 #ifdef GMX_MPI
3110                       pme->mpi_comm_d[1],
3111 #endif
3112                       pme->nnodes_minor,pme->nodeid_minor,
3113                       pme->nky,
3114                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3115
3116     /* Check for a limitation of the (current) sum_fftgrid_dd code */
3117     if (pme->nthread > 1 &&
3118         (pme->overlap[0].noverlap_nodes > 1 ||
3119          pme->overlap[1].noverlap_nodes > 1))
3120     {
3121         gmx_fatal(FARGS,"With threads the number of grid lines per node along x and/or y should be pme_order (%d) or more or exactly pme_order-1",pme->pme_order);
3122     }
3123
3124     snew(pme->bsp_mod[XX],pme->nkx);
3125     snew(pme->bsp_mod[YY],pme->nky);
3126     snew(pme->bsp_mod[ZZ],pme->nkz);
3127
3128     /* The required size of the interpolation grid, including overlap.
3129      * The allocated size (pmegrid_n?) might be slightly larger.
3130      */
3131     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3132                       pme->overlap[0].s2g0[pme->nodeid_major];
3133     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3134                       pme->overlap[1].s2g0[pme->nodeid_minor];
3135     pme->pmegrid_nz_base = pme->nkz;
3136     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3137     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3138
3139     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3140     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3141     pme->pmegrid_start_iz = 0;
3142
3143     make_gridindex5_to_localindex(pme->nkx,
3144                                   pme->pmegrid_start_ix,
3145                                   pme->pmegrid_nx - (pme->pme_order-1),
3146                                   &pme->nnx,&pme->fshx);
3147     make_gridindex5_to_localindex(pme->nky,
3148                                   pme->pmegrid_start_iy,
3149                                   pme->pmegrid_ny - (pme->pme_order-1),
3150                                   &pme->nny,&pme->fshy);
3151     make_gridindex5_to_localindex(pme->nkz,
3152                                   pme->pmegrid_start_iz,
3153                                   pme->pmegrid_nz_base,
3154                                   &pme->nnz,&pme->fshz);
3155
3156     pmegrids_init(&pme->pmegridA,
3157                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3158                   pme->pmegrid_nz_base,
3159                   pme->pme_order,
3160                   pme->nthread,
3161                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3162                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3163
3164     sse_mask_init(&pme->spline_work,pme->pme_order);
3165
3166     ndata[0] = pme->nkx;
3167     ndata[1] = pme->nky;
3168     ndata[2] = pme->nkz;
3169
3170     /* This routine will allocate the grid data to fit the FFTs */
3171     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3172                             &pme->fftgridA,&pme->cfftgridA,
3173                             pme->mpi_comm_d,
3174                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3175                             bReproducible,pme->nthread);
3176
3177     if (bFreeEnergy)
3178     {
3179         pmegrids_init(&pme->pmegridB,
3180                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3181                       pme->pmegrid_nz_base,
3182                       pme->pme_order,
3183                       pme->nthread,
3184                       pme->nkx % pme->nnodes_major != 0,
3185                       pme->nky % pme->nnodes_minor != 0);
3186
3187         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3188                                 &pme->fftgridB,&pme->cfftgridB,
3189                                 pme->mpi_comm_d,
3190                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3191                                 bReproducible,pme->nthread);
3192     }
3193     else
3194     {
3195         pme->pmegridB.grid.grid = NULL;
3196         pme->fftgridB           = NULL;
3197         pme->cfftgridB          = NULL;
3198     }
3199
3200     if (!pme->bP3M)
3201     {
3202         /* Use plain SPME B-spline interpolation */
3203         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3204     }
3205     else
3206     {
3207         /* Use the P3M grid-optimized influence function */
3208         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3209     }
3210
3211     /* Use atc[0] for spreading */
3212     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3213     if (pme->ndecompdim >= 2)
3214     {
3215         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3216     }
3217
3218     if (pme->nnodes == 1) {
3219         pme->atc[0].n = homenr;
3220         pme_realloc_atomcomm_things(&pme->atc[0]);
3221     }
3222
3223     {
3224         int thread;
3225
3226         /* Use fft5d, order after FFT is y major, z, x minor */
3227
3228         snew(pme->work,pme->nthread);
3229         for(thread=0; thread<pme->nthread; thread++)
3230         {
3231             realloc_work(&pme->work[thread],pme->nkx);
3232         }
3233     }
3234
3235     *pmedata = pme;
3236
3237     return 0;
3238 }
3239
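/* Minimal usage sketch (illustrative only; a real caller such as mdrun
 * passes the full t_commrec and t_inputrec and its own thread count):
 *
 *   gmx_pme_t pme = NULL;
 *   gmx_pme_init(&pme, cr, nnodes_major, nnodes_minor, ir,
 *                homenr, bFreeEnergy, bReproducible, nthread);
 *   ...  spread/solve/gather during the run
 *   gmx_pme_destroy(fplog, &pme);
 */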
3240
3241 static void copy_local_grid(gmx_pme_t pme,
3242                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3243 {
3244     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3245     int  fft_my,fft_mz;
3246     int  nsx,nsy,nsz;
3247     ivec nf;
3248     int  offx,offy,offz,x,y,z,i0,i0t;
3249     int  d;
3250     pmegrid_t *pmegrid;
3251     real *grid_th;
3252
3253     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3254                                    local_fft_ndata,
3255                                    local_fft_offset,
3256                                    local_fft_size);
3257     fft_my = local_fft_size[YY];
3258     fft_mz = local_fft_size[ZZ];
3259
3260     pmegrid = &pmegrids->grid_th[thread];
3261
3262     nsx = pmegrid->n[XX];
3263     nsy = pmegrid->n[YY];
3264     nsz = pmegrid->n[ZZ];
3265
3266     for(d=0; d<DIM; d++)
3267     {
3268         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3269                     local_fft_ndata[d] - pmegrid->offset[d]);
3270     }
3271
3272     offx = pmegrid->offset[XX];
3273     offy = pmegrid->offset[YY];
3274     offz = pmegrid->offset[ZZ];
3275
3276     /* Directly copy the non-overlapping parts of the local grids.
3277      * This also initializes the full grid.
3278      */
3279     grid_th = pmegrid->grid;
3280     for(x=0; x<nf[XX]; x++)
3281     {
3282         for(y=0; y<nf[YY]; y++)
3283         {
3284             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3285             i0t = (x*nsy + y)*nsz;
3286             for(z=0; z<nf[ZZ]; z++)
3287             {
3288                 fftgrid[i0+z] = grid_th[i0t+z];
3289             }
3290         }
3291     }
3292 }
3293
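/* print_sendbuf() is a debugging aid: for the major-dimension send buffer
 * it prints, for each (x,y) column, how many z entries are non-zero.
 */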
3294 static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3295 {
3296     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3297     pme_overlap_t *overlap;
3298     int datasize,nind;
3299     int i,x,y,z,n;
3300
3301     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3302                                    local_fft_ndata,
3303                                    local_fft_offset,
3304                                    local_fft_size);
3305     /* Major dimension */
3306     overlap = &pme->overlap[0];
3307
3308     nind   = overlap->comm_data[0].send_nindex;
3309
3310     for(y=0; y<local_fft_ndata[YY]; y++) {
3311         printf(" %2d",y);
3312     }
3313     printf("\n");
3314
3315     i = 0;
3316     for(x=0; x<nind; x++) {
3317         for(y=0; y<local_fft_ndata[YY]; y++) {
3318             n = 0;
3319             for(z=0; z<local_fft_ndata[ZZ]; z++) {
3320                 if (sendbuf[i] != 0) n++;
3321                 i++;
3322             }
3323             printf(" %2d",n);
3324         }
3325         printf("\n");
3326     }
3327 }
3328
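/* reduce_threadgrid_overlap() adds the overlap regions of the other
 * threads' spread grids onto the part of the FFT grid owned by this
 * thread. Contributions that extend beyond the node-local grid in x
 * and/or y belong to a neighboring PME node; with more than one node
 * in that dimension they are accumulated into commbuf_x/commbuf_y
 * instead and communicated later by sum_fftgrid_dd().
 */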
3329 static void
3330 reduce_threadgrid_overlap(gmx_pme_t pme,
3331                           const pmegrids_t *pmegrids,int thread,
3332                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3333 {
3334     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3335     int  fft_nx,fft_ny,fft_nz;
3336     int  fft_my,fft_mz;
3337     int  buf_my=-1;
3338     int  nsx,nsy,nsz;
3339     ivec ne;
3340     int  offx,offy,offz,x,y,z,i0,i0t;
3341     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3342     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3343     gmx_bool bCommX,bCommY;
3344     int  d;
3345     int  thread_f;
3346     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3347     const real *grid_th;
3348     real *commbuf=NULL;
3349
3350     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3351                                    local_fft_ndata,
3352                                    local_fft_offset,
3353                                    local_fft_size);
3354     fft_nx = local_fft_ndata[XX];
3355     fft_ny = local_fft_ndata[YY];
3356     fft_nz = local_fft_ndata[ZZ];
3357
3358     fft_my = local_fft_size[YY];
3359     fft_mz = local_fft_size[ZZ];
3360
3361     /* This routine is called when all threads have finished spreading.
3362      * Here each thread sums grid contributions calculated by other threads
3363      * onto its thread-local part of the grid volume.
3364      * To minimize the number of grid copying operations,
3365      * this routine sums directly from the pmegrid to the fftgrid.
3366      */
3367
3368     /* Determine which part of the full node grid we should operate on;
3369      * this is our thread-local part of the full grid.
3370      */
3371     pmegrid = &pmegrids->grid_th[thread];
3372
3373     for(d=0; d<DIM; d++)
3374     {
3375         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3376                     local_fft_ndata[d]);
3377     }
3378
3379     offx = pmegrid->offset[XX];
3380     offy = pmegrid->offset[YY];
3381     offz = pmegrid->offset[ZZ];
3382
3383
3384     bClearBufX  = TRUE;
3385     bClearBufY  = TRUE;
3386     bClearBufXY = TRUE;
3387
3388     /* Now loop over all the thread data blocks that contribute
3389      * to the grid region we (our thread) are operating on.
3390      */
3391     /* Note that fft_nx/fft_ny is equal to the number of grid points
3392      * between the first grid point of our node grid and that of the next node.
3393      */
3394     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3395     {
3396         fx = pmegrid->ci[XX] + sx;
3397         ox = 0;
3398         bCommX = FALSE;
3399         if (fx < 0) {
3400             fx += pmegrids->nc[XX];
3401             ox -= fft_nx;
3402             bCommX = (pme->nnodes_major > 1);
3403         }
3404         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3405         ox += pmegrid_g->offset[XX];
3406         if (!bCommX)
3407         {
3408             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3409         }
3410         else
3411         {
3412             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3413         }
3414
3415         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3416         {
3417             fy = pmegrid->ci[YY] + sy;
3418             oy = 0;
3419             bCommY = FALSE;
3420             if (fy < 0) {
3421                 fy += pmegrids->nc[YY];
3422                 oy -= fft_ny;
3423                 bCommY = (pme->nnodes_minor > 1);
3424             }
3425             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3426             oy += pmegrid_g->offset[YY];
3427             if (!bCommY)
3428             {
3429                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3430             }
3431             else
3432             {
3433                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3434             }
3435
3436             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3437             {
3438                 fz = pmegrid->ci[ZZ] + sz;
3439                 oz = 0;
3440                 if (fz < 0)
3441                 {
3442                     fz += pmegrids->nc[ZZ];
3443                     oz -= fft_nz;
3444                 }
3445                 pmegrid_g = &pmegrids->grid_th[fz];
3446                 oz += pmegrid_g->offset[ZZ];
3447                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3448
3449                 if (sx == 0 && sy == 0 && sz == 0)
3450                 {
3451                     /* We have already added our local contribution
3452                      * before calling this routine, so skip it here.
3453                      */
3454                     continue;
3455                 }
3456
3457                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3458
3459                 pmegrid_f = &pmegrids->grid_th[thread_f];
3460
3461                 grid_th = pmegrid_f->grid;
3462
3463                 nsx = pmegrid_f->n[XX];
3464                 nsy = pmegrid_f->n[YY];
3465                 nsz = pmegrid_f->n[ZZ];
3466
3467 #ifdef DEBUG_PME_REDUCE
3468                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3469                        pme->nodeid,thread,thread_f,
3470                        pme->pmegrid_start_ix,
3471                        pme->pmegrid_start_iy,
3472                        pme->pmegrid_start_iz,
3473                        sx,sy,sz,
3474                        offx-ox,tx1-ox,offx,tx1,
3475                        offy-oy,ty1-oy,offy,ty1,
3476                        offz-oz,tz1-oz,offz,tz1);
3477 #endif
3478
3479                 if (!(bCommX || bCommY))
3480                 {
3481                     /* Copy from the thread local grid to the node grid */
3482                     for(x=offx; x<tx1; x++)
3483                     {
3484                         for(y=offy; y<ty1; y++)
3485                         {
3486                             i0  = (x*fft_my + y)*fft_mz;
3487                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3488                             for(z=offz; z<tz1; z++)
3489                             {
3490                                 fftgrid[i0+z] += grid_th[i0t+z];
3491                             }
3492                         }
3493                     }
3494                 }
3495                 else
3496                 {
3497                     /* The order of this conditional decides
3498                      * where the corner volume gets stored with x+y decomp.
3499                      */
3500                     if (bCommY)
3501                     {
3502                         commbuf = commbuf_y;
3503                         buf_my  = ty1 - offy;
3504                         if (bCommX)
3505                         {
3506                             /* We index commbuf modulo the local grid size */
3507                             commbuf += buf_my*fft_nx*fft_nz;
3508
3509                             bClearBuf  = bClearBufXY;
3510                             bClearBufXY = FALSE;
3511                         }
3512                         else
3513                         {
3514                             bClearBuf  = bClearBufY;
3515                             bClearBufY = FALSE;
3516                         }
3517                     }
3518                     else
3519                     {
3520                         commbuf = commbuf_x;
3521                         buf_my  = fft_ny;
3522                         bClearBuf  = bClearBufX;
3523                         bClearBufX = FALSE;
3524                     }
3525
3526                     /* Copy to the communication buffer */
3527                     for(x=offx; x<tx1; x++)
3528                     {
3529                         for(y=offy; y<ty1; y++)
3530                         {
3531                             i0  = (x*buf_my + y)*fft_nz;
3532                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3533
3534                             if (bClearBuf)
3535                             {
3536                                 /* First access of commbuf, initialize it */
3537                                 for(z=offz; z<tz1; z++)
3538                                 {
3539                                     commbuf[i0+z]  = grid_th[i0t+z];
3540                                 }
3541                             }
3542                             else
3543                             {
3544                                 for(z=offz; z<tz1; z++)
3545                                 {
3546                                     commbuf[i0+z] += grid_th[i0t+z];
3547                                 }
3548                             }
3549                         }
3550                     }
3551                 }
3552             }
3553         }
3554     }
3555 }
3556
3557
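/* sum_fftgrid_dd() communicates the grid overlap with the neighboring
 * PME nodes and adds the received contributions to the local FFT grid.
 * The minor (y) dimension is exchanged first; with a 2D decomposition
 * the exchanged buffers also carry size_yx extra x-slices, which are
 * added into the major-dimension send buffer so that the x-y corner
 * volume reaches the correct node in the subsequent major (x) exchange.
 */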
3558 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3559 {
3560     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3561     pme_overlap_t *overlap;
3562     int  send_nindex;
3563     int  recv_index0,recv_nindex;
3564 #ifdef GMX_MPI
3565     MPI_Status stat;
3566 #endif
3567     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3568     real *sendptr,*recvptr;
3569     int  x,y,z,indg,indb;
3570
3571     /* Note that this routine is only used for forward communication.
3572      * Since the force gathering, unlike the charge spreading,
3573      * can be trivially parallelized over the particles,
3574      * the backwards process is much simpler and can use the "old"
3575      * communication setup.
3576      */
3577
3578     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3579                                    local_fft_ndata,
3580                                    local_fft_offset,
3581                                    local_fft_size);
3582
3583     /* Currently supports only a single communication pulse */
3584
3585 /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3586     if (pme->nnodes_minor > 1)
3587     {
3588         /* Minor dimension */
3589         overlap = &pme->overlap[1];
3590
3591         if (pme->nnodes_major > 1)
3592         {
3593             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3594         }
3595         else
3596         {
3597             size_yx = 0;
3598         }
3599         datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3600
3601         ipulse = 0;
3602
3603         send_id = overlap->send_id[ipulse];
3604         recv_id = overlap->recv_id[ipulse];
3605         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3606         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3607         recv_index0 = 0;
3608         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3609
3610         sendptr = overlap->sendbuf;
3611         recvptr = overlap->recvbuf;
3612
3613         /*
3614         printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3615                local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3616         printf("node %d send %f, %f\n",pme->nodeid,
3617                sendptr[0],sendptr[send_nindex*datasize-1]);
3618         */
3619
3620 #ifdef GMX_MPI
3621         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3622                      send_id,ipulse,
3623                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3624                      recv_id,ipulse,
3625                      overlap->mpi_comm,&stat);
3626 #endif
3627
3628         for(x=0; x<local_fft_ndata[XX]; x++)
3629         {
3630             for(y=0; y<recv_nindex; y++)
3631             {
3632                 indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3633                 indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3634                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3635                 {
3636                     fftgrid[indg+z] += recvptr[indb+z];
3637                 }
3638             }
3639         }
3640         if (pme->nnodes_major > 1)
3641         {
3642             sendptr = pme->overlap[0].sendbuf;
3643             for(x=0; x<size_yx; x++)
3644             {
3645                 for(y=0; y<recv_nindex; y++)
3646                 {
3647                     indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3648                     indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3649                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3650                     {
3651                         sendptr[indg+z] += recvptr[indb+z];
3652                     }
3653                 }
3654             }
3655         }
3656     }
3657
3658     /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3659     if (pme->nnodes_major > 1)
3660     {
3661         /* Major dimension */
3662         overlap = &pme->overlap[0];
3663
3664         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3665         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3666
3667         ipulse = 0;
3668
3669         send_id = overlap->send_id[ipulse];
3670         recv_id = overlap->recv_id[ipulse];
3671         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3672         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3673         recv_index0 = 0;
3674         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3675
3676         sendptr = overlap->sendbuf;
3677         recvptr = overlap->recvbuf;
3678
3679         if (debug != NULL)
3680         {
3681             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3682                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3683         }
3684
3685 #ifdef GMX_MPI
3686         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3687                      send_id,ipulse,
3688                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3689                      recv_id,ipulse,
3690                      overlap->mpi_comm,&stat);
3691 #endif
3692
3693         for(x=0; x<recv_nindex; x++)
3694         {
3695             for(y=0; y<local_fft_ndata[YY]; y++)
3696             {
3697                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3698                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3699                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3700                 {
3701                     fftgrid[indg+z] += recvptr[indb+z];
3702                 }
3703             }
3704         }
3705     }
3706 }
3707
3708
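/* spread_on_grid() performs the charge spreading in three stages:
 * 1) with bCalcSplines, compute the grid interpolation indices for all
 *    local atoms,
 * 2) per thread, compute the B-splines and, with bSpread, spread the
 *    charges onto the thread-local grid,
 * 3) when spreading with more than one thread, reduce the thread-grid
 *    overlaps into the FFT grid and, in parallel runs, communicate the
 *    node overlap with sum_fftgrid_dd().
 */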
3709 static void spread_on_grid(gmx_pme_t pme,
3710                            pme_atomcomm_t *atc,pmegrids_t *grids,
3711                            gmx_bool bCalcSplines,gmx_bool bSpread,
3712                            real *fftgrid)
3713 {
3714     int nthread,thread;
3715 #ifdef PME_TIME_THREADS
3716     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3717     static double cs1=0,cs2=0,cs3=0;
3718     static double cs1a[6]={0,0,0,0,0,0};
3719     static int cnt=0;
3720 #endif
3721
3722     nthread = pme->nthread;
3723     assert(nthread>0);
3724
3725 #ifdef PME_TIME_THREADS
3726     c1 = omp_cyc_start();
3727 #endif
3728     if (bCalcSplines)
3729     {
3730 #pragma omp parallel for num_threads(nthread) schedule(static)
3731         for(thread=0; thread<nthread; thread++)
3732         {
3733             int start,end;
3734
3735             start = atc->n* thread   /nthread;
3736             end   = atc->n*(thread+1)/nthread;
3737
3738             /* Compute the fftgrid index for all atoms,
3739              * with the help of some extra variables.
3740              */
3741             calc_interpolation_idx(pme,atc,start,end,thread);
3742         }
3743     }
3744 #ifdef PME_TIME_THREADS
3745     c1 = omp_cyc_end(c1);
3746     cs1 += (double)c1;
3747 #endif
3748
3749 #ifdef PME_TIME_THREADS
3750     c2 = omp_cyc_start();
3751 #endif
3752 #pragma omp parallel for num_threads(nthread) schedule(static)
3753     for(thread=0; thread<nthread; thread++)
3754     {
3755         splinedata_t *spline;
3756         pmegrid_t *grid;
3757
3758         /* make local bsplines  */
3759         if (grids == NULL || grids->nthread == 1)
3760         {
3761             spline = &atc->spline[0];
3762
3763             spline->n = atc->n;
3764
3765             grid = (grids != NULL ? &grids->grid : NULL); /* grids==NULL for energy-only calls */
3766         }
3767         else
3768         {
3769             spline = &atc->spline[thread];
3770
3771             make_thread_local_ind(atc,thread,spline);
3772
3773             grid = &grids->grid_th[thread];
3774         }
3775
3776         if (bCalcSplines)
3777         {
3778             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3779                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3780         }
3781
3782         if (bSpread)
3783         {
3784             /* put local atoms on grid. */
3785 #ifdef PME_TIME_SPREAD
3786             ct1a = omp_cyc_start();
3787 #endif
3788             spread_q_bsplines_thread(grid,atc,spline,&pme->spline_work);
3789
3790             if (grids->nthread > 1)
3791             {
3792                 copy_local_grid(pme,grids,thread,fftgrid);
3793             }
3794 #ifdef PME_TIME_SPREAD
3795             ct1a = omp_cyc_end(ct1a);
3796             cs1a[thread] += (double)ct1a;
3797 #endif
3798         }
3799     }
3800 #ifdef PME_TIME_THREADS
3801     c2 = omp_cyc_end(c2);
3802     cs2 += (double)c2;
3803 #endif
3804
3805     if (bSpread && grids->nthread > 1)
3806     {
3807 #ifdef PME_TIME_THREADS
3808         c3 = omp_cyc_start();
3809 #endif
3810 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3811         for(thread=0; thread<grids->nthread; thread++)
3812         {
3813             reduce_threadgrid_overlap(pme,grids,thread,
3814                                       fftgrid,
3815                                       pme->overlap[0].sendbuf,
3816                                       pme->overlap[1].sendbuf);
3817 #ifdef PRINT_PME_SENDBUF
3818             print_sendbuf(pme,pme->overlap[0].sendbuf);
3819 #endif
3820         }
3821 #ifdef PME_TIME_THREADS
3822         c3 = omp_cyc_end(c3);
3823         cs3 += (double)c3;
3824 #endif
3825
3826         if (pme->nnodes > 1)
3827         {
3828             /* Communicate the overlapping part of the fftgrid */
3829             sum_fftgrid_dd(pme,fftgrid);
3830         }
3831     }
3832
3833 #ifdef PME_TIME_THREADS
3834     cnt++;
3835     if (cnt % 20 == 0)
3836     {
3837         printf("idx %.2f spread %.2f red %.2f",
3838                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3839 #ifdef PME_TIME_SPREAD
3840         for(thread=0; thread<nthread; thread++)
3841             printf(" %.2f",cs1a[thread]*1e-9);
3842 #endif
3843         printf("\n");
3844     }
3845 #endif
3846 }
3847
3848
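/* dump_grid() and dump_local_fftgrid() are debugging helpers that print
 * the node-local part of the real-space grid, one grid point per line.
 */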
3849 static void dump_grid(FILE *fp,
3850                       int sx,int sy,int sz,int nx,int ny,int nz,
3851                       int my,int mz,const real *g)
3852 {
3853     int x,y,z;
3854
3855     for(x=0; x<nx; x++)
3856     {
3857         for(y=0; y<ny; y++)
3858         {
3859             for(z=0; z<nz; z++)
3860             {
3861                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3862                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3863             }
3864         }
3865     }
3866 }
3867
3868 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3869 {
3870     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3871
3872     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3873                                    local_fft_ndata,
3874                                    local_fft_offset,
3875                                    local_fft_size);
3876
3877     dump_grid(stderr,
3878               pme->pmegrid_start_ix,
3879               pme->pmegrid_start_iy,
3880               pme->pmegrid_start_iz,
3881               pme->pmegrid_nx-pme->pme_order+1,
3882               pme->pmegrid_ny-pme->pme_order+1,
3883               pme->pmegrid_nz-pme->pme_order+1,
3884               local_fft_size[YY],
3885               local_fft_size[ZZ],
3886               fftgrid);
3887 }
3888
3889
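/* gmx_pme_calc_energy() computes B-splines for the n given charges and
 * gathers their energy from whatever is currently stored on the A charge
 * grid; the charges are not spread again. It only works serially and
 * without free energy and is used e.g. for test particle insertion.
 */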
3890 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3891 {
3892     pme_atomcomm_t *atc;
3893     pmegrids_t *grid;
3894
3895     if (pme->nnodes > 1)
3896     {
3897         gmx_incons("gmx_pme_calc_energy called in parallel");
3898     }
3899     if (pme->bFEP)
3900     {
3901         gmx_incons("gmx_pme_calc_energy with free energy");
3902     }
3903
3904     atc = &pme->atc_energy;
3905     atc->nthread   = 1;
3906     if (atc->spline == NULL)
3907     {
3908         snew(atc->spline,atc->nthread);
3909     }
3910     atc->nslab     = 1;
3911     atc->bSpread   = TRUE;
3912     atc->pme_order = pme->pme_order;
3913     atc->n         = n;
3914     pme_realloc_atomcomm_things(atc);
3915     atc->x         = x;
3916     atc->q         = q;
3917
3918     /* We only use the A-charges grid */
3919     grid = &pme->pmegridA;
3920
3921     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
3922
3923     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3924 }
3925
3926
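/* reset_pmeonly_counters() is called when the cycle and flop counters
 * are reset mid-run; shifting init_step and nsteps keeps the reported
 * performance consistent with the number of remaining steps.
 */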
3927 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3928         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3929 {
3930     /* Reset all the counters related to performance over the run */
3931     wallcycle_stop(wcycle,ewcRUN);
3932     wallcycle_reset_all(wcycle);
3933     init_nrnb(nrnb);
3934     ir->init_step += step_rel;
3935     ir->nsteps    -= step_rel;
3936     wallcycle_start(wcycle,ewcRUN);
3937 }
3938
3939
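/* gmx_pmeonly() is the driver for a PME-only node: it keeps receiving
 * coordinates and charges from the PP nodes, runs gmx_pme_do() on them
 * and sends the forces, energy and virial back, until it receives the
 * signal to stop (natoms == -1).
 */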
3940 int gmx_pmeonly(gmx_pme_t pme,
3941                 t_commrec *cr,    t_nrnb *nrnb,
3942                 gmx_wallcycle_t wcycle,
3943                 real ewaldcoeff,  gmx_bool bGatherOnly,
3944                 t_inputrec *ir)
3945 {
3946     gmx_pme_pp_t pme_pp;
3947     int  natoms;
3948     matrix box;
3949     rvec *x_pp=NULL,*f_pp=NULL;
3950     real *chargeA=NULL,*chargeB=NULL;
3951     real lambda=0;
3952     int  maxshift_x=0,maxshift_y=0;
3953     real energy,dvdlambda;
3954     matrix vir;
3955     float cycles;
3956     int  count;
3957     gmx_bool bEnerVir;
3958     gmx_large_int_t step,step_rel;
3959
3960
3961     pme_pp = gmx_pme_pp_init(cr);
3962
3963     init_nrnb(nrnb);
3964
3965     count = 0;
3966     do /****** this is a quasi-loop over time steps! */
3967     {
3968         /* Domain decomposition */
3969         natoms = gmx_pme_recv_q_x(pme_pp,
3970                                   &chargeA,&chargeB,box,&x_pp,&f_pp,
3971                                   &maxshift_x,&maxshift_y,
3972                                   &pme->bFEP,&lambda,
3973                                   &bEnerVir,
3974                                   &step);
3975
3976         if (natoms == -1) {
3977             /* We should stop: break out of the loop */
3978             break;
3979         }
3980
3981         step_rel = step - ir->init_step;
3982
3983         if (count == 0)
3984             wallcycle_start(wcycle,ewcRUN);
3985
3986         wallcycle_start(wcycle,ewcPMEMESH);
3987
3988         dvdlambda = 0;
3989         clear_mat(vir);
3990         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
3991                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
3992                    &energy,lambda,&dvdlambda,
3993                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
3994
3995         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
3996
3997         gmx_pme_send_force_vir_ener(pme_pp,
3998                                     f_pp,vir,energy,dvdlambda,
3999                                     cycles);
4000
4001         count++;
4002
4003         if (step_rel == wcycle_get_reset_counters(wcycle))
4004         {
4005             /* Reset all the counters related to performance over the run */
4006             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4007             wcycle_set_reset_counters(wcycle, 0);
4008         }
4009
4010     } /***** end of quasi-loop, we stop with the break above */
4011     while (TRUE);
4012
4013     return 0;
4014 }
4015
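/* gmx_pme_do() performs a full PME mesh calculation. For each charge set
 * (one, or two with free energy) it redistributes the atoms over the PME
 * nodes when needed, spreads the charges on the grid (GMX_PME_SPREAD_Q),
 * does the forward 3D FFT and solves in reciprocal space (GMX_PME_SOLVE),
 * does the backward FFT and gathers the forces (GMX_PME_CALC_F), and
 * finally combines the A/B energies and virials using lambda.
 */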
4016 int gmx_pme_do(gmx_pme_t pme,
4017                int start,       int homenr,
4018                rvec x[],        rvec f[],
4019                real *chargeA,   real *chargeB,
4020                matrix box, t_commrec *cr,
4021                int  maxshift_x, int maxshift_y,
4022                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4023                matrix vir,      real ewaldcoeff,
4024                real *energy,    real lambda,
4025                real *dvdlambda, int flags)
4026 {
4027     int     q,d,i,j,ntot,npme;
4028     int     nx,ny,nz;
4029     int     n_d,local_ny;
4030     pme_atomcomm_t *atc=NULL;
4031     pmegrids_t *pmegrid=NULL;
4032     real    *grid=NULL;
4033     real    *ptr;
4034     rvec    *x_d,*f_d;
4035     real    *charge=NULL,*q_d;
4036     real    energy_AB[2];
4037     matrix  vir_AB[2];
4038     gmx_bool bClearF;
4039     gmx_parallel_3dfft_t pfft_setup;
4040     real *  fftgrid;
4041     t_complex * cfftgrid;
4042     int     thread;
4043     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4044     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4045
4046     assert(pme->nnodes > 0);
4047     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4048
4049     if (pme->nnodes > 1) {
4050         atc = &pme->atc[0];
4051         atc->npd = homenr;
4052         if (atc->npd > atc->pd_nalloc) {
4053             atc->pd_nalloc = over_alloc_dd(atc->npd);
4054             srenew(atc->pd,atc->pd_nalloc);
4055         }
4056         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4057     }
4058     else
4059     {
4060         /* This could be necessary for TPI */
4061         pme->atc[0].n = homenr;
4062     }
4063
4064     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4065         if (q == 0) {
4066             pmegrid = &pme->pmegridA;
4067             fftgrid = pme->fftgridA;
4068             cfftgrid = pme->cfftgridA;
4069             pfft_setup = pme->pfft_setupA;
4070             charge = chargeA+start;
4071         } else {
4072             pmegrid = &pme->pmegridB;
4073             fftgrid = pme->fftgridB;
4074             cfftgrid = pme->cfftgridB;
4075             pfft_setup = pme->pfft_setupB;
4076             charge = chargeB+start;
4077         }
4078         grid = pmegrid->grid.grid;
4079         /* Unpack structure */
4080         if (debug) {
4081             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4082                     cr->nnodes,cr->nodeid);
4083             fprintf(debug,"Grid = %p\n",(void*)grid);
4084             if (grid == NULL)
4085                 gmx_fatal(FARGS,"No grid!");
4086         }
4087         where();
4088
4089         m_inv_ur0(box,pme->recipbox);
4090
4091         if (pme->nnodes == 1) {
4092             atc = &pme->atc[0];
4093             if (DOMAINDECOMP(cr)) {
4094                 atc->n = homenr;
4095                 pme_realloc_atomcomm_things(atc);
4096             }
4097             atc->x = x;
4098             atc->q = charge;
4099             atc->f = f;
4100         } else {
4101             wallcycle_start(wcycle,ewcPME_REDISTXF);
4102             for(d=pme->ndecompdim-1; d>=0; d--)
4103             {
4104                 if (d == pme->ndecompdim-1)
4105                 {
4106                     n_d = homenr;
4107                     x_d = x + start;
4108                     q_d = charge;
4109                 }
4110                 else
4111                 {
4112                     n_d = pme->atc[d+1].n;
4113                     x_d = atc->x;
4114                     q_d = atc->q;
4115                 }
4116                 atc = &pme->atc[d];
4117                 atc->npd = n_d;
4118                 if (atc->npd > atc->pd_nalloc) {
4119                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4120                     srenew(atc->pd,atc->pd_nalloc);
4121                 }
4122                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4123                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4124                 where();
4125
4126                 GMX_BARRIER(cr->mpi_comm_mygroup);
4127                 /* Redistribute x (only once) and qA or qB */
4128                 if (DOMAINDECOMP(cr)) {
4129                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4130                 } else {
4131                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4132                 }
4133             }
4134             where();
4135
4136             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4137         }
4138
4139         if (debug)
4140             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4141                     cr->nodeid,atc->n);
4142
4143         if (flags & GMX_PME_SPREAD_Q)
4144         {
4145             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4146
4147             /* Spread the charges on a grid */
4148             GMX_MPE_LOG(ev_spread_on_grid_start);
4149
4151             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4152             GMX_MPE_LOG(ev_spread_on_grid_finish);
4153
4154             if (q == 0)
4155             {
4156                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4157             }
4158             inc_nrnb(nrnb,eNR_SPREADQBSP,
4159                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4160
4161             if (pme->nthread == 1)
4162             {
4163                 wrap_periodic_pmegrid(pme,grid);
4164
4165                 /* sum contributions to local grid from other nodes */
4166 #ifdef GMX_MPI
4167                 if (pme->nnodes > 1)
4168                 {
4169                     GMX_BARRIER(cr->mpi_comm_mygroup);
4170                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4171                     where();
4172                 }
4173 #endif
4174
4175                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4176             }
4177
4178             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4179
4180             /*
4181             dump_local_fftgrid(pme,fftgrid);
4182             exit(0);
4183             */
4184         }
4185
4186         /* Here we start a large thread-parallel region: 3D FFT, reciprocal-space solve, backward FFT and copying the result back to the pmegrid */
4187 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4188         for(thread=0; thread<pme->nthread; thread++)
4189         {
4190             if (flags & GMX_PME_SOLVE)
4191             {
4192                 int loop_count;
4193
4194                 /* do 3d-fft */
4195                 if (thread == 0)
4196                 {
4197                     GMX_BARRIER(cr->mpi_comm_mygroup);
4198                     GMX_MPE_LOG(ev_gmxfft3d_start);
4199                     wallcycle_start(wcycle,ewcPME_FFT);
4200                 }
4201                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4202                                            fftgrid,cfftgrid,thread,wcycle);
4203                 if (thread == 0)
4204                 {
4205                     wallcycle_stop(wcycle,ewcPME_FFT);
4206                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4207                 }
4208                 where();
4209
4210                 /* solve in k-space for our local cells */
4211                 if (thread == 0)
4212                 {
4213                     GMX_BARRIER(cr->mpi_comm_mygroup);
4214                     GMX_MPE_LOG(ev_solve_pme_start);
4215                     wallcycle_start(wcycle,ewcPME_SOLVE);
4216                 }
4217                 loop_count =
4218                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4219                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4220                                   bCalcEnerVir,
4221                                   pme->nthread,thread);
4222                 if (thread == 0)
4223                 {
4224                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4225                     where();
4226                     GMX_MPE_LOG(ev_solve_pme_finish);
4227                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4228                 }
4229             }
4230
4231             if (bCalcF)
4232             {
4233                 /* do 3d-invfft */
4234                 if (thread == 0)
4235                 {
4236                     GMX_BARRIER(cr->mpi_comm_mygroup);
4237                     GMX_MPE_LOG(ev_gmxfft3d_start);
4238                     where();
4239                     wallcycle_start(wcycle,ewcPME_FFT);
4240                 }
4241                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4242                                            cfftgrid,fftgrid,thread,wcycle);
4243                 if (thread == 0)
4244                 {
4245                     wallcycle_stop(wcycle,ewcPME_FFT);
4246
4247                     where();
4248                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4249
4250                     if (pme->nodeid == 0)
4251                     {
4252                         ntot = pme->nkx*pme->nky*pme->nkz;
4253                         npme  = ntot*log((real)ntot)/log(2.0);
4254                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4255                     }
4256
4257                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4258                 }
4259
4260                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4261             }
4262         }
4263         /* End of thread parallel section.
4264          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4265          */
4266
4267         if (bCalcF)
4268         {
4269             /* distribute local grid to all nodes */
4270 #ifdef GMX_MPI
4271             if (pme->nnodes > 1) {
4272                 GMX_BARRIER(cr->mpi_comm_mygroup);
4273                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4274             }
4275 #endif
4276             where();
4277
4278             unwrap_periodic_pmegrid(pme,grid);
4279
4280             /* interpolate forces for our local atoms */
4281             GMX_BARRIER(cr->mpi_comm_mygroup);
4282             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4283
4284             where();
4285
4286             /* If we are running without parallelization,
4287              * atc->f is the actual force array, not a buffer;
4288              * therefore we should not clear it.
4289              */
4290             bClearF = (q == 0 && PAR(cr));
4291 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4292             for(thread=0; thread<pme->nthread; thread++)
4293             {
4294                 gather_f_bsplines(pme,grid,bClearF,atc,
4295                                   &atc->spline[thread],
4296                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4297             }
4298
4299             where();
4300
4301             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4302
4303             inc_nrnb(nrnb,eNR_GATHERFBSP,
4304                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4305             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4306         }
4307
4308         if (bCalcEnerVir)
4309         {
4310             /* This should only be called on the master thread
4311              * and after the threads have synchronized.
4312              */
4313             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4314         }
4315     } /* of q-loop */
4316
4317     if (bCalcF && pme->nnodes > 1) {
4318         wallcycle_start(wcycle,ewcPME_REDISTXF);
4319         for(d=0; d<pme->ndecompdim; d++)
4320         {
4321             atc = &pme->atc[d];
4322             if (d == pme->ndecompdim - 1)
4323             {
4324                 n_d = homenr;
4325                 f_d = f + start;
4326             }
4327             else
4328             {
4329                 n_d = pme->atc[d+1].n;
4330                 f_d = pme->atc[d+1].f;
4331             }
4332             GMX_BARRIER(cr->mpi_comm_mygroup);
4333             if (DOMAINDECOMP(cr)) {
4334                 dd_pmeredist_f(pme,atc,n_d,f_d,
4335                                d==pme->ndecompdim-1 && pme->bPPnode);
4336             } else {
4337                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4338             }
4339         }
4340
4341         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4342     }
4343     where();
4344
4345     if (bCalcEnerVir)
4346     {
4347         if (!pme->bFEP) {
4348             *energy = energy_AB[0];
4349             m_add(vir,vir_AB[0],vir);
4350         } else {
4351             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4352             *dvdlambda += energy_AB[1] - energy_AB[0];
4353             for(i=0; i<DIM; i++)
4354             {
4355                 for(j=0; j<DIM; j++)
4356                 {
4357                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4358                         lambda*vir_AB[1][i][j];
4359                 }
4360             }
4361         }
4362     }
4363     else
4364     {
4365         *energy = 0;
4366     }
4367
4368     if (debug)
4369     {
4370         fprintf(debug,"PME mesh energy: %g\n",*energy);
4371     }
4372
4373     return 0;
4374 }