Merge release-4-6 into master
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include <stdio.h>
71 #include <string.h>
72 #include <math.h>
73 #include <assert.h>
74 #include "typedefs.h"
75 #include "txtdump.h"
76 #include "vec.h"
77 #include "gmxcomplex.h"
78 #include "smalloc.h"
79 #include "futil.h"
80 #include "coulomb.h"
81 #include "gmx_fatal.h"
82 #include "pme.h"
83 #include "network.h"
84 #include "physics.h"
85 #include "nrnb.h"
86 #include "copyrite.h"
87 #include "gmx_wallcycle.h"
88 #include "gmx_parallel_3dfft.h"
89 #include "pdbio.h"
90 #include "gmx_cyclecounter.h"
91 #include "macros.h"
92
93 /* Single precision, with SSE2 or higher available */
94 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
95
96 #include "gmx_x86_sse2.h"
97 #include "gmx_math_x86_sse2_single.h"
98
99 #define PME_SSE
100 /* Some old AMD processors could have problems with unaligned loads+stores */
101 #ifndef GMX_FAHCORE
102 #define PME_SSE_UNALIGNED
103 #endif
104 #endif
105
/* Tolerance constant; its use lies outside this part of the file. */
#define DFT_TOL 1e-7

/* Debug switches, disabled; enable by defining:
 *   PRT_FORCE
 *   PME_TIME_THREADS
 */

/* Condition for on-the-fly time measurement. A previously used
 * alternative was: (step > 1 && timesteps < 10)
 */
#define TAKETIME FALSE

/* MPI datatype that matches the precision of real */
#if defined(GMX_DOUBLE)
#define mpi_type MPI_DOUBLE
#else
#define mpi_type MPI_FLOAT
#endif

/* Must stay a multiple of 16 so that alignment is preserved */
#define GMX_CACHE_SEP 64

/* Upper bound on the PME interpolation order. It exists only so that
 * fixed-size local arrays can be used instead of allocations; orders
 * above 12 should never be needed, not even for tests. Raise it here
 * if that ever changes.
 */
#define PME_ORDER_MAX 12
128
/* Internal data structures */

/* Grid index ranges exchanged with one neighbour in a single pulse of
 * the grid-overlap communication.
 */
typedef struct {
    int send_index0;    /* First grid index to send          */
    int send_nindex;    /* Number of grid indices to send    */
    int recv_index0;    /* First grid index to receive       */
    int recv_nindex;    /* Number of grid indices to receive */
} pme_grid_comm_t;
136
137 typedef struct {
138 #ifdef GMX_MPI
139     MPI_Comm mpi_comm;
140 #endif
141     int  nnodes,nodeid;
142     int  *s2g0;
143     int  *s2g1;
144     int  noverlap_nodes;
145     int  *send_id,*recv_id;
146     pme_grid_comm_t *comm_data;
147     real *sendbuf;
148     real *recvbuf;
149 } pme_overlap_t;
150
/* Per-thread list of particle indices grouped by destination thread. */
typedef struct {
    int *n;      /* Cumulative particle counts per destination thread */
    int nalloc;  /* Number of elements allocated for i                */
    int *i;      /* Particle indices, ordered on thread index (see n) */
} thread_plist_t;
156
157 typedef struct {
158     int  n;
159     int  *ind;
160     splinevec theta;
161     splinevec dtheta;
162 } splinedata_t;
163
164 typedef struct {
165     int  dimind;            /* The index of the dimension, 0=x, 1=y */
166     int  nslab;
167     int  nodeid;
168 #ifdef GMX_MPI
169     MPI_Comm mpi_comm;
170 #endif
171
172     int  *node_dest;        /* The nodes to send x and q to with DD */
173     int  *node_src;         /* The nodes to receive x and q from with DD */
174     int  *buf_index;        /* Index for commnode into the buffers */
175
176     int  maxshift;
177
178     int  npd;
179     int  pd_nalloc;
180     int  *pd;
181     int  *count;            /* The number of atoms to send to each node */
182     int  **count_thread;
183     int  *rcount;           /* The number of atoms to receive */
184
185     int  n;
186     int  nalloc;
187     rvec *x;
188     real *q;
189     rvec *f;
190     gmx_bool bSpread;       /* These coordinates are used for spreading */
191     int  pme_order;
192     ivec *idx;
193     rvec *fractx;            /* Fractional coordinate relative to the
194                               * lower cell boundary
195                               */
196     int  nthread;
197     int  *thread_idx;        /* Which thread should spread which charge */
198     thread_plist_t *thread_plist;
199     splinedata_t *spline;
200 } pme_atomcomm_t;
201
/* Block-size constants; their use lies outside this part of the file. */
#define FLBS  3
#define FLBSZ 4
204
205 typedef struct {
206     ivec ci;     /* The spatial location of this grid       */
207     ivec n;      /* The size of *grid, including order-1    */
208     ivec offset; /* The grid offset from the full node grid */
209     int  order;  /* PME spreading order                     */
210     real *grid;  /* The grid local thread, size n           */
211 } pmegrid_t;
212
213 typedef struct {
214     pmegrid_t grid;     /* The full node grid (non thread-local)            */
215     int  nthread;       /* The number of threads operating on this grid     */
216     ivec nc;            /* The local spatial decomposition over the threads */
217     pmegrid_t *grid_th; /* Array of grids for each thread                   */
218     int  **g2t;         /* The grid to thread index                         */
219     ivec nthread_comm;  /* The number of threads to communicate with        */
220 } pmegrids_t;
221
222
/* Work data for spreading and gathering. */
typedef struct {
#ifdef PME_SSE
    /* Masks for SSE aligned spreading and gathering */
    __m128 mask_SSE0[6],mask_SSE1[6];
#else
    int dummy; /* C89 requires that struct has at least one member */
#endif
} pme_spline_work_t;
231
232 typedef struct {
233     /* work data for solve_pme */
234     int      nalloc;
235     real *   mhx;
236     real *   mhy;
237     real *   mhz;
238     real *   m2;
239     real *   denom;
240     real *   tmp1_alloc;
241     real *   tmp1;
242     real *   eterm;
243     real *   m2inv;
244
245     real     energy;
246     matrix   vir;
247 } pme_work_t;
248
249 typedef struct gmx_pme {
250     int  ndecompdim;         /* The number of decomposition dimensions */
251     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
252     int  nodeid_major;
253     int  nodeid_minor;
254     int  nnodes;             /* The number of nodes doing PME */
255     int  nnodes_major;
256     int  nnodes_minor;
257
258     MPI_Comm mpi_comm;
259     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
260 #ifdef GMX_MPI
261     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
262 #endif
263
264     int  nthread;            /* The number of threads doing PME */
265
266     gmx_bool bPPnode;        /* Node also does particle-particle forces */
267     gmx_bool bFEP;           /* Compute Free energy contribution */
268     int nkx,nky,nkz;         /* Grid dimensions */
269     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
270     int pme_order;
271     real epsilon_r;
272
273     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
274     pmegrids_t pmegridB;
275     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
276     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
277     /* pmegrid_nz might be larger than strictly necessary to ensure
278      * memory alignment, pmegrid_nz_base gives the real base size.
279      */
280     int     pmegrid_nz_base;
281     /* The local PME grid starting indices */
282     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
283
284     /* Work data for spreading and gathering */
285     pme_spline_work_t *spline_work;
286
287     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
288     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
289     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
290
291     t_complex *cfftgridA;             /* Grids for complex FFT data */
292     t_complex *cfftgridB;
293     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
294
295     gmx_parallel_3dfft_t  pfft_setupA;
296     gmx_parallel_3dfft_t  pfft_setupB;
297
298     int  *nnx,*nny,*nnz;
299     real *fshx,*fshy,*fshz;
300
301     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
302     matrix    recipbox;
303     splinevec bsp_mod;
304
305     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
306
307     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
308
309     rvec *bufv;             /* Communication buffer */
310     real *bufr;             /* Communication buffer */
311     int  buf_nalloc;        /* The communication buffer size */
312
313     /* thread local work data for solve_pme */
314     pme_work_t *work;
315
316     /* Work data for PME_redist */
317     gmx_bool redist_init;
318     int *    scounts;
319     int *    rcounts;
320     int *    sdispls;
321     int *    rdispls;
322     int *    sidx;
323     int *    idxa;
324     real *   redist_buf;
325     int      redist_buf_nalloc;
326
327     /* Work data for sum_qgrid */
328     real *   sum_qgrid_tmp;
329     real *   sum_qgrid_dd_tmp;
330 } t_gmx_pme;
331
332
333 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
334                                    int start,int end,int thread)
335 {
336     int  i;
337     int  *idxptr,tix,tiy,tiz;
338     real *xptr,*fptr,tx,ty,tz;
339     real rxx,ryx,ryy,rzx,rzy,rzz;
340     int  nx,ny,nz;
341     int  start_ix,start_iy,start_iz;
342     int  *g2tx,*g2ty,*g2tz;
343     gmx_bool bThreads;
344     int  *thread_idx=NULL;
345     thread_plist_t *tpl=NULL;
346     int  *tpl_n=NULL;
347     int  thread_i;
348
349     nx  = pme->nkx;
350     ny  = pme->nky;
351     nz  = pme->nkz;
352
353     start_ix = pme->pmegrid_start_ix;
354     start_iy = pme->pmegrid_start_iy;
355     start_iz = pme->pmegrid_start_iz;
356
357     rxx = pme->recipbox[XX][XX];
358     ryx = pme->recipbox[YY][XX];
359     ryy = pme->recipbox[YY][YY];
360     rzx = pme->recipbox[ZZ][XX];
361     rzy = pme->recipbox[ZZ][YY];
362     rzz = pme->recipbox[ZZ][ZZ];
363
364     g2tx = pme->pmegridA.g2t[XX];
365     g2ty = pme->pmegridA.g2t[YY];
366     g2tz = pme->pmegridA.g2t[ZZ];
367
368     bThreads = (atc->nthread > 1);
369     if (bThreads)
370     {
371         thread_idx = atc->thread_idx;
372
373         tpl   = &atc->thread_plist[thread];
374         tpl_n = tpl->n;
375         for(i=0; i<atc->nthread; i++)
376         {
377             tpl_n[i] = 0;
378         }
379     }
380
381     for(i=start; i<end; i++) {
382         xptr   = atc->x[i];
383         idxptr = atc->idx[i];
384         fptr   = atc->fractx[i];
385
386         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
387         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
388         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
389         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
390
391         tix = (int)(tx);
392         tiy = (int)(ty);
393         tiz = (int)(tz);
394
395         /* Because decomposition only occurs in x and y,
396          * we never have a fraction correction in z.
397          */
398         fptr[XX] = tx - tix + pme->fshx[tix];
399         fptr[YY] = ty - tiy + pme->fshy[tiy];
400         fptr[ZZ] = tz - tiz;
401
402         idxptr[XX] = pme->nnx[tix];
403         idxptr[YY] = pme->nny[tiy];
404         idxptr[ZZ] = pme->nnz[tiz];
405
406 #ifdef DEBUG
407         range_check(idxptr[XX],0,pme->pmegrid_nx);
408         range_check(idxptr[YY],0,pme->pmegrid_ny);
409         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
410 #endif
411
412         if (bThreads)
413         {
414             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
415             thread_idx[i] = thread_i;
416             tpl_n[thread_i]++;
417         }
418     }
419
420     if (bThreads)
421     {
422         /* Make a list of particle indices sorted on thread */
423
424         /* Get the cumulative count */
425         for(i=1; i<atc->nthread; i++)
426         {
427             tpl_n[i] += tpl_n[i-1];
428         }
429         /* The current implementation distributes particles equally
430          * over the threads, so we could actually allocate for that
431          * in pme_realloc_atomcomm_things.
432          */
433         if (tpl_n[atc->nthread-1] > tpl->nalloc)
434         {
435             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
436             srenew(tpl->i,tpl->nalloc);
437         }
438         /* Set tpl_n to the cumulative start */
439         for(i=atc->nthread-1; i>=1; i--)
440         {
441             tpl_n[i] = tpl_n[i-1];
442         }
443         tpl_n[0] = 0;
444
445         /* Fill our thread local array with indices sorted on thread */
446         for(i=start; i<end; i++)
447         {
448             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
449         }
450         /* Now tpl_n contains the cummulative count again */
451     }
452 }
453
454 static void make_thread_local_ind(pme_atomcomm_t *atc,
455                                   int thread,splinedata_t *spline)
456 {
457     int  n,t,i,start,end;
458     thread_plist_t *tpl;
459
460     /* Combine the indices made by each thread into one index */
461
462     n = 0;
463     start = 0;
464     for(t=0; t<atc->nthread; t++)
465     {
466         tpl = &atc->thread_plist[t];
467         /* Copy our part (start - end) from the list of thread t */
468         if (thread > 0)
469         {
470             start = tpl->n[thread-1];
471         }
472         end = tpl->n[thread];
473         for(i=start; i<end; i++)
474         {
475             spline->ind[n++] = tpl->i[i];
476         }
477     }
478
479     spline->n = n;
480 }
481
482
483 static void pme_calc_pidx(int start, int end,
484                           matrix recipbox, rvec x[],
485                           pme_atomcomm_t *atc, int *count)
486 {
487     int  nslab,i;
488     int  si;
489     real *xptr,s;
490     real rxx,ryx,rzx,ryy,rzy;
491     int *pd;
492
493     /* Calculate PME task index (pidx) for each grid index.
494      * Here we always assign equally sized slabs to each node
495      * for load balancing reasons (the PME grid spacing is not used).
496      */
497
498     nslab = atc->nslab;
499     pd    = atc->pd;
500
501     /* Reset the count */
502     for(i=0; i<nslab; i++)
503     {
504         count[i] = 0;
505     }
506
507     if (atc->dimind == 0)
508     {
509         rxx = recipbox[XX][XX];
510         ryx = recipbox[YY][XX];
511         rzx = recipbox[ZZ][XX];
512         /* Calculate the node index in x-dimension */
513         for(i=start; i<end; i++)
514         {
515             xptr   = x[i];
516             /* Fractional coordinates along box vectors */
517             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
518             si = (int)(s + 2*nslab) % nslab;
519             pd[i] = si;
520             count[si]++;
521         }
522     }
523     else
524     {
525         ryy = recipbox[YY][YY];
526         rzy = recipbox[ZZ][YY];
527         /* Calculate the node index in y-dimension */
528         for(i=start; i<end; i++)
529         {
530             xptr   = x[i];
531             /* Fractional coordinates along box vectors */
532             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
533             si = (int)(s + 2*nslab) % nslab;
534             pd[i] = si;
535             count[si]++;
536         }
537     }
538 }
539
540 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
541                                   pme_atomcomm_t *atc)
542 {
543     int nthread,thread,slab;
544
545     nthread = atc->nthread;
546
547 #pragma omp parallel for num_threads(nthread) schedule(static)
548     for(thread=0; thread<nthread; thread++)
549     {
550         pme_calc_pidx(natoms* thread   /nthread,
551                       natoms*(thread+1)/nthread,
552                       recipbox,x,atc,atc->count_thread[thread]);
553     }
554     /* Non-parallel reduction, since nslab is small */
555
556     for(thread=1; thread<nthread; thread++)
557     {
558         for(slab=0; slab<atc->nslab; slab++)
559         {
560             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
561         }
562     }
563 }
564
565 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
566 {
567     int i,d;
568
569     srenew(spline->ind,atc->nalloc);
570     /* Initialize the index to identity so it works without threads */
571     for(i=0; i<atc->nalloc; i++)
572     {
573         spline->ind[i] = i;
574     }
575
576     for(d=0;d<DIM;d++)
577     {
578         srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
579         srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
580     }
581 }
582
583 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
584 {
585     int nalloc_old,i,j,nalloc_tpl;
586
587     /* We have to avoid a NULL pointer for atc->x to avoid
588      * possible fatal errors in MPI routines.
589      */
590     if (atc->n > atc->nalloc || atc->nalloc == 0)
591     {
592         nalloc_old = atc->nalloc;
593         atc->nalloc = over_alloc_dd(max(atc->n,1));
594
595         if (atc->nslab > 1) {
596             srenew(atc->x,atc->nalloc);
597             srenew(atc->q,atc->nalloc);
598             srenew(atc->f,atc->nalloc);
599             for(i=nalloc_old; i<atc->nalloc; i++)
600             {
601                 clear_rvec(atc->f[i]);
602             }
603         }
604         if (atc->bSpread) {
605             srenew(atc->fractx,atc->nalloc);
606             srenew(atc->idx   ,atc->nalloc);
607
608             if (atc->nthread > 1)
609             {
610                 srenew(atc->thread_idx,atc->nalloc);
611             }
612
613             for(i=0; i<atc->nthread; i++)
614             {
615                 pme_realloc_splinedata(&atc->spline[i],atc);
616             }
617         }
618     }
619 }
620
621 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
622                          int n, gmx_bool bXF, rvec *x_f, real *charge,
623                          pme_atomcomm_t *atc)
624 /* Redistribute particle data for PME calculation */
625 /* domain decomposition by x coordinate           */
626 {
627     int *idxa;
628     int i, ii;
629
630     if(FALSE == pme->redist_init) {
631         snew(pme->scounts,atc->nslab);
632         snew(pme->rcounts,atc->nslab);
633         snew(pme->sdispls,atc->nslab);
634         snew(pme->rdispls,atc->nslab);
635         snew(pme->sidx,atc->nslab);
636         pme->redist_init = TRUE;
637     }
638     if (n > pme->redist_buf_nalloc) {
639         pme->redist_buf_nalloc = over_alloc_dd(n);
640         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
641     }
642
643     pme->idxa = atc->pd;
644
645 #ifdef GMX_MPI
646     if (forw && bXF) {
647         /* forward, redistribution from pp to pme */
648
649         /* Calculate send counts and exchange them with other nodes */
650         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
651         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
652         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
653
654         /* Calculate send and receive displacements and index into send
655            buffer */
656         pme->sdispls[0]=0;
657         pme->rdispls[0]=0;
658         pme->sidx[0]=0;
659         for(i=1; i<atc->nslab; i++) {
660             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
661             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
662             pme->sidx[i]=pme->sdispls[i];
663         }
664         /* Total # of particles to be received */
665         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
666
667         pme_realloc_atomcomm_things(atc);
668
669         /* Copy particle coordinates into send buffer and exchange*/
670         for(i=0; (i<n); i++) {
671             ii=DIM*pme->sidx[pme->idxa[i]];
672             pme->sidx[pme->idxa[i]]++;
673             pme->redist_buf[ii+XX]=x_f[i][XX];
674             pme->redist_buf[ii+YY]=x_f[i][YY];
675             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
676         }
677         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
678                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
679                       pme->rvec_mpi, atc->mpi_comm);
680     }
681     if (forw) {
682         /* Copy charge into send buffer and exchange*/
683         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
684         for(i=0; (i<n); i++) {
685             ii=pme->sidx[pme->idxa[i]];
686             pme->sidx[pme->idxa[i]]++;
687             pme->redist_buf[ii]=charge[i];
688         }
689         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
690                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
691                       atc->mpi_comm);
692     }
693     else { /* backward, redistribution from pme to pp */
694         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
695                       pme->redist_buf, pme->scounts, pme->sdispls,
696                       pme->rvec_mpi, atc->mpi_comm);
697
698         /* Copy data from receive buffer */
699         for(i=0; i<atc->nslab; i++)
700             pme->sidx[i] = pme->sdispls[i];
701         for(i=0; (i<n); i++) {
702             ii = DIM*pme->sidx[pme->idxa[i]];
703             x_f[i][XX] += pme->redist_buf[ii+XX];
704             x_f[i][YY] += pme->redist_buf[ii+YY];
705             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
706             pme->sidx[pme->idxa[i]]++;
707         }
708     }
709 #endif
710 }
711
712 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
713                             gmx_bool bBackward,int shift,
714                             void *buf_s,int nbyte_s,
715                             void *buf_r,int nbyte_r)
716 {
717 #ifdef GMX_MPI
718     int dest,src;
719     MPI_Status stat;
720
721     if (bBackward == FALSE) {
722         dest = atc->node_dest[shift];
723         src  = atc->node_src[shift];
724     } else {
725         dest = atc->node_src[shift];
726         src  = atc->node_dest[shift];
727     }
728
729     if (nbyte_s > 0 && nbyte_r > 0) {
730         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
731                      dest,shift,
732                      buf_r,nbyte_r,MPI_BYTE,
733                      src,shift,
734                      atc->mpi_comm,&stat);
735     } else if (nbyte_s > 0) {
736         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
737                  dest,shift,
738                  atc->mpi_comm);
739     } else if (nbyte_r > 0) {
740         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
741                  src,shift,
742                  atc->mpi_comm,&stat);
743     }
744 #endif
745 }
746
747 static void dd_pmeredist_x_q(gmx_pme_t pme,
748                              int n, gmx_bool bX, rvec *x, real *charge,
749                              pme_atomcomm_t *atc)
750 {
751     int *commnode,*buf_index;
752     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
753
754     commnode  = atc->node_dest;
755     buf_index = atc->buf_index;
756
757     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
758
759     nsend = 0;
760     for(i=0; i<nnodes_comm; i++) {
761         buf_index[commnode[i]] = nsend;
762         nsend += atc->count[commnode[i]];
763     }
764     if (bX) {
765         if (atc->count[atc->nodeid] + nsend != n)
766             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
767                       "This usually means that your system is not well equilibrated.",
768                       n - (atc->count[atc->nodeid] + nsend),
769                       pme->nodeid,'x'+atc->dimind);
770
771         if (nsend > pme->buf_nalloc) {
772             pme->buf_nalloc = over_alloc_dd(nsend);
773             srenew(pme->bufv,pme->buf_nalloc);
774             srenew(pme->bufr,pme->buf_nalloc);
775         }
776
777         atc->n = atc->count[atc->nodeid];
778         for(i=0; i<nnodes_comm; i++) {
779             scount = atc->count[commnode[i]];
780             /* Communicate the count */
781             if (debug)
782                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
783                         atc->dimind,atc->nodeid,commnode[i],scount);
784             pme_dd_sendrecv(atc,FALSE,i,
785                             &scount,sizeof(int),
786                             &atc->rcount[i],sizeof(int));
787             atc->n += atc->rcount[i];
788         }
789
790         pme_realloc_atomcomm_things(atc);
791     }
792
793     local_pos = 0;
794     for(i=0; i<n; i++) {
795         node = atc->pd[i];
796         if (node == atc->nodeid) {
797             /* Copy direct to the receive buffer */
798             if (bX) {
799                 copy_rvec(x[i],atc->x[local_pos]);
800             }
801             atc->q[local_pos] = charge[i];
802             local_pos++;
803         } else {
804             /* Copy to the send buffer */
805             if (bX) {
806                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
807             }
808             pme->bufr[buf_index[node]] = charge[i];
809             buf_index[node]++;
810         }
811     }
812
813     buf_pos = 0;
814     for(i=0; i<nnodes_comm; i++) {
815         scount = atc->count[commnode[i]];
816         rcount = atc->rcount[i];
817         if (scount > 0 || rcount > 0) {
818             if (bX) {
819                 /* Communicate the coordinates */
820                 pme_dd_sendrecv(atc,FALSE,i,
821                                 pme->bufv[buf_pos],scount*sizeof(rvec),
822                                 atc->x[local_pos],rcount*sizeof(rvec));
823             }
824             /* Communicate the charges */
825             pme_dd_sendrecv(atc,FALSE,i,
826                             pme->bufr+buf_pos,scount*sizeof(real),
827                             atc->q+local_pos,rcount*sizeof(real));
828             buf_pos   += scount;
829             local_pos += atc->rcount[i];
830         }
831     }
832 }
833
834 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
835                            int n, rvec *f,
836                            gmx_bool bAddF)
837 {
838   int *commnode,*buf_index;
839   int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
840
841   commnode  = atc->node_dest;
842   buf_index = atc->buf_index;
843
844   nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
845
846   local_pos = atc->count[atc->nodeid];
847   buf_pos = 0;
848   for(i=0; i<nnodes_comm; i++) {
849     scount = atc->rcount[i];
850     rcount = atc->count[commnode[i]];
851     if (scount > 0 || rcount > 0) {
852       /* Communicate the forces */
853       pme_dd_sendrecv(atc,TRUE,i,
854                       atc->f[local_pos],scount*sizeof(rvec),
855                       pme->bufv[buf_pos],rcount*sizeof(rvec));
856       local_pos += scount;
857     }
858     buf_index[commnode[i]] = buf_pos;
859     buf_pos   += rcount;
860   }
861
862     local_pos = 0;
863     if (bAddF)
864     {
865         for(i=0; i<n; i++)
866         {
867             node = atc->pd[i];
868             if (node == atc->nodeid)
869             {
870                 /* Add from the local force array */
871                 rvec_inc(f[i],atc->f[local_pos]);
872                 local_pos++;
873             }
874             else
875             {
876                 /* Add from the receive buffer */
877                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
878                 buf_index[node]++;
879             }
880         }
881     }
882     else
883     {
884         for(i=0; i<n; i++)
885         {
886             node = atc->pd[i];
887             if (node == atc->nodeid)
888             {
889                 /* Copy from the local force array */
890                 copy_rvec(atc->f[local_pos],f[i]);
891                 local_pos++;
892             }
893             else
894             {
895                 /* Copy from the receive buffer */
896                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
897                 buf_index[node]++;
898             }
899         }
900     }
901 }
902
#ifdef GMX_MPI
/* Exchange the overlap regions of the local PME grid with the PME nodes
 * listed in pme->overlap.
 *
 * grid      : this node's local pmegrid, indexed as
 *             ix*(pmegrid_ny*pmegrid_nz) + iy*pmegrid_nz + iz (z fastest).
 * direction : GMX_SUM_QGRID_FORWARD sends the ranges described by
 *             send_index0/send_nindex to send_id and ADDS the data received
 *             from recv_id into the grid. For any other value the roles of
 *             the send/recv ranges and node ids are swapped and the received
 *             data OVERWRITES the corresponding grid points.
 *
 * The minor dimension (pme->overlap[1], y lines) is exchanged first through
 * the contiguous overlap->sendbuf/recvbuf buffers, only for iz in [0,nkz).
 * The major dimension (pme->overlap[0], x planes) is exchanged directly from
 * the grid, covering the full pmegrid_ny*pmegrid_nz extent of each plane.
 */
