1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted by a whole box side. For instance, you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
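/* Worked example (illustrative only) of the fractional-coordinate mapping this
 * file relies on: with box vectors stored as rows, the fractional coordinate is
 * s_i = sum_j x_j * recipbox[j][i], where recipbox is the inverse of the box.
 * For the two test boxes above this gives
 *
 *   rectangular:  recipbox = diag(1/2, 1/2, 1/6)
 *
 *   triclinic:     1/2    0    0
 *                    0  1/2    0
 *                 -1/6 -1/6  1/6
 *
 * The two boxes generate the same periodic lattice, which is why energies,
 * forces and the virial should agree to PME precision.
 */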
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include <stdio.h>
71 #include <string.h>
72 #include <math.h>
73 #include <assert.h>
74 #include "typedefs.h"
75 #include "txtdump.h"
76 #include "vec.h"
77 #include "gmxcomplex.h"
78 #include "smalloc.h"
79 #include "futil.h"
80 #include "coulomb.h"
81 #include "gmx_fatal.h"
82 #include "pme.h"
83 #include "network.h"
84 #include "physics.h"
85 #include "nrnb.h"
86 #include "copyrite.h"
87 #include "gmx_wallcycle.h"
88 #include "gmx_parallel_3dfft.h"
89 #include "pdbio.h"
90 #include "gmx_cyclecounter.h"
91 #include "macros.h"
92
93 /* Single precision, with SSE2 or higher available */
94 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
95
96 #include "gmx_x86_sse2.h"
97 #include "gmx_math_x86_sse2_single.h"
98
99 #define PME_SSE
100 /* Some old AMD processors could have problems with unaligned loads+stores */
101 #ifndef GMX_FAHCORE
102 #define PME_SSE_UNALIGNED
103 #endif
104 #endif
105
106 #define DFT_TOL 1e-7
107 /* #define PRT_FORCE */
108 /* Conditions for on-the-fly time measurement */
109 /* #define TAKETIME (step > 1 && timesteps < 10) */
110 #define TAKETIME FALSE
111
112 /* #define PME_TIME_THREADS */
113
114 #ifdef GMX_DOUBLE
115 #define mpi_type MPI_DOUBLE
116 #else
117 #define mpi_type MPI_FLOAT
118 #endif
119
120 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
121 #define GMX_CACHE_SEP 64
122
123 /* We only define a maximum to be able to use local arrays without allocation.
124  * An order larger than 12 should never be needed, even for test cases.
125  * If needed it can be changed here.
126  */
127 #define PME_ORDER_MAX 12
128
129 /* Internal datastructures */
130 typedef struct {
131     int send_index0;
132     int send_nindex;
133     int recv_index0;
134     int recv_nindex;
135 } pme_grid_comm_t;
136
137 typedef struct {
138 #ifdef GMX_MPI
139     MPI_Comm mpi_comm;
140 #endif
141     int  nnodes,nodeid;
142     int  *s2g0;
143     int  *s2g1;
144     int  noverlap_nodes;
145     int  *send_id,*recv_id;
146     pme_grid_comm_t *comm_data;
147     real *sendbuf;
148     real *recvbuf;
149 } pme_overlap_t;
150
151 typedef struct {
152     int *n;     /* Cumulative counts of the number of particles per thread */
153     int nalloc; /* Allocation size of i */
154     int *i;     /* Particle indices ordered on thread index (n) */
155 } thread_plist_t;
156
157 typedef struct {
158     int  n;
159     int  *ind;
160     splinevec theta;
161     splinevec dtheta;
162 } splinedata_t;
163
164 typedef struct {
165     int  dimind;            /* The index of the dimension, 0=x, 1=y */
166     int  nslab;
167     int  nodeid;
168 #ifdef GMX_MPI
169     MPI_Comm mpi_comm;
170 #endif
171
172     int  *node_dest;        /* The nodes to send x and q to with DD */
173     int  *node_src;         /* The nodes to receive x and q from with DD */
174     int  *buf_index;        /* Index for commnode into the buffers */
175
176     int  maxshift;
177
178     int  npd;
179     int  pd_nalloc;
180     int  *pd;
181     int  *count;            /* The number of atoms to send to each node */
182     int  **count_thread;
183     int  *rcount;           /* The number of atoms to receive */
184
185     int  n;
186     int  nalloc;
187     rvec *x;
188     real *q;
189     rvec *f;
190     gmx_bool bSpread;       /* These coordinates are used for spreading */
191     int  pme_order;
192     ivec *idx;
193     rvec *fractx;            /* Fractional coordinate relative to the
194                               * lower cell boundary
195                               */
196     int  nthread;
197     int  *thread_idx;        /* Which thread should spread which charge */
198     thread_plist_t *thread_plist;
199     splinedata_t *spline;
200 } pme_atomcomm_t;
201
202 #define FLBS  3
203 #define FLBSZ 4
204
205 typedef struct {
206     ivec ci;     /* The spatial location of this grid       */
207     ivec n;      /* The size of *grid, including order-1    */
208     ivec offset; /* The grid offset from the full node grid */
209     int  order;  /* PME spreading order                     */
210     real *grid;  /* The grid for the local thread, size n   */
211 } pmegrid_t;
212
213 typedef struct {
214     pmegrid_t grid;     /* The full node grid (non thread-local)            */
215     int  nthread;       /* The number of threads operating on this grid     */
216     ivec nc;            /* The local spatial decomposition over the threads */
217     pmegrid_t *grid_th; /* Array of grids for each thread                   */
218     int  **g2t;         /* The grid to thread index                         */
219     ivec nthread_comm;  /* The number of threads to communicate with        */
220 } pmegrids_t;
221
222
223 typedef struct {
224 #ifdef PME_SSE
225     /* Masks for SSE aligned spreading and gathering */
226     __m128 mask_SSE0[6],mask_SSE1[6];
227 #else
228     int dummy; /* C89 requires that a struct has at least one member */
229 #endif
230 } pme_spline_work_t;
231
232 typedef struct {
233     /* work data for solve_pme */
234     int      nalloc;
235     real *   mhx;
236     real *   mhy;
237     real *   mhz;
238     real *   m2;
239     real *   denom;
240     real *   tmp1_alloc;
241     real *   tmp1;
242     real *   eterm;
243     real *   m2inv;
244
245     real     energy;
246     matrix   vir;
247 } pme_work_t;
248
249 typedef struct gmx_pme {
250     int  ndecompdim;         /* The number of decomposition dimensions */
251     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
252     int  nodeid_major;
253     int  nodeid_minor;
254     int  nnodes;             /* The number of nodes doing PME */
255     int  nnodes_major;
256     int  nnodes_minor;
257
258     MPI_Comm mpi_comm;
259     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
260 #ifdef GMX_MPI
261     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
262 #endif
263
264     int  nthread;            /* The number of threads doing PME */
265
266     gmx_bool bPPnode;        /* Node also does particle-particle forces */
267     gmx_bool bFEP;           /* Compute Free energy contribution */
268     int nkx,nky,nkz;         /* Grid dimensions */
269     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
270     int pme_order;
271     real epsilon_r;
272
273     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
274     pmegrids_t pmegridB;
275     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
276     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
277     /* pmegrid_nz might be larger than strictly necessary to ensure
278      * memory alignment, pmegrid_nz_base gives the real base size.
279      */
280     int     pmegrid_nz_base;
281     /* The local PME grid starting indices */
282     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
283
284     /* Work data for spreading and gathering */
285     pme_spline_work_t spline_work;
286
287     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
288     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
289     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
290
291     t_complex *cfftgridA;             /* Grids for complex FFT data */
292     t_complex *cfftgridB;
293     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
294
295     gmx_parallel_3dfft_t  pfft_setupA;
296     gmx_parallel_3dfft_t  pfft_setupB;
297
298     int  *nnx,*nny,*nnz;
299     real *fshx,*fshy,*fshz;
300
301     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
302     matrix    recipbox;
303     splinevec bsp_mod;
304
305     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
306
307     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
308
309     rvec *bufv;             /* Communication buffer */
310     real *bufr;             /* Communication buffer */
311     int  buf_nalloc;        /* The communication buffer size */
312
313     /* thread local work data for solve_pme */
314     pme_work_t *work;
315
316     /* Work data for PME_redist */
317     gmx_bool redist_init;
318     int *    scounts;
319     int *    rcounts;
320     int *    sdispls;
321     int *    rdispls;
322     int *    sidx;
323     int *    idxa;
324     real *   redist_buf;
325     int      redist_buf_nalloc;
326
327     /* Work data for sum_qgrid */
328     real *   sum_qgrid_tmp;
329     real *   sum_qgrid_dd_tmp;
330 } t_gmx_pme;
331
332
333 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
334                                    int start,int end,int thread)
335 {
336     int  i;
337     int  *idxptr,tix,tiy,tiz;
338     real *xptr,*fptr,tx,ty,tz;
339     real rxx,ryx,ryy,rzx,rzy,rzz;
340     int  nx,ny,nz;
341     int  start_ix,start_iy,start_iz;
342     int  *g2tx,*g2ty,*g2tz;
343     gmx_bool bThreads;
344     int  *thread_idx=NULL;
345     thread_plist_t *tpl=NULL;
346     int  *tpl_n=NULL;
347     int  thread_i;
348
349     nx  = pme->nkx;
350     ny  = pme->nky;
351     nz  = pme->nkz;
352
353     start_ix = pme->pmegrid_start_ix;
354     start_iy = pme->pmegrid_start_iy;
355     start_iz = pme->pmegrid_start_iz;
356
357     rxx = pme->recipbox[XX][XX];
358     ryx = pme->recipbox[YY][XX];
359     ryy = pme->recipbox[YY][YY];
360     rzx = pme->recipbox[ZZ][XX];
361     rzy = pme->recipbox[ZZ][YY];
362     rzz = pme->recipbox[ZZ][ZZ];
363
364     g2tx = pme->pmegridA.g2t[XX];
365     g2ty = pme->pmegridA.g2t[YY];
366     g2tz = pme->pmegridA.g2t[ZZ];
367
368     bThreads = (atc->nthread > 1);
369     if (bThreads)
370     {
371         thread_idx = atc->thread_idx;
372
373         tpl   = &atc->thread_plist[thread];
374         tpl_n = tpl->n;
375         for(i=0; i<atc->nthread; i++)
376         {
377             tpl_n[i] = 0;
378         }
379     }
380
381     for(i=start; i<end; i++) {
382         xptr   = atc->x[i];
383         idxptr = atc->idx[i];
384         fptr   = atc->fractx[i];
385
386         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
387         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
388         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
389         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
390
391         tix = (int)(tx);
392         tiy = (int)(ty);
393         tiz = (int)(tz);
394
395         /* Because decomposition only occurs in x and y,
396          * we never have a fraction correction in z.
397          */
398         fptr[XX] = tx - tix + pme->fshx[tix];
399         fptr[YY] = ty - tiy + pme->fshy[tiy];
400         fptr[ZZ] = tz - tiz;
401
402         idxptr[XX] = pme->nnx[tix];
403         idxptr[YY] = pme->nny[tiy];
404         idxptr[ZZ] = pme->nnz[tiz];
405
406 #ifdef DEBUG
407         range_check(idxptr[XX],0,pme->pmegrid_nx);
408         range_check(idxptr[YY],0,pme->pmegrid_ny);
409         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
410 #endif
411
412         if (bThreads)
413         {
414             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
415             thread_idx[i] = thread_i;
416             tpl_n[thread_i]++;
417         }
418     }
419
420     if (bThreads)
421     {
422         /* Make a list of particle indices sorted on thread */
423
424         /* Get the cumulative count */
425         for(i=1; i<atc->nthread; i++)
426         {
427             tpl_n[i] += tpl_n[i-1];
428         }
429         /* The current implementation distributes particles equally
430          * over the threads, so we could actually allocate for that
431          * in pme_realloc_atomcomm_things.
432          */
433         if (tpl_n[atc->nthread-1] > tpl->nalloc)
434         {
435             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
436             srenew(tpl->i,tpl->nalloc);
437         }
438         /* Set tpl_n to the cumulative start */
439         for(i=atc->nthread-1; i>=1; i--)
440         {
441             tpl_n[i] = tpl_n[i-1];
442         }
443         tpl_n[0] = 0;
444
445         /* Fill our thread local array with indices sorted on thread */
446         for(i=start; i<end; i++)
447         {
448             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
449         }
450         /* Now tpl_n contains the cumulative count again */
451     }
452 }
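/* Illustrative sketch (not called from the code in this file) of the
 * counting-sort pattern used above to build the per-thread particle lists;
 * the names are generic placeholders, not GROMACS API:
 *
 *   for(i=0; i<n; i++)          count[thread_of[i]]++;           // histogram
 *   for(t=1; t<nthread; t++)    count[t] += count[t-1];          // cumulative ends
 *   for(t=nthread-1; t>=1; t--) count[t] = count[t-1];           // shift to starts
 *   count[0] = 0;
 *   for(i=0; i<n; i++)          list[count[thread_of[i]]++] = i; // scatter
 *
 * Afterwards count[] again holds the cumulative counts, as noted above.
 */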
453
454 static void make_thread_local_ind(pme_atomcomm_t *atc,
455                                   int thread,splinedata_t *spline)
456 {
457     int  n,t,i,start,end;
458     thread_plist_t *tpl;
459
460     /* Combine the indices made by each thread into one index */
461
462     n = 0;
463     start = 0;
464     for(t=0; t<atc->nthread; t++)
465     {
466         tpl = &atc->thread_plist[t];
467         /* Copy our part (start - end) from the list of thread t */
468         if (thread > 0)
469         {
470             start = tpl->n[thread-1];
471         }
472         end = tpl->n[thread];
473         for(i=start; i<end; i++)
474         {
475             spline->ind[n++] = tpl->i[i];
476         }
477     }
478
479     spline->n = n;
480 }
481
482
483 static void pme_calc_pidx(int start, int end,
484                           matrix recipbox, rvec x[],
485                           pme_atomcomm_t *atc, int *count)
486 {
487     int  nslab,i;
488     int  si;
489     real *xptr,s;
490     real rxx,ryx,rzx,ryy,rzy;
491     int *pd;
492
493     /* Calculate the PME task index (pidx) for each atom.
494      * Here we always assign equally sized slabs to each node
495      * for load balancing reasons (the PME grid spacing is not used).
496      */
497
498     nslab = atc->nslab;
499     pd    = atc->pd;
500
501     /* Reset the count */
502     for(i=0; i<nslab; i++)
503     {
504         count[i] = 0;
505     }
506
507     if (atc->dimind == 0)
508     {
509         rxx = recipbox[XX][XX];
510         ryx = recipbox[YY][XX];
511         rzx = recipbox[ZZ][XX];
512         /* Calculate the node index in x-dimension */
513         for(i=start; i<end; i++)
514         {
515             xptr   = x[i];
516             /* Fractional coordinates along box vectors */
517             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
518             si = (int)(s + 2*nslab) % nslab;
519             pd[i] = si;
520             count[si]++;
521         }
522     }
523     else
524     {
525         ryy = recipbox[YY][YY];
526         rzy = recipbox[ZZ][YY];
527         /* Calculate the node index in y-dimension */
528         for(i=start; i<end; i++)
529         {
530             xptr   = x[i];
531             /* Fractional coordinates along box vectors */
532             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
533             si = (int)(s + 2*nslab) % nslab;
534             pd[i] = si;
535             count[si]++;
536         }
537     }
538 }
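/* Worked example for the slab assignment above: with nslab = 4 and a slightly
 * negative fractional coordinate, s = 4*(-0.01) = -0.04, so
 * (int)(s + 2*4) = (int)7.96 = 7 and 7 % 4 = 3, i.e. the atom wraps to the
 * last slab instead of producing a negative index. Adding 2*nslab keeps the
 * argument non-negative for fractional coordinates down to -2, i.e. up to two
 * box lengths below the box.
 */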
539
540 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
541                                   pme_atomcomm_t *atc)
542 {
543     int nthread,thread,slab;
544
545     nthread = atc->nthread;
546
547 #pragma omp parallel for num_threads(nthread) schedule(static)
548     for(thread=0; thread<nthread; thread++)
549     {
550         pme_calc_pidx(natoms* thread   /nthread,
551                       natoms*(thread+1)/nthread,
552                       recipbox,x,atc,atc->count_thread[thread]);
553     }
554     /* Non-parallel reduction, since nslab is small */
555
556     for(thread=1; thread<nthread; thread++)
557     {
558         for(slab=0; slab<atc->nslab; slab++)
559         {
560             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
561         }
562     }
563 }
564
565 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
566 {
567     int i,d;
568
569     srenew(spline->ind,atc->nalloc);
570     /* Initialize the index to identity so it works without threads */
571     for(i=0; i<atc->nalloc; i++)
572     {
573         spline->ind[i] = i;
574     }
575
576     for(d=0;d<DIM;d++)
577     {
578         srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
579         srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
580     }
581 }
582
583 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
584 {
585     int nalloc_old,i,j,nalloc_tpl;
586
587     /* We have to avoid a NULL pointer for atc->x to avoid
588      * possible fatal errors in MPI routines.
589      */
590     if (atc->n > atc->nalloc || atc->nalloc == 0)
591     {
592         nalloc_old = atc->nalloc;
593         atc->nalloc = over_alloc_dd(max(atc->n,1));
594
595         if (atc->nslab > 1) {
596             srenew(atc->x,atc->nalloc);
597             srenew(atc->q,atc->nalloc);
598             srenew(atc->f,atc->nalloc);
599             for(i=nalloc_old; i<atc->nalloc; i++)
600             {
601                 clear_rvec(atc->f[i]);
602             }
603         }
604         if (atc->bSpread) {
605             srenew(atc->fractx,atc->nalloc);
606             srenew(atc->idx   ,atc->nalloc);
607
608             if (atc->nthread > 1)
609             {
610                 srenew(atc->thread_idx,atc->nalloc);
611             }
612
613             for(i=0; i<atc->nthread; i++)
614             {
615                 pme_realloc_splinedata(&atc->spline[i],atc);
616             }
617         }
618     }
619 }
620
621 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
622                          int n, gmx_bool bXF, rvec *x_f, real *charge,
623                          pme_atomcomm_t *atc)
624 /* Redistribute particle data for PME calculation */
625 /* domain decomposition by x coordinate           */
626 {
627     int *idxa;
628     int i, ii;
629
630     if(FALSE == pme->redist_init) {
631         snew(pme->scounts,atc->nslab);
632         snew(pme->rcounts,atc->nslab);
633         snew(pme->sdispls,atc->nslab);
634         snew(pme->rdispls,atc->nslab);
635         snew(pme->sidx,atc->nslab);
636         pme->redist_init = TRUE;
637     }
638     if (n > pme->redist_buf_nalloc) {
639         pme->redist_buf_nalloc = over_alloc_dd(n);
640         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
641     }
642
643     pme->idxa = atc->pd;
644
645 #ifdef GMX_MPI
646     if (forw && bXF) {
647         /* forward, redistribution from pp to pme */
648
649         /* Calculate send counts and exchange them with other nodes */
650         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
651         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
652         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
653
654         /* Calculate send and receive displacements and index into send
655            buffer */
656         pme->sdispls[0]=0;
657         pme->rdispls[0]=0;
658         pme->sidx[0]=0;
659         for(i=1; i<atc->nslab; i++) {
660             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
661             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
662             pme->sidx[i]=pme->sdispls[i];
663         }
664         /* Total # of particles to be received */
665         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
666
667         pme_realloc_atomcomm_things(atc);
668
669         /* Copy particle coordinates into send buffer and exchange*/
670         for(i=0; (i<n); i++) {
671             ii=DIM*pme->sidx[pme->idxa[i]];
672             pme->sidx[pme->idxa[i]]++;
673             pme->redist_buf[ii+XX]=x_f[i][XX];
674             pme->redist_buf[ii+YY]=x_f[i][YY];
675             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
676         }
677         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
678                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
679                       pme->rvec_mpi, atc->mpi_comm);
680     }
681     if (forw) {
682         /* Copy charge into send buffer and exchange*/
683         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
684         for(i=0; (i<n); i++) {
685             ii=pme->sidx[pme->idxa[i]];
686             pme->sidx[pme->idxa[i]]++;
687             pme->redist_buf[ii]=charge[i];
688         }
689         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
690                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
691                       atc->mpi_comm);
692     }
693     else { /* backward, redistribution from pme to pp */
694         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
695                       pme->redist_buf, pme->scounts, pme->sdispls,
696                       pme->rvec_mpi, atc->mpi_comm);
697
698         /* Copy data from receive buffer */
699         for(i=0; i<atc->nslab; i++)
700             pme->sidx[i] = pme->sdispls[i];
701         for(i=0; (i<n); i++) {
702             ii = DIM*pme->sidx[pme->idxa[i]];
703             x_f[i][XX] += pme->redist_buf[ii+XX];
704             x_f[i][YY] += pme->redist_buf[ii+YY];
705             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
706             pme->sidx[pme->idxa[i]]++;
707         }
708     }
709 #endif
710 }
711
712 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
713                             gmx_bool bBackward,int shift,
714                             void *buf_s,int nbyte_s,
715                             void *buf_r,int nbyte_r)
716 {
717 #ifdef GMX_MPI
718     int dest,src;
719     MPI_Status stat;
720
721     if (bBackward == FALSE) {
722         dest = atc->node_dest[shift];
723         src  = atc->node_src[shift];
724     } else {
725         dest = atc->node_src[shift];
726         src  = atc->node_dest[shift];
727     }
728
729     if (nbyte_s > 0 && nbyte_r > 0) {
730         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
731                      dest,shift,
732                      buf_r,nbyte_r,MPI_BYTE,
733                      src,shift,
734                      atc->mpi_comm,&stat);
735     } else if (nbyte_s > 0) {
736         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
737                  dest,shift,
738                  atc->mpi_comm);
739     } else if (nbyte_r > 0) {
740         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
741                  src,shift,
742                  atc->mpi_comm,&stat);
743     }
744 #endif
745 }
746
747 static void dd_pmeredist_x_q(gmx_pme_t pme,
748                              int n, gmx_bool bX, rvec *x, real *charge,
749                              pme_atomcomm_t *atc)
750 {
751     int *commnode,*buf_index;
752     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
753
754     commnode  = atc->node_dest;
755     buf_index = atc->buf_index;
756
757     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
758
759     nsend = 0;
760     for(i=0; i<nnodes_comm; i++) {
761         buf_index[commnode[i]] = nsend;
762         nsend += atc->count[commnode[i]];
763     }
764     if (bX) {
765         if (atc->count[atc->nodeid] + nsend != n)
766             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
767                       "This usually means that your system is not well equilibrated.",
768                       n - (atc->count[atc->nodeid] + nsend),
769                       pme->nodeid,'x'+atc->dimind);
770
771         if (nsend > pme->buf_nalloc) {
772             pme->buf_nalloc = over_alloc_dd(nsend);
773             srenew(pme->bufv,pme->buf_nalloc);
774             srenew(pme->bufr,pme->buf_nalloc);
775         }
776
777         atc->n = atc->count[atc->nodeid];
778         for(i=0; i<nnodes_comm; i++) {
779             scount = atc->count[commnode[i]];
780             /* Communicate the count */
781             if (debug)
782                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
783                         atc->dimind,atc->nodeid,commnode[i],scount);
784             pme_dd_sendrecv(atc,FALSE,i,
785                             &scount,sizeof(int),
786                             &atc->rcount[i],sizeof(int));
787             atc->n += atc->rcount[i];
788         }
789
790         pme_realloc_atomcomm_things(atc);
791     }
792
793     local_pos = 0;
794     for(i=0; i<n; i++) {
795         node = atc->pd[i];
796         if (node == atc->nodeid) {
797             /* Copy direct to the receive buffer */
798             if (bX) {
799                 copy_rvec(x[i],atc->x[local_pos]);
800             }
801             atc->q[local_pos] = charge[i];
802             local_pos++;
803         } else {
804             /* Copy to the send buffer */
805             if (bX) {
806                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
807             }
808             pme->bufr[buf_index[node]] = charge[i];
809             buf_index[node]++;
810         }
811     }
812
813     buf_pos = 0;
814     for(i=0; i<nnodes_comm; i++) {
815         scount = atc->count[commnode[i]];
816         rcount = atc->rcount[i];
817         if (scount > 0 || rcount > 0) {
818             if (bX) {
819                 /* Communicate the coordinates */
820                 pme_dd_sendrecv(atc,FALSE,i,
821                                 pme->bufv[buf_pos],scount*sizeof(rvec),
822                                 atc->x[local_pos],rcount*sizeof(rvec));
823             }
824             /* Communicate the charges */
825             pme_dd_sendrecv(atc,FALSE,i,
826                             pme->bufr+buf_pos,scount*sizeof(real),
827                             atc->q+local_pos,rcount*sizeof(real));
828             buf_pos   += scount;
829             local_pos += atc->rcount[i];
830         }
831     }
832 }
833
834 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
835                            int n, rvec *f,
836                            gmx_bool bAddF)
837 {
838     int *commnode,*buf_index;
839     int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
840
841     commnode  = atc->node_dest;
842     buf_index = atc->buf_index;
843
844     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
845
846     local_pos = atc->count[atc->nodeid];
847     buf_pos = 0;
848     for(i=0; i<nnodes_comm; i++) {
849         scount = atc->rcount[i];
850         rcount = atc->count[commnode[i]];
851         if (scount > 0 || rcount > 0) {
852             /* Communicate the forces */
853             pme_dd_sendrecv(atc,TRUE,i,
854                             atc->f[local_pos],scount*sizeof(rvec),
855                             pme->bufv[buf_pos],rcount*sizeof(rvec));
856             local_pos += scount;
857         }
858         buf_index[commnode[i]] = buf_pos;
859         buf_pos   += rcount;
860     }
861
862     local_pos = 0;
863     if (bAddF)
864     {
865         for(i=0; i<n; i++)
866         {
867             node = atc->pd[i];
868             if (node == atc->nodeid)
869             {
870                 /* Add from the local force array */
871                 rvec_inc(f[i],atc->f[local_pos]);
872                 local_pos++;
873             }
874             else
875             {
876                 /* Add from the receive buffer */
877                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
878                 buf_index[node]++;
879             }
880         }
881     }
882     else
883     {
884         for(i=0; i<n; i++)
885         {
886             node = atc->pd[i];
887             if (node == atc->nodeid)
888             {
889                 /* Copy from the local force array */
890                 copy_rvec(atc->f[local_pos],f[i]);
891                 local_pos++;
892             }
893             else
894             {
895                 /* Copy from the receive buffer */
896                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
897                 buf_index[node]++;
898             }
899         }
900     }
901 }
902
903 #ifdef GMX_MPI
904 static void
905 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
906 {
907     pme_overlap_t *overlap;
908     int send_index0,send_nindex;
909     int recv_index0,recv_nindex;
910     MPI_Status stat;
911     int i,j,k,ix,iy,iz,icnt;
912     int ipulse,send_id,recv_id,datasize;
913     real *p;
914     real *sendptr,*recvptr;
915
916     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
917     overlap = &pme->overlap[1];
918
919     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
920     {
921         /* Since we have already (un)wrapped the overlap in the z-dimension,
922          * we only have to communicate 0 to nkz (not pmegrid_nz).
923          */
924         if (direction==GMX_SUM_QGRID_FORWARD)
925         {
926             send_id = overlap->send_id[ipulse];
927             recv_id = overlap->recv_id[ipulse];
928             send_index0   = overlap->comm_data[ipulse].send_index0;
929             send_nindex   = overlap->comm_data[ipulse].send_nindex;
930             recv_index0   = overlap->comm_data[ipulse].recv_index0;
931             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
932         }
933         else
934         {
935             send_id = overlap->recv_id[ipulse];
936             recv_id = overlap->send_id[ipulse];
937             send_index0   = overlap->comm_data[ipulse].recv_index0;
938             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
939             recv_index0   = overlap->comm_data[ipulse].send_index0;
940             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
941         }
942
943         /* Copy data to contiguous send buffer */
944         if (debug)
945         {
946             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
947                     pme->nodeid,overlap->nodeid,send_id,
948                     pme->pmegrid_start_iy,
949                     send_index0-pme->pmegrid_start_iy,
950                     send_index0-pme->pmegrid_start_iy+send_nindex);
951         }
952         icnt = 0;
953         for(i=0;i<pme->pmegrid_nx;i++)
954         {
955             ix = i;
956             for(j=0;j<send_nindex;j++)
957             {
958                 iy = j + send_index0 - pme->pmegrid_start_iy;
959                 for(k=0;k<pme->nkz;k++)
960                 {
961                     iz = k;
962                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
963                 }
964             }
965         }
966
967         datasize      = pme->pmegrid_nx * pme->nkz;
968
969         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
970                      send_id,ipulse,
971                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
972                      recv_id,ipulse,
973                      overlap->mpi_comm,&stat);
974
975         /* Get data from contiguous recv buffer */
976         if (debug)
977         {
978             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
979                     pme->nodeid,overlap->nodeid,recv_id,
980                     pme->pmegrid_start_iy,
981                     recv_index0-pme->pmegrid_start_iy,
982                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
983         }
984         icnt = 0;
985         for(i=0;i<pme->pmegrid_nx;i++)
986         {
987             ix = i;
988             for(j=0;j<recv_nindex;j++)
989             {
990                 iy = j + recv_index0 - pme->pmegrid_start_iy;
991                 for(k=0;k<pme->nkz;k++)
992                 {
993                     iz = k;
994                     if(direction==GMX_SUM_QGRID_FORWARD)
995                     {
996                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
997                     }
998                     else
999                     {
1000                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1001                     }
1002                 }
1003             }
1004         }
1005     }
1006
1007     /* Major dimension is easier, no copying required,
1008      * but we might have to sum into a separate array.
1009      * Since we don't copy, we have to communicate up to pmegrid_nz,
1010      * not nkz as for the minor direction.
1011      */
1012     overlap = &pme->overlap[0];
1013
1014     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1015     {
1016         if(direction==GMX_SUM_QGRID_FORWARD)
1017         {
1018             send_id = overlap->send_id[ipulse];
1019             recv_id = overlap->recv_id[ipulse];
1020             send_index0   = overlap->comm_data[ipulse].send_index0;
1021             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1022             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1023             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1024             recvptr   = overlap->recvbuf;
1025         }
1026         else
1027         {
1028             send_id = overlap->recv_id[ipulse];
1029             recv_id = overlap->send_id[ipulse];
1030             send_index0   = overlap->comm_data[ipulse].recv_index0;
1031             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1032             recv_index0   = overlap->comm_data[ipulse].send_index0;
1033             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1034             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1035         }
1036
1037         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1038         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1039
1040         if (debug)
1041         {
1042             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1043                     pme->nodeid,overlap->nodeid,send_id,
1044                     pme->pmegrid_start_ix,
1045                     send_index0-pme->pmegrid_start_ix,
1046                     send_index0-pme->pmegrid_start_ix+send_nindex);
1047             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1048                     pme->nodeid,overlap->nodeid,recv_id,
1049                     pme->pmegrid_start_ix,
1050                     recv_index0-pme->pmegrid_start_ix,
1051                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1052         }
1053
1054         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1055                      send_id,ipulse,
1056                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1057                      recv_id,ipulse,
1058                      overlap->mpi_comm,&stat);
1059
1060         /* ADD data from contiguous recv buffer */
1061         if(direction==GMX_SUM_QGRID_FORWARD)
1062         {
1063             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1064             for(i=0;i<recv_nindex*datasize;i++)
1065             {
1066                 p[i] += overlap->recvbuf[i];
1067             }
1068         }
1069     }
1070 }
1071 #endif
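/* Note on the two directions above: GMX_SUM_QGRID_FORWARD adds the received
 * overlap region onto the local grid (summing charge spread by neighbouring
 * nodes), while the backward direction simply overwrites the local overlap
 * region with the neighbour's now-complete grid values before force gathering.
 */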
1072
1073
1074 static int
1075 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1076 {
1077     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1078     ivec    local_pme_size;
1079     int     i,ix,iy,iz;
1080     int     pmeidx,fftidx;
1081
1082     /* Dimensions should be identical for A/B grid, so we just use A here */
1083     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1084                                    local_fft_ndata,
1085                                    local_fft_offset,
1086                                    local_fft_size);
1087
1088     local_pme_size[0] = pme->pmegrid_nx;
1089     local_pme_size[1] = pme->pmegrid_ny;
1090     local_pme_size[2] = pme->pmegrid_nz;
1091
1092     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1093      the offset is identical, and the PME grid always has more data (due to overlap)
1094      */
1095     {
1096 #ifdef DEBUG_PME
1097         FILE *fp,*fp2;
1098         char fn[STRLEN],format[STRLEN];
1099         real val;
1100         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1101         fp = ffopen(fn,"w");
1102         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1103         fp2 = ffopen(fn,"w");
1104      sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1105 #endif
1106
1107     for(ix=0;ix<local_fft_ndata[XX];ix++)
1108     {
1109         for(iy=0;iy<local_fft_ndata[YY];iy++)
1110         {
1111             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1112             {
1113                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1114                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1115                 fftgrid[fftidx] = pmegrid[pmeidx];
1116 #ifdef DEBUG_PME
1117                 val = 100*pmegrid[pmeidx];
1118                 if (pmegrid[pmeidx] != 0)
1119                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1120                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1121                 if (pmegrid[pmeidx] != 0)
1122                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1123                             "qgrid",
1124                             pme->pmegrid_start_ix + ix,
1125                             pme->pmegrid_start_iy + iy,
1126                             pme->pmegrid_start_iz + iz,
1127                             pmegrid[pmeidx]);
1128 #endif
1129             }
1130         }
1131     }
1132 #ifdef DEBUG_PME
1133     ffclose(fp);
1134     ffclose(fp2);
1135 #endif
1136     }
1137     return 0;
1138 }
1139
1140
1141 static gmx_cycles_t omp_cyc_start(void)
1142 {
1143     return gmx_cycles_read();
1144 }
1145
1146 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1147 {
1148     return gmx_cycles_read() - c;
1149 }
1150
1151
1152 static int
1153 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1154                         int nthread,int thread)
1155 {
1156     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1157     ivec    local_pme_size;
1158     int     ixy0,ixy1,ixy,ix,iy,iz;
1159     int     pmeidx,fftidx;
1160 #ifdef PME_TIME_THREADS
1161     gmx_cycles_t c1;
1162     static double cs1=0;
1163     static int cnt=0;
1164 #endif
1165
1166 #ifdef PME_TIME_THREADS
1167     c1 = omp_cyc_start();
1168 #endif
1169     /* Dimensions should be identical for A/B grid, so we just use A here */
1170     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1171                                    local_fft_ndata,
1172                                    local_fft_offset,
1173                                    local_fft_size);
1174
1175     local_pme_size[0] = pme->pmegrid_nx;
1176     local_pme_size[1] = pme->pmegrid_ny;
1177     local_pme_size[2] = pme->pmegrid_nz;
1178
1179     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1180      the offset is identical, and the PME grid always has more data (due to overlap)
1181      */
1182     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1183     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1184
1185     for(ixy=ixy0;ixy<ixy1;ixy++)
1186     {
1187         ix = ixy/local_fft_ndata[YY];
1188         iy = ixy - ix*local_fft_ndata[YY];
1189
1190         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1191         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1192         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1193         {
1194             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1195         }
1196     }
1197
1198 #ifdef PME_TIME_THREADS
1199     c1 = omp_cyc_end(c1);
1200     cs1 += (double)c1;
1201     cnt++;
1202     if (cnt % 20 == 0)
1203     {
1204         printf("copy %.2f\n",cs1*1e-9);
1205     }
1206 #endif
1207
1208     return 0;
1209 }
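/* Example of the thread partitioning used above: with nthread = 3 and
 * local_fft_ndata[XX]*local_fft_ndata[YY] = 10 x-y lines, the (ixy0,ixy1)
 * ranges are [0,3), [3,6) and [6,10): contiguous, non-overlapping and
 * covering all lines even when the line count is not divisible by nthread.
 */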
1210
1211
1212 static void
1213 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1214 {
1215     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1216
1217     nx = pme->nkx;
1218     ny = pme->nky;
1219     nz = pme->nkz;
1220
1221     pnx = pme->pmegrid_nx;
1222     pny = pme->pmegrid_ny;
1223     pnz = pme->pmegrid_nz;
1224
1225     overlap = pme->pme_order - 1;
1226
1227     /* Add periodic overlap in z */
1228     for(ix=0; ix<pme->pmegrid_nx; ix++)
1229     {
1230         for(iy=0; iy<pme->pmegrid_ny; iy++)
1231         {
1232             for(iz=0; iz<overlap; iz++)
1233             {
1234                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1235                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1236             }
1237         }
1238     }
1239
1240     if (pme->nnodes_minor == 1)
1241     {
1242        for(ix=0; ix<pme->pmegrid_nx; ix++)
1243        {
1244            for(iy=0; iy<overlap; iy++)
1245            {
1246                for(iz=0; iz<nz; iz++)
1247                {
1248                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1249                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1250                }
1251            }
1252        }
1253     }
1254
1255     if (pme->nnodes_major == 1)
1256     {
1257         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1258
1259         for(ix=0; ix<overlap; ix++)
1260         {
1261             for(iy=0; iy<ny_x; iy++)
1262             {
1263                 for(iz=0; iz<nz; iz++)
1264                 {
1265                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1266                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1267                 }
1268             }
1269         }
1270     }
1271 }
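/* 1D illustration of the z-wrapping above: with nz = 8 and pme_order = 4 the
 * overlap is 3, so spreading can write to z-cells 0..10 and the wrap does
 *   grid[0] += grid[8];  grid[1] += grid[9];  grid[2] += grid[10];
 * folding the charge spread beyond the grid back into cells 0..2; only cells
 * 0..nkz-1 are then copied to the FFT grid.
 */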
1272
1273
1274 static void
1275 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1276 {
1277     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1278
1279     nx = pme->nkx;
1280     ny = pme->nky;
1281     nz = pme->nkz;
1282
1283     pnx = pme->pmegrid_nx;
1284     pny = pme->pmegrid_ny;
1285     pnz = pme->pmegrid_nz;
1286
1287     overlap = pme->pme_order - 1;
1288
1289     if (pme->nnodes_major == 1)
1290     {
1291         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1292
1293         for(ix=0; ix<overlap; ix++)
1294         {
1295             int iy,iz;
1296
1297             for(iy=0; iy<ny_x; iy++)
1298             {
1299                 for(iz=0; iz<nz; iz++)
1300                 {
1301                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1302                         pmegrid[(ix*pny+iy)*pnz+iz];
1303                 }
1304             }
1305         }
1306     }
1307
1308     if (pme->nnodes_minor == 1)
1309     {
1310 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1311        for(ix=0; ix<pme->pmegrid_nx; ix++)
1312        {
1313            int iy,iz;
1314
1315            for(iy=0; iy<overlap; iy++)
1316            {
1317                for(iz=0; iz<nz; iz++)
1318                {
1319                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1320                        pmegrid[(ix*pny+iy)*pnz+iz];
1321                }
1322            }
1323        }
1324     }
1325
1326     /* Copy periodic overlap in z */
1327 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1328     for(ix=0; ix<pme->pmegrid_nx; ix++)
1329     {
1330         int iy,iz;
1331
1332         for(iy=0; iy<pme->pmegrid_ny; iy++)
1333         {
1334             for(iz=0; iz<overlap; iz++)
1335             {
1336                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1337                     pmegrid[(ix*pny+iy)*pnz+iz];
1338             }
1339         }
1340     }
1341 }
1342
1343 static void clear_grid(int nx,int ny,int nz,real *grid,
1344                        ivec fs,int *flag,
1345                        int fx,int fy,int fz,
1346                        int order)
1347 {
1348     int nc,ncz;
1349     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1350     int flind;
1351
1352     nc  = 2 + (order - 2)/FLBS;
1353     ncz = 2 + (order - 2)/FLBSZ;
1354
1355     for(fsx=fx; fsx<fx+nc; fsx++)
1356     {
1357         for(fsy=fy; fsy<fy+nc; fsy++)
1358         {
1359             for(fsz=fz; fsz<fz+ncz; fsz++)
1360             {
1361                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1362                 if (flag[flind] == 0)
1363                 {
1364                     gx = fsx*FLBS;
1365                     gy = fsy*FLBS;
1366                     gz = fsz*FLBSZ;
1367                     g0x = (gx*ny + gy)*nz + gz;
1368                     for(x=0; x<FLBS; x++)
1369                     {
1370                         g0y = g0x;
1371                         for(y=0; y<FLBS; y++)
1372                         {
1373                             for(z=0; z<FLBSZ; z++)
1374                             {
1375                                 grid[g0y+z] = 0;
1376                             }
1377                             g0y += nz;
1378                         }
1379                         g0x += ny*nz;
1380                     }
1381
1382                     flag[flind] = 1;
1383                 }
1384             }
1385         }
1386     }
1387 }
1388
1389 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1390 #define DO_BSPLINE(order)                            \
1391 for(ithx=0; (ithx<order); ithx++)                    \
1392 {                                                    \
1393     index_x = (i0+ithx)*pny*pnz;                     \
1394     valx    = qn*thx[ithx];                          \
1395                                                      \
1396     for(ithy=0; (ithy<order); ithy++)                \
1397     {                                                \
1398         valxy    = valx*thy[ithy];                   \
1399         index_xy = index_x+(j0+ithy)*pnz;            \
1400                                                      \
1401         for(ithz=0; (ithz<order); ithz++)            \
1402         {                                            \
1403             index_xyz        = index_xy+(k0+ithz);   \
1404             grid[index_xyz] += valxy*thz[ithz];      \
1405         }                                            \
1406     }                                                \
1407 }
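/* Note on DO_BSPLINE: for each atom the order^3 weights thx[ithx]*thy[ithy]*thz[ithz]
 * form a partition of unity (a property of the B-spline weights), so spreading
 * deposits exactly the charge qn, distributed over an order x order x order
 * block of grid points.
 */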
1408
1409
1410 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1411                                      pme_atomcomm_t *atc, splinedata_t *spline,
1412                                      pme_spline_work_t *work)
1413 {
1414
1415     /* spread charges from home atoms to local grid */
1416     real     *grid;
1417     pme_overlap_t *ol;
1418     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1419     int *    idxptr;
1420     int      order,norder,index_x,index_xy,index_xyz;
1421     real     valx,valxy,qn;
1422     real     *thx,*thy,*thz;
1423     int      localsize, bndsize;
1424     int      pnx,pny,pnz,ndatatot;
1425     int      offx,offy,offz;
1426
1427     pnx = pmegrid->n[XX];
1428     pny = pmegrid->n[YY];
1429     pnz = pmegrid->n[ZZ];
1430
1431     offx = pmegrid->offset[XX];
1432     offy = pmegrid->offset[YY];
1433     offz = pmegrid->offset[ZZ];
1434
1435     ndatatot = pnx*pny*pnz;
1436     grid = pmegrid->grid;
1437     for(i=0;i<ndatatot;i++)
1438     {
1439         grid[i] = 0;
1440     }
1441
1442     order = pmegrid->order;
1443
1444     for(nn=0; nn<spline->n; nn++)
1445     {
1446         n  = spline->ind[nn];
1447         qn = atc->q[n];
1448
1449         if (qn != 0)
1450         {
1451             idxptr = atc->idx[n];
1452             norder = nn*order;
1453
1454             i0   = idxptr[XX] - offx;
1455             j0   = idxptr[YY] - offy;
1456             k0   = idxptr[ZZ] - offz;
1457
1458             thx = spline->theta[XX] + norder;
1459             thy = spline->theta[YY] + norder;
1460             thz = spline->theta[ZZ] + norder;
1461
1462             switch (order) {
1463             case 4:
1464 #ifdef PME_SSE
1465 #ifdef PME_SSE_UNALIGNED
1466 #define PME_SPREAD_SSE_ORDER4
1467 #else
1468 #define PME_SPREAD_SSE_ALIGNED
1469 #define PME_ORDER 4
1470 #endif
1471 #include "pme_sse_single.h"
1472 #else
1473                 DO_BSPLINE(4);
1474 #endif
1475                 break;
1476             case 5:
1477 #ifdef PME_SSE
1478 #define PME_SPREAD_SSE_ALIGNED
1479 #define PME_ORDER 5
1480 #include "pme_sse_single.h"
1481 #else
1482                 DO_BSPLINE(5);
1483 #endif
1484                 break;
1485             default:
1486                 DO_BSPLINE(order);
1487                 break;
1488             }
1489         }
1490     }
1491 }
1492
1493 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1494 {
1495 #ifdef PME_SSE
1496     if (pme_order == 5
1497 #ifndef PME_SSE_UNALIGNED
1498         || pme_order == 4
1499 #endif
1500         )
1501     {
1502         /* Round nz up to a multiple of 4 to ensure alignment */
1503         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1504     }
1505 #endif
1506 }
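/* Example of the rounding above: (*pmegrid_nz + 3) & ~3 rounds up to the next
 * multiple of 4, e.g. 10 -> 12, 12 -> 12, 13 -> 16, so that each z-row starts
 * at the same offset modulo 16 bytes for 4-wide single-precision SSE access.
 */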
1507
1508 static void set_gridsize_alignment(int *gridsize,int pme_order)
1509 {
1510 #ifdef PME_SSE
1511 #ifndef PME_SSE_UNALIGNED
1512     if (pme_order == 4)
1513     {
1514          * Add extra elements to ensure that aligned operations do not go
1515          * beyond the allocated grid size.
1516          * Note that for pme_order=5, the pme grid z-size alignment
1517          * ensures that we will not go beyond the grid size.
1518          */
1519          *gridsize += 4;
1520     }
1521 #endif
1522 #endif
1523 }
1524
1525 static void pmegrid_init(pmegrid_t *grid,
1526                          int cx, int cy, int cz,
1527                          int x0, int y0, int z0,
1528                          int x1, int y1, int z1,
1529                          gmx_bool set_alignment,
1530                          int pme_order,
1531                          real *ptr)
1532 {
1533     int nz,gridsize;
1534
1535     grid->ci[XX] = cx;
1536     grid->ci[YY] = cy;
1537     grid->ci[ZZ] = cz;
1538     grid->offset[XX] = x0;
1539     grid->offset[YY] = y0;
1540     grid->offset[ZZ] = z0;
1541     grid->n[XX]      = x1 - x0 + pme_order - 1;
1542     grid->n[YY]      = y1 - y0 + pme_order - 1;
1543     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1544
1545     nz = grid->n[ZZ];
1546     set_grid_alignment(&nz,pme_order);
1547     if (set_alignment)
1548     {
1549         grid->n[ZZ] = nz;
1550     }
1551     else if (nz != grid->n[ZZ])
1552     {
1553         gmx_incons("pmegrid_init call with an unaligned z size");
1554     }
1555
1556     grid->order = pme_order;
1557     if (ptr == NULL)
1558     {
1559         gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1560         set_gridsize_alignment(&gridsize,pme_order);
1561         snew_aligned(grid->grid,gridsize,16);
1562     }
1563     else
1564     {
1565         grid->grid = ptr;
1566     }
1567 }
1568
1569 static int div_round_up(int numerator,int denominator)
1570 {
1571     return (numerator + denominator - 1)/denominator;
1572 }
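/* Usage example: div_round_up(10,4) == 3 and div_round_up(12,4) == 3,
 * i.e. integer division rounded up, used below to size the per-thread subgrids.
 */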
1573
1574 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1575                                   ivec nsub)
1576 {
1577     int gsize_opt,gsize;
1578     int nsx,nsy,nsz;
1579     char *env;
1580
1581     gsize_opt = -1;
1582     for(nsx=1; nsx<=nthread; nsx++)
1583     {
1584         if (nthread % nsx == 0)
1585         {
1586             for(nsy=1; nsy<=nthread; nsy++)
1587             {
1588                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1589                 {
1590                     nsz = nthread/(nsx*nsy);
1591
1592                     /* Determine the number of grid points per thread */
1593                     gsize =
1594                         (div_round_up(n[XX],nsx) + ovl)*
1595                         (div_round_up(n[YY],nsy) + ovl)*
1596                         (div_round_up(n[ZZ],nsz) + ovl);
1597
1598                     /* Minimize the number of grid points per thread
1599                      * and, secondarily, the number of cuts in minor dimensions.
1600                      */
1601                     if (gsize_opt == -1 ||
1602                         gsize < gsize_opt ||
1603                         (gsize == gsize_opt &&
1604                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1605                     {
1606                         nsub[XX] = nsx;
1607                         nsub[YY] = nsy;
1608                         nsub[ZZ] = nsz;
1609                         gsize_opt = gsize;
1610                     }
1611                 }
1612             }
1613         }
1614     }
1615
1616     env = getenv("GMX_PME_THREAD_DIVISION");
1617     if (env != NULL)
1618     {
1619         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1620     }
1621
1622     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1623     {
1624         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1625     }
1626 }
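/* Example (assuming a POSIX shell): the automatic subgrid division can be
 * overridden at run time with, e.g.,
 *   GMX_PME_THREAD_DIVISION="2 2 1" mdrun ...
 * requesting a 2 x 2 x 1 thread grid; the product must equal the number of
 * PME threads or the gmx_fatal() check above aborts the run.
 */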
1627
1628 static void pmegrids_init(pmegrids_t *grids,
1629                           int nx,int ny,int nz,int nz_base,
1630                           int pme_order,
1631                           int nthread,
1632                           int overlap_x,
1633                           int overlap_y)
1634 {
1635     ivec n,n_base,g0,g1;
1636     int t,x,y,z,d,i,tfac;
1637     int max_comm_lines;
1638
1639     n[XX] = nx - (pme_order - 1);
1640     n[YY] = ny - (pme_order - 1);
1641     n[ZZ] = nz - (pme_order - 1);
1642
1643     copy_ivec(n,n_base);
1644     n_base[ZZ] = nz_base;
1645
1646     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1647                  NULL);
1648
1649     grids->nthread = nthread;
1650
1651     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1652
1653     if (grids->nthread > 1)
1654     {
1655         ivec nst;
1656         int gridsize;
1657         real *grid_all;
1658
1659         for(d=0; d<DIM; d++)
1660         {
1661             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1662         }
1663         set_grid_alignment(&nst[ZZ],pme_order);
1664
1665         if (debug)
1666         {
1667             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1668                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1669             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1670                     nx,ny,nz,
1671                     nst[XX],nst[YY],nst[ZZ]);
1672         }
1673
1674         snew(grids->grid_th,grids->nthread);
1675         t = 0;
1676         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1677         set_gridsize_alignment(&gridsize,pme_order);
1678         snew_aligned(grid_all,
1679                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1680                      16);
1681
1682         for(x=0; x<grids->nc[XX]; x++)
1683         {
1684             for(y=0; y<grids->nc[YY]; y++)
1685             {
1686                 for(z=0; z<grids->nc[ZZ]; z++)
1687                 {
1688                     pmegrid_init(&grids->grid_th[t],
1689                                  x,y,z,
1690                                  (n[XX]*(x  ))/grids->nc[XX],
1691                                  (n[YY]*(y  ))/grids->nc[YY],
1692                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1693                                  (n[XX]*(x+1))/grids->nc[XX],
1694                                  (n[YY]*(y+1))/grids->nc[YY],
1695                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1696                                  TRUE,
1697                                  pme_order,
1698                                  grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1699                     t++;
1700                 }
1701             }
1702         }
1703     }
1704
1705     snew(grids->g2t,DIM);
1706     tfac = 1;
1707     for(d=DIM-1; d>=0; d--)
1708     {
1709         snew(grids->g2t[d],n[d]);
1710         t = 0;
1711         for(i=0; i<n[d]; i++)
1712         {
1713             /* The second check should match the parameters
1714              * of the pmegrid_init call above.
1715              */
1716             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1717             {
1718                 t++;
1719             }
1720             grids->g2t[d][i] = t*tfac;
1721         }
1722
1723         tfac *= grids->nc[d];
1724
1725         switch (d)
1726         {
1727         case XX: max_comm_lines = overlap_x;     break;
1728         case YY: max_comm_lines = overlap_y;     break;
1729         case ZZ: max_comm_lines = pme_order - 1; break;
1730         }
1731         grids->nthread_comm[d] = 0;
1732         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1733         {
1734             grids->nthread_comm[d]++;
1735         }
1736         if (debug != NULL)
1737         {
1738             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1739                     'x'+d,grids->nthread_comm[d]);
1740         }
1741         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1742          * work, but this is not a problematic restriction.
1743          */
1744         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1745         {
1746             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1747         }
1748     }
1749 }
1750
1751
1752 static void pmegrids_destroy(pmegrids_t *grids)
1753 {
1754     int t;
1755
1756     if (grids->grid.grid != NULL)
1757     {
1758         sfree(grids->grid.grid);
1759
1760         if (grids->nthread > 0)
1761         {
1762             for(t=0; t<grids->nthread; t++)
1763             {
1764                 sfree(grids->grid_th[t].grid);
1765             }
1766             sfree(grids->grid_th);
1767         }
1768     }
1769 }
1770
1771
1772 static void realloc_work(pme_work_t *work,int nkx)
1773 {
1774     if (nkx > work->nalloc)
1775     {
1776         work->nalloc = nkx;
1777         srenew(work->mhx  ,work->nalloc);
1778         srenew(work->mhy  ,work->nalloc);
1779         srenew(work->mhz  ,work->nalloc);
1780         srenew(work->m2   ,work->nalloc);
1781         /* Allocate an aligned pointer for SSE operations, including 3 extra
1782          * elements at the end since SSE operates on 4 elements at a time.
1783          */
1784         sfree_aligned(work->denom);
1785         sfree_aligned(work->tmp1);
1786         sfree_aligned(work->eterm);
1787         snew_aligned(work->denom,work->nalloc+3,16);
1788         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1789         snew_aligned(work->eterm,work->nalloc+3,16);
1790         srenew(work->m2inv,work->nalloc);
1791     }
1792 }
1793
1794
1795 static void free_work(pme_work_t *work)
1796 {
1797     sfree(work->mhx);
1798     sfree(work->mhy);
1799     sfree(work->mhz);
1800     sfree(work->m2);
1801     sfree_aligned(work->denom);
1802     sfree_aligned(work->tmp1);
1803     sfree_aligned(work->eterm);
1804     sfree(work->m2inv);
1805 }
1806
1807
1808 #ifdef PME_SSE
1809 /* Calculate exponentials through SSE in float precision */
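/* The loop below processes four k-values per iteration: _mm_rcp_ps returns
 * only a ~12-bit accurate reciprocal of denom, so one Newton-Raphson step,
 * d_inv = lu*(2 - lu*d), is applied to recover close to single precision
 * before multiplying with f and exp(tmp1).
 */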
1810 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1811 {
1812     {
1813         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1814         __m128 f_sse;
1815         __m128 lu;
1816         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1817         int kx;
1818         f_sse = _mm_load1_ps(&f);
1819         for(kx=0; kx<end; kx+=4)
1820         {
1821             tmp_d1   = _mm_load_ps(d_aligned+kx);
1822             lu       = _mm_rcp_ps(tmp_d1);
1823             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1824             tmp_r    = _mm_load_ps(r_aligned+kx);
1825             tmp_r    = gmx_mm_exp_ps(tmp_r);
1826             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1827             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1828             _mm_store_ps(e_aligned+kx,tmp_e);
1829         }
1830     }
1831 }
1832 #else
1833 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1834 {
1835     int kx;
1836     for(kx=start; kx<end; kx++)
1837     {
1838         d[kx] = 1.0/d[kx];
1839     }
1840     for(kx=start; kx<end; kx++)
1841     {
1842         r[kx] = exp(r[kx]);
1843     }
1844     for(kx=start; kx<end; kx++)
1845     {
1846         e[kx] = f*r[kx]*d[kx];
1847     }
1848 }
1849 #endif
1850
1851
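/* Reciprocal-space part of (smooth) PME on the transposed complex grid
 * (y major, z middle, x minor). Each thread handles a contiguous range of
 * (y,z) lines; for every k-vector m = kx*recip_x + ky*recip_y + kz*recip_z
 * the grid value is scaled by
 *   eterm(m) = elfac * exp(-pi^2 m^2 / ewaldcoeff^2) / (pi V B(m) m^2),
 * where B(m) is the product of the B-spline moduli bsp_mod in x, y and z.
 * With bEnerVir the same loop also accumulates the reciprocal-space energy
 * and virial into the per-thread work struct (with a 0.5 weight for the
 * kz=0 and Nyquist planes of the half-complex grid); these per-thread
 * contributions are summed later in get_pme_ener_vir().
 */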
1852 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1853                          real ewaldcoeff,real vol,
1854                          gmx_bool bEnerVir,
1855                          int nthread,int thread)
1856 {
1857     /* do recip sum over local cells in grid */
1858     /* y major, z middle, x minor or continuous */
1859     t_complex *p0;
1860     int     kx,ky,kz,maxkx,maxky,maxkz;
1861     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1862     real    mx,my,mz;
1863     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1864     real    ets2,struct2,vfactor,ets2vf;
1865     real    d1,d2,energy=0;
1866     real    by,bz;
1867     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1868     real    rxx,ryx,ryy,rzx,rzy,rzz;
1869     pme_work_t *work;
1870     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1871     real    mhxk,mhyk,mhzk,m2k;
1872     real    corner_fac;
1873     ivec    complex_order;
1874     ivec    local_ndata,local_offset,local_size;
1875     real    elfac;
1876
1877     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1878
1879     nx = pme->nkx;
1880     ny = pme->nky;
1881     nz = pme->nkz;
1882
1883     /* Dimensions should be identical for A/B grid, so we just use A here */
1884     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1885                                       complex_order,
1886                                       local_ndata,
1887                                       local_offset,
1888                                       local_size);
1889
1890     rxx = pme->recipbox[XX][XX];
1891     ryx = pme->recipbox[YY][XX];
1892     ryy = pme->recipbox[YY][YY];
1893     rzx = pme->recipbox[ZZ][XX];
1894     rzy = pme->recipbox[ZZ][YY];
1895     rzz = pme->recipbox[ZZ][ZZ];
1896
1897     maxkx = (nx+1)/2;
1898     maxky = (ny+1)/2;
1899     maxkz = nz/2+1;
1900
1901     work = &pme->work[thread];
1902     mhx   = work->mhx;
1903     mhy   = work->mhy;
1904     mhz   = work->mhz;
1905     m2    = work->m2;
1906     denom = work->denom;
1907     tmp1  = work->tmp1;
1908     eterm = work->eterm;
1909     m2inv = work->m2inv;
1910
1911     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1912     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1913
1914     for(iyz=iyz0; iyz<iyz1; iyz++)
1915     {
1916         iy = iyz/local_ndata[ZZ];
1917         iz = iyz - iy*local_ndata[ZZ];
1918
1919         ky = iy + local_offset[YY];
1920
1921         if (ky < maxky)
1922         {
1923             my = ky;
1924         }
1925         else
1926         {
1927             my = (ky - ny);
1928         }
1929
1930         by = M_PI*vol*pme->bsp_mod[YY][ky];
1931
1932         kz = iz + local_offset[ZZ];
1933
1934         mz = kz;
1935
1936         bz = pme->bsp_mod[ZZ][kz];
1937
1938         /* 0.5 correction for corner points */
1939         corner_fac = 1;
1940         if (kz == 0 || kz == (nz+1)/2)
1941         {
1942             corner_fac = 0.5;
1943         }
1944
1945         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1946
1947         /* We should skip the k-space point (0,0,0) */
1948         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1949         {
1950             kxstart = local_offset[XX];
1951         }
1952         else
1953         {
1954             kxstart = local_offset[XX] + 1;
1955             p0++;
1956         }
1957         kxend = local_offset[XX] + local_ndata[XX];
1958
1959         if (bEnerVir)
1960         {
1961             /* More expensive inner loop, especially because of the storage
1962              * of the mh elements in arrays.
1963              * Because x is the minor grid index, all mh elements
1964              * depend on kx for triclinic unit cells.
1965              */
1966
1967                 /* Two explicit loops to avoid a conditional inside the loop */
1968             for(kx=kxstart; kx<maxkx; kx++)
1969             {
1970                 mx = kx;
1971
1972                 mhxk      = mx * rxx;
1973                 mhyk      = mx * ryx + my * ryy;
1974                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1975                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1976                 mhx[kx]   = mhxk;
1977                 mhy[kx]   = mhyk;
1978                 mhz[kx]   = mhzk;
1979                 m2[kx]    = m2k;
1980                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1981                 tmp1[kx]  = -factor*m2k;
1982             }
1983
1984             for(kx=maxkx; kx<kxend; kx++)
1985             {
1986                 mx = (kx - nx);
1987
1988                 mhxk      = mx * rxx;
1989                 mhyk      = mx * ryx + my * ryy;
1990                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1991                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1992                 mhx[kx]   = mhxk;
1993                 mhy[kx]   = mhyk;
1994                 mhz[kx]   = mhzk;
1995                 m2[kx]    = m2k;
1996                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1997                 tmp1[kx]  = -factor*m2k;
1998             }
1999
2000             for(kx=kxstart; kx<kxend; kx++)
2001             {
2002                 m2inv[kx] = 1.0/m2[kx];
2003             }
2004
2005             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2006
2007             for(kx=kxstart; kx<kxend; kx++,p0++)
2008             {
2009                 d1      = p0->re;
2010                 d2      = p0->im;
2011
2012                 p0->re  = d1*eterm[kx];
2013                 p0->im  = d2*eterm[kx];
2014
2015                 struct2 = 2.0*(d1*d1+d2*d2);
2016
2017                 tmp1[kx] = eterm[kx]*struct2;
2018             }
2019
2020             for(kx=kxstart; kx<kxend; kx++)
2021             {
2022                 ets2     = corner_fac*tmp1[kx];
2023                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2024                 energy  += ets2;
2025
2026                 ets2vf   = ets2*vfactor;
2027                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2028                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2029                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2030                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2031                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2032                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2033             }
2034         }
2035         else
2036         {
2037             /* We don't need to calculate the energy and the virial.
2038              * In this case the triclinic overhead is small.
2039              */
2040
2041             /* Two explicit loops to avoid a conditional inside the loop */
2042
2043             for(kx=kxstart; kx<maxkx; kx++)
2044             {
2045                 mx = kx;
2046
2047                 mhxk      = mx * rxx;
2048                 mhyk      = mx * ryx + my * ryy;
2049                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2050                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2051                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2052                 tmp1[kx]  = -factor*m2k;
2053             }
2054
2055             for(kx=maxkx; kx<kxend; kx++)
2056             {
2057                 mx = (kx - nx);
2058
2059                 mhxk      = mx * rxx;
2060                 mhyk      = mx * ryx + my * ryy;
2061                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2062                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2063                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2064                 tmp1[kx]  = -factor*m2k;
2065             }
2066
2067             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2068
2069             for(kx=kxstart; kx<kxend; kx++,p0++)
2070             {
2071                 d1      = p0->re;
2072                 d2      = p0->im;
2073
2074                 p0->re  = d1*eterm[kx];
2075                 p0->im  = d2*eterm[kx];
2076             }
2077         }
2078     }
2079
2080     if (bEnerVir)
2081     {
2082         /* Update virial with local values.
2083          * The virial is symmetric by definition.
2084          * This virial seems OK for isotropic scaling, but I'm
2085          * experiencing problems with semi-isotropic membranes.
2086          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2087          */
2088         work->vir[XX][XX] = 0.25*virxx;
2089         work->vir[YY][YY] = 0.25*viryy;
2090         work->vir[ZZ][ZZ] = 0.25*virzz;
2091         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2092         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2093         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2094
2095         /* This energy should be corrected for a charged system */
2096         work->energy = 0.5*energy;
2097     }
2098
2099     /* Return the loop count */
2100     return local_ndata[YY]*local_ndata[XX];
2101 }
2102
2103 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2104                              real *mesh_energy,matrix vir)
2105 {
2106     /* This function sums output over threads
2107      * and should therefore only be called after thread synchronization.
2108      */
2109     int thread;
2110
2111     *mesh_energy = pme->work[0].energy;
2112     copy_mat(pme->work[0].vir,vir);
2113
2114     for(thread=1; thread<nthread; thread++)
2115     {
2116         *mesh_energy += pme->work[thread].energy;
2117         m_add(vir,pme->work[thread].vir,vir);
2118     }
2119 }
2120
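/* DO_FSPLINE accumulates, for one charge, the gradient of the interpolated
 * potential with respect to the fractional coordinates: a triple tensor
 * product of the order B-spline values (thx/thy/thz) and their derivatives
 * (dthx/dthy/dthz) over the order^3 grid points the charge was spread to.
 * gather_f_bsplines() below then converts (fx,fy,fz) to Cartesian forces
 * using the reciprocal box: F_alpha = -q * sum_d f_d * n_d * recipbox[d][alpha].
 */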
2121 #define DO_FSPLINE(order)                      \
2122 for(ithx=0; (ithx<order); ithx++)              \
2123 {                                              \
2124     index_x = (i0+ithx)*pny*pnz;               \
2125     tx      = thx[ithx];                       \
2126     dx      = dthx[ithx];                      \
2127                                                \
2128     for(ithy=0; (ithy<order); ithy++)          \
2129     {                                          \
2130         index_xy = index_x+(j0+ithy)*pnz;      \
2131         ty       = thy[ithy];                  \
2132         dy       = dthy[ithy];                 \
2133         fxy1     = fz1 = 0;                    \
2134                                                \
2135         for(ithz=0; (ithz<order); ithz++)      \
2136         {                                      \
2137             gval  = grid[index_xy+(k0+ithz)];  \
2138             fxy1 += thz[ithz]*gval;            \
2139             fz1  += dthz[ithz]*gval;           \
2140         }                                      \
2141         fx += dx*ty*fxy1;                      \
2142         fy += tx*dy*fxy1;                      \
2143         fz += tx*ty*fz1;                       \
2144     }                                          \
2145 }
2146
2147
2148 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2149                               gmx_bool bClearF,pme_atomcomm_t *atc,
2150                               splinedata_t *spline,
2151                               real scale)
2152 {
2153     /* sum forces for local particles */
2154     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2155     int     index_x,index_xy;
2156     int     nx,ny,nz,pnx,pny,pnz;
2157     int *   idxptr;
2158     real    tx,ty,dx,dy,qn;
2159     real    fx,fy,fz,gval;
2160     real    fxy1,fz1;
2161     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2162     int     norder;
2163     real    rxx,ryx,ryy,rzx,rzy,rzz;
2164     int     order;
2165
2166     pme_spline_work_t *work;
2167
2168     work = &pme->spline_work;
2169
2170     order = pme->pme_order;
2171     thx   = spline->theta[XX];
2172     thy   = spline->theta[YY];
2173     thz   = spline->theta[ZZ];
2174     dthx  = spline->dtheta[XX];
2175     dthy  = spline->dtheta[YY];
2176     dthz  = spline->dtheta[ZZ];
2177     nx    = pme->nkx;
2178     ny    = pme->nky;
2179     nz    = pme->nkz;
2180     pnx   = pme->pmegrid_nx;
2181     pny   = pme->pmegrid_ny;
2182     pnz   = pme->pmegrid_nz;
2183
2184     rxx   = pme->recipbox[XX][XX];
2185     ryx   = pme->recipbox[YY][XX];
2186     ryy   = pme->recipbox[YY][YY];
2187     rzx   = pme->recipbox[ZZ][XX];
2188     rzy   = pme->recipbox[ZZ][YY];
2189     rzz   = pme->recipbox[ZZ][ZZ];
2190
2191     for(nn=0; nn<spline->n; nn++)
2192     {
2193         n  = spline->ind[nn];
2194         qn = scale*atc->q[n];
2195
2196         if (bClearF)
2197         {
2198             atc->f[n][XX] = 0;
2199             atc->f[n][YY] = 0;
2200             atc->f[n][ZZ] = 0;
2201         }
2202         if (qn != 0)
2203         {
2204             fx     = 0;
2205             fy     = 0;
2206             fz     = 0;
2207             idxptr = atc->idx[n];
2208             norder = nn*order;
2209
2210             i0   = idxptr[XX];
2211             j0   = idxptr[YY];
2212             k0   = idxptr[ZZ];
2213
2214             /* Pointer arithmetic alert, next six statements */
2215             thx  = spline->theta[XX] + norder;
2216             thy  = spline->theta[YY] + norder;
2217             thz  = spline->theta[ZZ] + norder;
2218             dthx = spline->dtheta[XX] + norder;
2219             dthy = spline->dtheta[YY] + norder;
2220             dthz = spline->dtheta[ZZ] + norder;
2221
2222             switch (order) {
2223             case 4:
2224 #ifdef PME_SSE
2225 #ifdef PME_SSE_UNALIGNED
2226 #define PME_GATHER_F_SSE_ORDER4
2227 #else
2228 #define PME_GATHER_F_SSE_ALIGNED
2229 #define PME_ORDER 4
2230 #endif
2231 #include "pme_sse_single.h"
2232 #else
2233                 DO_FSPLINE(4);
2234 #endif
2235                 break;
2236             case 5:
2237 #ifdef PME_SSE
2238 #define PME_GATHER_F_SSE_ALIGNED
2239 #define PME_ORDER 5
2240 #include "pme_sse_single.h"
2241 #else
2242                 DO_FSPLINE(5);
2243 #endif
2244                 break;
2245             default:
2246                 DO_FSPLINE(order);
2247                 break;
2248             }
2249
2250             atc->f[n][XX] += -qn*( fx*nx*rxx );
2251             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2252             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2253         }
2254     }
2255     /* Since the energy, and not the force, is interpolated,
2256      * the net force might not be exactly zero.
2257      * This could be solved by also interpolating F, but
2258      * that comes at a cost.
2259      * A better hack is to remove the net force every
2260      * step, but that must be done at a higher level,
2261      * since this routine doesn't see all atoms when running
2262      * in parallel. It is unclear how important this is.  EL 990726
2263      */
2264 }
2265
2266
2267 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2268                                    pme_atomcomm_t *atc)
2269 {
2270     splinedata_t *spline;
2271     int     n,ithx,ithy,ithz,i0,j0,k0;
2272     int     index_x,index_xy;
2273     int *   idxptr;
2274     real    energy,pot,tx,ty,qn,gval;
2275     real    *thx,*thy,*thz;
2276     int     norder;
2277     int     order;
2278
2279     spline = &atc->spline[0];
2280
2281     order = pme->pme_order;
2282
2283     energy = 0;
2284     for(n=0; (n<atc->n); n++) {
2285         qn      = atc->q[n];
2286
2287         if (qn != 0) {
2288             idxptr = atc->idx[n];
2289             norder = n*order;
2290
2291             i0   = idxptr[XX];
2292             j0   = idxptr[YY];
2293             k0   = idxptr[ZZ];
2294
2295             /* Pointer arithmetic alert, next three statements */
2296             thx  = spline->theta[XX] + norder;
2297             thy  = spline->theta[YY] + norder;
2298             thz  = spline->theta[ZZ] + norder;
2299
2300             pot = 0;
2301             for(ithx=0; (ithx<order); ithx++)
2302             {
2303                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2304                 tx      = thx[ithx];
2305
2306                 for(ithy=0; (ithy<order); ithy++)
2307                 {
2308                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2309                     ty       = thy[ithy];
2310
2311                     for(ithz=0; (ithz<order); ithz++)
2312                     {
2313                         gval  = grid[index_xy+(k0+ithz)];
2314                         pot  += tx*ty*thz[ithz]*gval;
2315                     }
2316
2317                 }
2318             }
2319
2320             energy += pot*qn;
2321         }
2322     }
2323
2324     return energy;
2325 }
2326
2327 /* Macro to force loop unrolling by fixing order.
2328  * This gives a significant performance gain.
2329  */
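/* CALC_SPLINE evaluates, for one particle and all three dimensions, the
 * cardinal B-spline weights at the fractional offset dr using the standard
 * SPME recursion: data[] is initialized with the order-2 (linear) spline and
 * built up order by order; the derivative weights follow from
 * dM_n(u)/du = M_{n-1}(u) - M_{n-1}(u-1) applied to the order n-1 values,
 * and one final recursion step yields the order-n values themselves. The
 * results are stored interleaved per particle in theta[] and dtheta[].
 */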
2330 #define CALC_SPLINE(order)                     \
2331 {                                              \
2332     int j,k,l;                                 \
2333     real dr,div;                               \
2334     real data[PME_ORDER_MAX];                  \
2335     real ddata[PME_ORDER_MAX];                 \
2336                                                \
2337     for(j=0; (j<DIM); j++)                     \
2338     {                                          \
2339         dr  = xptr[j];                         \
2340                                                \
2341         /* dr is relative offset from lower cell limit */ \
2342         data[order-1] = 0;                     \
2343         data[1] = dr;                          \
2344         data[0] = 1 - dr;                      \
2345                                                \
2346         for(k=3; (k<order); k++)               \
2347         {                                      \
2348             div = 1.0/(k - 1.0);               \
2349             data[k-1] = div*dr*data[k-2];      \
2350             for(l=1; (l<(k-1)); l++)           \
2351             {                                  \
2352                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2353                                    data[k-l-1]);                \
2354             }                                  \
2355             data[0] = div*(1-dr)*data[0];      \
2356         }                                      \
2357         /* differentiate */                    \
2358         ddata[0] = -data[0];                   \
2359         for(k=1; (k<order); k++)               \
2360         {                                      \
2361             ddata[k] = data[k-1] - data[k];    \
2362         }                                      \
2363                                                \
2364         div = 1.0/(order - 1);                 \
2365         data[order-1] = div*dr*data[order-2];  \
2366         for(l=1; (l<(order-1)); l++)           \
2367         {                                      \
2368             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2369                                (order-l-dr)*data[order-l-1]); \
2370         }                                      \
2371         data[0] = div*(1 - dr)*data[0];        \
2372                                                \
2373         for(k=0; k<order; k++)                 \
2374         {                                      \
2375             theta[j][i*order+k]  = data[k];    \
2376             dtheta[j][i*order+k] = ddata[k];   \
2377         }                                      \
2378     }                                          \
2379 }
2380
2381 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2382                    rvec fractx[],int nr,int ind[],real charge[],
2383                    gmx_bool bFreeEnergy)
2384 {
2385     /* construct splines for local atoms */
2386     int  i,ii;
2387     real *xptr;
2388
2389     for(i=0; i<nr; i++)
2390     {
2391         /* With free energy we do not use the charge check.
2392          * In most cases this will be more efficient than calling make_bsplines
2393          * twice, since usually more than half the particles have charges.
2394          */
2395         ii = ind[i];
2396         if (bFreeEnergy || charge[ii] != 0.0) {
2397             xptr = fractx[ii];
2398             switch(order) {
2399             case 4:  CALC_SPLINE(4);     break;
2400             case 5:  CALC_SPLINE(5);     break;
2401             default: CALC_SPLINE(order); break;
2402             }
2403         }
2404     }
2405 }
2406
2407
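/* make_dft_mod computes the squared moduli |b(m)|^2 of the discrete Fourier
 * transform of the B-spline values tabulated in data[]:
 *   mod[i] = ( sum_j data[j] cos(2 pi i j / n) )^2
 *          + ( sum_j data[j] sin(2 pi i j / n) )^2
 * Values that are essentially zero (which can occur near the Nyquist
 * frequency for some interpolation orders) are replaced by the average of
 * their neighbours, so the reciprocal-space solver never divides by zero.
 */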
2408 void make_dft_mod(real *mod,real *data,int ndata)
2409 {
2410   int i,j;
2411   real sc,ss,arg;
2412
2413   for(i=0;i<ndata;i++) {
2414     sc=ss=0;
2415     for(j=0;j<ndata;j++) {
2416       arg=(2.0*M_PI*i*j)/ndata;
2417       sc+=data[j]*cos(arg);
2418       ss+=data[j]*sin(arg);
2419     }
2420     mod[i]=sc*sc+ss*ss;
2421   }
2422   for(i=0;i<ndata;i++)
2423     if(mod[i]<1e-7)
2424       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2425 }
2426
2427
2428 static void make_bspline_moduli(splinevec bsp_mod,
2429                                 int nx,int ny,int nz,int order)
2430 {
2431   int nmax=max(nx,max(ny,nz));
2432   real *data,*ddata,*bsp_data;
2433   int i,k,l;
2434   real div;
2435
2436   snew(data,order);
2437   snew(ddata,order);
2438   snew(bsp_data,nmax);
2439
2440   data[order-1]=0;
2441   data[1]=0;
2442   data[0]=1;
2443
2444   for(k=3;k<order;k++) {
2445     div=1.0/(k-1.0);
2446     data[k-1]=0;
2447     for(l=1;l<(k-1);l++)
2448       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2449     data[0]=div*data[0];
2450   }
2451   /* differentiate */
2452   ddata[0]=-data[0];
2453   for(k=1;k<order;k++)
2454     ddata[k]=data[k-1]-data[k];
2455   div=1.0/(order-1);
2456   data[order-1]=0;
2457   for(l=1;l<(order-1);l++)
2458     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2459   data[0]=div*data[0];
2460
2461   for(i=0;i<nmax;i++)
2462     bsp_data[i]=0;
2463   for(i=1;i<=order;i++)
2464     bsp_data[i]=data[i-1];
2465
2466   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2467   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2468   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2469
2470   sfree(data);
2471   sfree(ddata);
2472   sfree(bsp_data);
2473 }
2474
2475
2476 /* Return the P3M optimal influence function */
2477 static double do_p3m_influence(double z, int order)
2478 {
2479     double z2,z4;
2480
2481     z2 = z*z;
2482     z4 = z2*z2;
2483
2484     /* The formula and most constants can be found in:
2485      * Ballenegger et al., JCTC 8, 936 (2012)
2486      */
2487     switch(order)
2488     {
2489     case 2:
2490         return 1.0 - 2.0*z2/3.0;
2491         break;
2492     case 3:
2493         return 1.0 - z2 + 2.0*z4/15.0;
2494         break;
2495     case 4:
2496         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2497         break;
2498     case 5:
2499         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2500         break;
2501     case 6:
2502         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2503         break;
2504     case 7:
2505         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2506     case 8:
2507         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2508         break;
2509     }
2510
2511     return 0.0;
2512 }
2513
2514 /* Calculate the P3M B-spline moduli for one dimension */
2515 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2516 {
2517     double zarg,zai,sinzai,infl;
2518     int    maxk,i;
2519
2520     if (order > 8)
2521     {
2522         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2523     }
2524
2525     zarg = M_PI/n;
2526
2527     maxk = (n + 1)/2;
2528
2529     for(i=-maxk; i<0; i++)
2530     {
2531         zai    = zarg*i;
2532         sinzai = sin(zai);
2533         infl   = do_p3m_influence(sinzai,order);
2534         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2535     }
2536     bsp_mod[0] = 1.0;
2537     for(i=1; i<maxk; i++)
2538     {
2539         zai    = zarg*i;
2540         sinzai = sin(zai);
2541         infl   = do_p3m_influence(sinzai,order);
2542         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2543     }
2544 }
2545
2546 /* Calculate the P3M B-spline moduli */
2547 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2548                                     int nx,int ny,int nz,int order)
2549 {
2550     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2551     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2552     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2553 }
2554
2555
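/* Order the send/receive partners for the slab redistribution of
 * coordinates so that communication alternates forward and backward along
 * the ring of nslab nodes. For illustration: on node 0 of a 4-slab setup
 * the destination list becomes {1,3,2} with matching sources {3,1,2}.
 */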
2556 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2557 {
2558   int nslab,n,i;
2559   int fw,bw;
2560
2561   nslab = atc->nslab;
2562
2563   n = 0;
2564   for(i=1; i<=nslab/2; i++) {
2565     fw = (atc->nodeid + i) % nslab;
2566     bw = (atc->nodeid - i + nslab) % nslab;
2567     if (n < nslab - 1) {
2568       atc->node_dest[n] = fw;
2569       atc->node_src[n]  = bw;
2570       n++;
2571     }
2572     if (n < nslab - 1) {
2573       atc->node_dest[n] = bw;
2574       atc->node_src[n]  = fw;
2575       n++;
2576     }
2577   }
2578 }
2579
2580 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2581 {
2582     int thread;
2583
2584     if(NULL != log)
2585     {
2586         fprintf(log,"Destroying PME data structures.\n");
2587     }
2588
2589     sfree((*pmedata)->nnx);
2590     sfree((*pmedata)->nny);
2591     sfree((*pmedata)->nnz);
2592
2593     pmegrids_destroy(&(*pmedata)->pmegridA);
2594
2595     sfree((*pmedata)->fftgridA);
2596     sfree((*pmedata)->cfftgridA);
2597     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2598
2599     if ((*pmedata)->pmegridB.grid.grid != NULL)
2600     {
2601         pmegrids_destroy(&(*pmedata)->pmegridB);
2602         sfree((*pmedata)->fftgridB);
2603         sfree((*pmedata)->cfftgridB);
2604         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2605     }
2606     for(thread=0; thread<(*pmedata)->nthread; thread++)
2607     {
2608         free_work(&(*pmedata)->work[thread]);
2609     }
2610     sfree((*pmedata)->work);
2611
2612     sfree(*pmedata);
2613     *pmedata = NULL;
2614
2615     return 0;
2616 }
2617
2618 static int mult_up(int n,int f)
2619 {
2620     return ((n + f - 1)/f)*f;
2621 }
2622
2623
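/* Rough estimate of the extra FFT+solve cost caused by PME grids that are
 * not divisible by the node counts: n1, n2 and n3 are the grid volumes of
 * the three transposed layouts, padded up per node, and solve is folded in
 * through the factor 3 on n3 (one FFT pass plus a solve costing roughly
 * twice an FFT). gmx_pme_init() prints a note when this ratio reaches 1.2.
 */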
2624 static double pme_load_imbalance(gmx_pme_t pme)
2625 {
2626     int    nma,nmi;
2627     double n1,n2,n3;
2628
2629     nma = pme->nnodes_major;
2630     nmi = pme->nnodes_minor;
2631
2632     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2633     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2634     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2635
2636     /* pme_solve is roughly double the cost of an fft */
2637
2638     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2639 }
2640
2641 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2642                           int dimind,gmx_bool bSpread)
2643 {
2644     int nk,k,s,thread;
2645
2646     atc->dimind = dimind;
2647     atc->nslab  = 1;
2648     atc->nodeid = 0;
2649     atc->pd_nalloc = 0;
2650 #ifdef GMX_MPI
2651     if (pme->nnodes > 1)
2652     {
2653         atc->mpi_comm = pme->mpi_comm_d[dimind];
2654         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2655         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2656     }
2657     if (debug)
2658     {
2659         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2660     }
2661 #endif
2662
2663     atc->bSpread   = bSpread;
2664     atc->pme_order = pme->pme_order;
2665
2666     if (atc->nslab > 1)
2667     {
2668         /* These three allocations are not required for particle decomp. */
2669         snew(atc->node_dest,atc->nslab);
2670         snew(atc->node_src,atc->nslab);
2671         setup_coordinate_communication(atc);
2672
2673         snew(atc->count_thread,pme->nthread);
2674         for(thread=0; thread<pme->nthread; thread++)
2675         {
2676             snew(atc->count_thread[thread],atc->nslab);
2677         }
2678         atc->count = atc->count_thread[0];
2679         snew(atc->rcount,atc->nslab);
2680         snew(atc->buf_index,atc->nslab);
2681     }
2682
2683     atc->nthread = pme->nthread;
2684     if (atc->nthread > 1)
2685     {
2686         snew(atc->thread_plist,atc->nthread);
2687     }
2688     snew(atc->spline,atc->nthread);
2689     for(thread=0; thread<atc->nthread; thread++)
2690     {
2691         if (atc->nthread > 1)
2692         {
2693             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2694             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2695         }
2696     }
2697 }
2698
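/* Set up the forward-only grid-overlap communication along one
 * decomposition dimension: s2g0/s2g1 give, per node, the first grid line
 * and one past the last line including the pme_order-1 spreading overlap,
 * from which the number of overlapping neighbours and the send/receive
 * index ranges are derived. For illustration: with ndata=30 grid lines,
 * nnodes=4 and norder=4 this gives s2g0 = {0,7,15,22,30},
 * s2g1 = {11,18,26,33} and one overlap neighbour per node.
 */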
2699 static void
2700 init_overlap_comm(pme_overlap_t *  ol,
2701                   int              norder,
2702 #ifdef GMX_MPI
2703                   MPI_Comm         comm,
2704 #endif
2705                   int              nnodes,
2706                   int              nodeid,
2707                   int              ndata,
2708                   int              commplainsize)
2709 {
2710     int lbnd,rbnd,maxlr,b,i;
2711     int exten;
2712     int nn,nk;
2713     pme_grid_comm_t *pgc;
2714     gmx_bool bCont;
2715     int fft_start,fft_end,send_index1,recv_index1;
2716
2717 #ifdef GMX_MPI
2718     ol->mpi_comm = comm;
2719 #endif
2720
2721     ol->nnodes = nnodes;
2722     ol->nodeid = nodeid;
2723
2724     /* Linear translation of the PME grid won't affect reciprocal space
2725      * calculations, so to optimize we only interpolate "upwards",
2726      * which also means we only have to consider overlap in one direction.
2727      * I.e., particles on this node might also be spread to grid indices
2728      * that belong to higher nodes (modulo nnodes)
2729      */
2730
2731     snew(ol->s2g0,ol->nnodes+1);
2732     snew(ol->s2g1,ol->nnodes);
2733     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2734     for(i=0; i<nnodes; i++)
2735     {
2736         /* s2g0 is the local interpolation grid start.
2737          * s2g1 is the local interpolation grid end.
2738          * Because grid overlap communication only goes forward,
2739          * the grid slabs for the FFTs should be rounded down.
2740          */
2741         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2742         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2743
2744         if (debug)
2745         {
2746             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2747         }
2748     }
2749     ol->s2g0[nnodes] = ndata;
2750     if (debug) { fprintf(debug,"\n"); }
2751
2752     /* Determine with how many nodes we need to communicate the grid overlap */
2753     b = 0;
2754     do
2755     {
2756         b++;
2757         bCont = FALSE;
2758         for(i=0; i<nnodes; i++)
2759         {
2760             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2761                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2762             {
2763                 bCont = TRUE;
2764             }
2765         }
2766     }
2767     while (bCont && b < nnodes);
2768     ol->noverlap_nodes = b - 1;
2769
2770     snew(ol->send_id,ol->noverlap_nodes);
2771     snew(ol->recv_id,ol->noverlap_nodes);
2772     for(b=0; b<ol->noverlap_nodes; b++)
2773     {
2774         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2775         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2776     }
2777     snew(ol->comm_data, ol->noverlap_nodes);
2778
2779     for(b=0; b<ol->noverlap_nodes; b++)
2780     {
2781         pgc = &ol->comm_data[b];
2782         /* Send */
2783         fft_start        = ol->s2g0[ol->send_id[b]];
2784         fft_end          = ol->s2g0[ol->send_id[b]+1];
2785         if (ol->send_id[b] < nodeid)
2786         {
2787             fft_start += ndata;
2788             fft_end   += ndata;
2789         }
2790         send_index1      = ol->s2g1[nodeid];
2791         send_index1      = min(send_index1,fft_end);
2792         pgc->send_index0 = fft_start;
2793         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2794
2795         /* We always start receiving at the first index of our slab */
2796         fft_start        = ol->s2g0[ol->nodeid];
2797         fft_end          = ol->s2g0[ol->nodeid+1];
2798         recv_index1      = ol->s2g1[ol->recv_id[b]];
2799         if (ol->recv_id[b] > nodeid)
2800         {
2801             recv_index1 -= ndata;
2802         }
2803         recv_index1      = min(recv_index1,fft_end);
2804         pgc->recv_index0 = fft_start;
2805         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2806     }
2807
2808     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2809     snew(ol->sendbuf,norder*commplainsize);
2810     snew(ol->recvbuf,norder*commplainsize);
2811 }
2812
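/* Build a lookup table of length 5*n that maps (possibly shifted) global
 * grid indices to local grid indices, together with a fraction shift of
 * -1, 0 or +1. When rounding puts a particle one cell outside the local
 * interpolation range, the index is pulled back inside and the fractional
 * coordinate is shifted by the same amount in the opposite direction, so
 * the spline weights stay consistent.
 */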
2813 static void
2814 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2815                               int **global_to_local,
2816                               real **fraction_shift)
2817 {
2818     int i;
2819     int * gtl;
2820     real * fsh;
2821
2822     snew(gtl,5*n);
2823     snew(fsh,5*n);
2824     for(i=0; (i<5*n); i++)
2825     {
2826         /* Determine the global to local grid index */
2827         gtl[i] = (i - local_start + n) % n;
2828         /* For coordinates that fall within the local grid the fraction
2829          * is correct; we don't need to shift it.
2830          */
2831         fsh[i] = 0;
2832         if (local_range < n)
2833         {
2834             /* Due to rounding issues i could be 1 beyond the lower or
2835              * upper boundary of the local grid. Correct the index for this.
2836              * If we shift the index, we need to shift the fraction by
2837              * the same amount in the other direction to not affect
2838              * the weights.
2839              * Note that due to this shifting the weights at the end of
2840              * the spline might change, but that will only involve values
2841              * between zero and values close to the precision of a real,
2842              * which is anyhow the accuracy of the whole mesh calculation.
2843              */
2844             /* With local_range=0 we should not change i=local_start */
2845             if (i % n != local_start)
2846             {
2847                 if (gtl[i] == n-1)
2848                 {
2849                     gtl[i] = 0;
2850                     fsh[i] = -1;
2851                 }
2852                 else if (gtl[i] == local_range)
2853                 {
2854                     gtl[i] = local_range - 1;
2855                     fsh[i] = 1;
2856                 }
2857             }
2858         }
2859     }
2860
2861     *global_to_local = gtl;
2862     *fraction_shift  = fsh;
2863 }
2864
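/* Precompute SSE masks for spreading/gathering with interpolation order 4
 * or 5: for each alignment offset 'of' of the first spline point within an
 * 8-float window, two __m128 masks are set that select exactly the 'order'
 * grid points starting at that offset, so the SSE kernels in
 * pme_sse_single.h can work on full registers while zeroing out elements
 * outside the spline support.
 */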
2865 static void sse_mask_init(pme_spline_work_t *work,int order)
2866 {
2867 #ifdef PME_SSE
2868     float  tmp[8];
2869     __m128 zero_SSE;
2870     int    of,i;
2871
2872     zero_SSE = _mm_setzero_ps();
2873
2874     for(of=0; of<8-(order-1); of++)
2875     {
2876         for(i=0; i<8; i++)
2877         {
2878             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2879         }
2880         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2881         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2882         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2883         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2884     }
2885 #endif
2886 }
2887
2888 static void
2889 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2890 {
2891     int nk_new;
2892
2893     if (*nk % nnodes != 0)
2894     {
2895         nk_new = nnodes*(*nk/nnodes + 1);
2896
2897         if (2*nk_new >= 3*(*nk))
2898         {
2899             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2900                       dim,*nk,dim,nnodes,dim);
2901         }
2902
2903         if (fplog != NULL)
2904         {
2905             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2906                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2907         }
2908
2909         *nk = nk_new;
2910     }
2911 }
2912
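/* gmx_pme_init sets up all per-rank PME data: the (optionally 2D) node
 * decomposition and communicators, the grid sizes and B-spline or P3M
 * moduli, the overlap communication structures, the charge-spreading grids
 * (A and, with free energy, B), the parallel 3D FFT setups and the
 * per-thread solver work arrays.
 */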
2913 int gmx_pme_init(gmx_pme_t *         pmedata,
2914                  t_commrec *         cr,
2915                  int                 nnodes_major,
2916                  int                 nnodes_minor,
2917                  t_inputrec *        ir,
2918                  int                 homenr,
2919                  gmx_bool            bFreeEnergy,
2920                  gmx_bool            bReproducible,
2921                  int                 nthread)
2922 {
2923     gmx_pme_t pme=NULL;
2924
2925     pme_atomcomm_t *atc;
2926     ivec ndata;
2927
2928     if (debug)
2929         fprintf(debug,"Creating PME data structures.\n");
2930     snew(pme,1);
2931
2932     pme->redist_init         = FALSE;
2933     pme->sum_qgrid_tmp       = NULL;
2934     pme->sum_qgrid_dd_tmp    = NULL;
2935     pme->buf_nalloc          = 0;
2936     pme->redist_buf_nalloc   = 0;
2937
2938     pme->nnodes              = 1;
2939     pme->bPPnode             = TRUE;
2940
2941     pme->nnodes_major        = nnodes_major;
2942     pme->nnodes_minor        = nnodes_minor;
2943
2944 #ifdef GMX_MPI
2945     if (nnodes_major*nnodes_minor > 1)
2946     {
2947         pme->mpi_comm = cr->mpi_comm_mygroup;
2948
2949         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2950         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2951         if (pme->nnodes != nnodes_major*nnodes_minor)
2952         {
2953             gmx_incons("PME node count mismatch");
2954         }
2955     }
2956     else
2957     {
2958         pme->mpi_comm = MPI_COMM_NULL;
2959     }
2960 #endif
2961
2962     if (pme->nnodes == 1)
2963     {
2964         pme->ndecompdim = 0;
2965         pme->nodeid_major = 0;
2966         pme->nodeid_minor = 0;
2967 #ifdef GMX_MPI
2968         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2969 #endif
2970     }
2971     else
2972     {
2973         if (nnodes_minor == 1)
2974         {
2975 #ifdef GMX_MPI
2976             pme->mpi_comm_d[0] = pme->mpi_comm;
2977             pme->mpi_comm_d[1] = MPI_COMM_NULL;
2978 #endif
2979             pme->ndecompdim = 1;
2980             pme->nodeid_major = pme->nodeid;
2981             pme->nodeid_minor = 0;
2982
2983         }
2984         else if (nnodes_major == 1)
2985         {
2986 #ifdef GMX_MPI
2987             pme->mpi_comm_d[0] = MPI_COMM_NULL;
2988             pme->mpi_comm_d[1] = pme->mpi_comm;
2989 #endif
2990             pme->ndecompdim = 1;
2991             pme->nodeid_major = 0;
2992             pme->nodeid_minor = pme->nodeid;
2993         }
2994         else
2995         {
2996             if (pme->nnodes % nnodes_major != 0)
2997             {
2998                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
2999             }
3000             pme->ndecompdim = 2;
3001
3002 #ifdef GMX_MPI
3003             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3004                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3005             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3006                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3007
3008             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3009             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3010             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3011             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3012 #endif
3013         }
3014         pme->bPPnode = (cr->duty & DUTY_PP);
3015     }
3016
3017     pme->nthread = nthread;
3018
3019     if (ir->ePBC == epbcSCREW)
3020     {
3021         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3022     }
3023
3024     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3025     pme->nkx         = ir->nkx;
3026     pme->nky         = ir->nky;
3027     pme->nkz         = ir->nkz;
3028     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3029     pme->pme_order   = ir->pme_order;
3030     pme->epsilon_r   = ir->epsilon_r;
3031
3032     if (pme->pme_order > PME_ORDER_MAX)
3033     {
3034         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3035                   pme->pme_order,PME_ORDER_MAX);
3036     }
3037
3038     /* Currently pme.c supports only the fft5d FFT code.
3039      * Therefore the grid always needs to be divisible by nnodes.
3040      * When the old 1D code is also supported again, change this check.
3041      *
3042      * This check should be done before calling gmx_pme_init
3043      * and fplog should be passed instead of stderr.
3044      *
3045     if (pme->ndecompdim >= 2)
3046     */
3047     if (pme->ndecompdim >= 1)
3048     {
3049         /*
3050         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3051                                         'x',nnodes_major,&pme->nkx);
3052         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3053                                         'y',nnodes_minor,&pme->nky);
3054         */
3055     }
3056
3057     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3058         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3059         pme->nkz <= pme->pme_order)
3060     {
3061         gmx_fatal(FARGS,"The PME grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_order for x and/or y",pme->pme_order);
3062     }
3063
3064     if (pme->nnodes > 1) {
3065         double imbal;
3066
3067 #ifdef GMX_MPI
3068         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3069         MPI_Type_commit(&(pme->rvec_mpi));
3070 #endif
3071
3072         /* Note that the charge spreading and force gathering, which usually
3073          * take about the same amount of time as FFT+solve_pme,
3074          * are always fully load balanced
3075          * (unless the charge distribution is inhomogeneous).
3076          */
3077
3078         imbal = pme_load_imbalance(pme);
3079         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3080         {
3081             fprintf(stderr,
3082                     "\n"
3083                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3084                     "      For optimal PME load balancing\n"
3085                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3086                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3087                     "\n",
3088                     (int)((imbal-1)*100 + 0.5),
3089                     pme->nkx,pme->nky,pme->nnodes_major,
3090                     pme->nky,pme->nkz,pme->nnodes_minor);
3091         }
3092     }
3093
3094     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3095     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3096      * y is always copied through a buffer: we don't need padding in z,
3097      * but we do need the overlap in x because of the communication order.
3098      */
3099     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3100 #ifdef GMX_MPI
3101                       pme->mpi_comm_d[0],
3102 #endif
3103                       pme->nnodes_major,pme->nodeid_major,
3104                       pme->nkx,
3105                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3106
3107     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3108 #ifdef GMX_MPI
3109                       pme->mpi_comm_d[1],
3110 #endif
3111                       pme->nnodes_minor,pme->nodeid_minor,
3112                       pme->nky,
3113                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3114
3115     /* Check for a limitation of the (current) sum_fftgrid_dd code */
3116     if (pme->nthread > 1 &&
3117         (pme->overlap[0].noverlap_nodes > 1 ||
3118          pme->overlap[1].noverlap_nodes > 1))
3119     {
3120         gmx_fatal(FARGS,"With threads the number of grid lines per node along x and/or y should be pme_order (%d) or more, or exactly pme_order-1",pme->pme_order);
3121     }
3122
3123     snew(pme->bsp_mod[XX],pme->nkx);
3124     snew(pme->bsp_mod[YY],pme->nky);
3125     snew(pme->bsp_mod[ZZ],pme->nkz);
3126
3127     /* The required size of the interpolation grid, including overlap.
3128      * The allocated size (pmegrid_n?) might be slightly larger.
3129      */
3130     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3131                       pme->overlap[0].s2g0[pme->nodeid_major];
3132     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3133                       pme->overlap[1].s2g0[pme->nodeid_minor];
3134     pme->pmegrid_nz_base = pme->nkz;
3135     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3136     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3137
3138     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3139     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3140     pme->pmegrid_start_iz = 0;
3141
3142     make_gridindex5_to_localindex(pme->nkx,
3143                                   pme->pmegrid_start_ix,
3144                                   pme->pmegrid_nx - (pme->pme_order-1),
3145                                   &pme->nnx,&pme->fshx);
3146     make_gridindex5_to_localindex(pme->nky,
3147                                   pme->pmegrid_start_iy,
3148                                   pme->pmegrid_ny - (pme->pme_order-1),
3149                                   &pme->nny,&pme->fshy);
3150     make_gridindex5_to_localindex(pme->nkz,
3151                                   pme->pmegrid_start_iz,
3152                                   pme->pmegrid_nz_base,
3153                                   &pme->nnz,&pme->fshz);
3154
3155     pmegrids_init(&pme->pmegridA,
3156                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3157                   pme->pmegrid_nz_base,
3158                   pme->pme_order,
3159                   pme->nthread,
3160                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3161                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3162
3163     sse_mask_init(&pme->spline_work,pme->pme_order);
3164
3165     ndata[0] = pme->nkx;
3166     ndata[1] = pme->nky;
3167     ndata[2] = pme->nkz;
3168
3169     /* This routine will allocate the grid data to fit the FFTs */
3170     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3171                             &pme->fftgridA,&pme->cfftgridA,
3172                             pme->mpi_comm_d,
3173                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3174                             bReproducible,pme->nthread);
3175
3176     if (bFreeEnergy)
3177     {
3178         pmegrids_init(&pme->pmegridB,
3179                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3180                       pme->pmegrid_nz_base,
3181                       pme->pme_order,
3182                       pme->nthread,
3183                       pme->nkx % pme->nnodes_major != 0,
3184                       pme->nky % pme->nnodes_minor != 0);
3185
3186         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3187                                 &pme->fftgridB,&pme->cfftgridB,
3188                                 pme->mpi_comm_d,
3189                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3190                                 bReproducible,pme->nthread);
3191     }
3192     else
3193     {
3194         pme->pmegridB.grid.grid = NULL;
3195         pme->fftgridB           = NULL;
3196         pme->cfftgridB          = NULL;
3197     }
3198
3199     if (!pme->bP3M)
3200     {
3201         /* Use plain SPME B-spline interpolation */
3202         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3203     }
3204     else
3205     {
3206         /* Use the P3M grid-optimized influence function */
3207         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3208     }
3209
3210     /* Use atc[0] for spreading */
3211     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3212     if (pme->ndecompdim >= 2)
3213     {
3214         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3215     }
3216
3217     if (pme->nnodes == 1) {
3218         pme->atc[0].n = homenr;
3219         pme_realloc_atomcomm_things(&pme->atc[0]);
3220     }
3221
3222     {
3223         int thread;
3224
3225         /* Use fft5d, order after FFT is y major, z, x minor */
3226
3227         snew(pme->work,pme->nthread);
3228         for(thread=0; thread<pme->nthread; thread++)
3229         {
3230             realloc_work(&pme->work[thread],pme->nkx);
3231         }
3232     }
3233
3234     *pmedata = pme;
3235
3236     return 0;
3237 }
3238
3239
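/* Copy the non-overlapping interior of one thread's spreading grid into the
 * (possibly padded) real-space FFT grid. The overlap regions of width
 * pme_order-1 are handled separately by reduce_threadgrid_overlap() below,
 * which also fills the buffers for contributions that must be communicated
 * to neighbouring nodes.
 */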
3240 static void copy_local_grid(gmx_pme_t pme,
3241                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3242 {
3243     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3244     int  fft_my,fft_mz;
3245     int  nsx,nsy,nsz;
3246     ivec nf;
3247     int  offx,offy,offz,x,y,z,i0,i0t;
3248     int  d;
3249     pmegrid_t *pmegrid;
3250     real *grid_th;
3251
3252     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3253                                    local_fft_ndata,
3254                                    local_fft_offset,
3255                                    local_fft_size);
3256     fft_my = local_fft_size[YY];
3257     fft_mz = local_fft_size[ZZ];
3258
3259     pmegrid = &pmegrids->grid_th[thread];
3260
3261     nsx = pmegrid->n[XX];
3262     nsy = pmegrid->n[YY];
3263     nsz = pmegrid->n[ZZ];
3264
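         /* nf is the size of the thread-local interior: the last (order-1)
          * points overlap with the next thread's grid and are handled in
          * reduce_threadgrid_overlap; we also clamp to the part that falls
          * inside the local FFT grid.
          */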
3265     for(d=0; d<DIM; d++)
3266     {
3267         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3268                     local_fft_ndata[d] - pmegrid->offset[d]);
3269     }
3270
3271     offx = pmegrid->offset[XX];
3272     offy = pmegrid->offset[YY];
3273     offz = pmegrid->offset[ZZ];
3274
3275     /* Directly copy the non-overlapping parts of the local grids.
3276      * This also initializes the full grid.
3277      */
3278     grid_th = pmegrid->grid;
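         /* i0 indexes fftgrid, which uses the allocated (padded) strides
          * fft_my and fft_mz; i0t indexes the compact thread-local grid
          * with strides nsy and nsz.
          */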
3279     for(x=0; x<nf[XX]; x++)
3280     {
3281         for(y=0; y<nf[YY]; y++)
3282         {
3283             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3284             i0t = (x*nsy + y)*nsz;
3285             for(z=0; z<nf[ZZ]; z++)
3286             {
3287                 fftgrid[i0+z] = grid_th[i0t+z];
3288             }
3289         }
3290     }
3291 }
3292
3293 static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3294 {
3295     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3296     pme_overlap_t *overlap;
3297     int nind;
3298     int i,x,y,z,n;
3299
3300     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3301                                    local_fft_ndata,
3302                                    local_fft_offset,
3303                                    local_fft_size);
3304     /* Major dimension */
3305     overlap = &pme->overlap[0];
3306
3307     nind   = overlap->comm_data[0].send_nindex;
3308
3309     for(y=0; y<local_fft_ndata[YY]; y++) {
3310          printf(" %2d",y);
3311     }
3312     printf("\n");
3313
3314     i = 0;
3315     for(x=0; x<nind; x++) {
3316         for(y=0; y<local_fft_ndata[YY]; y++) {
3317             n = 0;
3318             for(z=0; z<local_fft_ndata[ZZ]; z++) {
3319                 if (sendbuf[i] != 0) n++;
3320                 i++;
3321             }
3322             printf(" %2d",n);
3323         }
3324         printf("\n");
3325     }
3326 }
3327
3328 static void
3329 reduce_threadgrid_overlap(gmx_pme_t pme,
3330                           const pmegrids_t *pmegrids,int thread,
3331                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3332 {
3333     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3334     int  fft_nx,fft_ny,fft_nz;
3335     int  fft_my,fft_mz;
3336     int  buf_my=-1;
3337     int  nsx,nsy,nsz;
3338     ivec ne;
3339     int  offx,offy,offz,x,y,z,i0,i0t;
3340     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3341     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3342     gmx_bool bCommX,bCommY;
3343     int  d;
3344     int  thread_f;
3345     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3346     const real *grid_th;
3347     real *commbuf=NULL;
3348
3349     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3350                                    local_fft_ndata,
3351                                    local_fft_offset,
3352                                    local_fft_size);
3353     fft_nx = local_fft_ndata[XX];
3354     fft_ny = local_fft_ndata[YY];
3355     fft_nz = local_fft_ndata[ZZ];
3356
3357     fft_my = local_fft_size[YY];
3358     fft_mz = local_fft_size[ZZ];
3359
3360     /* This routine is called when all threads have finished spreading.
3361      * Here each thread sums the grid contributions calculated by other
3362      * threads into its own thread-local part of the grid volume.
3363      * To minimize the number of grid copying operations,
3364      * this routine sums directly from the pmegrid into the fftgrid.
3365      */
3366
3367     /* Determine which part of the full node grid we should operate on:
3368      * this is our thread-local part of the full grid.
3369      */
3370     pmegrid = &pmegrids->grid_th[thread];
3371
3372     for(d=0; d<DIM; d++)
3373     {
3374         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3375                     local_fft_ndata[d]);
3376     }
3377
3378     offx = pmegrid->offset[XX];
3379     offy = pmegrid->offset[YY];
3380     offz = pmegrid->offset[ZZ];
3381
3382
3383     bClearBufX  = TRUE;
3384     bClearBufY  = TRUE;
3385     bClearBufXY = TRUE;
3386
3387     /* Now loop over all the thread data blocks that contribute
3388      * to the grid region we (our thread) are operating on.
3389      */
3390     /* Note that fft_nx/fft_ny is equal to the number of grid points
3391      * between the first point of our node grid and that of the next node.
3392      */
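         /* The loops below cover our own block (sx,sy,sz == 0) and the
          * neighboring thread blocks at lower indices whose spreading region
          * overlaps ours; blocks that wrap around the node boundary go into
          * the communication buffers instead (bCommX/bCommY).
          */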
3393     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3394     {
3395         fx = pmegrid->ci[XX] + sx;
3396         ox = 0;
3397         bCommX = FALSE;
3398         if (fx < 0) {
3399             fx += pmegrids->nc[XX];
3400             ox -= fft_nx;
3401             bCommX = (pme->nnodes_major > 1);
3402         }
3403         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3404         ox += pmegrid_g->offset[XX];
3405         if (!bCommX)
3406         {
3407             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3408         }
3409         else
3410         {
3411             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3412         }
3413
3414         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3415         {
3416             fy = pmegrid->ci[YY] + sy;
3417             oy = 0;
3418             bCommY = FALSE;
3419             if (fy < 0) {
3420                 fy += pmegrids->nc[YY];
3421                 oy -= fft_ny;
3422                 bCommY = (pme->nnodes_minor > 1);
3423             }
3424             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3425             oy += pmegrid_g->offset[YY];
3426             if (!bCommY)
3427             {
3428                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3429             }
3430             else
3431             {
3432                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3433             }
3434
3435             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3436             {
3437                 fz = pmegrid->ci[ZZ] + sz;
3438                 oz = 0;
3439                 if (fz < 0)
3440                 {
3441                     fz += pmegrids->nc[ZZ];
3442                     oz -= fft_nz;
3443                 }
3444                 pmegrid_g = &pmegrids->grid_th[fz];
3445                 oz += pmegrid_g->offset[ZZ];
3446                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3447
3448                 if (sx == 0 && sy == 0 && sz == 0)
3449                 {
3450                     /* We have already added our local contribution
3451                      * before calling this routine, so skip it here.
3452                      */
3453                     continue;
3454                 }
3455
3456                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3457
3458                 pmegrid_f = &pmegrids->grid_th[thread_f];
3459
3460                 grid_th = pmegrid_f->grid;
3461
3462                 nsx = pmegrid_f->n[XX];
3463                 nsy = pmegrid_f->n[YY];
3464                 nsz = pmegrid_f->n[ZZ];
3465
3466 #ifdef DEBUG_PME_REDUCE
3467                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3468                        pme->nodeid,thread,thread_f,
3469                        pme->pmegrid_start_ix,
3470                        pme->pmegrid_start_iy,
3471                        pme->pmegrid_start_iz,
3472                        sx,sy,sz,
3473                        offx-ox,tx1-ox,offx,tx1,
3474                        offy-oy,ty1-oy,offy,ty1,
3475                        offz-oz,tz1-oz,offz,tz1);
3476 #endif
3477
3478                 if (!(bCommX || bCommY))
3479                 {
3480                     /* Add the thread-local grid contribution to the node grid */
3481                     for(x=offx; x<tx1; x++)
3482                     {
3483                         for(y=offy; y<ty1; y++)
3484                         {
3485                             i0  = (x*fft_my + y)*fft_mz;
3486                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3487                             for(z=offz; z<tz1; z++)
3488                             {
3489                                 fftgrid[i0+z] += grid_th[i0t+z];
3490                             }
3491                         }
3492                     }
3493                 }
3494                 else
3495                 {
3496                     /* The order of this conditional decides
3497                      * where the corner volume gets stored with x+y decomp.
3498                      */
3499                     if (bCommY)
3500                     {
3501                         commbuf = commbuf_y;
3502                         buf_my  = ty1 - offy;
3503                         if (bCommX)
3504                         {
3505                             /* We index commbuf modulo the local grid size */
3506                             commbuf += buf_my*fft_nx*fft_nz;
3507
3508                             bClearBuf  = bClearBufXY;
3509                             bClearBufXY = FALSE;
3510                         }
3511                         else
3512                         {
3513                             bClearBuf  = bClearBufY;
3514                             bClearBufY = FALSE;
3515                         }
3516                     }
3517                     else
3518                     {
3519                         commbuf = commbuf_x;
3520                         buf_my  = fft_ny;
3521                         bClearBuf  = bClearBufX;
3522                         bClearBufX = FALSE;
3523                     }
3524
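                         /* buf_my is the number of y rows per x slice in the
                          * communication buffer: only the overlapping y rows
                          * for the y buffer, but the full fft_ny for the
                          * x buffer.
                          */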
3525                     /* Copy or add to the communication buffer */
3526                     for(x=offx; x<tx1; x++)
3527                     {
3528                         for(y=offy; y<ty1; y++)
3529                         {
3530                             i0  = (x*buf_my + y)*fft_nz;
3531                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3532
3533                             if (bClearBuf)
3534                             {
3535                                 /* First access of commbuf, initialize it */
3536                                 for(z=offz; z<tz1; z++)
3537                                 {
3538                                     commbuf[i0+z]  = grid_th[i0t+z];
3539                                 }
3540                             }
3541                             else
3542                             {
3543                                 for(z=offz; z<tz1; z++)
3544                                 {
3545                                     commbuf[i0+z] += grid_th[i0t+z];
3546                                 }
3547                             }
3548                         }
3549                     }
3550                 }
3551             }
3552         }
3553     }
3554 }
3555
3556
3557 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3558 {
3559     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3560     pme_overlap_t *overlap;
3561     int  send_nindex;
3562     int  recv_index0,recv_nindex;
3563 #ifdef GMX_MPI
3564     MPI_Status stat;
3565 #endif
3566     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3567     real *sendptr,*recvptr;
3568     int  x,y,z,indg,indb;
3569
3570     /* Note that this routine is only used for forward communication.
3571      * Since the force gathering, unlike the charge spreading,
3572      * can be trivially parallelized over the particles,
3573      * the backwards process is much simpler and can use the "old"
3574      * communication setup.
3575      */
3576
3577     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3578                                    local_fft_ndata,
3579                                    local_fft_offset,
3580                                    local_fft_size);
3581
3582     /* Currently supports only a single communication pulse */
3583
3584 /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3585     if (pme->nnodes_minor > 1)
3586     {
3587         /* Minor dimension */
3588         overlap = &pme->overlap[1];
3589
3590         if (pme->nnodes_major > 1)
3591         {
3592              size_yx = pme->overlap[0].comm_data[0].send_nindex;
3593         }
3594         else
3595         {
3596             size_yx = 0;
3597         }
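             /* With an additional x decomposition, the y send/receive buffers
              * have size_yx extra x rows appended which hold the x+y corner
              * volume (see reduce_threadgrid_overlap); datasize includes them.
              */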
3598         datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3599
3600         ipulse = 0;
3601
3602         send_id = overlap->send_id[ipulse];
3603         recv_id = overlap->recv_id[ipulse];
3604         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3605         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3606         recv_index0 = 0;
3607         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3608
3609         sendptr = overlap->sendbuf;
3610         recvptr = overlap->recvbuf;
3611
3612         /*
3613         printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3614                local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3615         printf("node %d send %f, %f\n",pme->nodeid,
3616                sendptr[0],sendptr[send_nindex*datasize-1]);
3617         */
3618
3619 #ifdef GMX_MPI
3620         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3621                      send_id,ipulse,
3622                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3623                      recv_id,ipulse,
3624                      overlap->mpi_comm,&stat);
3625 #endif
3626
3627         for(x=0; x<local_fft_ndata[XX]; x++)
3628         {
3629             for(y=0; y<recv_nindex; y++)
3630             {
3631                 indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3632                 indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3633                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3634                 {
3635                     fftgrid[indg+z] += recvptr[indb+z];
3636                 }
3637             }
3638         }
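             /* The corner volume just received from the y neighbor still has
              * to be communicated along x; add it to the x send buffer, which
              * is sent in the major-dimension block below.
              */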
3639         if (pme->nnodes_major > 1)
3640         {
3641             sendptr = pme->overlap[0].sendbuf;
3642             for(x=0; x<size_yx; x++)
3643             {
3644                 for(y=0; y<recv_nindex; y++)
3645                 {
3646                     indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3647                     indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3648                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3649                     {
3650                         sendptr[indg+z] += recvptr[indb+z];
3651                     }
3652                 }
3653             }
3654         }
3655     }
3656
3657     /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3658     if (pme->nnodes_major > 1)
3659     {
3660         /* Major dimension */
3661         overlap = &pme->overlap[0];
3662
3663         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3664         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3665
3666         ipulse = 0;
3667
3668         send_id = overlap->send_id[ipulse];
3669         recv_id = overlap->recv_id[ipulse];
3670         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3671         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3672         recv_index0 = 0;
3673         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3674
3675         sendptr = overlap->sendbuf;
3676         recvptr = overlap->recvbuf;
3677
3678         if (debug != NULL)
3679         {
3680             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3681                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3682         }
3683
3684 #ifdef GMX_MPI
3685         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3686                      send_id,ipulse,
3687                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3688                      recv_id,ipulse,
3689                      overlap->mpi_comm,&stat);
3690 #endif
3691
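             /* Add the received x slab to the local FFT grid: indg uses the
              * padded local_fft_size strides, indb the compact
              * local_fft_ndata strides of the communication buffer.
              */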
3692         for(x=0; x<recv_nindex; x++)
3693         {
3694             for(y=0; y<local_fft_ndata[YY]; y++)
3695             {
3696                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3697                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3698                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3699                 {
3700                     fftgrid[indg+z] += recvptr[indb+z];
3701                 }
3702             }
3703         }
3704     }
3705 }
3706
3707
3708 static void spread_on_grid(gmx_pme_t pme,
3709                            pme_atomcomm_t *atc,pmegrids_t *grids,
3710                            gmx_bool bCalcSplines,gmx_bool bSpread,
3711                            real *fftgrid)
3712 {
3713     int nthread,thread;
3714 #ifdef PME_TIME_THREADS
3715     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3716     static double cs1=0,cs2=0,cs3=0;
3717     static double cs1a[6]={0,0,0,0,0,0};
3718     static int cnt=0;
3719 #endif
3720
3721     nthread = pme->nthread;
3722     assert(nthread>0);
3723
3724 #ifdef PME_TIME_THREADS
3725     c1 = omp_cyc_start();
3726 #endif
3727     if (bCalcSplines)
3728     {
3729 #pragma omp parallel for num_threads(nthread) schedule(static)
3730         for(thread=0; thread<nthread; thread++)
3731         {
3732             int start,end;
3733
3734             start = atc->n* thread   /nthread;
3735             end   = atc->n*(thread+1)/nthread;
3736
3737             /* Compute the fftgrid index for all atoms,
3738              * with the help of some extra variables.
3739              */
3740             calc_interpolation_idx(pme,atc,start,end,thread);
3741         }
3742     }
3743 #ifdef PME_TIME_THREADS
3744     c1 = omp_cyc_end(c1);
3745     cs1 += (double)c1;
3746 #endif
3747
3748 #ifdef PME_TIME_THREADS
3749     c2 = omp_cyc_start();
3750 #endif
3751 #pragma omp parallel for num_threads(nthread) schedule(static)
3752     for(thread=0; thread<nthread; thread++)
3753     {
3754         splinedata_t *spline;
3755         pmegrid_t *grid;
3756
3757         /* make local bsplines  */
3758         if (grids == NULL || grids->nthread == 1)
3759         {
3760             spline = &atc->spline[0];
3761
3762             spline->n = atc->n;
3763
3764             grid = &grids->grid;
3765         }
3766         else
3767         {
3768             spline = &atc->spline[thread];
3769
3770             make_thread_local_ind(atc,thread,spline);
3771
3772             grid = &grids->grid_th[thread];
3773         }
3774
3775         if (bCalcSplines)
3776         {
3777             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3778                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3779         }
3780
3781         if (bSpread)
3782         {
3783             /* put local atoms on grid. */
3784 #ifdef PME_TIME_SPREAD
3785             ct1a = omp_cyc_start();
3786 #endif
3787             spread_q_bsplines_thread(grid,atc,spline,&pme->spline_work);
3788
3789             if (grids->nthread > 1)
3790             {
3791                 copy_local_grid(pme,grids,thread,fftgrid);
3792             }
3793 #ifdef PME_TIME_SPREAD
3794             ct1a = omp_cyc_end(ct1a);
3795             cs1a[thread] += (double)ct1a;
3796 #endif
3797         }
3798     }
3799 #ifdef PME_TIME_THREADS
3800     c2 = omp_cyc_end(c2);
3801     cs2 += (double)c2;
3802 #endif
3803
3804     if (bSpread && grids->nthread > 1)
3805     {
3806 #ifdef PME_TIME_THREADS
3807         c3 = omp_cyc_start();
3808 #endif
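             /* Each thread sums the overlapping parts of all thread-local
              * grids into its own region of the node grid and fills the MPI
              * send buffers for the parts that belong to neighboring nodes.
              */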
3809 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3810         for(thread=0; thread<grids->nthread; thread++)
3811         {
3812             reduce_threadgrid_overlap(pme,grids,thread,
3813                                       fftgrid,
3814                                       pme->overlap[0].sendbuf,
3815                                       pme->overlap[1].sendbuf);
3816 #ifdef PRINT_PME_SENDBUF
3817             print_sendbuf(pme,pme->overlap[0].sendbuf);
3818 #endif
3819         }
3820 #ifdef PME_TIME_THREADS
3821         c3 = omp_cyc_end(c3);
3822         cs3 += (double)c3;
3823 #endif
3824
3825         if (pme->nnodes > 1)
3826         {
3827             /* Communicate the overlapping part of the fftgrid */
3828             sum_fftgrid_dd(pme,fftgrid);
3829         }
3830     }
3831
3832 #ifdef PME_TIME_THREADS
3833     cnt++;
3834     if (cnt % 20 == 0)
3835     {
3836         printf("idx %.2f spread %.2f red %.2f",
3837                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3838 #ifdef PME_TIME_SPREAD
3839         for(thread=0; thread<nthread; thread++)
3840             printf(" %.2f",cs1a[thread]*1e-9);
3841 #endif
3842         printf("\n");
3843     }
3844 #endif
3845 }
3846
3847
3848 static void dump_grid(FILE *fp,
3849                       int sx,int sy,int sz,int nx,int ny,int nz,
3850                       int my,int mz,const real *g)
3851 {
3852     int x,y,z;
3853
3854     for(x=0; x<nx; x++)
3855     {
3856         for(y=0; y<ny; y++)
3857         {
3858             for(z=0; z<nz; z++)
3859             {
3860                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3861                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3862             }
3863         }
3864     }
3865 }
3866
3867 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3868 {
3869     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3870
3871     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3872                                    local_fft_ndata,
3873                                    local_fft_offset,
3874                                    local_fft_size);
3875
3876     dump_grid(stderr,
3877               pme->pmegrid_start_ix,
3878               pme->pmegrid_start_iy,
3879               pme->pmegrid_start_iz,
3880               pme->pmegrid_nx-pme->pme_order+1,
3881               pme->pmegrid_ny-pme->pme_order+1,
3882               pme->pmegrid_nz-pme->pme_order+1,
3883               local_fft_size[YY],
3884               local_fft_size[ZZ],
3885               fftgrid);
3886 }
3887
3888
3889 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3890 {
3891     pme_atomcomm_t *atc;
3892     pmegrids_t *grid;
3893
3894     if (pme->nnodes > 1)
3895     {
3896         gmx_incons("gmx_pme_calc_energy called in parallel");
3897     }
3898     if (pme->bFEP)
3899     {
3900         gmx_incons("gmx_pme_calc_energy with free energy");
3901     }
3902
3903     atc = &pme->atc_energy;
3904     atc->nthread   = 1;
3905     if (atc->spline == NULL)
3906     {
3907         snew(atc->spline,atc->nthread);
3908     }
3909     atc->nslab     = 1;
3910     atc->bSpread   = TRUE;
3911     atc->pme_order = pme->pme_order;
3912     atc->n         = n;
3913     pme_realloc_atomcomm_things(atc);
3914     atc->x         = x;
3915     atc->q         = q;
3916
3917     /* We only use the A-charges grid */
3918     grid = &pme->pmegridA;
3919
3920     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
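         /* Only calculate the spline coefficients here (bSpread==FALSE);
          * the energy is then gathered from the existing A grid below.
          */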
3921
3922     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3923 }
3924
3925
3926 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3927         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3928 {
3929     /* Reset all the counters related to performance over the run */
3930     wallcycle_stop(wcycle,ewcRUN);
3931     wallcycle_reset_all(wcycle);
3932     init_nrnb(nrnb);
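         /* Shift the step counters so that the remaining part of the run
          * is accounted for as if it started at this point.
          */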
3933     ir->init_step += step_rel;
3934     ir->nsteps    -= step_rel;
3935     wallcycle_start(wcycle,ewcRUN);
3936 }
3937
3938
3939 int gmx_pmeonly(gmx_pme_t pme,
3940                 t_commrec *cr,    t_nrnb *nrnb,
3941                 gmx_wallcycle_t wcycle,
3942                 real ewaldcoeff,  gmx_bool bGatherOnly,
3943                 t_inputrec *ir)
3944 {
3945     gmx_pme_pp_t pme_pp;
3946     int  natoms;
3947     matrix box;
3948     rvec *x_pp=NULL,*f_pp=NULL;
3949     real *chargeA=NULL,*chargeB=NULL;
3950     real lambda=0;
3951     int  maxshift_x=0,maxshift_y=0;
3952     real energy,dvdlambda;
3953     matrix vir;
3954     float cycles;
3955     int  count;
3956     gmx_bool bEnerVir;
3957     gmx_large_int_t step,step_rel;
3958
3959
3960     pme_pp = gmx_pme_pp_init(cr);
3961
3962     init_nrnb(nrnb);
3963
3964     count = 0;
3965     do /****** this is a quasi-loop over time steps! */
3966     {
3967         /* Domain decomposition */
3968         natoms = gmx_pme_recv_q_x(pme_pp,
3969                                   &chargeA,&chargeB,box,&x_pp,&f_pp,
3970                                   &maxshift_x,&maxshift_y,
3971                                   &pme->bFEP,&lambda,
3972                                   &bEnerVir,
3973                                   &step);
3974
3975         if (natoms == -1) {
3976             /* We should stop: break out of the loop */
3977             break;
3978         }
3979
3980         step_rel = step - ir->init_step;
3981
3982         if (count == 0)
3983             wallcycle_start(wcycle,ewcRUN);
3984
3985         wallcycle_start(wcycle,ewcPMEMESH);
3986
3987         dvdlambda = 0;
3988         clear_mat(vir);
3989         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
3990                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
3991                    &energy,lambda,&dvdlambda,
3992                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
3993
3994         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
3995
3996         gmx_pme_send_force_vir_ener(pme_pp,
3997                                     f_pp,vir,energy,dvdlambda,
3998                                     cycles);
3999
4000         count++;
4001
4002         if (step_rel == wcycle_get_reset_counters(wcycle))
4003         {
4004             /* Reset all the counters related to performance over the run */
4005             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4006             wcycle_set_reset_counters(wcycle, 0);
4007         }
4008
4009     } /***** end of quasi-loop, we stop with the break above */
4010     while (TRUE);
4011
4012     return 0;
4013 }
4014
4015 int gmx_pme_do(gmx_pme_t pme,
4016                int start,       int homenr,
4017                rvec x[],        rvec f[],
4018                real *chargeA,   real *chargeB,
4019                matrix box, t_commrec *cr,
4020                int  maxshift_x, int maxshift_y,
4021                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4022                matrix vir,      real ewaldcoeff,
4023                real *energy,    real lambda,
4024                real *dvdlambda, int flags)
4025 {
4026     int     q,d,i,j,ntot,npme;
4027     int     nx,ny,nz;
4028     int     n_d;
4029     pme_atomcomm_t *atc=NULL;
4030     pmegrids_t *pmegrid=NULL;
4031     real    *grid=NULL;
4032     real    *ptr;
4033     rvec    *x_d,*f_d;
4034     real    *charge=NULL,*q_d;
4035     real    energy_AB[2];
4036     matrix  vir_AB[2];
4037     gmx_bool bClearF;
4038     gmx_parallel_3dfft_t pfft_setup;
4039     real *  fftgrid;
4040     t_complex * cfftgrid;
4041     int     thread;
4042     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4043     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4044
4045     assert(pme->nnodes > 0);
4046     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4047
4048     if (pme->nnodes > 1) {
4049         atc = &pme->atc[0];
4050         atc->npd = homenr;
4051         if (atc->npd > atc->pd_nalloc) {
4052             atc->pd_nalloc = over_alloc_dd(atc->npd);
4053             srenew(atc->pd,atc->pd_nalloc);
4054         }
4055         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4056     }
4057     else
4058     {
4059         /* This could be necessary for TPI */
4060         pme->atc[0].n = homenr;
4061     }
4062
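         /* q==0 uses the A-state charges; with free energy (bFEP) a second
          * pass with q==1 repeats the mesh part for the B-state charges.
          */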
4063     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4064         if (q == 0) {
4065             pmegrid = &pme->pmegridA;
4066             fftgrid = pme->fftgridA;
4067             cfftgrid = pme->cfftgridA;
4068             pfft_setup = pme->pfft_setupA;
4069             charge = chargeA+start;
4070         } else {
4071             pmegrid = &pme->pmegridB;
4072             fftgrid = pme->fftgridB;
4073             cfftgrid = pme->cfftgridB;
4074             pfft_setup = pme->pfft_setupB;
4075             charge = chargeB+start;
4076         }
4077         grid = pmegrid->grid.grid;
4078         /* Unpack structure */
4079         if (debug) {
4080             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4081                     cr->nnodes,cr->nodeid);
4082             fprintf(debug,"Grid = %p\n",(void*)grid);
4083             if (grid == NULL)
4084                 gmx_fatal(FARGS,"No grid!");
4085         }
4086         where();
4087
4088         m_inv_ur0(box,pme->recipbox);
4089
4090         if (pme->nnodes == 1) {
4091             atc = &pme->atc[0];
4092             if (DOMAINDECOMP(cr)) {
4093                 atc->n = homenr;
4094                 pme_realloc_atomcomm_things(atc);
4095             }
4096             atc->x = x;
4097             atc->q = charge;
4098             atc->f = f;
4099         } else {
4100             wallcycle_start(wcycle,ewcPME_REDISTXF);
4101             for(d=pme->ndecompdim-1; d>=0; d--)
4102             {
4103                 if (d == pme->ndecompdim-1)
4104                 {
4105                     n_d = homenr;
4106                     x_d = x + start;
4107                     q_d = charge;
4108                 }
4109                 else
4110                 {
4111                     n_d = pme->atc[d+1].n;
4112                     x_d = atc->x;
4113                     q_d = atc->q;
4114                 }
4115                 atc = &pme->atc[d];
4116                 atc->npd = n_d;
4117                 if (atc->npd > atc->pd_nalloc) {
4118                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4119                     srenew(atc->pd,atc->pd_nalloc);
4120                 }
4121                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4122                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4123                 where();
4124
4125                 /* Redistribute x (only once) and qA or qB */
4126                 if (DOMAINDECOMP(cr)) {
4127                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4128                 } else {
4129                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4130                 }
4131             }
4132             where();
4133
4134             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4135         }
4136
4137         if (debug)
4138             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4139                     cr->nodeid,atc->n);
4140
4141         if (flags & GMX_PME_SPREAD_Q)
4142         {
4143             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4144
4145             /* Spread the charges on a grid */
4146             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4147
4148             if (q == 0)
4149             {
4150                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4151             }
4152             inc_nrnb(nrnb,eNR_SPREADQBSP,
4153                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4154
4155             if (pme->nthread == 1)
4156             {
4157                 wrap_periodic_pmegrid(pme,grid);
4158
4159                 /* sum contributions to local grid from other nodes */
4160 #ifdef GMX_MPI
4161                 if (pme->nnodes > 1)
4162                 {
4163                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4164                     where();
4165                 }
4166 #endif
4167
4168                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4169             }
4170
4171             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4172
4173             /*
4174             dump_local_fftgrid(pme,fftgrid);
4175             exit(0);
4176             */
4177         }
4178
4179         /* Here we start a large thread parallel region */
4180 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4181         for(thread=0; thread<pme->nthread; thread++)
4182         {
4183             if (flags & GMX_PME_SOLVE)
4184             {
4185                 int loop_count;
4186
4187                 /* do 3d-fft */
4188                 if (thread == 0)
4189                 {
4190                     wallcycle_start(wcycle,ewcPME_FFT);
4191                 }
4192                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4193                                            fftgrid,cfftgrid,thread,wcycle);
4194                 if (thread == 0)
4195                 {
4196                     wallcycle_stop(wcycle,ewcPME_FFT);
4197                 }
4198                 where();
4199
4200                 /* solve in k-space for our local cells */
4201                 if (thread == 0)
4202                 {
4203                     wallcycle_start(wcycle,ewcPME_SOLVE);
4204                 }
4205                 loop_count =
4206                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4207                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4208                                   bCalcEnerVir,
4209                                   pme->nthread,thread);
4210                 if (thread == 0)
4211                 {
4212                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4213                     where();
4214                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4215                 }
4216             }
4217
4218             if (bCalcF)
4219             {
4220                 /* do 3d-invfft */
4221                 if (thread == 0)
4222                 {
4223                     where();
4224                     wallcycle_start(wcycle,ewcPME_FFT);
4225                 }
4226                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4227                                            cfftgrid,fftgrid,thread,wcycle);
4228                 if (thread == 0)
4229                 {
4230                     wallcycle_stop(wcycle,ewcPME_FFT);
4231
4232                     where();
4233
4234                     if (pme->nodeid == 0)
4235                     {
4236                         ntot = pme->nkx*pme->nky*pme->nkz;
4237                         npme  = ntot*log((real)ntot)/log(2.0);
4238                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4239                     }
4240
4241                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4242                 }
4243
4244                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4245             }
4246         }
4247         /* End of thread parallel section.
4248          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4249          */
4250
4251         if (bCalcF)
4252         {
4253             /* distribute local grid to all nodes */
4254 #ifdef GMX_MPI
4255             if (pme->nnodes > 1) {
4256                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4257             }
4258 #endif
4259             where();
4260
4261             unwrap_periodic_pmegrid(pme,grid);
4262
4263             /* interpolate forces for our local atoms */
4264
4265             where();
4266
4267             /* If we are running without parallelization,
4268              * atc->f is the actual force array, not a buffer;
4269              * therefore we should not clear it.
4270              */
4271             bClearF = (q == 0 && PAR(cr));
4272 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4273             for(thread=0; thread<pme->nthread; thread++)
4274             {
4275                 gather_f_bsplines(pme,grid,bClearF,atc,
4276                                   &atc->spline[thread],
4277                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4278             }
4279
4280             where();
4281
4282             inc_nrnb(nrnb,eNR_GATHERFBSP,
4283                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4284             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4285         }
4286
4287         if (bCalcEnerVir)
4288         {
4289             /* This should only be called on the master thread
4290              * and after the threads have synchronized.
4291              */
4292             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4293         }
4294     } /* of q-loop */
4295
4296     if (bCalcF && pme->nnodes > 1) {
4297         wallcycle_start(wcycle,ewcPME_REDISTXF);
4298         for(d=0; d<pme->ndecompdim; d++)
4299         {
4300             atc = &pme->atc[d];
4301             if (d == pme->ndecompdim - 1)
4302             {
4303                 n_d = homenr;
4304                 f_d = f + start;
4305             }
4306             else
4307             {
4308                 n_d = pme->atc[d+1].n;
4309                 f_d = pme->atc[d+1].f;
4310             }
4311             if (DOMAINDECOMP(cr)) {
4312                 dd_pmeredist_f(pme,atc,n_d,f_d,
4313                                d==pme->ndecompdim-1 && pme->bPPnode);
4314             } else {
4315                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4316             }
4317         }
4318
4319         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4320     }
4321     where();
4322
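         /* With free energy the mesh energy and virial are interpolated
          * linearly in lambda: E = (1-lambda)*E_A + lambda*E_B and
          * dE/dlambda = E_B - E_A.
          */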
4323     if (bCalcEnerVir)
4324     {
4325         if (!pme->bFEP) {
4326             *energy = energy_AB[0];
4327             m_add(vir,vir_AB[0],vir);
4328         } else {
4329             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4330             *dvdlambda += energy_AB[1] - energy_AB[0];
4331             for(i=0; i<DIM; i++)
4332             {
4333                 for(j=0; j<DIM; j++)
4334                 {
4335                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4336                         lambda*vir_AB[1][i][j];
4337                 }
4338             }
4339         }
4340     }
4341     else
4342     {
4343         *energy = 0;
4344     }
4345
4346     if (debug)
4347     {
4348         fprintf(debug,"PME mesh energy: %g\n",*energy);
4349     }
4350
4351     return 0;
4352 }