static void
gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int send_index0,send_nindex;
    int recv_index0,recv_nindex;
    MPI_Status stat;
    int i,j,k,ix,iy,iz,icnt;
    int ipulse,send_id,recv_id,datasize;
    real *p;
    real *sendptr,*recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction==GMX_SUM_QGRID_FORWARD)
        {
            send_id = overlap->send_id[ipulse];
            recv_id = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            /* Backward: the forward receive range is sent back, and the
             * forward send range is received (and overwritten below).
             */
            send_id = overlap->recv_id[ipulse];
            recv_id = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,overlap->nodeid,send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        icnt = 0;
        for(i=0;i<pme->pmegrid_nx;i++)
        {
            ix = i;
            for(j=0;j<send_nindex;j++)
            {
                /* send_index0 is a global y index; convert to local */
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for(k=0;k<pme->nkz;k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        /* Number of values per communicated y line: all x planes times nkz */
        datasize      = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
                     send_id,ipulse,
                     overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
                     recv_id,ipulse,
                     overlap->mpi_comm,&stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,overlap->nodeid,recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        /* Unpack in the same (x, y, z) order as the packing loop above */
        icnt = 0;
        for(i=0;i<pme->pmegrid_nx;i++)
        {
            ix = i;
            for(j=0;j<recv_nindex;j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for(k=0;k<pme->nkz;k++)
                {
                    iz = k;
                    if(direction==GMX_SUM_QGRID_FORWARD)
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
    {
        if(direction==GMX_SUM_QGRID_FORWARD)
        {
            send_id = overlap->send_id[ipulse];
            recv_id = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            /* Forward: receive into a buffer, then add to the grid below */
            recvptr   = overlap->recvbuf;
        }
        else
        {
            send_id = overlap->recv_id[ipulse];
            recv_id = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            /* Backward: receive straight into (and overwrite) the grid */
            recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        /* Number of values per communicated x plane */
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,overlap->nodeid,send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,overlap->nodeid,recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
                     send_id,ipulse,
                     recvptr,recv_nindex*datasize,GMX_MPI_REAL,
                     recv_id,ipulse,
                     overlap->mpi_comm,&stat);

        /* ADD data from contiguous recv buffer */
        if(direction==GMX_SUM_QGRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for(i=0;i<recv_nindex*datasize;i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif
1072
1073
1074 static int
1075 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1076 {
1077     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1078     ivec    local_pme_size;
1079     int     i,ix,iy,iz;
1080     int     pmeidx,fftidx;
1081
1082     /* Dimensions should be identical for A/B grid, so we just use A here */
1083     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1084                                    local_fft_ndata,
1085                                    local_fft_offset,
1086                                    local_fft_size);
1087
1088     local_pme_size[0] = pme->pmegrid_nx;
1089     local_pme_size[1] = pme->pmegrid_ny;
1090     local_pme_size[2] = pme->pmegrid_nz;
1091
1092     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1093      the offset is identical, and the PME grid always has more data (due to overlap)
1094      */
1095     {
1096 #ifdef DEBUG_PME
1097         FILE *fp,*fp2;
1098         char fn[STRLEN],format[STRLEN];
1099         real val;
1100         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1101         fp = ffopen(fn,"w");
1102         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1103         fp2 = ffopen(fn,"w");
1104      sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1105 #endif
1106
1107     for(ix=0;ix<local_fft_ndata[XX];ix++)
1108     {
1109         for(iy=0;iy<local_fft_ndata[YY];iy++)
1110         {
1111             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1112             {
1113                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1114                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1115                 fftgrid[fftidx] = pmegrid[pmeidx];
1116 #ifdef DEBUG_PME
1117                 val = 100*pmegrid[pmeidx];
1118                 if (pmegrid[pmeidx] != 0)
1119                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1120                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1121                 if (pmegrid[pmeidx] != 0)
1122                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1123                             "qgrid",
1124                             pme->pmegrid_start_ix + ix,
1125                             pme->pmegrid_start_iy + iy,
1126                             pme->pmegrid_start_iz + iz,
1127                             pmegrid[pmeidx]);
1128 #endif
1129             }
1130         }
1131     }
1132 #ifdef DEBUG_PME
1133     ffclose(fp);
1134     ffclose(fp2);
1135 #endif
1136     }
1137     return 0;
1138 }
1139
1140
1141 static gmx_cycles_t omp_cyc_start()
1142 {
1143     return gmx_cycles_read();
1144 }
1145
1146 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1147 {
1148     return gmx_cycles_read() - c;
1149 }
1150
1151
1152 static int
1153 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1154                         int nthread,int thread)
1155 {
1156     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1157     ivec    local_pme_size;
1158     int     ixy0,ixy1,ixy,ix,iy,iz;
1159     int     pmeidx,fftidx;
1160 #ifdef PME_TIME_THREADS
1161     gmx_cycles_t c1;
1162     static double cs1=0;
1163     static int cnt=0;
1164 #endif
1165
1166 #ifdef PME_TIME_THREADS
1167     c1 = omp_cyc_start();
1168 #endif
1169     /* Dimensions should be identical for A/B grid, so we just use A here */
1170     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1171                                    local_fft_ndata,
1172                                    local_fft_offset,
1173                                    local_fft_size);
1174
1175     local_pme_size[0] = pme->pmegrid_nx;
1176     local_pme_size[1] = pme->pmegrid_ny;
1177     local_pme_size[2] = pme->pmegrid_nz;
1178
1179     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1180      the offset is identical, and the PME grid always has more data (due to overlap)
1181      */
1182     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1183     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1184
1185     for(ixy=ixy0;ixy<ixy1;ixy++)
1186     {
1187         ix = ixy/local_fft_ndata[YY];
1188         iy = ixy - ix*local_fft_ndata[YY];
1189
1190         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1191         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1192         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1193         {
1194             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1195         }
1196     }
1197
1198 #ifdef PME_TIME_THREADS
1199     c1 = omp_cyc_end(c1);
1200     cs1 += (double)c1;
1201     cnt++;
1202     if (cnt % 20 == 0)
1203     {
1204         printf("copy %.2f\n",cs1*1e-9);
1205     }
1206 #endif
1207
1208     return 0;
1209 }
1210
1211
1212 static void
1213 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1214 {
1215     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1216
1217     nx = pme->nkx;
1218     ny = pme->nky;
1219     nz = pme->nkz;
1220
1221     pnx = pme->pmegrid_nx;
1222     pny = pme->pmegrid_ny;
1223     pnz = pme->pmegrid_nz;
1224
1225     overlap = pme->pme_order - 1;
1226
1227     /* Add periodic overlap in z */
1228     for(ix=0; ix<pme->pmegrid_nx; ix++)
1229     {
1230         for(iy=0; iy<pme->pmegrid_ny; iy++)
1231         {
1232             for(iz=0; iz<overlap; iz++)
1233             {
1234                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1235                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1236             }
1237         }
1238     }
1239
1240     if (pme->nnodes_minor == 1)
1241     {
1242        for(ix=0; ix<pme->pmegrid_nx; ix++)
1243        {
1244            for(iy=0; iy<overlap; iy++)
1245            {
1246                for(iz=0; iz<nz; iz++)
1247                {
1248                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1249                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1250                }
1251            }
1252        }
1253     }
1254
1255     if (pme->nnodes_major == 1)
1256     {
1257         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1258
1259         for(ix=0; ix<overlap; ix++)
1260         {
1261             for(iy=0; iy<ny_x; iy++)
1262             {
1263                 for(iz=0; iz<nz; iz++)
1264                 {
1265                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1266                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1267                 }
1268             }
1269         }
1270     }
1271 }
1272
1273
1274 static void
1275 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1276 {
1277     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1278
1279     nx = pme->nkx;
1280     ny = pme->nky;
1281     nz = pme->nkz;
1282
1283     pnx = pme->pmegrid_nx;
1284     pny = pme->pmegrid_ny;
1285     pnz = pme->pmegrid_nz;
1286
1287     overlap = pme->pme_order - 1;
1288
1289     if (pme->nnodes_major == 1)
1290     {
1291         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1292
1293         for(ix=0; ix<overlap; ix++)
1294         {
1295             int iy,iz;
1296
1297             for(iy=0; iy<ny_x; iy++)
1298             {
1299                 for(iz=0; iz<nz; iz++)
1300                 {
1301                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1302                         pmegrid[(ix*pny+iy)*pnz+iz];
1303                 }
1304             }
1305         }
1306     }
1307
1308     if (pme->nnodes_minor == 1)
1309     {
1310 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1311        for(ix=0; ix<pme->pmegrid_nx; ix++)
1312        {
1313            int iy,iz;
1314
1315            for(iy=0; iy<overlap; iy++)
1316            {
1317                for(iz=0; iz<nz; iz++)
1318                {
1319                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1320                        pmegrid[(ix*pny+iy)*pnz+iz];
1321                }
1322            }
1323        }
1324     }
1325
1326     /* Copy periodic overlap in z */
1327 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1328     for(ix=0; ix<pme->pmegrid_nx; ix++)
1329     {
1330         int iy,iz;
1331
1332         for(iy=0; iy<pme->pmegrid_ny; iy++)
1333         {
1334             for(iz=0; iz<overlap; iz++)
1335             {
1336                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1337                     pmegrid[(ix*pny+iy)*pnz+iz];
1338             }
1339         }
1340     }
1341 }
1342
1343 static void clear_grid(int nx,int ny,int nz,real *grid,
1344                        ivec fs,int *flag,
1345                        int fx,int fy,int fz,
1346                        int order)
1347 {
1348     int nc,ncz;
1349     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1350     int flind;
1351
1352     nc  = 2 + (order - 2)/FLBS;
1353     ncz = 2 + (order - 2)/FLBSZ;
1354
1355     for(fsx=fx; fsx<fx+nc; fsx++)
1356     {
1357         for(fsy=fy; fsy<fy+nc; fsy++)
1358         {
1359             for(fsz=fz; fsz<fz+ncz; fsz++)
1360             {
1361                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1362                 if (flag[flind] == 0)
1363                 {
1364                     gx = fsx*FLBS;
1365                     gy = fsy*FLBS;
1366                     gz = fsz*FLBSZ;
1367                     g0x = (gx*ny + gy)*nz + gz;
1368                     for(x=0; x<FLBS; x++)
1369                     {
1370                         g0y = g0x;
1371                         for(y=0; y<FLBS; y++)
1372                         {
1373                             for(z=0; z<FLBSZ; z++)
1374                             {
1375                                 grid[g0y+z] = 0;
1376                             }
1377                             g0y += nz;
1378                         }
1379                         g0x += ny*nz;
1380                     }
1381
1382                     flag[flind] = 1;
1383                 }
1384             }
1385         }
1386     }
1387 }
1388
/* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
/* Spread the charge qn of one atom onto an order x order x order block of
 * the local grid starting at point (i0,j0,k0), weighting each point by
 * thx[ithx]*thy[ithy]*thz[ithz]. The grid is indexed with x slowest and
 * z fastest, using the row sizes pny and pnz.
 *
 * The expansion uses these variables from the enclosing scope:
 *   read:    grid, pny, pnz, i0, j0, k0, qn, thx, thy, thz
 *   written: ithx, ithy, ithz, index_x, index_xy, index_xyz, valx, valxy,
 *            and grid (accumulated with +=).
 * It expands to a single braced for statement, so a trailing ';' is an
 * empty statement.
 */
#define DO_BSPLINE(order)                            \
for(ithx=0; (ithx<order); ithx++)                    \
{                                                    \
    index_x = (i0+ithx)*pny*pnz;                     \
    valx    = qn*thx[ithx];                          \
                                                     \
    for(ithy=0; (ithy<order); ithy++)                \
    {                                                \
        valxy    = valx*thy[ithy];                   \
        index_xy = index_x+(j0+ithy)*pnz;            \
                                                     \
        for(ithz=0; (ithz<order); ithz++)            \
        {                                            \
            index_xyz        = index_xy+(k0+ithz);   \
            grid[index_xyz] += valxy*thz[ithz];      \
        }                                            \
    }                                                \
}
1408
1409
/* Spread the charges of the atoms listed in spline->ind onto the
 * (thread-local) grid pmegrid.
 *
 * The whole grid (n[XX]*n[YY]*n[ZZ] values) is zeroed first. Then, for each
 * listed atom n with non-zero charge atc->q[n], the charge is spread over
 * an order^3 block starting at the atom's grid index atc->idx[n] shifted by
 * pmegrid->offset. The B-spline weights are taken from spline->theta at
 * position nn*order, where nn is the atom's position in spline->ind.
 * Orders 4 and 5 use the kernels in pme_sse_single.h when PME_SSE is
 * defined; all other cases use the generic DO_BSPLINE macro.
 *
 * NOTE(review): ol, b, localsize, bndsize and the work argument are not
 * used in the visible code; pme_sse_single.h may use some of them, so
 * confirm before removing.
 */
static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
                                     pme_atomcomm_t *atc, splinedata_t *spline,
                                     pme_spline_work_t *work)
{

    /* spread charges from home atoms to local grid */
    real     *grid;
    pme_overlap_t *ol;
    int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
    int *    idxptr;
    int      order,norder,index_x,index_xy,index_xyz;
    real     valx,valxy,qn;
    real     *thx,*thy,*thz;
    int      localsize, bndsize;
    int      pnx,pny,pnz,ndatatot;
    int      offx,offy,offz;

    pnx = pmegrid->n[XX];
    pny = pmegrid->n[YY];
    pnz = pmegrid->n[ZZ];

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the whole grid, including the overlap points */
    ndatatot = pnx*pny*pnz;
    grid = pmegrid->grid;
    for(i=0;i<ndatatot;i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for(nn=0; nn<spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = atc->q[n];

        /* Uncharged atoms contribute nothing */
        if (qn != 0)
        {
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Grid index relative to the origin of this (sub)grid */
            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta[XX] + norder;
            thy = spline->theta[YY] + norder;
            thz = spline->theta[ZZ] + norder;

            switch (order) {
            case 4:
#ifdef PME_SSE
#ifdef PME_SSE_UNALIGNED
#define PME_SPREAD_SSE_ORDER4
#else
#define PME_SPREAD_SSE_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_sse_single.h"
#else
                DO_BSPLINE(4);
#endif
                break;
            case 5:
#ifdef PME_SSE
#define PME_SPREAD_SSE_ALIGNED
#define PME_ORDER 5
#include "pme_sse_single.h"
#else
                DO_BSPLINE(5);
#endif
                break;
            default:
                DO_BSPLINE(order);
                break;
            }
        }
    }
}
1492
static void set_grid_alignment(int *pmegrid_nz,int pme_order)
{
    /* Pad the z size of a PME grid for the SSE spreading kernels.
     *
     * Only active when PME_SSE is defined. For pme_order 5, and for
     * pme_order 4 unless PME_SSE_UNALIGNED is defined, *pmegrid_nz is
     * rounded up to a multiple of 4. Otherwise *pmegrid_nz is unchanged.
     */
#ifdef PME_SSE
    int bPad;

#ifdef PME_SSE_UNALIGNED
    bPad = (pme_order == 5);
#else
    bPad = (pme_order == 5 || pme_order == 4);
#endif
    if (bPad)
    {
        /* Adding 3 and clearing the two low bits gives the next multiple of 4 */
        *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
    }
#endif
}
1507
static void set_gridsize_alignment(int *gridsize,int pme_order)
{
    /* Enlarge a grid allocation size (in reals) for the aligned SSE kernels.
     *
     * With PME_SSE defined and PME_SSE_UNALIGNED not defined, pme_order 4
     * gets 4 extra elements, so that aligned operations do not go beyond
     * the allocated grid. For pme_order 5 the z-size padding done by
     * set_grid_alignment() already guarantees this, so nothing is added.
     */
#if defined(PME_SSE) && !defined(PME_SSE_UNALIGNED)
    if (pme_order == 4)
    {
        *gridsize = *gridsize + 4;
    }
#endif
}
1524
1525 static void pmegrid_init(pmegrid_t *grid,
1526                          int cx, int cy, int cz,
1527                          int x0, int y0, int z0,
1528                          int x1, int y1, int z1,
1529                          gmx_bool set_alignment,
1530                          int pme_order,
1531                          real *ptr)
1532 {
1533     int nz,gridsize;
1534
1535     grid->ci[XX] = cx;
1536     grid->ci[YY] = cy;
1537     grid->ci[ZZ] = cz;
1538     grid->offset[XX] = x0;
1539     grid->offset[YY] = y0;
1540     grid->offset[ZZ] = z0;
1541     grid->n[XX]      = x1 - x0 + pme_order - 1;
1542     grid->n[YY]      = y1 - y0 + pme_order - 1;
1543     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1544
1545     nz = grid->n[ZZ];
1546     set_grid_alignment(&nz,pme_order);
1547     if (set_alignment)
1548     {
1549         grid->n[ZZ] = nz;
1550     }
1551     else if (nz != grid->n[ZZ])
1552     {
1553         gmx_incons("pmegrid_init call with an unaligned z size");
1554     }
1555
1556     grid->order = pme_order;
1557     if (ptr == NULL)
1558     {
1559         gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1560         set_gridsize_alignment(&gridsize,pme_order);
1561         snew_aligned(grid->grid,gridsize,16);
1562     }
1563     else
1564     {
1565         grid->grid = ptr;
1566     }
1567 }
1568
/* Return numerator/denominator rounded up to the next integer (the
 * ceiling of the exact quotient), for denominator > 0.
 *
 * Computed as the truncated quotient plus one when a positive remainder
 * is left. This cannot overflow for numerators near INT_MAX, which the
 * form (numerator + denominator - 1)/denominator could. It is also the
 * exact ceiling for negative numerators, where C division already
 * truncates toward zero.
 */
static int div_round_up(int numerator,int denominator)
{
    return numerator/denominator + (numerator % denominator > 0 ? 1 : 0);
}
1573
1574 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1575                                   ivec nsub)
1576 {
1577     int gsize_opt,gsize;
1578     int nsx,nsy,nsz;
1579     char *env;
1580
1581     gsize_opt = -1;
1582     for(nsx=1; nsx<=nthread; nsx++)
1583     {
1584         if (nthread % nsx == 0)
1585         {
1586             for(nsy=1; nsy<=nthread; nsy++)
1587             {
1588                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1589                 {
1590                     nsz = nthread/(nsx*nsy);
1591
1592                     /* Determine the number of grid points per thread */
1593                     gsize =
1594                         (div_round_up(n[XX],nsx) + ovl)*
1595                         (div_round_up(n[YY],nsy) + ovl)*
1596                         (div_round_up(n[ZZ],nsz) + ovl);
1597
1598                     /* Minimize the number of grids points per thread
1599                      * and, secondarily, the number of cuts in minor dimensions.
1600                      */
1601                     if (gsize_opt == -1 ||
1602                         gsize < gsize_opt ||
1603                         (gsize == gsize_opt &&
1604                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1605                     {
1606                         nsub[XX] = nsx;
1607                         nsub[YY] = nsy;
1608                         nsub[ZZ] = nsz;
1609                         gsize_opt = gsize;
1610                     }
1611                 }
1612             }
1613         }
1614     }
1615
1616     env = getenv("GMX_PME_THREAD_DIVISION");
1617     if (env != NULL)
1618     {
1619         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1620     }
1621
1622     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1623     {
1624         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1625     }
1626 }
1627
1628 static void pmegrids_init(pmegrids_t *grids,
1629                           int nx,int ny,int nz,int nz_base,
1630                           int pme_order,
1631                           int nthread,
1632                           int overlap_x,
1633                           int overlap_y)
1634 {
1635     ivec n,n_base,g0,g1;
1636     int t,x,y,z,d,i,tfac;
1637     int max_comm_lines;
1638
1639     n[XX] = nx - (pme_order - 1);
1640     n[YY] = ny - (pme_order - 1);
1641     n[ZZ] = nz - (pme_order - 1);
1642
1643     copy_ivec(n,n_base);
1644     n_base[ZZ] = nz_base;
1645
1646     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1647                  NULL);
1648
1649     grids->nthread = nthread;
1650
1651     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1652
1653     if (grids->nthread > 1)
1654     {
1655         ivec nst;
1656         int gridsize;
1657         real *grid_all;
1658
1659         for(d=0; d<DIM; d++)
1660         {
1661             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1662         }
1663         set_grid_alignment(&nst[ZZ],pme_order);
1664
1665         if (debug)
1666         {
1667             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1668                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1669             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1670                     nx,ny,nz,
1671                     nst[XX],nst[YY],nst[ZZ]);
1672         }
1673
1674         snew(grids->grid_th,grids->nthread);
1675         t = 0;
1676         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1677         set_gridsize_alignment(&gridsize,pme_order);
1678         snew_aligned(grid_all,
1679                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1680                      16);
1681
1682         for(x=0; x<grids->nc[XX]; x++)
1683         {
1684             for(y=0; y<grids->nc[YY]; y++)
1685             {
1686                 for(z=0; z<grids->nc[ZZ]; z++)
1687                 {
1688                     pmegrid_init(&grids->grid_th[t],
1689                                  x,y,z,
1690                                  (n[XX]*(x  ))/grids->nc[XX],
1691                                  (n[YY]*(y  ))/grids->nc[YY],
1692                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1693                                  (n[XX]*(x+1))/grids->nc[XX],
1694                                  (n[YY]*(y+1))/grids->nc[YY],
1695                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1696                                  TRUE,
1697                                  pme_order,
1698                                  grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1699                     t++;
1700                 }
1701             }
1702         }
1703     }
1704
1705     snew(grids->g2t,DIM);
1706     tfac = 1;
1707     for(d=DIM-1; d>=0; d--)
1708     {
1709         snew(grids->g2t[d],n[d]);
1710         t = 0;
1711         for(i=0; i<n[d]; i++)
1712         {
1713             /* The second check should match the parameters
1714              * of the pmegrid_init call above.
1715              */
1716             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1717             {
1718                 t++;
1719             }
1720             grids->g2t[d][i] = t*tfac;
1721         }
1722
1723         tfac *= grids->nc[d];
1724
1725         switch (d)
1726         {
1727         case XX: max_comm_lines = overlap_x;     break;
1728         case YY: max_comm_lines = overlap_y;     break;
1729         case ZZ: max_comm_lines = pme_order - 1; break;
1730         }
1731         grids->nthread_comm[d] = 0;
1732         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1733         {
1734             grids->nthread_comm[d]++;
1735         }
1736         if (debug != NULL)
1737         {
1738             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1739                     'x'+d,grids->nthread_comm[d]);
1740         }
1741         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1742          * work, but this is not a problematic restriction.
1743          */
1744         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1745         {
1746             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1747         }
1748     }
1749 }
1750
1751
1752 static void pmegrids_destroy(pmegrids_t *grids)
1753 {
1754     int t;
1755
1756     if (grids->grid.grid != NULL)
1757     {
1758         sfree(grids->grid.grid);
1759
1760         if (grids->nthread > 0)
1761         {
1762             for(t=0; t<grids->nthread; t++)
1763             {
1764                 sfree(grids->grid_th[t].grid);
1765             }
1766             sfree(grids->grid_th);
1767         }
1768     }
1769 }
1770
1771
1772 static void realloc_work(pme_work_t *work,int nkx)
1773 {
1774     if (nkx > work->nalloc)
1775     {
1776         work->nalloc = nkx;
1777         srenew(work->mhx  ,work->nalloc);
1778         srenew(work->mhy  ,work->nalloc);
1779         srenew(work->mhz  ,work->nalloc);
1780         srenew(work->m2   ,work->nalloc);
1781         /* Allocate an aligned pointer for SSE operations, including 3 extra
1782          * elements at the end since SSE operates on 4 elements at a time.
1783          */
1784         sfree_aligned(work->denom);
1785         sfree_aligned(work->tmp1);
1786         sfree_aligned(work->eterm);
1787         snew_aligned(work->denom,work->nalloc+3,16);
1788         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1789         snew_aligned(work->eterm,work->nalloc+3,16);
1790         srenew(work->m2inv,work->nalloc);
1791     }
1792 }
1793
1794
1795 static void free_work(pme_work_t *work)
1796 {
1797     sfree(work->mhx);
1798     sfree(work->mhy);
1799     sfree(work->mhz);
1800     sfree(work->m2);
1801     sfree_aligned(work->denom);
1802     sfree_aligned(work->tmp1);
1803     sfree_aligned(work->eterm);
1804     sfree(work->m2inv);
1805 }
1806
1807
1808 #ifdef PME_SSE
1809     /* Calculate exponentials through SSE in float precision */
1810 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1811 {
1812     {
1813         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1814         __m128 f_sse;
1815         __m128 lu;
1816         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1817         int kx;
1818         f_sse = _mm_load1_ps(&f);
1819         for(kx=0; kx<end; kx+=4)
1820         {
1821             tmp_d1   = _mm_load_ps(d_aligned+kx);
1822             lu       = _mm_rcp_ps(tmp_d1);
1823             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1824             tmp_r    = _mm_load_ps(r_aligned+kx);
1825             tmp_r    = gmx_mm_exp_ps(tmp_r);
1826             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1827             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1828             _mm_store_ps(e_aligned+kx,tmp_e);
1829         }
1830     }
1831 }
1832 #else
1833 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1834 {
1835     int kx;
1836     for(kx=start; kx<end; kx++)
1837     {
1838         d[kx] = 1.0/d[kx];
1839     }
1840     for(kx=start; kx<end; kx++)
1841     {
1842         r[kx] = exp(r[kx]);
1843     }
1844     for(kx=start; kx<end; kx++)
1845     {
1846         e[kx] = f*r[kx]*d[kx];
1847     }
1848 }
1849 #endif
1850
1851
/* Reciprocal-space PME solve for this thread's share of the local complex
 * grid (layout: y major, z middle, x minor).
 *
 * Every complex grid value at wave vector index (kx,ky,kz), except the
 * k = (0,0,0) point, is multiplied in place by
 *   eterm = elfac*exp(-factor*m2) / (M_PI*vol*m2*bsp_x*bsp_y*bsp_z)
 * where factor = M_PI^2/ewaldcoeff^2, elfac = ONE_4PI_EPS0/pme->epsilon_r,
 * bsp_d = pme->bsp_mod[d][kd] and m2 = |mh|^2. mh = (mhx,mhy,mhz) is built
 * from (mx,my,mz) and the lower-triangular elements of pme->recipbox.
 *
 * pme        PME setup; reads nkx/nky/nkz, recipbox, bsp_mod, epsilon_r,
 *            pfft_setupA, and uses pme->work[thread] as scratch/output.
 * grid       local complex grid, scaled in place.
 * ewaldcoeff Ewald splitting coefficient.
 * vol        volume factor in the denominator (presumably the box
 *            volume -- confirm at the caller).
 * bEnerVir   when TRUE, this thread's energy and virial contributions are
 *            stored in pme->work[thread].energy and .vir.
 * nthread,thread  the local_ndata[YY]*local_ndata[ZZ] yz columns are split
 *            into nthread contiguous chunks; this call handles chunk thread.
 *
 * Returns local_ndata[YY]*local_ndata[XX]; see the note at the return.
 */
1852 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1853                          real ewaldcoeff,real vol,
1854                          gmx_bool bEnerVir,
1855                          int nthread,int thread)
1856 {
1857     /* do recip sum over local cells in grid */
1858     /* y major, z middle, x minor or continuous */
1859     t_complex *p0;
1860     int     kx,ky,kz,maxkx,maxky,maxkz;
1861     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1862     real    mx,my,mz;
1863     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1864     real    ets2,struct2,vfactor,ets2vf;
1865     real    d1,d2,energy=0;
1866     real    by,bz;
1867     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1868     real    rxx,ryx,ryy,rzx,rzy,rzz;
1869     pme_work_t *work;
1870     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1871     real    mhxk,mhyk,mhzk,m2k;
1872     real    corner_fac;
1873     ivec    complex_order;
1874     ivec    local_ndata,local_offset,local_size;
1875     real    elfac;
1876
1877     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1878
1879     nx = pme->nkx;
1880     ny = pme->nky;
1881     nz = pme->nkz;
1882
1883     /* Dimensions should be identical for A/B grid, so we just use A here */
1884     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1885                                       complex_order,
1886                                       local_ndata,
1887                                       local_offset,
1888                                       local_size);
1889
1890     rxx = pme->recipbox[XX][XX];
1891     ryx = pme->recipbox[YY][XX];
1892     ryy = pme->recipbox[YY][YY];
1893     rzx = pme->recipbox[ZZ][XX];
1894     rzy = pme->recipbox[ZZ][YY];
1895     rzz = pme->recipbox[ZZ][ZZ];
1896
    /* Indices >= maxkx (x) and >= maxky (y) are folded to negative
     * wave numbers below. NOTE(review): maxkz is computed but not used. */
1897     maxkx = (nx+1)/2;
1898     maxky = (ny+1)/2;
1899     maxkz = nz/2+1;
1900
    /* Per-thread scratch arrays, indexed by global kx */
1901     work = &pme->work[thread];
1902     mhx   = work->mhx;
1903     mhy   = work->mhy;
1904     mhz   = work->mhz;
1905     m2    = work->m2;
1906     denom = work->denom;
1907     tmp1  = work->tmp1;
1908     eterm = work->eterm;
1909     m2inv = work->m2inv;
1910
    /* Contiguous range [iyz0,iyz1) of the flattened local yz columns
     * that this thread processes */
1911     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1912     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1913
1914     for(iyz=iyz0; iyz<iyz1; iyz++)
1915     {
1916         iy = iyz/local_ndata[ZZ];
1917         iz = iyz - iy*local_ndata[ZZ];
1918
1919         ky = iy + local_offset[YY];
1920
        /* Upper half of the y indices are negative frequencies */
1921         if (ky < maxky)
1922         {
1923             my = ky;
1924         }
1925         else
1926         {
1927             my = (ky - ny);
1928         }
1929
1930         by = M_PI*vol*pme->bsp_mod[YY][ky];
1931
1932         kz = iz + local_offset[ZZ];
1933
        /* kz is used unfolded (never mapped to a negative value) */
1934         mz = kz;
1935
1936         bz = pme->bsp_mod[ZZ][kz];
1937
        /* struct2 below carries a factor 2, presumably because only half
         * of the z range of the real-to-complex transform is stored;
         * the planes kz==0 and kz==(nz+1)/2 are weighted by 0.5 instead. */
1938         /* 0.5 correction for corner points */
1939         corner_fac = 1;
1940         if (kz == 0 || kz == (nz+1)/2)
1941         {
1942             corner_fac = 0.5;
1943         }
1944
        /* p0 points to the first x element of this (iy,iz) row */
1945         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1946
1947         /* We should skip the k-space point (0,0,0) */
1948         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1949         {
1950             kxstart = local_offset[XX];
1951         }
1952         else
1953         {
1954             kxstart = local_offset[XX] + 1;
1955             p0++;
1956         }
1957         kxend = local_offset[XX] + local_ndata[XX];
1958
1959         if (bEnerVir)
1960         {
1961             /* More expensive inner loop, especially because of the storage
1962              * of the mh elements in array's.
1963              * Because x is the minor grid index, all mh elements
1964              * depend on kx for triclinic unit cells.
1965              */
1966
            /* NOTE(review): the first loop runs up to maxkx even when
             * kxend < maxkx, and the second starts at maxkx even when
             * kxstart > maxkx, so a few entries outside [kxstart,kxend)
             * may be filled. They are not read below; presumably the work
             * arrays hold at least nkx entries -- confirm in realloc_work. */
1967                 /* Two explicit loops to avoid a conditional inside the loop */
1968             for(kx=kxstart; kx<maxkx; kx++)
1969             {
1970                 mx = kx;
1971
1972                 mhxk      = mx * rxx;
1973                 mhyk      = mx * ryx + my * ryy;
1974                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1975                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1976                 mhx[kx]   = mhxk;
1977                 mhy[kx]   = mhyk;
1978                 mhz[kx]   = mhzk;
1979                 m2[kx]    = m2k;
1980                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1981                 tmp1[kx]  = -factor*m2k;
1982             }
1983
            /* Negative x frequencies: mx = kx - nx */
1984             for(kx=maxkx; kx<kxend; kx++)
1985             {
1986                 mx = (kx - nx);
1987
1988                 mhxk      = mx * rxx;
1989                 mhyk      = mx * ryx + my * ryy;
1990                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1991                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1992                 mhx[kx]   = mhxk;
1993                 mhy[kx]   = mhyk;
1994                 mhz[kx]   = mhzk;
1995                 m2[kx]    = m2k;
1996                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1997                 tmp1[kx]  = -factor*m2k;
1998             }
1999
2000             for(kx=kxstart; kx<kxend; kx++)
2001             {
2002                 m2inv[kx] = 1.0/m2[kx];
2003             }
2004
            /* eterm[kx] = elfac*exp(tmp1[kx])/denom[kx]; denom and tmp1
             * may be overwritten (the plain-C version overwrites both) */
2005             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2006
            /* Scale the grid; struct2 uses the value before scaling and
             * tmp1 is reused as scratch for eterm*struct2 */
2007             for(kx=kxstart; kx<kxend; kx++,p0++)
2008             {
2009                 d1      = p0->re;
2010                 d2      = p0->im;
2011
2012                 p0->re  = d1*eterm[kx];
2013                 p0->im  = d2*eterm[kx];
2014
2015                 struct2 = 2.0*(d1*d1+d2*d2);
2016
2017                 tmp1[kx] = eterm[kx]*struct2;
2018             }
2019
            /* Accumulate energy and virial:
             * vir_ab += ets2*(vfactor*mh_a*mh_b - delta_ab),
             * vfactor = 2*(factor*m2 + 1)/m2 */
2020             for(kx=kxstart; kx<kxend; kx++)
2021             {
2022                 ets2     = corner_fac*tmp1[kx];
2023                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2024                 energy  += ets2;
2025
2026                 ets2vf   = ets2*vfactor;
2027                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2028                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2029                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2030                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2031                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2032                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2033             }
2034         }
2035         else
2036         {
2037             /* We don't need to calculate the energy and the virial.
2038              * In this case the triclinic overhead is small.
2039              */
2040
2041             /* Two explicit loops to avoid a conditional inside the loop */
2042
2043             for(kx=kxstart; kx<maxkx; kx++)
2044             {
2045                 mx = kx;
2046
2047                 mhxk      = mx * rxx;
2048                 mhyk      = mx * ryx + my * ryy;
2049                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2050                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2051                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2052                 tmp1[kx]  = -factor*m2k;
2053             }
2054
2055             for(kx=maxkx; kx<kxend; kx++)
2056             {
2057                 mx = (kx - nx);
2058
2059                 mhxk      = mx * rxx;
2060                 mhyk      = mx * ryx + my * ryy;
2061                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2062                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2063                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2064                 tmp1[kx]  = -factor*m2k;
2065             }
2066
2067             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2068
2069             for(kx=kxstart; kx<kxend; kx++,p0++)
2070             {
2071                 d1      = p0->re;
2072                 d2      = p0->im;
2073
2074                 p0->re  = d1*eterm[kx];
2075                 p0->im  = d2*eterm[kx];
2076             }
2077         }
2078     }
2079
2080     if (bEnerVir)
2081     {
2082         /* Update virial with local values.
2083          * The virial is symmetric by definition.
2084          * this virial seems ok for isotropic scaling, but I'm
2085          * experiencing problems on semiisotropic membranes.
2086          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2087          */
2088         work->vir[XX][XX] = 0.25*virxx;
2089         work->vir[YY][YY] = 0.25*viryy;
2090         work->vir[ZZ][ZZ] = 0.25*virzz;
2091         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2092         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2093         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2094
2095         /* This energy should be corrected for a charged system */
2096         work->energy = 0.5*energy;
2097     }
2098
    /* NOTE(review): the loop above runs over (iyz1-iyz0) yz columns of
     * local_ndata[XX] points each, but YY*XX is returned; confirm what
     * callers use this value for. */
2099     /* Return the loop count */
2100     return local_ndata[YY]*local_ndata[XX];
2101 }
2102
2103 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2104                              real *mesh_energy,matrix vir)
2105 {
2106     /* This function sums output over threads
2107      * and should therefore only be called after thread synchronization.
2108      */
2109     int thread;
2110
2111     *mesh_energy = pme->work[0].energy;
2112     copy_mat(pme->work[0].vir,vir);
2113
2114     for(thread=1; thread<nthread; thread++)
2115     {
2116         *mesh_energy += pme->work[thread].energy;
2117         m_add(vir,pme->work[thread].vir,vir);
2118     }
2119 }
2120
/* DO_FSPLINE(order): for one atom, accumulate the derivatives of the
 * B-spline-interpolated grid value with respect to the three grid
 * coordinates, over the order^3 grid points starting at (i0,j0,k0):
 *   fx += sum dthx*thy*thz*g,  fy += sum thx*dthy*thz*g,
 *   fz += sum thx*thy*dthz*g,
 * with g = grid[(i0+ithx)*pny*pnz + (j0+ithy)*pnz + (k0+ithz)].
 * It expands to a bare for statement and works on caller variables:
 * writes ithx, ithy, ithz, index_x, index_xy, tx, ty, dx, dy, fxy1, fz1,
 * gval and accumulates into fx, fy, fz (which the caller initializes);
 * reads i0, j0, k0, pny, pnz, grid, thx, thy, thz, dthx, dthy, dthz.
 * Passing a literal order lets the compiler unroll the loops.
 */
2121 #define DO_FSPLINE(order)                      \
2122 for(ithx=0; (ithx<order); ithx++)              \
2123 {                                              \
2124     index_x = (i0+ithx)*pny*pnz;               \
2125     tx      = thx[ithx];                       \
2126     dx      = dthx[ithx];                      \
2127                                                \
2128     for(ithy=0; (ithy<order); ithy++)          \
2129     {                                          \
2130         index_xy = index_x+(j0+ithy)*pnz;      \
2131         ty       = thy[ithy];                  \
2132         dy       = dthy[ithy];                 \
2133         fxy1     = fz1 = 0;                    \
2134                                                \
2135         for(ithz=0; (ithz<order); ithz++)      \
2136         {                                      \
2137             gval  = grid[index_xy+(k0+ithz)];  \
2138             fxy1 += thz[ithz]*gval;            \
2139             fz1  += dthz[ithz]*gval;           \
2140         }                                      \
2141         fx += dx*ty*fxy1;                      \
2142         fy += tx*dy*fxy1;                      \
2143         fz += tx*ty*fz1;                       \
2144     }                                          \
2145 }
2146
2147
/* Interpolate PME forces for the atoms listed in spline->ind[0..spline->n-1]
 * from the real-space potential grid, using the B-spline weights
 * spline->theta and derivatives spline->dtheta stored for spline slot nn
 * at offset nn*pme->pme_order.
 *
 * grid     real-space grid with strides pmegrid_ny*pmegrid_nz and pmegrid_nz.
 * bClearF  when TRUE, atc->f[n] is zeroed before the contribution is added
 *          (also for uncharged atoms); otherwise forces are accumulated.
 * atc      provides charges q, grid start indices idx and forces f.
 * scale    multiplies each charge atc->q[n].
 *
 * The grid-coordinate derivatives (fx,fy,fz) are scaled by the grid sizes
 * and transformed with the lower-triangular recipbox elements; the
 * result times -q*scale is added to atc->f[n].
 */
2148 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2149                               gmx_bool bClearF,pme_atomcomm_t *atc,
2150                               splinedata_t *spline,
2151                               real scale)
2152 {
2153     /* sum forces for local particles */
2154     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2155     int     index_x,index_xy;
2156     int     nx,ny,nz,pnx,pny,pnz;
2157     int *   idxptr;
2158     real    tx,ty,dx,dy,qn;
2159     real    fx,fy,fz,gval;
2160     real    fxy1,fz1;
2161     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2162     int     norder;
2163     real    rxx,ryx,ryy,rzx,rzy,rzz;
2164     int     order;
2165
2166     pme_spline_work_t *work;
2167
    /* work is not referenced by the plain-C path (DO_FSPLINE); presumably
     * it is used by the SIMD kernel included below -- confirm there. */
2168     work = pme->spline_work;
2169
    /* The theta/dtheta pointers set here are reassigned per atom below */
2170     order = pme->pme_order;
2171     thx   = spline->theta[XX];
2172     thy   = spline->theta[YY];
2173     thz   = spline->theta[ZZ];
2174     dthx  = spline->dtheta[XX];
2175     dthy  = spline->dtheta[YY];
2176     dthz  = spline->dtheta[ZZ];
2177     nx    = pme->nkx;
2178     ny    = pme->nky;
2179     nz    = pme->nkz;
2180     pnx   = pme->pmegrid_nx;
2181     pny   = pme->pmegrid_ny;
2182     pnz   = pme->pmegrid_nz;
2183
2184     rxx   = pme->recipbox[XX][XX];
2185     ryx   = pme->recipbox[YY][XX];
2186     ryy   = pme->recipbox[YY][YY];
2187     rzx   = pme->recipbox[ZZ][XX];
2188     rzy   = pme->recipbox[ZZ][YY];
2189     rzz   = pme->recipbox[ZZ][ZZ];
2190
2191     for(nn=0; nn<spline->n; nn++)
2192     {
2193         n  = spline->ind[nn];
2194         qn = scale*atc->q[n];
2195
2196         if (bClearF)
2197         {
2198             atc->f[n][XX] = 0;
2199             atc->f[n][YY] = 0;
2200             atc->f[n][ZZ] = 0;
2201         }
2202         if (qn != 0)
2203         {
2204             fx     = 0;
2205             fy     = 0;
2206             fz     = 0;
2207             idxptr = atc->idx[n];
            /* Spline weights are stored per slot nn, not per atom n */
2208             norder = nn*order;
2209
2210             i0   = idxptr[XX];
2211             j0   = idxptr[YY];
2212             k0   = idxptr[ZZ];
2213
2214             /* Pointer arithmetic alert, next six statements */
2215             thx  = spline->theta[XX] + norder;
2216             thy  = spline->theta[YY] + norder;
2217             thz  = spline->theta[ZZ] + norder;
2218             dthx = spline->dtheta[XX] + norder;
2219             dthy = spline->dtheta[YY] + norder;
2220             dthz = spline->dtheta[ZZ] + norder;
2221
            /* With PME_SSE, orders 4 and 5 use the kernel from
             * pme_sse_single.h, selected by the macros defined just before
             * each #include. That header presumably reads the locals set
             * above and accumulates into fx, fy, fz -- confirm there.
             * Other orders (or no PME_SSE) use the DO_FSPLINE macro. */
2222             switch (order) {
2223             case 4:
2224 #ifdef PME_SSE
2225 #ifdef PME_SSE_UNALIGNED
2226 #define PME_GATHER_F_SSE_ORDER4
2227 #else
2228 #define PME_GATHER_F_SSE_ALIGNED
2229 #define PME_ORDER 4
2230 #endif
2231 #include "pme_sse_single.h"
2232 #else
2233                 DO_FSPLINE(4);
2234 #endif
2235                 break;
2236             case 5:
2237 #ifdef PME_SSE
2238 #define PME_GATHER_F_SSE_ALIGNED
2239 #define PME_ORDER 5
2240 #include "pme_sse_single.h"
2241 #else
2242                 DO_FSPLINE(5);
2243 #endif
2244                 break;
2245             default:
2246                 DO_FSPLINE(order);
2247                 break;
2248             }
2249
            /* Chain rule: grid-coordinate derivatives times grid size,
             * transformed with the lower-triangular recipbox */
2250             atc->f[n][XX] += -qn*( fx*nx*rxx );
2251             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2252             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2253         }
2254     }
2255     /* Since the energy and not forces are interpolated
2256      * the net force might not be exactly zero.
2257      * This can be solved by also interpolating F, but
2258      * that comes at a cost.
2259      * A better hack is to remove the net force every
2260      * step, but that must be done at a higher level
2261      * since this routine doesn't see all atoms if running
2262      * in parallel. Don't know how important it is?  EL 990726
2263      */
2264 }
2265
2266
2267 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2268                                    pme_atomcomm_t *atc)
2269 {
2270     splinedata_t *spline;
2271     int     n,ithx,ithy,ithz,i0,j0,k0;
2272     int     index_x,index_xy;
2273     int *   idxptr;
2274     real    energy,pot,tx,ty,qn,gval;
2275     real    *thx,*thy,*thz;
2276     int     norder;
2277     int     order;
2278
2279     spline = &atc->spline[0];
2280
2281     order = pme->pme_order;
2282
2283     energy = 0;
2284     for(n=0; (n<atc->n); n++) {
2285         qn      = atc->q[n];
2286
2287         if (qn != 0) {
2288             idxptr = atc->idx[n];
2289             norder = n*order;
2290
2291             i0   = idxptr[XX];
2292             j0   = idxptr[YY];
2293             k0   = idxptr[ZZ];
2294
2295             /* Pointer arithmetic alert, next three statements */
2296             thx  = spline->theta[XX] + norder;
2297             thy  = spline->theta[YY] + norder;
2298             thz  = spline->theta[ZZ] + norder;
2299
2300             pot = 0;
2301             for(ithx=0; (ithx<order); ithx++)
2302             {
2303                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2304                 tx      = thx[ithx];
2305
2306                 for(ithy=0; (ithy<order); ithy++)
2307                 {
2308                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2309                     ty       = thy[ithy];
2310
2311                     for(ithz=0; (ithz<order); ithz++)
2312                     {
2313                         gval  = grid[index_xy+(k0+ithz)];
2314                         pot  += tx*ty*thz[ithz]*gval;
2315                     }
2316
2317                 }
2318             }
2319
2320             energy += pot*qn;
2321         }
2322     }
2323
2324     return energy;
2325 }
2326
/* CALC_SPLINE(order): for spline slot i, compute for each dimension
 * j = 0..DIM-1 the order B-spline weights of the fractional offset
 * dr = xptr[j] and their derivatives, and store them in
 * theta[j][i*order+k] and dtheta[j][i*order+k] for k = 0..order-1.
 * The weights are built by the recursion starting from the linear spline
 * (1-dr, dr). The derivatives are the differences of the order-1 weights
 * (ddata[0] = -data[0], ddata[k] = data[k-1]-data[k]), taken before the
 * final recursion step.
 * Uses the caller's i, xptr, theta and dtheta; requires
 * order <= PME_ORDER_MAX and order >= 2.
 */
2327 /* Macro to force loop unrolling by fixing order.
2328  * This gives a significant performance gain.
2329  */
2330 #define CALC_SPLINE(order)                     \
2331 {                                              \
2332     int j,k,l;                                 \
2333     real dr,div;                               \
2334     real data[PME_ORDER_MAX];                  \
2335     real ddata[PME_ORDER_MAX];                 \
2336                                                \
2337     for(j=0; (j<DIM); j++)                     \
2338     {                                          \
2339         dr  = xptr[j];                         \
2340                                                \
2341         /* dr is relative offset from lower cell limit */ \
2342         data[order-1] = 0;                     \
2343         data[1] = dr;                          \
2344         data[0] = 1 - dr;                      \
2345                                                \
2346         for(k=3; (k<order); k++)               \
2347         {                                      \
2348             div = 1.0/(k - 1.0);               \
2349             data[k-1] = div*dr*data[k-2];      \
2350             for(l=1; (l<(k-1)); l++)           \
2351             {                                  \
2352                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2353                                    data[k-l-1]);                \
2354             }                                  \
2355             data[0] = div*(1-dr)*data[0];      \
2356         }                                      \
2357         /* differentiate */                    \
2358         ddata[0] = -data[0];                   \
2359         for(k=1; (k<order); k++)               \
2360         {                                      \
2361             ddata[k] = data[k-1] - data[k];    \
2362         }                                      \
2363                                                \
2364         div = 1.0/(order - 1);                 \
2365         data[order-1] = div*dr*data[order-2];  \
2366         for(l=1; (l<(order-1)); l++)           \
2367         {                                      \
2368             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2369                                (order-l-dr)*data[order-l-1]); \
2370         }                                      \
2371         data[0] = div*(1 - dr)*data[0];        \
2372                                                \
2373         for(k=0; k<order; k++)                 \
2374         {                                      \
2375             theta[j][i*order+k]  = data[k];    \
2376             dtheta[j][i*order+k] = ddata[k];   \
2377         }                                      \
2378     }                                          \
2379 }
2380
2381 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2382                    rvec fractx[],int nr,int ind[],real charge[],
2383                    gmx_bool bFreeEnergy)
2384 {
2385     /* construct splines for local atoms */
2386     int  i,ii;
2387     real *xptr;
2388
2389     for(i=0; i<nr; i++)
2390     {
2391         /* With free energy we do not use the charge check.
2392          * In most cases this will be more efficient than calling make_bsplines
2393          * twice, since usually more than half the particles have charges.
2394          */
2395         ii = ind[i];
2396         if (bFreeEnergy || charge[ii] != 0.0) {
2397             xptr = fractx[ii];
2398             switch(order) {
2399             case 4:  CALC_SPLINE(4);     break;
2400             case 5:  CALC_SPLINE(5);     break;
2401             default: CALC_SPLINE(order); break;
2402             }
2403         }
2404     }
2405 }
2406
2407
2408 void make_dft_mod(real *mod,real *data,int ndata)
2409 {
2410   int i,j;
2411   real sc,ss,arg;
2412
2413   for(i=0;i<ndata;i++) {
2414     sc=ss=0;
2415     for(j=0;j<ndata;j++) {
2416       arg=(2.0*M_PI*i*j)/ndata;
2417       sc+=data[j]*cos(arg);
2418       ss+=data[j]*sin(arg);
2419     }
2420     mod[i]=sc*sc+ss*ss;
2421   }
2422   for(i=0;i<ndata;i++)
2423     if(mod[i]<1e-7)
2424       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2425 }
2426
2427
2428 static void make_bspline_moduli(splinevec bsp_mod,
2429                                 int nx,int ny,int nz,int order)
2430 {
2431   int nmax=max(nx,max(ny,nz));
2432   real *data,*ddata,*bsp_data;
2433   int i,k,l;
2434   real div;
2435
2436   snew(data,order);
2437   snew(ddata,order);
2438   snew(bsp_data,nmax);
2439
2440   data[order-1]=0;
2441   data[1]=0;
2442   data[0]=1;
2443
2444   for(k=3;k<order;k++) {
2445     div=1.0/(k-1.0);
2446     data[k-1]=0;
2447     for(l=1;l<(k-1);l++)
2448       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2449     data[0]=div*data[0];
2450   }
2451   /* differentiate */
2452   ddata[0]=-data[0];
2453   for(k=1;k<order;k++)
2454     ddata[k]=data[k-1]-data[k];
2455   div=1.0/(order-1);
2456   data[order-1]=0;
2457   for(l=1;l<(order-1);l++)
2458     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2459   data[0]=div*data[0];
2460
2461   for(i=0;i<nmax;i++)
2462     bsp_data[i]=0;
2463   for(i=1;i<=order;i++)
2464     bsp_data[i]=data[i-1];
2465
2466   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2467   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2468   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2469
2470   sfree(data);
2471   sfree(ddata);
2472   sfree(bsp_data);
2473 }
2474
2475
/* P3M optimal influence polynomial for interpolation order `order` (2..8),
 * evaluated at z; any other order yields 0.0.
 * The formula and most constants can be found in:
 * Ballenegger et al., JCTC 8, 936 (2012)
 */
static double do_p3m_influence(double z, int order)
{
    const double z2 = z*z;
    const double z4 = z2*z2;
    double       infl;

    switch(order)
    {
    case 2:
        infl = 1.0 - 2.0*z2/3.0;
        break;
    case 3:
        infl = 1.0 - z2 + 2.0*z4/15.0;
        break;
    case 4:
        infl = 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
        break;
    case 5:
        infl = 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
        break;
    case 6:
        infl = 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
        break;
    case 7:
        infl = 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
        break;
    case 8:
        infl = 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
        break;
    default:
        /* Unsupported order */
        infl = 0.0;
        break;
    }

    return infl;
}
2513
2514 /* Calculate the P3M B-spline moduli for one dimension */
2515 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2516 {
2517     double zarg,zai,sinzai,infl;
2518     int    maxk,i;
2519
2520     if (order > 8)
2521     {
2522         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2523     }
2524
2525     zarg = M_PI/n;
2526
2527     maxk = (n + 1)/2;
2528
2529     for(i=-maxk; i<0; i++)
2530     {
2531         zai    = zarg*i;
2532         sinzai = sin(zai);
2533         infl   = do_p3m_influence(sinzai,order);
2534         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2535     }
2536     bsp_mod[0] = 1.0;
2537     for(i=1; i<maxk; i++)
2538     {
2539         zai    = zarg*i;
2540         sinzai = sin(zai);
2541         infl   = do_p3m_influence(sinzai,order);
2542         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2543     }
2544 }
2545
2546 /* Calculate the P3M B-spline moduli */
2547 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2548                                     int nx,int ny,int nz,int order)
2549 {
2550     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2551     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2552     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2553 }
2554
2555
2556 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2557 {
2558   int nslab,n,i;
2559   int fw,bw;
2560
2561   nslab = atc->nslab;
2562
2563   n = 0;
2564   for(i=1; i<=nslab/2; i++) {
2565     fw = (atc->nodeid + i) % nslab;
2566     bw = (atc->nodeid - i + nslab) % nslab;
2567     if (n < nslab - 1) {
2568       atc->node_dest[n] = fw;
2569       atc->node_src[n]  = bw;
2570       n++;
2571     }
2572     if (n < nslab - 1) {
2573       atc->node_dest[n] = bw;
2574       atc->node_src[n]  = fw;
2575       n++;
2576     }
2577   }
2578 }
2579
2580 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2581 {
2582     int thread;
2583
2584     if(NULL != log)
2585     {
2586         fprintf(log,"Destroying PME data structures.\n");
2587     }
2588
2589     sfree((*pmedata)->nnx);
2590     sfree((*pmedata)->nny);
2591     sfree((*pmedata)->nnz);
2592
2593     pmegrids_destroy(&(*pmedata)->pmegridA);
2594
2595     sfree((*pmedata)->fftgridA);
2596     sfree((*pmedata)->cfftgridA);
2597     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2598
2599     if ((*pmedata)->pmegridB.grid.grid != NULL)
2600     {
2601         pmegrids_destroy(&(*pmedata)->pmegridB);
2602         sfree((*pmedata)->fftgridB);
2603         sfree((*pmedata)->cfftgridB);
2604         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2605     }
2606     for(thread=0; thread<(*pmedata)->nthread; thread++)
2607     {
2608         free_work(&(*pmedata)->work[thread]);
2609     }
2610     sfree((*pmedata)->work);
2611
2612     sfree(*pmedata);
2613     *pmedata = NULL;
2614
2615   return 0;
2616 }
2617
static int mult_up(int n,int f)
{
    /* Round n up to the nearest multiple of f (intended for n >= 0,
     * f > 0); uses C integer division. */
    int nblocks = (n + f - 1)/f;

    return nblocks*f;
}
2622
2623
2624 static double pme_load_imbalance(gmx_pme_t pme)
2625 {
2626     int    nma,nmi;
2627     double n1,n2,n3;
2628
2629     nma = pme->nnodes_major;
2630     nmi = pme->nnodes_minor;
2631
2632     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2633     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2634     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2635
2636     /* pme_solve is roughly double the cost of an fft */
2637
2638     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2639 }
2640
2641 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2642                           int dimind,gmx_bool bSpread)
2643 {
2644     int nk,k,s,thread;
2645
2646     atc->dimind = dimind;
2647     atc->nslab  = 1;
2648     atc->nodeid = 0;
2649     atc->pd_nalloc = 0;
2650 #ifdef GMX_MPI
2651     if (pme->nnodes > 1)
2652     {
2653         atc->mpi_comm = pme->mpi_comm_d[dimind];
2654         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2655         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2656     }
2657     if (debug)
2658     {
2659         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2660     }
2661 #endif
2662
2663     atc->bSpread   = bSpread;
2664     atc->pme_order = pme->pme_order;
2665
2666     if (atc->nslab > 1)
2667     {
2668         /* These three allocations are not required for particle decomp. */
2669         snew(atc->node_dest,atc->nslab);
2670         snew(atc->node_src,atc->nslab);
2671         setup_coordinate_communication(atc);
2672
2673         snew(atc->count_thread,pme->nthread);
2674         for(thread=0; thread<pme->nthread; thread++)
2675         {
2676             snew(atc->count_thread[thread],atc->nslab);
2677         }
2678         atc->count = atc->count_thread[0];
2679         snew(atc->rcount,atc->nslab);
2680         snew(atc->buf_index,atc->nslab);
2681     }
2682
2683     atc->nthread = pme->nthread;
2684     if (atc->nthread > 1)
2685     {
2686         snew(atc->thread_plist,atc->nthread);
2687     }
2688     snew(atc->spline,atc->nthread);
2689     for(thread=0; thread<atc->nthread; thread++)
2690     {
2691         if (atc->nthread > 1)
2692         {
2693             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2694             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2695         }
2696     }
2697 }
2698
2699 static void
2700 init_overlap_comm(pme_overlap_t *  ol,
2701                   int              norder,
2702 #ifdef GMX_MPI
2703                   MPI_Comm         comm,
2704 #endif
2705                   int              nnodes,
2706                   int              nodeid,
2707                   int              ndata,
2708                   int              commplainsize)
2709 {
2710     int lbnd,rbnd,maxlr,b,i;
2711     int exten;
2712     int nn,nk;
2713     pme_grid_comm_t *pgc;
2714     gmx_bool bCont;
2715     int fft_start,fft_end,send_index1,recv_index1;
2716
2717 #ifdef GMX_MPI
2718     ol->mpi_comm = comm;
2719 #endif
2720
2721     ol->nnodes = nnodes;
2722     ol->nodeid = nodeid;
2723
2724     /* Linear translation of the PME grid wo'nt affect reciprocal space
2725      * calculations, so to optimize we only interpolate "upwards",
2726      * which also means we only have to consider overlap in one direction.
2727      * I.e., particles on this node might also be spread to grid indices
2728      * that belong to higher nodes (modulo nnodes)
2729      */
2730
2731     snew(ol->s2g0,ol->nnodes+1);
2732     snew(ol->s2g1,ol->nnodes);
2733     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2734     for(i=0; i<nnodes; i++)
2735     {
2736         /* s2g0 the local interpolation grid start.
2737          * s2g1 the local interpolation grid end.
2738          * Because grid overlap communication only goes forward,
2739          * the grid the slabs for fft's should be rounded down.
2740          */
2741         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2742         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2743
2744         if (debug)
2745         {
2746             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2747         }
2748     }
2749     ol->s2g0[nnodes] = ndata;
2750     if (debug) { fprintf(debug,"\n"); }
2751
2752     /* Determine with how many nodes we need to communicate the grid overlap */
2753     b = 0;
2754     do
2755     {
2756         b++;
2757         bCont = FALSE;
2758         for(i=0; i<nnodes; i++)
2759         {
2760             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2761                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2762             {
2763                 bCont = TRUE;
2764             }
2765         }
2766     }
2767     while (bCont && b < nnodes);
2768     ol->noverlap_nodes = b - 1;
2769
2770     snew(ol->send_id,ol->noverlap_nodes);
2771     snew(ol->recv_id,ol->noverlap_nodes);
2772     for(b=0; b<ol->noverlap_nodes; b++)
2773     {
2774         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2775         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2776     }
2777     snew(ol->comm_data, ol->noverlap_nodes);
2778
2779     for(b=0; b<ol->noverlap_nodes; b++)
2780     {
2781         pgc = &ol->comm_data[b];
2782         /* Send */
2783         fft_start        = ol->s2g0[ol->send_id[b]];
2784         fft_end          = ol->s2g0[ol->send_id[b]+1];
2785         if (ol->send_id[b] < nodeid)
2786         {
2787             fft_start += ndata;
2788             fft_end   += ndata;
2789         }
2790         send_index1      = ol->s2g1[nodeid];
2791         send_index1      = min(send_index1,fft_end);
2792         pgc->send_index0 = fft_start;
2793         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2794
2795         /* We always start receiving to the first index of our slab */
2796         fft_start        = ol->s2g0[ol->nodeid];
2797         fft_end          = ol->s2g0[ol->nodeid+1];
2798         recv_index1      = ol->s2g1[ol->recv_id[b]];
2799         if (ol->recv_id[b] > nodeid)
2800         {
2801             recv_index1 -= ndata;
2802         }
2803         recv_index1      = min(recv_index1,fft_end);
2804         pgc->recv_index0 = fft_start;
2805         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2806     }
2807
2808     /* For non-divisible grid we need pme_order iso pme_order-1 */
2809     snew(ol->sendbuf,norder*commplainsize);
2810     snew(ol->recvbuf,norder*commplainsize);
2811 }
2812
2813 static void
2814 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2815                               int **global_to_local,
2816                               real **fraction_shift)
2817 {
2818     int i;
2819     int * gtl;
2820     real * fsh;
2821
2822     snew(gtl,5*n);
2823     snew(fsh,5*n);
2824     for(i=0; (i<5*n); i++)
2825     {
2826         /* Determine the global to local grid index */
2827         gtl[i] = (i - local_start + n) % n;
2828         /* For coordinates that fall within the local grid the fraction
2829          * is correct, we don't need to shift it.
2830          */
2831         fsh[i] = 0;
2832         if (local_range < n)
2833         {
2834             /* Due to rounding issues i could be 1 beyond the lower or
2835              * upper boundary of the local grid. Correct the index for this.
2836              * If we shift the index, we need to shift the fraction by
2837              * the same amount in the other direction to not affect
2838              * the weights.
2839              * Note that due to this shifting the weights at the end of
2840              * the spline might change, but that will only involve values
2841              * between zero and values close to the precision of a real,
2842              * which is anyhow the accuracy of the whole mesh calculation.
2843              */
2844             /* With local_range=0 we should not change i=local_start */
2845             if (i % n != local_start)
2846             {
2847                 if (gtl[i] == n-1)
2848                 {
2849                     gtl[i] = 0;
2850                     fsh[i] = -1;
2851                 }
2852                 else if (gtl[i] == local_range)
2853                 {
2854                     gtl[i] = local_range - 1;
2855                     fsh[i] = 1;
2856                 }
2857             }
2858         }
2859     }
2860
2861     *global_to_local = gtl;
2862     *fraction_shift  = fsh;
2863 }
2864
2865 static pme_spline_work_t *make_pme_spline_work(int order)
2866 {
2867     pme_spline_work_t *work;
2868
2869 #ifdef PME_SSE
2870     float  tmp[8];
2871     __m128 zero_SSE;
2872     int    of,i;
2873
2874     snew_aligned(work,1,16);
2875
2876     zero_SSE = _mm_setzero_ps();
2877
2878     /* Generate bit masks to mask out the unused grid entries,
2879      * as we only operate on order of the 8 grid entries that are
2880      * load into 2 SSE float registers.
2881      */
2882     for(of=0; of<8-(order-1); of++)
2883     {
2884         for(i=0; i<8; i++)
2885         {
2886             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2887         }
2888         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2889         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2890         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2891         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2892     }
2893 #else
2894     work = NULL;
2895 #endif
2896
2897     return work;
2898 }
2899
2900 static void
2901 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2902 {
2903     int nk_new;
2904
2905     if (*nk % nnodes != 0)
2906     {
2907         nk_new = nnodes*(*nk/nnodes + 1);
2908
2909         if (2*nk_new >= 3*(*nk))
2910         {
2911             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisble by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2912                       dim,*nk,dim,nnodes,dim);
2913         }
2914
2915         if (fplog != NULL)
2916         {
2917             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisble by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2918                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2919         }
2920
2921         *nk = nk_new;
2922     }
2923 }
2924
2925 int gmx_pme_init(gmx_pme_t *         pmedata,
2926                  t_commrec *         cr,
2927                  int                 nnodes_major,
2928                  int                 nnodes_minor,
2929                  t_inputrec *        ir,
2930                  int                 homenr,
2931                  gmx_bool            bFreeEnergy,
2932                  gmx_bool            bReproducible,
2933                  int                 nthread)
2934 {
2935     gmx_pme_t pme=NULL;
2936
2937     pme_atomcomm_t *atc;
2938     ivec ndata;
2939
2940     if (debug)
2941         fprintf(debug,"Creating PME data structures.\n");
2942     snew(pme,1);
2943
2944     pme->redist_init         = FALSE;
2945     pme->sum_qgrid_tmp       = NULL;
2946     pme->sum_qgrid_dd_tmp    = NULL;
2947     pme->buf_nalloc          = 0;
2948     pme->redist_buf_nalloc   = 0;
2949
2950     pme->nnodes              = 1;
2951     pme->bPPnode             = TRUE;
2952
2953     pme->nnodes_major        = nnodes_major;
2954     pme->nnodes_minor        = nnodes_minor;
2955
2956 #ifdef GMX_MPI
2957     if (nnodes_major*nnodes_minor > 1)
2958     {
2959         pme->mpi_comm = cr->mpi_comm_mygroup;
2960
2961         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2962         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2963         if (pme->nnodes != nnodes_major*nnodes_minor)
2964         {
2965             gmx_incons("PME node count mismatch");
2966         }
2967     }
2968     else
2969     {
2970         pme->mpi_comm = MPI_COMM_NULL;
2971     }
2972 #endif
2973
2974     if (pme->nnodes == 1)
2975     {
2976 #ifdef GMX_MPI
2977         pme->mpi_comm_d[0] = MPI_COMM_NULL;
2978         pme->mpi_comm_d[1] = MPI_COMM_NULL;
2979 #endif
2980         pme->ndecompdim = 0;
2981         pme->nodeid_major = 0;
2982         pme->nodeid_minor = 0;
2983 #ifdef GMX_MPI
2984         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2985 #endif
2986     }
2987     else
2988     {
2989         if (nnodes_minor == 1)
2990         {
2991 #ifdef GMX_MPI
2992             pme->mpi_comm_d[0] = pme->mpi_comm;
2993             pme->mpi_comm_d[1] = MPI_COMM_NULL;
2994 #endif
2995             pme->ndecompdim = 1;
2996             pme->nodeid_major = pme->nodeid;
2997             pme->nodeid_minor = 0;
2998
2999         }
3000         else if (nnodes_major == 1)
3001         {
3002 #ifdef GMX_MPI
3003             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3004             pme->mpi_comm_d[1] = pme->mpi_comm;
3005 #endif
3006             pme->ndecompdim = 1;
3007             pme->nodeid_major = 0;
3008             pme->nodeid_minor = pme->nodeid;
3009         }
3010         else
3011         {
3012             if (pme->nnodes % nnodes_major != 0)
3013             {
3014                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3015             }
3016             pme->ndecompdim = 2;
3017
3018 #ifdef GMX_MPI
3019             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3020                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3021             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3022                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3023
3024             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3025             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3026             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3027             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3028 #endif
3029         }
3030         pme->bPPnode = (cr->duty & DUTY_PP);
3031     }
3032
3033     pme->nthread = nthread;
3034
3035     if (ir->ePBC == epbcSCREW)
3036     {
3037         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3038     }
3039
3040     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3041     pme->nkx         = ir->nkx;
3042     pme->nky         = ir->nky;
3043     pme->nkz         = ir->nkz;
3044     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3045     pme->pme_order   = ir->pme_order;
3046     pme->epsilon_r   = ir->epsilon_r;
3047
3048     if (pme->pme_order > PME_ORDER_MAX)
3049     {
3050         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3051                   pme->pme_order,PME_ORDER_MAX);
3052     }
3053
3054     /* Currently pme.c supports only the fft5d FFT code.
3055      * Therefore the grid always needs to be divisible by nnodes.
3056      * When the old 1D code is also supported again, change this check.
3057      *
3058      * This check should be done before calling gmx_pme_init
3059      * and fplog should be passed iso stderr.
3060      *
3061     if (pme->ndecompdim >= 2)
3062     */
3063     if (pme->ndecompdim >= 1)
3064     {
3065         /*
3066         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3067                                         'x',nnodes_major,&pme->nkx);
3068         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3069                                         'y',nnodes_minor,&pme->nky);
3070         */
3071     }
3072
3073     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3074         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3075         pme->nkz <= pme->pme_order)
3076     {
3077         gmx_fatal(FARGS,"The pme grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_ordern for x and/or y",pme->pme_order);
3078     }
3079
3080     if (pme->nnodes > 1) {
3081         double imbal;
3082
3083 #ifdef GMX_MPI
3084         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3085         MPI_Type_commit(&(pme->rvec_mpi));
3086 #endif
3087
3088         /* Note that the charge spreading and force gathering, which usually
3089          * takes about the same amount of time as FFT+solve_pme,
3090          * is always fully load balanced
3091          * (unless the charge distribution is inhomogeneous).
3092          */
3093
3094         imbal = pme_load_imbalance(pme);
3095         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3096         {
3097             fprintf(stderr,
3098                     "\n"
3099                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3100                     "      For optimal PME load balancing\n"
3101                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3102                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3103                     "\n",
3104                     (int)((imbal-1)*100 + 0.5),
3105                     pme->nkx,pme->nky,pme->nnodes_major,
3106                     pme->nky,pme->nkz,pme->nnodes_minor);
3107         }
3108     }
3109
3110     /* For non-divisible grid we need pme_order iso pme_order-1 */
3111     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3112      * y is always copied through a buffer: we don't need padding in z,
3113      * but we do need the overlap in x because of the communication order.
3114      */
3115     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3116 #ifdef GMX_MPI
3117                       pme->mpi_comm_d[0],
3118 #endif
3119                       pme->nnodes_major,pme->nodeid_major,
3120                       pme->nkx,
3121                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3122
3123     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3124 #ifdef GMX_MPI
3125                       pme->mpi_comm_d[1],
3126 #endif
3127                       pme->nnodes_minor,pme->nodeid_minor,
3128                       pme->nky,
3129                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3130
3131     /* Check for a limitation of the (current) sum_fftgrid_dd code */
3132     if (pme->nthread > 1 &&
3133         (pme->overlap[0].noverlap_nodes > 1 ||
3134          pme->overlap[1].noverlap_nodes > 1))
3135     {
3136         gmx_fatal(FARGS,"With threads the number of grid lines per node along x and or y should be pme_order (%d) or more or exactly pme_order-1",pme->pme_order);
3137     }
3138
3139     snew(pme->bsp_mod[XX],pme->nkx);
3140     snew(pme->bsp_mod[YY],pme->nky);
3141     snew(pme->bsp_mod[ZZ],pme->nkz);
3142
3143     /* The required size of the interpolation grid, including overlap.
3144      * The allocated size (pmegrid_n?) might be slightly larger.
3145      */
3146     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3147                       pme->overlap[0].s2g0[pme->nodeid_major];
3148     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3149                       pme->overlap[1].s2g0[pme->nodeid_minor];
3150     pme->pmegrid_nz_base = pme->nkz;
3151     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3152     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3153
3154     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3155     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3156     pme->pmegrid_start_iz = 0;
3157
3158     make_gridindex5_to_localindex(pme->nkx,
3159                                   pme->pmegrid_start_ix,
3160                                   pme->pmegrid_nx - (pme->pme_order-1),
3161                                   &pme->nnx,&pme->fshx);
3162     make_gridindex5_to_localindex(pme->nky,
3163                                   pme->pmegrid_start_iy,
3164                                   pme->pmegrid_ny - (pme->pme_order-1),
3165                                   &pme->nny,&pme->fshy);
3166     make_gridindex5_to_localindex(pme->nkz,
3167                                   pme->pmegrid_start_iz,
3168                                   pme->pmegrid_nz_base,
3169                                   &pme->nnz,&pme->fshz);
3170
3171     pmegrids_init(&pme->pmegridA,
3172                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3173                   pme->pmegrid_nz_base,
3174                   pme->pme_order,
3175                   pme->nthread,
3176                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3177                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3178
3179     pme->spline_work = make_pme_spline_work(pme->pme_order);
3180
3181     ndata[0] = pme->nkx;
3182     ndata[1] = pme->nky;
3183     ndata[2] = pme->nkz;
3184
3185     /* This routine will allocate the grid data to fit the FFTs */
3186     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3187                             &pme->fftgridA,&pme->cfftgridA,
3188                             pme->mpi_comm_d,
3189                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3190                             bReproducible,pme->nthread);
3191
3192     if (bFreeEnergy)
3193     {
3194         pmegrids_init(&pme->pmegridB,
3195                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3196                       pme->pmegrid_nz_base,
3197                       pme->pme_order,
3198                       pme->nthread,
3199                       pme->nkx % pme->nnodes_major != 0,
3200                       pme->nky % pme->nnodes_minor != 0);
3201
3202         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3203                                 &pme->fftgridB,&pme->cfftgridB,
3204                                 pme->mpi_comm_d,
3205                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3206                                 bReproducible,pme->nthread);
3207     }
3208     else
3209     {
3210         pme->pmegridB.grid.grid = NULL;
3211         pme->fftgridB           = NULL;
3212         pme->cfftgridB          = NULL;
3213     }
3214
3215     if (!pme->bP3M)
3216     {
3217         /* Use plain SPME B-spline interpolation */
3218         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3219     }
3220     else
3221     {
3222         /* Use the P3M grid-optimized influence function */
3223         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3224     }
3225
3226     /* Use atc[0] for spreading */
3227     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3228     if (pme->ndecompdim >= 2)
3229     {
3230         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3231     }
3232
3233     if (pme->nnodes == 1) {
3234         pme->atc[0].n = homenr;
3235         pme_realloc_atomcomm_things(&pme->atc[0]);
3236     }
3237
3238     {
3239         int thread;
3240
3241         /* Use fft5d, order after FFT is y major, z, x minor */
3242
3243         snew(pme->work,pme->nthread);
3244         for(thread=0; thread<pme->nthread; thread++)
3245         {
3246             realloc_work(&pme->work[thread],pme->nkx);
3247         }
3248     }
3249
3250     *pmedata = pme;
3251
3252     return 0;
3253 }
3254
3255
3256 static void copy_local_grid(gmx_pme_t pme,
3257                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3258 {
3259     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3260     int  fft_my,fft_mz;
3261     int  nsx,nsy,nsz;
3262     ivec nf;
3263     int  offx,offy,offz,x,y,z,i0,i0t;
3264     int  d;
3265     pmegrid_t *pmegrid;
3266     real *grid_th;
3267
3268     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3269                                    local_fft_ndata,
3270                                    local_fft_offset,
3271                                    local_fft_size);
3272     fft_my = local_fft_size[YY];
3273     fft_mz = local_fft_size[ZZ];
3274
3275     pmegrid = &pmegrids->grid_th[thread];
3276
3277     nsx = pmegrid->n[XX];
3278     nsy = pmegrid->n[YY];
3279     nsz = pmegrid->n[ZZ];
3280
3281     for(d=0; d<DIM; d++)
3282     {
3283         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3284                     local_fft_ndata[d] - pmegrid->offset[d]);
3285     }
3286
3287     offx = pmegrid->offset[XX];
3288     offy = pmegrid->offset[YY];
3289     offz = pmegrid->offset[ZZ];
3290
3291     /* Directly copy the non-overlapping parts of the local grids.
3292      * This also initializes the full grid.
3293      */
3294     grid_th = pmegrid->grid;
3295     for(x=0; x<nf[XX]; x++)
3296     {
3297         for(y=0; y<nf[YY]; y++)
3298         {
3299             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3300             i0t = (x*nsy + y)*nsz;
3301             for(z=0; z<nf[ZZ]; z++)
3302             {
3303                 fftgrid[i0+z] = grid_th[i0t+z];
3304             }
3305         }
3306     }
3307 }
3308
3309 static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3310 {
3311     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3312     pme_overlap_t *overlap;
3313     int datasize,nind;
3314     int i,x,y,z,n;
3315
3316     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3317                                    local_fft_ndata,
3318                                    local_fft_offset,
3319                                    local_fft_size);
3320     /* Major dimension */
3321     overlap = &pme->overlap[0];
3322
3323     nind   = overlap->comm_data[0].send_nindex;
3324
3325     for(y=0; y<local_fft_ndata[YY]; y++) {
3326          printf(" %2d",y);
3327     }
3328     printf("\n");
3329
3330     i = 0;
3331     for(x=0; x<nind; x++) {
3332         for(y=0; y<local_fft_ndata[YY]; y++) {
3333             n = 0;
3334             for(z=0; z<local_fft_ndata[ZZ]; z++) {
3335                 if (sendbuf[i] != 0) n++;
3336                 i++;
3337             }
3338             printf(" %2d",n);
3339         }
3340         printf("\n");
3341     }
3342 }
3343
/* Add the contributions of other threads' spreading grids that overlap the
 * grid region owned by thread "thread".
 *
 * pme       : PME data; only pfft_setupA (local FFT grid limits) and the
 *             node counts along x/y (nnodes_major/minor) are read
 * pmegrids  : per-thread spreading grids (grid_th, nc, nthread_comm)
 * thread    : index of the thread whose region is reduced
 * fftgrid   : node-local real FFT grid; contributions that stay on this
 *             node are added into it
 * commbuf_x : receives contributions from thread blocks that wrap around
 *             along x when there are multiple PME nodes along x
 * commbuf_y : same for y; contributions that wrap in both x and y are
 *             stored in commbuf_y after the pure-y part
 *
 * The own contribution (sx=sy=sz=0) is skipped: the caller is expected to
 * have added it already.
 */
static void
reduce_threadgrid_overlap(gmx_pme_t pme,
                          const pmegrids_t *pmegrids,int thread,
                          real *fftgrid,real *commbuf_x,real *commbuf_y)
{
    ivec local_fft_ndata,local_fft_offset,local_fft_size;
    int  fft_nx,fft_ny,fft_nz;
    int  fft_my,fft_mz;
    int  buf_my=-1;
    int  nsx,nsy,nsz;
    ivec ne;
    int  offx,offy,offz,x,y,z,i0,i0t;
    int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
    gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
    gmx_bool bCommX,bCommY;
    int  d;
    int  thread_f;
    const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
    const real *grid_th;
    real *commbuf=NULL;

    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);
    /* Number of local FFT grid points per dimension */
    fft_nx = local_fft_ndata[XX];
    fft_ny = local_fft_ndata[YY];
    fft_nz = local_fft_ndata[ZZ];

    /* Allocated (possibly padded) sizes, used as strides in fftgrid */
    fft_my = local_fft_size[YY];
    fft_mz = local_fft_size[ZZ];

    /* This routine is called when all threads have finished spreading.
     * Here each thread sums grid contributions calculated by other threads
     * to the thread local grid volume.
     * To minimize the number of grid copying operations,
     * this routine sums immediately from the pmegrid to the fftgrid.
     */

    /* Determine which part of the full node grid we should operate on,
     * this is our thread local part of the full grid.
     */
    pmegrid = &pmegrids->grid_th[thread];

    /* ne: end (exclusive) of our region per dimension, i.e. our thread
     * grid without its order-1 overlap, clipped to the local FFT grid.
     */
    for(d=0; d<DIM; d++)
    {
        ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
                    local_fft_ndata[d]);
    }

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];


    /* The first write into each communication buffer region assigns,
     * later writes accumulate.
     */
    bClearBufX  = TRUE;
    bClearBufY  = TRUE;
    bClearBufXY = TRUE;

    /* Now loop over all the thread data blocks that contribute
     * to the grid region we (our thread) are operating on.
     */
    /* Note that fft_nx/fft_ny is equal to the number of grid points
     * between the first point of our node grid and the one of the next node.
     */
    for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
    {
        fx = pmegrid->ci[XX] + sx;
        ox = 0;
        bCommX = FALSE;
        if (fx < 0) {
            /* Periodic wrap to the last thread block along x: its data is
             * shifted by -fft_nx. With multiple PME nodes along x this
             * contribution goes into a communication buffer.
             */
            fx += pmegrids->nc[XX];
            ox -= fft_nx;
            bCommX = (pme->nnodes_major > 1);
        }
        /* Thread block with x index fx (y and z index 0), used only for
         * its x offset and size.
         */
        pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
        ox += pmegrid_g->offset[XX];
        if (!bCommX)
        {
            tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
        }
        else
        {
            /* NOTE(review): presumably at most pme_order lines of the
             * wrapped block can overlap the start of the grid.
             */
            tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
        }

        for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
        {
            fy = pmegrid->ci[YY] + sy;
            oy = 0;
            bCommY = FALSE;
            if (fy < 0) {
                /* Periodic wrap along y, analogous to x above */
                fy += pmegrids->nc[YY];
                oy -= fft_ny;
                bCommY = (pme->nnodes_minor > 1);
            }
            /* Thread block with y index fy (x and z index 0) */
            pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
            oy += pmegrid_g->offset[YY];
            if (!bCommY)
            {
                ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
            }
            else
            {
                ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
            }

            for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
            {
                fz = pmegrid->ci[ZZ] + sz;
                oz = 0;
                if (fz < 0)
                {
                    /* z is never decomposed over nodes: wrap locally */
                    fz += pmegrids->nc[ZZ];
                    oz -= fft_nz;
                }
                pmegrid_g = &pmegrids->grid_th[fz];
                oz += pmegrid_g->offset[ZZ];
                tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);

                if (sx == 0 && sy == 0 && sz == 0)
                {
                    /* We have already added our local contribution
                     * before calling this routine, so skip it here.
                     */
                    continue;
                }

                /* Linear index of the contributing thread block */
                thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;

                pmegrid_f = &pmegrids->grid_th[thread_f];

                grid_th = pmegrid_f->grid;

                /* Dimensions of the contributing thread grid, used as strides
                 * (nsx is not used below).
                 */
                nsx = pmegrid_f->n[XX];
                nsy = pmegrid_f->n[YY];
                nsz = pmegrid_f->n[ZZ];

#ifdef DEBUG_PME_REDUCE
                printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
                       pme->nodeid,thread,thread_f,
                       pme->pmegrid_start_ix,
                       pme->pmegrid_start_iy,
                       pme->pmegrid_start_iz,
                       sx,sy,sz,
                       offx-ox,tx1-ox,offx,tx1,
                       offy-oy,ty1-oy,offy,ty1,
                       offz-oz,tz1-oz,offz,tz1);
#endif

                if (!(bCommX || bCommY))
                {
                    /* Copy from the thread local grid to the node grid */
                    for(x=offx; x<tx1; x++)
                    {
                        for(y=offy; y<ty1; y++)
                        {
                            i0  = (x*fft_my + y)*fft_mz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
                            for(z=offz; z<tz1; z++)
                            {
                                fftgrid[i0+z] += grid_th[i0t+z];
                            }
                        }
                    }
                }
                else
                {
                    /* The order of this conditional decides
                     * where the corner volume gets stored with x+y decomp.
                     */
                    if (bCommY)
                    {
                        commbuf = commbuf_y;
                        buf_my  = ty1 - offy;
                        if (bCommX)
                        {
                            /* We index commbuf modulo the local grid size */
                            commbuf += buf_my*fft_nx*fft_nz;

                            bClearBuf  = bClearBufXY;
                            bClearBufXY = FALSE;
                        }
                        else
                        {
                            bClearBuf  = bClearBufY;
                            bClearBufY = FALSE;
                        }
                    }
                    else
                    {
                        commbuf = commbuf_x;
                        buf_my  = fft_ny;
                        bClearBuf  = bClearBufX;
                        bClearBufX = FALSE;
                    }

                    /* Copy to the communication buffer, which is unpadded:
                     * rows of buf_my y lines with fft_nz z points each.
                     */
                    for(x=offx; x<tx1; x++)
                    {
                        for(y=offy; y<ty1; y++)
                        {
                            i0  = (x*buf_my + y)*fft_nz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;

                            if (bClearBuf)
                            {
                                /* First access of commbuf, initialize it */
                                for(z=offz; z<tz1; z++)
                                {
                                    commbuf[i0+z]  = grid_th[i0t+z];
                                }
                            }
                            else
                            {
                                for(z=offz; z<tz1; z++)
                                {
                                    commbuf[i0+z] += grid_th[i0t+z];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
3571
3572
3573 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3574 {
3575     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3576     pme_overlap_t *overlap;
3577     int  send_nindex;
3578     int  recv_index0,recv_nindex;
3579 #ifdef GMX_MPI
3580     MPI_Status stat;
3581 #endif
3582     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3583     real *sendptr,*recvptr;
3584     int  x,y,z,indg,indb;
3585
3586     /* Note that this routine is only used for forward communication.
3587      * Since the force gathering, unlike the charge spreading,
3588      * can be trivially parallelized over the particles,
3589      * the backwards process is much simpler and can use the "old"
3590      * communication setup.
3591      */
3592
3593     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3594                                    local_fft_ndata,
3595                                    local_fft_offset,
3596                                    local_fft_size);
3597
3598     /* Currently supports only a single communication pulse */
3599
3600 /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3601     if (pme->nnodes_minor > 1)
3602     {
3603         /* Major dimension */
3604         overlap = &pme->overlap[1];
3605
3606         if (pme->nnodes_major > 1)
3607         {
3608              size_yx = pme->overlap[0].comm_data[0].send_nindex;
3609         }
3610         else
3611         {
3612             size_yx = 0;
3613         }
3614         datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3615
3616         ipulse = 0;
3617
3618         send_id = overlap->send_id[ipulse];
3619         recv_id = overlap->recv_id[ipulse];
3620         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3621         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3622         recv_index0 = 0;
3623         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3624
3625         sendptr = overlap->sendbuf;
3626         recvptr = overlap->recvbuf;
3627
3628         /*
3629         printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3630                local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3631         printf("node %d send %f, %f\n",pme->nodeid,
3632                sendptr[0],sendptr[send_nindex*datasize-1]);
3633         */
3634
3635 #ifdef GMX_MPI
3636         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3637                      send_id,ipulse,
3638                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3639                      recv_id,ipulse,
3640                      overlap->mpi_comm,&stat);
3641 #endif
3642
3643         for(x=0; x<local_fft_ndata[XX]; x++)
3644         {
3645             for(y=0; y<recv_nindex; y++)
3646             {
3647                 indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3648                 indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3649                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3650                 {
3651                     fftgrid[indg+z] += recvptr[indb+z];
3652                 }
3653             }
3654         }
3655         if (pme->nnodes_major > 1)
3656         {
3657             sendptr = pme->overlap[0].sendbuf;
3658             for(x=0; x<size_yx; x++)
3659             {
3660                 for(y=0; y<recv_nindex; y++)
3661                 {
3662                     indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3663                     indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3664                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3665                     {
3666                         sendptr[indg+z] += recvptr[indb+z];
3667                     }
3668                 }
3669             }
3670         }
3671     }
3672
3673     /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3674     if (pme->nnodes_major > 1)
3675     {
3676         /* Major dimension */
3677         overlap = &pme->overlap[0];
3678
3679         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3680         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3681
3682         ipulse = 0;
3683
3684         send_id = overlap->send_id[ipulse];
3685         recv_id = overlap->recv_id[ipulse];
3686         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3687         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3688         recv_index0 = 0;
3689         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3690
3691         sendptr = overlap->sendbuf;
3692         recvptr = overlap->recvbuf;
3693
3694         if (debug != NULL)
3695         {
3696             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3697                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3698         }
3699
3700 #ifdef GMX_MPI
3701         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3702                      send_id,ipulse,
3703                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3704                      recv_id,ipulse,
3705                      overlap->mpi_comm,&stat);
3706 #endif
3707
3708         for(x=0; x<recv_nindex; x++)
3709         {
3710             for(y=0; y<local_fft_ndata[YY]; y++)
3711             {
3712                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3713                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3714                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3715                 {
3716                     fftgrid[indg+z] += recvptr[indb+z];
3717                 }
3718             }
3719         }
3720     }
3721 }
3722
3723
3724 static void spread_on_grid(gmx_pme_t pme,
3725                            pme_atomcomm_t *atc,pmegrids_t *grids,
3726                            gmx_bool bCalcSplines,gmx_bool bSpread,
3727                            real *fftgrid)
3728 {
3729     int nthread,thread;
3730 #ifdef PME_TIME_THREADS
3731     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3732     static double cs1=0,cs2=0,cs3=0;
3733     static double cs1a[6]={0,0,0,0,0,0};
3734     static int cnt=0;
3735 #endif
3736
3737     nthread = pme->nthread;
3738     assert(nthread>0);
3739
3740 #ifdef PME_TIME_THREADS
3741     c1 = omp_cyc_start();
3742 #endif
3743     if (bCalcSplines)
3744     {
3745 #pragma omp parallel for num_threads(nthread) schedule(static)
3746         for(thread=0; thread<nthread; thread++)
3747         {
3748             int start,end;
3749
3750             start = atc->n* thread   /nthread;
3751             end   = atc->n*(thread+1)/nthread;
3752
3753             /* Compute fftgrid index for all atoms,
3754              * with help of some extra variables.
3755              */
3756             calc_interpolation_idx(pme,atc,start,end,thread);
3757         }
3758     }
3759 #ifdef PME_TIME_THREADS
3760     c1 = omp_cyc_end(c1);
3761     cs1 += (double)c1;
3762 #endif
3763
3764 #ifdef PME_TIME_THREADS
3765     c2 = omp_cyc_start();
3766 #endif
3767 #pragma omp parallel for num_threads(nthread) schedule(static)
3768     for(thread=0; thread<nthread; thread++)
3769     {
3770         splinedata_t *spline;
3771         pmegrid_t *grid;
3772
3773         /* make local bsplines  */
3774         if (grids == NULL || grids->nthread == 1)
3775         {
3776             spline = &atc->spline[0];
3777
3778             spline->n = atc->n;
3779
3780             grid = &grids->grid;
3781         }
3782         else
3783         {
3784             spline = &atc->spline[thread];
3785
3786             make_thread_local_ind(atc,thread,spline);
3787
3788             grid = &grids->grid_th[thread];
3789         }
3790
3791         if (bCalcSplines)
3792         {
3793             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3794                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3795         }
3796
3797         if (bSpread)
3798         {
3799             /* put local atoms on grid. */
3800 #ifdef PME_TIME_SPREAD
3801             ct1a = omp_cyc_start();
3802 #endif
3803             spread_q_bsplines_thread(grid,atc,spline,pme->spline_work);
3804
3805             if (grids->nthread > 1)
3806             {
3807                 copy_local_grid(pme,grids,thread,fftgrid);
3808             }
3809 #ifdef PME_TIME_SPREAD
3810             ct1a = omp_cyc_end(ct1a);
3811             cs1a[thread] += (double)ct1a;
3812 #endif
3813         }
3814     }
3815 #ifdef PME_TIME_THREADS
3816     c2 = omp_cyc_end(c2);
3817     cs2 += (double)c2;
3818 #endif
3819
3820     if (bSpread && grids->nthread > 1)
3821     {
3822 #ifdef PME_TIME_THREADS
3823         c3 = omp_cyc_start();
3824 #endif
3825 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3826         for(thread=0; thread<grids->nthread; thread++)
3827         {
3828             reduce_threadgrid_overlap(pme,grids,thread,
3829                                       fftgrid,
3830                                       pme->overlap[0].sendbuf,
3831                                       pme->overlap[1].sendbuf);
3832 #ifdef PRINT_PME_SENDBUF
3833             print_sendbuf(pme,pme->overlap[0].sendbuf);
3834 #endif
3835         }
3836 #ifdef PME_TIME_THREADS
3837         c3 = omp_cyc_end(c3);
3838         cs3 += (double)c3;
3839 #endif
3840
3841         if (pme->nnodes > 1)
3842         {
3843             /* Communicate the overlapping part of the fftgrid */
3844             sum_fftgrid_dd(pme,fftgrid);
3845         }
3846     }
3847
3848 #ifdef PME_TIME_THREADS
3849     cnt++;
3850     if (cnt % 20 == 0)
3851     {
3852         printf("idx %.2f spread %.2f red %.2f",
3853                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3854 #ifdef PME_TIME_SPREAD
3855         for(thread=0; thread<nthread; thread++)
3856             printf(" %.2f",cs1a[thread]*1e-9);
3857 #endif
3858         printf("\n");
3859     }
3860 #endif
3861 }
3862
3863
3864 static void dump_grid(FILE *fp,
3865                       int sx,int sy,int sz,int nx,int ny,int nz,
3866                       int my,int mz,const real *g)
3867 {
3868     int x,y,z;
3869
3870     for(x=0; x<nx; x++)
3871     {
3872         for(y=0; y<ny; y++)
3873         {
3874             for(z=0; z<nz; z++)
3875             {
3876                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3877                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3878             }
3879         }
3880     }
3881 }
3882
3883 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3884 {
3885     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3886
3887     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3888                                    local_fft_ndata,
3889                                    local_fft_offset,
3890                                    local_fft_size);
3891
3892     dump_grid(stderr,
3893               pme->pmegrid_start_ix,
3894               pme->pmegrid_start_iy,
3895               pme->pmegrid_start_iz,
3896               pme->pmegrid_nx-pme->pme_order+1,
3897               pme->pmegrid_ny-pme->pme_order+1,
3898               pme->pmegrid_nz-pme->pme_order+1,
3899               local_fft_size[YY],
3900               local_fft_size[ZZ],
3901               fftgrid);
3902 }
3903
3904
3905 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3906 {
3907     pme_atomcomm_t *atc;
3908     pmegrids_t *grid;
3909
3910     if (pme->nnodes > 1)
3911     {
3912         gmx_incons("gmx_pme_calc_energy called in parallel");
3913     }
3914     if (pme->bFEP > 1)
3915     {
3916         gmx_incons("gmx_pme_calc_energy with free energy");
3917     }
3918
3919     atc = &pme->atc_energy;
3920     atc->nthread   = 1;
3921     if (atc->spline == NULL)
3922     {
3923         snew(atc->spline,atc->nthread);
3924     }
3925     atc->nslab     = 1;
3926     atc->bSpread   = TRUE;
3927     atc->pme_order = pme->pme_order;
3928     atc->n         = n;
3929     pme_realloc_atomcomm_things(atc);
3930     atc->x         = x;
3931     atc->q         = q;
3932
3933     /* We only use the A-charges grid */
3934     grid = &pme->pmegridA;
3935
3936     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
3937
3938     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3939 }
3940
3941
3942 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3943         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3944 {
3945     /* Reset all the counters related to performance over the run */
3946     wallcycle_stop(wcycle,ewcRUN);
3947     wallcycle_reset_all(wcycle);
3948     init_nrnb(nrnb);
3949     ir->init_step += step_rel;
3950     ir->nsteps    -= step_rel;
3951     wallcycle_start(wcycle,ewcRUN);
3952 }
3953
3954
3955 int gmx_pmeonly(gmx_pme_t pme,
3956                 t_commrec *cr,    t_nrnb *nrnb,
3957                 gmx_wallcycle_t wcycle,
3958                 real ewaldcoeff,  gmx_bool bGatherOnly,
3959                 t_inputrec *ir)
3960 {
3961     gmx_pme_pp_t pme_pp;
3962     int  natoms;
3963     matrix box;
3964     rvec *x_pp=NULL,*f_pp=NULL;
3965     real *chargeA=NULL,*chargeB=NULL;
3966     real lambda=0;
3967     int  maxshift_x=0,maxshift_y=0;
3968     real energy,dvdlambda;
3969     matrix vir;
3970     float cycles;
3971     int  count;
3972     gmx_bool bEnerVir;
3973     gmx_large_int_t step,step_rel;
3974
3975
3976     pme_pp = gmx_pme_pp_init(cr);
3977
3978     init_nrnb(nrnb);
3979
3980     count = 0;
3981     do /****** this is a quasi-loop over time steps! */
3982     {
3983         /* Domain decomposition */
3984         natoms = gmx_pme_recv_q_x(pme_pp,
3985                                   &chargeA,&chargeB,box,&x_pp,&f_pp,
3986                                   &maxshift_x,&maxshift_y,
3987                                   &pme->bFEP,&lambda,
3988                                   &bEnerVir,
3989                                   &step);
3990
3991         if (natoms == -1) {
3992             /* We should stop: break out of the loop */
3993             break;
3994         }
3995
3996         step_rel = step - ir->init_step;
3997
3998         if (count == 0)
3999             wallcycle_start(wcycle,ewcRUN);
4000
4001         wallcycle_start(wcycle,ewcPMEMESH);
4002
4003         dvdlambda = 0;
4004         clear_mat(vir);
4005         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
4006                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
4007                    &energy,lambda,&dvdlambda,
4008                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4009
4010         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
4011
4012         gmx_pme_send_force_vir_ener(pme_pp,
4013                                     f_pp,vir,energy,dvdlambda,
4014                                     cycles);
4015
4016         count++;
4017
4018         if (step_rel == wcycle_get_reset_counters(wcycle))
4019         {
4020             /* Reset all the counters related to performance over the run */
4021             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4022             wcycle_set_reset_counters(wcycle, 0);
4023         }
4024
4025     } /***** end of quasi-loop, we stop with the break above */
4026     while (TRUE);
4027
4028     return 0;
4029 }
4030
/* Compute the PME mesh contribution for the homenr local atoms starting
 * at atom index start.
 *
 * The charge-state loop over q runs once (A charges), or twice (A, then
 * B charges) when pme->bFEP is set. For each state, depending on flags:
 *  - with pme->nnodes > 1, coordinates and charges are redistributed
 *    over the PME decomposition dimensions, last dimension first;
 *  - GMX_PME_SPREAD_Q: grid indices and splines are computed (A state
 *    only, bCalcSplines = q==0) and the charges are spread onto the grid;
 *  - GMX_PME_SOLVE: forward 3D FFT and solve_pme_yzx;
 *  - GMX_PME_CALC_F: backward 3D FFT, copy back to the PME grid,
 *    backward grid communication, and force interpolation with
 *    gather_f_bsplines;
 *  - GMX_PME_CALC_ENER_VIR: energy and virial of this state are read
 *    with get_pme_ener_vir.
 * After the loop, forces are redistributed back to the home atoms
 * (GMX_PME_CALC_F with pme->nnodes > 1). The A and B states are then
 * combined with weights 1-lambda and lambda.
 *
 * Outputs:
 *  - *energy is overwritten; it is 0 without GMX_PME_CALC_ENER_VIR.
 *  - The mesh virial is added to vir, not assigned.
 *  - With pme->bFEP, energy_B - energy_A is added to *dvdlambda.
 * Always returns 0.
 */
int gmx_pme_do(gmx_pme_t pme,
               int start,       int homenr,
               rvec x[],        rvec f[],
               real *chargeA,   real *chargeB,
               matrix box, t_commrec *cr,
               int  maxshift_x, int maxshift_y,
               t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
               matrix vir,      real ewaldcoeff,
               real *energy,    real lambda,
               real *dvdlambda, int flags)
{
    /* NOTE(review): nx, ny, nz, local_ny and ptr appear to be unused
     * in this function.
     */
    int     q,d,i,j,ntot,npme;
    int     nx,ny,nz;
    int     n_d,local_ny;
    pme_atomcomm_t *atc=NULL;
    pmegrids_t *pmegrid=NULL;
    real    *grid=NULL;
    real    *ptr;
    rvec    *x_d,*f_d;
    real    *charge=NULL,*q_d;
    real    energy_AB[2];   /* per charge state, filled when bCalcEnerVir */
    matrix  vir_AB[2];      /* per charge state, filled when bCalcEnerVir */
    gmx_bool bClearF;
    gmx_parallel_3dfft_t pfft_setup;
    real *  fftgrid;
    t_complex * cfftgrid;
    int     thread;
    const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
    const gmx_bool bCalcF = flags & GMX_PME_CALC_F;

    assert(pme->nnodes > 0);
    assert(pme->nnodes == 1 || pme->ndecompdim > 0);

    if (pme->nnodes > 1) {
        /* Size the redistribution index buffer of the first atom-comm
         * structure for the home atoms.
         */
        atc = &pme->atc[0];
        atc->npd = homenr;
        if (atc->npd > atc->pd_nalloc) {
            atc->pd_nalloc = over_alloc_dd(atc->npd);
            srenew(atc->pd,atc->pd_nalloc);
        }
        atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
    }
    else
    {
        /* This could be necessary for TPI */
        pme->atc[0].n = homenr;
    }

    for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
        /* Select the grids, FFT setup and charges of state A (q=0)
         * or state B (q=1).
         */
        if (q == 0) {
            pmegrid = &pme->pmegridA;
            fftgrid = pme->fftgridA;
            cfftgrid = pme->cfftgridA;
            pfft_setup = pme->pfft_setupA;
            charge = chargeA+start;
        } else {
            pmegrid = &pme->pmegridB;
            fftgrid = pme->fftgridB;
            cfftgrid = pme->cfftgridB;
            pfft_setup = pme->pfft_setupB;
            charge = chargeB+start;
        }
        grid = pmegrid->grid.grid;
        /* Unpack structure */
        if (debug) {
            fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
                    cr->nnodes,cr->nodeid);
            fprintf(debug,"Grid = %p\n",(void*)grid);
            if (grid == NULL)
                gmx_fatal(FARGS,"No grid!");
        }
        where();

        m_inv_ur0(box,pme->recipbox);

        if (pme->nnodes == 1) {
            /* No redistribution: atc points directly at the caller's
             * arrays. NOTE(review): x and f are used without the start
             * offset, while charge already includes it; this presumably
             * relies on start == 0 here. Confirm with the callers.
             */
            atc = &pme->atc[0];
            if (DOMAINDECOMP(cr)) {
                atc->n = homenr;
                pme_realloc_atomcomm_things(atc);
            }
            atc->x = x;
            atc->q = charge;
            atc->f = f;
        } else {
            wallcycle_start(wcycle,ewcPME_REDISTXF);
            /* Redistribute over the decomposition dimensions, from the
             * last to the first. Each stage consumes the output of the
             * previous one (atc of dimension d+1).
             */
            for(d=pme->ndecompdim-1; d>=0; d--)
            {
                if (d == pme->ndecompdim-1)
                {
                    n_d = homenr;
                    x_d = x + start;
                    q_d = charge;
                }
                else
                {
                    n_d = pme->atc[d+1].n;
                    x_d = atc->x;
                    q_d = atc->q;
                }
                atc = &pme->atc[d];
                atc->npd = n_d;
                if (atc->npd > atc->pd_nalloc) {
                    atc->pd_nalloc = over_alloc_dd(atc->npd);
                    srenew(atc->pd,atc->pd_nalloc);
                }
                atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
                pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
                where();

                /* Redistribute x (only once) and qA or qB */
                if (DOMAINDECOMP(cr)) {
                    dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
                } else {
                    pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
                }
            }
            where();

            wallcycle_stop(wcycle,ewcPME_REDISTXF);
        }
        /* From here on atc == &pme->atc[0] in both branches above */

        if (debug)
            fprintf(debug,"Node= %6d, pme local particles=%6d\n",
                    cr->nodeid,atc->n);

        if (flags & GMX_PME_SPREAD_Q)
        {
            wallcycle_start(wcycle,ewcPME_SPREADGATHER);

            /* Spread the charges on a grid */
            spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);

            if (q == 0)
            {
                inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
            }
            inc_nrnb(nrnb,eNR_SPREADQBSP,
                     pme->pme_order*pme->pme_order*pme->pme_order*atc->n);

            /* With one thread, the periodic wrap, the grid sum over
             * nodes and the copy to the FFT grid are done here.
             * With more threads, spread_on_grid has already filled
             * fftgrid.
             */
            if (pme->nthread == 1)
            {
                wrap_periodic_pmegrid(pme,grid);

                /* sum contributions to local grid from other nodes */
#ifdef GMX_MPI
                if (pme->nnodes > 1)
                {
                    gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
                    where();
                }
#endif

                copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
            }

            wallcycle_stop(wcycle,ewcPME_SPREADGATHER);

            /*
            dump_local_fftgrid(pme,fftgrid);
            exit(0);
            */
        }

        /* Here we start a large thread parallel region */
        /* Thread 0 alone does the wallcycle and nrnb accounting */
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
        for(thread=0; thread<pme->nthread; thread++)
        {
            if (flags & GMX_PME_SOLVE)
            {
                int loop_count;

                /* do 3d-fft */
                if (thread == 0)
                {
                    wallcycle_start(wcycle,ewcPME_FFT);
                }
                gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
                                           fftgrid,cfftgrid,thread,wcycle);
                if (thread == 0)
                {
                    wallcycle_stop(wcycle,ewcPME_FFT);
                }
                where();

                /* solve in k-space for our local cells */
                if (thread == 0)
                {
                    wallcycle_start(wcycle,ewcPME_SOLVE);
                }
                loop_count =
                    solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
                                  box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
                                  bCalcEnerVir,
                                  pme->nthread,thread);
                if (thread == 0)
                {
                    wallcycle_stop(wcycle,ewcPME_SOLVE);
                    where();
                    inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
                }
            }

            if (bCalcF)
            {
                /* do 3d-invfft */
                if (thread == 0)
                {
                    where();
                    wallcycle_start(wcycle,ewcPME_FFT);
                }
                gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
                                           cfftgrid,fftgrid,thread,wcycle);
                if (thread == 0)
                {
                    wallcycle_stop(wcycle,ewcPME_FFT);

                    where();

                    /* Flop count for the forward plus backward FFT of
                     * the full grid, counted on PME node 0 only.
                     */
                    if (pme->nodeid == 0)
                    {
                        ntot = pme->nkx*pme->nky*pme->nkz;
                        npme  = ntot*log((real)ntot)/log(2.0);
                        inc_nrnb(nrnb,eNR_FFT,2*npme);
                    }

                    /* Stopped after the force gather below */
                    wallcycle_start(wcycle,ewcPME_SPREADGATHER);
                }

                copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
            }
        }
        /* End of thread parallel section.
         * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
         */

        if (bCalcF)
        {
            /* distribute local grid to all nodes */
#ifdef GMX_MPI
            if (pme->nnodes > 1) {
                gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
            }
#endif
            where();

            unwrap_periodic_pmegrid(pme,grid);

            /* interpolate forces for our local atoms */

            where();

            /* If we are running without parallelization,
             * atc->f is the actual force array, not a buffer,
             * therefore we should not clear it.
             */
            bClearF = (q == 0 && PAR(cr));
            /* The state weight is 1-lambda for A and lambda for B with
             * free energy, otherwise 1.
             */
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
            for(thread=0; thread<pme->nthread; thread++)
            {
                gather_f_bsplines(pme,grid,bClearF,atc,
                                  &atc->spline[thread],
                                  pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
            }

            where();

            inc_nrnb(nrnb,eNR_GATHERFBSP,
                     pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
            wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
        }

        if (bCalcEnerVir)
        {
            /* This should only be called on the master thread
             * and after the threads have synchronized.
             */
            get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
        }
    } /* of q-loop */

    /* Send the forces back along the decomposition dimensions, first to
     * last (the reverse of the coordinate redistribution), ending in the
     * caller's f + start.
     */
    if (bCalcF && pme->nnodes > 1) {
        wallcycle_start(wcycle,ewcPME_REDISTXF);
        for(d=0; d<pme->ndecompdim; d++)
        {
            atc = &pme->atc[d];
            if (d == pme->ndecompdim - 1)
            {
                n_d = homenr;
                f_d = f + start;
            }
            else
            {
                n_d = pme->atc[d+1].n;
                f_d = pme->atc[d+1].f;
            }
            if (DOMAINDECOMP(cr)) {
                dd_pmeredist_f(pme,atc,n_d,f_d,
                               d==pme->ndecompdim-1 && pme->bPPnode);
            } else {
                pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
            }
        }

        wallcycle_stop(wcycle,ewcPME_REDISTXF);
    }
    where();

    /* Combine the charge states. The virial is accumulated into vir;
     * the energy is assigned.
     */
    if (bCalcEnerVir)
    {
        if (!pme->bFEP) {
            *energy = energy_AB[0];
            m_add(vir,vir_AB[0],vir);
        } else {
            *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
            *dvdlambda += energy_AB[1] - energy_AB[0];
            for(i=0; i<DIM; i++)
            {
                for(j=0; j<DIM; j++)
                {
                    vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
                        lambda*vir_AB[1][i][j];
                }
            }
        }
    }
    else
    {
        *energy = 0;
    }

    if (debug)
    {
        fprintf(debug,"PME mesh energy: %g\n",*energy);
    }

    return 0;
}