bc2f8580b5e6759759bfd40bc2456164b720ea22
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #include <stdio.h>
64 #include <string.h>
65 #include <math.h>
66 #include <assert.h>
67 #include "typedefs.h"
68 #include "txtdump.h"
69 #include "vec.h"
70 #include "gmxcomplex.h"
71 #include "smalloc.h"
72 #include "coulomb.h"
73 #include "gmx_fatal.h"
74 #include "pme.h"
75 #include "network.h"
76 #include "physics.h"
77 #include "nrnb.h"
78 #include "macros.h"
79
80 #include "gromacs/fft/parallel_3dfft.h"
81 #include "gromacs/fileio/futil.h"
82 #include "gromacs/fileio/pdbio.h"
83 #include "gromacs/timing/cyclecounter.h"
84 #include "gromacs/timing/wallcycle.h"
85 #include "gromacs/utility/gmxmpi.h"
86 #include "gromacs/utility/gmxomp.h"
87
88 /* Include the SIMD macro file and then check for support */
89 #include "gmx_simd_macros.h"
90 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
91 /* Turn on arbitrary width SIMD intrinsics for PME solve */
92 #define PME_SIMD
93 #endif
94
95 /* Include the 4-wide SIMD macro file */
96 #include "gmx_simd4_macros.h"
97 /* Check if we have 4-wide SIMD macro support */
98 #ifdef GMX_HAVE_SIMD4_MACROS
99 /* Do PME spread and gather with 4-wide SIMD.
100  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
101  */
102 #define PME_SIMD4_SPREAD_GATHER
103
104 #ifdef GMX_SIMD4_HAVE_UNALIGNED
105 /* With PME-order=4 on x86, unaligned load+store is slightly faster
106  * than doubling all SIMD operations when using aligned load+store.
107  */
108 #define PME_SIMD4_UNALIGNED
109 #endif
110 #endif
111
112 #define DFT_TOL 1e-7
113 /* #define PRT_FORCE */
114 /* conditions for on the fly time-measurement */
115 /* #define TAKETIME (step > 1 && timesteps < 10) */
116 #define TAKETIME FALSE
117
118 /* #define PME_TIME_THREADS */
119
120 #ifdef GMX_DOUBLE
121 #define mpi_type MPI_DOUBLE
122 #else
123 #define mpi_type MPI_FLOAT
124 #endif
125
126 #ifdef PME_SIMD4_SPREAD_GATHER
127 #define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
128 #else
129 /* We can use any alignment, apart from 0, so we use 4 reals */
130 #define SIMD4_ALIGNMENT  (4*sizeof(real))
131 #endif
132
133 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
134  * to preserve alignment.
135  */
136 #define GMX_CACHE_SEP 64
137
138 /* We only define a maximum to be able to use local arrays without allocation.
139  * An order larger than 12 should never be needed, even for test cases.
140  * If needed it can be changed here.
141  */
142 #define PME_ORDER_MAX 12
143
144 /* Internal datastructures */
/* Index ranges for PME grid overlap communication; pme_overlap_t keeps an
 * array of these (presumably one entry per communication partner).
 */
typedef struct {
    int send_index0;
    int send_nindex;
    int recv_index0;
    int recv_nindex;
    int recv_size;      /* Width of the receive buffer, used with OpenMP */
} pme_grid_comm_t;
152
153 typedef struct {
154 #ifdef GMX_MPI
155     MPI_Comm         mpi_comm;
156 #endif
157     int              nnodes, nodeid;
158     int             *s2g0;
159     int             *s2g1;
160     int              noverlap_nodes;
161     int             *send_id, *recv_id;
162     int              send_size; /* Send buffer width, used with OpenMP */
163     pme_grid_comm_t *comm_data;
164     real            *sendbuf;
165     real            *recvbuf;
166 } pme_overlap_t;
167
/* Particle indices of one thread's atom range, sorted on the thread that
 * owns each particle.
 */
typedef struct {
    int *n;         /* Per thread: cumulative particle counts */
    int  nalloc;    /* Number of elements allocated for i */
    int *i;         /* Particle indices, ordered on owner thread (see n) */
} thread_plist_t;
173
174 typedef struct {
175     int      *thread_one;
176     int       n;
177     int      *ind;
178     splinevec theta;
179     real     *ptr_theta_z;
180     splinevec dtheta;
181     real     *ptr_dtheta_z;
182 } splinedata_t;
183
184 typedef struct {
185     int      dimind;        /* The index of the dimension, 0=x, 1=y */
186     int      nslab;
187     int      nodeid;
188 #ifdef GMX_MPI
189     MPI_Comm mpi_comm;
190 #endif
191
192     int     *node_dest;     /* The nodes to send x and q to with DD */
193     int     *node_src;      /* The nodes to receive x and q from with DD */
194     int     *buf_index;     /* Index for commnode into the buffers */
195
196     int      maxshift;
197
198     int      npd;
199     int      pd_nalloc;
200     int     *pd;
201     int     *count;         /* The number of atoms to send to each node */
202     int    **count_thread;
203     int     *rcount;        /* The number of atoms to receive */
204
205     int      n;
206     int      nalloc;
207     rvec    *x;
208     real    *q;
209     rvec    *f;
210     gmx_bool bSpread;       /* These coordinates are used for spreading */
211     int      pme_order;
212     ivec    *idx;
213     rvec    *fractx;            /* Fractional coordinate relative to the
214                                  * lower cell boundary
215                                  */
216     int             nthread;
217     int            *thread_idx; /* Which thread should spread which charge */
218     thread_plist_t *thread_plist;
219     splinedata_t   *spline;
220 } pme_atomcomm_t;
221
222 #define FLBS  3
223 #define FLBSZ 4
224
225 typedef struct {
226     ivec  ci;     /* The spatial location of this grid         */
227     ivec  n;      /* The used size of *grid, including order-1 */
228     ivec  offset; /* The grid offset from the full node grid   */
229     int   order;  /* PME spreading order                       */
230     ivec  s;      /* The allocated size of *grid, s >= n       */
231     real *grid;   /* The grid local thread, size n             */
232 } pmegrid_t;
233
234 typedef struct {
235     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
236     int        nthread;      /* The number of threads operating on this grid     */
237     ivec       nc;           /* The local spatial decomposition over the threads */
238     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
239     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
240     int      **g2t;          /* The grid to thread index                         */
241     ivec       nthread_comm; /* The number of threads to communicate with        */
242 } pmegrids_t;
243
244
/* Work data for spreading and gathering */
typedef struct {
#ifdef PME_SIMD4_SPREAD_GATHER
    /* Masks for aligned 4-wide SIMD spreading and gathering */
    gmx_simd4_pb mask_S0[6], mask_S1[6];
#else
    int          dummy; /* C89 requires that struct has at least one member */
#endif
} pme_spline_work_t;
253
254 typedef struct {
255     /* work data for solve_pme */
256     int      nalloc;
257     real *   mhx;
258     real *   mhy;
259     real *   mhz;
260     real *   m2;
261     real *   denom;
262     real *   tmp1_alloc;
263     real *   tmp1;
264     real *   eterm;
265     real *   m2inv;
266
267     real     energy;
268     matrix   vir;
269 } pme_work_t;
270
271 typedef struct gmx_pme {
272     int           ndecompdim; /* The number of decomposition dimensions */
273     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
274     int           nodeid_major;
275     int           nodeid_minor;
276     int           nnodes;    /* The number of nodes doing PME */
277     int           nnodes_major;
278     int           nnodes_minor;
279
280     MPI_Comm      mpi_comm;
281     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
282 #ifdef GMX_MPI
283     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
284 #endif
285
286     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
287     int        nthread;       /* The number of threads doing PME on our rank */
288
289     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
290     gmx_bool   bFEP;          /* Compute Free energy contribution */
291     int        nkx, nky, nkz; /* Grid dimensions */
292     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
293     int        pme_order;
294     real       epsilon_r;
295
296     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
297     pmegrids_t pmegridB;
298     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
299     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
300     /* pmegrid_nz might be larger than strictly necessary to ensure
301      * memory alignment, pmegrid_nz_base gives the real base size.
302      */
303     int     pmegrid_nz_base;
304     /* The local PME grid starting indices */
305     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
306
307     /* Work data for spreading and gathering */
308     pme_spline_work_t    *spline_work;
309
310     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
311     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
312     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
313
314     t_complex            *cfftgridA;  /* Grids for complex FFT data */
315     t_complex            *cfftgridB;
316     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
317
318     gmx_parallel_3dfft_t  pfft_setupA;
319     gmx_parallel_3dfft_t  pfft_setupB;
320
321     int                  *nnx, *nny, *nnz;
322     real                 *fshx, *fshy, *fshz;
323
324     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
325     matrix                recipbox;
326     splinevec             bsp_mod;
327
328     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
329
330     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
331
332     rvec                 *bufv;       /* Communication buffer */
333     real                 *bufr;       /* Communication buffer */
334     int                   buf_nalloc; /* The communication buffer size */
335
336     /* thread local work data for solve_pme */
337     pme_work_t *work;
338
339     /* Work data for PME_redist */
340     gmx_bool redist_init;
341     int *    scounts;
342     int *    rcounts;
343     int *    sdispls;
344     int *    rdispls;
345     int *    sidx;
346     int *    idxa;
347     real *   redist_buf;
348     int      redist_buf_nalloc;
349
350     /* Work data for sum_qgrid */
351     real *   sum_qgrid_tmp;
352     real *   sum_qgrid_dd_tmp;
353 } t_gmx_pme;
354
355
356 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
357                                    int start, int end, int thread)
358 {
359     int             i;
360     int            *idxptr, tix, tiy, tiz;
361     real           *xptr, *fptr, tx, ty, tz;
362     real            rxx, ryx, ryy, rzx, rzy, rzz;
363     int             nx, ny, nz;
364     int             start_ix, start_iy, start_iz;
365     int            *g2tx, *g2ty, *g2tz;
366     gmx_bool        bThreads;
367     int            *thread_idx = NULL;
368     thread_plist_t *tpl        = NULL;
369     int            *tpl_n      = NULL;
370     int             thread_i;
371
372     nx  = pme->nkx;
373     ny  = pme->nky;
374     nz  = pme->nkz;
375
376     start_ix = pme->pmegrid_start_ix;
377     start_iy = pme->pmegrid_start_iy;
378     start_iz = pme->pmegrid_start_iz;
379
380     rxx = pme->recipbox[XX][XX];
381     ryx = pme->recipbox[YY][XX];
382     ryy = pme->recipbox[YY][YY];
383     rzx = pme->recipbox[ZZ][XX];
384     rzy = pme->recipbox[ZZ][YY];
385     rzz = pme->recipbox[ZZ][ZZ];
386
387     g2tx = pme->pmegridA.g2t[XX];
388     g2ty = pme->pmegridA.g2t[YY];
389     g2tz = pme->pmegridA.g2t[ZZ];
390
391     bThreads = (atc->nthread > 1);
392     if (bThreads)
393     {
394         thread_idx = atc->thread_idx;
395
396         tpl   = &atc->thread_plist[thread];
397         tpl_n = tpl->n;
398         for (i = 0; i < atc->nthread; i++)
399         {
400             tpl_n[i] = 0;
401         }
402     }
403
404     for (i = start; i < end; i++)
405     {
406         xptr   = atc->x[i];
407         idxptr = atc->idx[i];
408         fptr   = atc->fractx[i];
409
410         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
411         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
412         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
413         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
414
415         tix = (int)(tx);
416         tiy = (int)(ty);
417         tiz = (int)(tz);
418
419         /* Because decomposition only occurs in x and y,
420          * we never have a fraction correction in z.
421          */
422         fptr[XX] = tx - tix + pme->fshx[tix];
423         fptr[YY] = ty - tiy + pme->fshy[tiy];
424         fptr[ZZ] = tz - tiz;
425
426         idxptr[XX] = pme->nnx[tix];
427         idxptr[YY] = pme->nny[tiy];
428         idxptr[ZZ] = pme->nnz[tiz];
429
430 #ifdef DEBUG
431         range_check(idxptr[XX], 0, pme->pmegrid_nx);
432         range_check(idxptr[YY], 0, pme->pmegrid_ny);
433         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
434 #endif
435
436         if (bThreads)
437         {
438             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
439             thread_idx[i] = thread_i;
440             tpl_n[thread_i]++;
441         }
442     }
443
444     if (bThreads)
445     {
446         /* Make a list of particle indices sorted on thread */
447
448         /* Get the cumulative count */
449         for (i = 1; i < atc->nthread; i++)
450         {
451             tpl_n[i] += tpl_n[i-1];
452         }
453         /* The current implementation distributes particles equally
454          * over the threads, so we could actually allocate for that
455          * in pme_realloc_atomcomm_things.
456          */
457         if (tpl_n[atc->nthread-1] > tpl->nalloc)
458         {
459             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
460             srenew(tpl->i, tpl->nalloc);
461         }
462         /* Set tpl_n to the cumulative start */
463         for (i = atc->nthread-1; i >= 1; i--)
464         {
465             tpl_n[i] = tpl_n[i-1];
466         }
467         tpl_n[0] = 0;
468
469         /* Fill our thread local array with indices sorted on thread */
470         for (i = start; i < end; i++)
471         {
472             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
473         }
474         /* Now tpl_n contains the cummulative count again */
475     }
476 }
477
478 static void make_thread_local_ind(pme_atomcomm_t *atc,
479                                   int thread, splinedata_t *spline)
480 {
481     int             n, t, i, start, end;
482     thread_plist_t *tpl;
483
484     /* Combine the indices made by each thread into one index */
485
486     n     = 0;
487     start = 0;
488     for (t = 0; t < atc->nthread; t++)
489     {
490         tpl = &atc->thread_plist[t];
491         /* Copy our part (start - end) from the list of thread t */
492         if (thread > 0)
493         {
494             start = tpl->n[thread-1];
495         }
496         end = tpl->n[thread];
497         for (i = start; i < end; i++)
498         {
499             spline->ind[n++] = tpl->i[i];
500         }
501     }
502
503     spline->n = n;
504 }
505
506
507 static void pme_calc_pidx(int start, int end,
508                           matrix recipbox, rvec x[],
509                           pme_atomcomm_t *atc, int *count)
510 {
511     int   nslab, i;
512     int   si;
513     real *xptr, s;
514     real  rxx, ryx, rzx, ryy, rzy;
515     int  *pd;
516
517     /* Calculate PME task index (pidx) for each grid index.
518      * Here we always assign equally sized slabs to each node
519      * for load balancing reasons (the PME grid spacing is not used).
520      */
521
522     nslab = atc->nslab;
523     pd    = atc->pd;
524
525     /* Reset the count */
526     for (i = 0; i < nslab; i++)
527     {
528         count[i] = 0;
529     }
530
531     if (atc->dimind == 0)
532     {
533         rxx = recipbox[XX][XX];
534         ryx = recipbox[YY][XX];
535         rzx = recipbox[ZZ][XX];
536         /* Calculate the node index in x-dimension */
537         for (i = start; i < end; i++)
538         {
539             xptr   = x[i];
540             /* Fractional coordinates along box vectors */
541             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
542             si    = (int)(s + 2*nslab) % nslab;
543             pd[i] = si;
544             count[si]++;
545         }
546     }
547     else
548     {
549         ryy = recipbox[YY][YY];
550         rzy = recipbox[ZZ][YY];
551         /* Calculate the node index in y-dimension */
552         for (i = start; i < end; i++)
553         {
554             xptr   = x[i];
555             /* Fractional coordinates along box vectors */
556             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
557             si    = (int)(s + 2*nslab) % nslab;
558             pd[i] = si;
559             count[si]++;
560         }
561     }
562 }
563
564 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
565                                   pme_atomcomm_t *atc)
566 {
567     int nthread, thread, slab;
568
569     nthread = atc->nthread;
570
571 #pragma omp parallel for num_threads(nthread) schedule(static)
572     for (thread = 0; thread < nthread; thread++)
573     {
574         pme_calc_pidx(natoms* thread   /nthread,
575                       natoms*(thread+1)/nthread,
576                       recipbox, x, atc, atc->count_thread[thread]);
577     }
578     /* Non-parallel reduction, since nslab is small */
579
580     for (thread = 1; thread < nthread; thread++)
581     {
582         for (slab = 0; slab < atc->nslab; slab++)
583         {
584             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
585         }
586     }
587 }
588
589 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
590 {
591     const int padding = 4;
592     int       i;
593
594     srenew(th[XX], nalloc);
595     srenew(th[YY], nalloc);
596     /* In z we add padding, this is only required for the aligned SIMD code */
597     sfree_aligned(*ptr_z);
598     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
599     th[ZZ] = *ptr_z + padding;
600
601     for (i = 0; i < padding; i++)
602     {
603         (*ptr_z)[               i] = 0;
604         (*ptr_z)[padding+nalloc+i] = 0;
605     }
606 }
607
608 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
609 {
610     int i, d;
611
612     srenew(spline->ind, atc->nalloc);
613     /* Initialize the index to identity so it works without threads */
614     for (i = 0; i < atc->nalloc; i++)
615     {
616         spline->ind[i] = i;
617     }
618
619     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
620                       atc->pme_order*atc->nalloc);
621     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
622                       atc->pme_order*atc->nalloc);
623 }
624
625 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
626 {
627     int nalloc_old, i, j, nalloc_tpl;
628
629     /* We have to avoid a NULL pointer for atc->x to avoid
630      * possible fatal errors in MPI routines.
631      */
632     if (atc->n > atc->nalloc || atc->nalloc == 0)
633     {
634         nalloc_old  = atc->nalloc;
635         atc->nalloc = over_alloc_dd(max(atc->n, 1));
636
637         if (atc->nslab > 1)
638         {
639             srenew(atc->x, atc->nalloc);
640             srenew(atc->q, atc->nalloc);
641             srenew(atc->f, atc->nalloc);
642             for (i = nalloc_old; i < atc->nalloc; i++)
643             {
644                 clear_rvec(atc->f[i]);
645             }
646         }
647         if (atc->bSpread)
648         {
649             srenew(atc->fractx, atc->nalloc);
650             srenew(atc->idx, atc->nalloc);
651
652             if (atc->nthread > 1)
653             {
654                 srenew(atc->thread_idx, atc->nalloc);
655             }
656
657             for (i = 0; i < atc->nthread; i++)
658             {
659                 pme_realloc_splinedata(&atc->spline[i], atc);
660             }
661         }
662     }
663 }
664
665 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
666                          int n, gmx_bool bXF, rvec *x_f, real *charge,
667                          pme_atomcomm_t *atc)
668 /* Redistribute particle data for PME calculation */
669 /* domain decomposition by x coordinate           */
670 {
671     int *idxa;
672     int  i, ii;
673
674     if (FALSE == pme->redist_init)
675     {
676         snew(pme->scounts, atc->nslab);
677         snew(pme->rcounts, atc->nslab);
678         snew(pme->sdispls, atc->nslab);
679         snew(pme->rdispls, atc->nslab);
680         snew(pme->sidx, atc->nslab);
681         pme->redist_init = TRUE;
682     }
683     if (n > pme->redist_buf_nalloc)
684     {
685         pme->redist_buf_nalloc = over_alloc_dd(n);
686         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
687     }
688
689     pme->idxa = atc->pd;
690
691 #ifdef GMX_MPI
692     if (forw && bXF)
693     {
694         /* forward, redistribution from pp to pme */
695
696         /* Calculate send counts and exchange them with other nodes */
697         for (i = 0; (i < atc->nslab); i++)
698         {
699             pme->scounts[i] = 0;
700         }
701         for (i = 0; (i < n); i++)
702         {
703             pme->scounts[pme->idxa[i]]++;
704         }
705         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
706
707         /* Calculate send and receive displacements and index into send
708            buffer */
709         pme->sdispls[0] = 0;
710         pme->rdispls[0] = 0;
711         pme->sidx[0]    = 0;
712         for (i = 1; i < atc->nslab; i++)
713         {
714             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
715             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
716             pme->sidx[i]    = pme->sdispls[i];
717         }
718         /* Total # of particles to be received */
719         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
720
721         pme_realloc_atomcomm_things(atc);
722
723         /* Copy particle coordinates into send buffer and exchange*/
724         for (i = 0; (i < n); i++)
725         {
726             ii = DIM*pme->sidx[pme->idxa[i]];
727             pme->sidx[pme->idxa[i]]++;
728             pme->redist_buf[ii+XX] = x_f[i][XX];
729             pme->redist_buf[ii+YY] = x_f[i][YY];
730             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
731         }
732         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
733                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
734                       pme->rvec_mpi, atc->mpi_comm);
735     }
736     if (forw)
737     {
738         /* Copy charge into send buffer and exchange*/
739         for (i = 0; i < atc->nslab; i++)
740         {
741             pme->sidx[i] = pme->sdispls[i];
742         }
743         for (i = 0; (i < n); i++)
744         {
745             ii = pme->sidx[pme->idxa[i]];
746             pme->sidx[pme->idxa[i]]++;
747             pme->redist_buf[ii] = charge[i];
748         }
749         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
750                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
751                       atc->mpi_comm);
752     }
753     else   /* backward, redistribution from pme to pp */
754     {
755         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
756                       pme->redist_buf, pme->scounts, pme->sdispls,
757                       pme->rvec_mpi, atc->mpi_comm);
758
759         /* Copy data from receive buffer */
760         for (i = 0; i < atc->nslab; i++)
761         {
762             pme->sidx[i] = pme->sdispls[i];
763         }
764         for (i = 0; (i < n); i++)
765         {
766             ii          = DIM*pme->sidx[pme->idxa[i]];
767             x_f[i][XX] += pme->redist_buf[ii+XX];
768             x_f[i][YY] += pme->redist_buf[ii+YY];
769             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
770             pme->sidx[pme->idxa[i]]++;
771         }
772     }
773 #endif
774 }
775
776 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
777                             gmx_bool bBackward, int shift,
778                             void *buf_s, int nbyte_s,
779                             void *buf_r, int nbyte_r)
780 {
781 #ifdef GMX_MPI
782     int        dest, src;
783     MPI_Status stat;
784
785     if (bBackward == FALSE)
786     {
787         dest = atc->node_dest[shift];
788         src  = atc->node_src[shift];
789     }
790     else
791     {
792         dest = atc->node_src[shift];
793         src  = atc->node_dest[shift];
794     }
795
796     if (nbyte_s > 0 && nbyte_r > 0)
797     {
798         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
799                      dest, shift,
800                      buf_r, nbyte_r, MPI_BYTE,
801                      src, shift,
802                      atc->mpi_comm, &stat);
803     }
804     else if (nbyte_s > 0)
805     {
806         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
807                  dest, shift,
808                  atc->mpi_comm);
809     }
810     else if (nbyte_r > 0)
811     {
812         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
813                  src, shift,
814                  atc->mpi_comm, &stat);
815     }
816 #endif
817 }
818
819 static void dd_pmeredist_x_q(gmx_pme_t pme,
820                              int n, gmx_bool bX, rvec *x, real *charge,
821                              pme_atomcomm_t *atc)
822 {
823     int *commnode, *buf_index;
824     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
825
826     commnode  = atc->node_dest;
827     buf_index = atc->buf_index;
828
829     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
830
831     nsend = 0;
832     for (i = 0; i < nnodes_comm; i++)
833     {
834         buf_index[commnode[i]] = nsend;
835         nsend                 += atc->count[commnode[i]];
836     }
837     if (bX)
838     {
839         if (atc->count[atc->nodeid] + nsend != n)
840         {
841             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
842                       "This usually means that your system is not well equilibrated.",
843                       n - (atc->count[atc->nodeid] + nsend),
844                       pme->nodeid, 'x'+atc->dimind);
845         }
846
847         if (nsend > pme->buf_nalloc)
848         {
849             pme->buf_nalloc = over_alloc_dd(nsend);
850             srenew(pme->bufv, pme->buf_nalloc);
851             srenew(pme->bufr, pme->buf_nalloc);
852         }
853
854         atc->n = atc->count[atc->nodeid];
855         for (i = 0; i < nnodes_comm; i++)
856         {
857             scount = atc->count[commnode[i]];
858             /* Communicate the count */
859             if (debug)
860             {
861                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
862                         atc->dimind, atc->nodeid, commnode[i], scount);
863             }
864             pme_dd_sendrecv(atc, FALSE, i,
865                             &scount, sizeof(int),
866                             &atc->rcount[i], sizeof(int));
867             atc->n += atc->rcount[i];
868         }
869
870         pme_realloc_atomcomm_things(atc);
871     }
872
873     local_pos = 0;
874     for (i = 0; i < n; i++)
875     {
876         node = atc->pd[i];
877         if (node == atc->nodeid)
878         {
879             /* Copy direct to the receive buffer */
880             if (bX)
881             {
882                 copy_rvec(x[i], atc->x[local_pos]);
883             }
884             atc->q[local_pos] = charge[i];
885             local_pos++;
886         }
887         else
888         {
889             /* Copy to the send buffer */
890             if (bX)
891             {
892                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
893             }
894             pme->bufr[buf_index[node]] = charge[i];
895             buf_index[node]++;
896         }
897     }
898
899     buf_pos = 0;
900     for (i = 0; i < nnodes_comm; i++)
901     {
902         scount = atc->count[commnode[i]];
903         rcount = atc->rcount[i];
904         if (scount > 0 || rcount > 0)
905         {
906             if (bX)
907             {
908                 /* Communicate the coordinates */
909                 pme_dd_sendrecv(atc, FALSE, i,
910                                 pme->bufv[buf_pos], scount*sizeof(rvec),
911                                 atc->x[local_pos], rcount*sizeof(rvec));
912             }
913             /* Communicate the charges */
914             pme_dd_sendrecv(atc, FALSE, i,
915                             pme->bufr+buf_pos, scount*sizeof(real),
916                             atc->q+local_pos, rcount*sizeof(real));
917             buf_pos   += scount;
918             local_pos += atc->rcount[i];
919         }
920     }
921 }
922
923 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
924                            int n, rvec *f,
925                            gmx_bool bAddF)
926 {
927     int *commnode, *buf_index;
928     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
929
930     commnode  = atc->node_dest;
931     buf_index = atc->buf_index;
932
933     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
934
935     local_pos = atc->count[atc->nodeid];
936     buf_pos   = 0;
937     for (i = 0; i < nnodes_comm; i++)
938     {
939         scount = atc->rcount[i];
940         rcount = atc->count[commnode[i]];
941         if (scount > 0 || rcount > 0)
942         {
943             /* Communicate the forces */
944             pme_dd_sendrecv(atc, TRUE, i,
945                             atc->f[local_pos], scount*sizeof(rvec),
946                             pme->bufv[buf_pos], rcount*sizeof(rvec));
947             local_pos += scount;
948         }
949         buf_index[commnode[i]] = buf_pos;
950         buf_pos               += rcount;
951     }
952
953     local_pos = 0;
954     if (bAddF)
955     {
956         for (i = 0; i < n; i++)
957         {
958             node = atc->pd[i];
959             if (node == atc->nodeid)
960             {
961                 /* Add from the local force array */
962                 rvec_inc(f[i], atc->f[local_pos]);
963                 local_pos++;
964             }
965             else
966             {
967                 /* Add from the receive buffer */
968                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
969                 buf_index[node]++;
970             }
971         }
972     }
973     else
974     {
975         for (i = 0; i < n; i++)
976         {
977             node = atc->pd[i];
978             if (node == atc->nodeid)
979             {
980                 /* Copy from the local force array */
981                 copy_rvec(atc->f[local_pos], f[i]);
982                 local_pos++;
983             }
984             else
985             {
986                 /* Copy from the receive buffer */
987                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
988                 buf_index[node]++;
989             }
990         }
991     }
992 }
993
#ifdef GMX_MPI
/* Exchange the overlap regions of the local PME grid with neighbouring PME
 * ranks, first along the minor decomposition dimension (pme->overlap[1],
 * indexed with iy / pmegrid_start_iy) and then along the major one
 * (pme->overlap[0], indexed with pmegrid_start_ix).
 *
 * direction == GMX_SUM_QGRID_FORWARD: data is sent to send_id and the
 * received overlap data is ADDED to the local grid.
 * Any other direction: the send/receive roles are swapped and the received
 * data OVERWRITES the corresponding local grid region.
 */
static void
gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            ipulse, send_id, recv_id, datasize;
    real          *p;
    real          *sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            /* Backward: reverse the roles of sender and receiver */
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        /* Number of reals per communicated y-line: all x lines, nkz z points */
        datasize      = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    if (direction == GMX_SUM_QGRID_FORWARD)
                    {
                        /* Forward: sum the neighbour's contribution */
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        /* Backward: overwrite with the neighbour's values */
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            /* Forward: receive into a separate buffer, summed in below */
            recvptr       = overlap->recvbuf;
        }
        else
        {
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            /* Backward: receive directly into the grid (overwrite) */
            recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        /* x-slices are contiguous, so the grid itself can be sent */
        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for (i = 0; i < recv_nindex*datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif
1163
1164
1165 static int
1166 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1167 {
1168     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1169     ivec    local_pme_size;
1170     int     i, ix, iy, iz;
1171     int     pmeidx, fftidx;
1172
1173     /* Dimensions should be identical for A/B grid, so we just use A here */
1174     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1175                                    local_fft_ndata,
1176                                    local_fft_offset,
1177                                    local_fft_size);
1178
1179     local_pme_size[0] = pme->pmegrid_nx;
1180     local_pme_size[1] = pme->pmegrid_ny;
1181     local_pme_size[2] = pme->pmegrid_nz;
1182
1183     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1184        the offset is identical, and the PME grid always has more data (due to overlap)
1185      */
1186     {
1187 #ifdef DEBUG_PME
1188         FILE *fp, *fp2;
1189         char  fn[STRLEN], format[STRLEN];
1190         real  val;
1191         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1192         fp = ffopen(fn, "w");
1193         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1194         fp2 = ffopen(fn, "w");
1195         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1196 #endif
1197
1198         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1199         {
1200             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1201             {
1202                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1203                 {
1204                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1205                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1206                     fftgrid[fftidx] = pmegrid[pmeidx];
1207 #ifdef DEBUG_PME
1208                     val = 100*pmegrid[pmeidx];
1209                     if (pmegrid[pmeidx] != 0)
1210                     {
1211                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1212                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1213                     }
1214                     if (pmegrid[pmeidx] != 0)
1215                     {
1216                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1217                                 "qgrid",
1218                                 pme->pmegrid_start_ix + ix,
1219                                 pme->pmegrid_start_iy + iy,
1220                                 pme->pmegrid_start_iz + iz,
1221                                 pmegrid[pmeidx]);
1222                     }
1223 #endif
1224                 }
1225             }
1226         }
1227 #ifdef DEBUG_PME
1228         ffclose(fp);
1229         ffclose(fp2);
1230 #endif
1231     }
1232     return 0;
1233 }
1234
1235
1236 static gmx_cycles_t omp_cyc_start()
1237 {
1238     return gmx_cycles_read();
1239 }
1240
1241 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1242 {
1243     return gmx_cycles_read() - c;
1244 }
1245
1246
1247 static int
1248 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1249                         int nthread, int thread)
1250 {
1251     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1252     ivec          local_pme_size;
1253     int           ixy0, ixy1, ixy, ix, iy, iz;
1254     int           pmeidx, fftidx;
1255 #ifdef PME_TIME_THREADS
1256     gmx_cycles_t  c1;
1257     static double cs1 = 0;
1258     static int    cnt = 0;
1259 #endif
1260
1261 #ifdef PME_TIME_THREADS
1262     c1 = omp_cyc_start();
1263 #endif
1264     /* Dimensions should be identical for A/B grid, so we just use A here */
1265     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1266                                    local_fft_ndata,
1267                                    local_fft_offset,
1268                                    local_fft_size);
1269
1270     local_pme_size[0] = pme->pmegrid_nx;
1271     local_pme_size[1] = pme->pmegrid_ny;
1272     local_pme_size[2] = pme->pmegrid_nz;
1273
1274     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1275        the offset is identical, and the PME grid always has more data (due to overlap)
1276      */
1277     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1278     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1279
1280     for (ixy = ixy0; ixy < ixy1; ixy++)
1281     {
1282         ix = ixy/local_fft_ndata[YY];
1283         iy = ixy - ix*local_fft_ndata[YY];
1284
1285         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1286         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1287         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1288         {
1289             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1290         }
1291     }
1292
1293 #ifdef PME_TIME_THREADS
1294     c1   = omp_cyc_end(c1);
1295     cs1 += (double)c1;
1296     cnt++;
1297     if (cnt % 20 == 0)
1298     {
1299         printf("copy %.2f\n", cs1*1e-9);
1300     }
1301 #endif
1302
1303     return 0;
1304 }
1305
1306
1307 static void
1308 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1309 {
1310     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1311
1312     nx = pme->nkx;
1313     ny = pme->nky;
1314     nz = pme->nkz;
1315
1316     pnx = pme->pmegrid_nx;
1317     pny = pme->pmegrid_ny;
1318     pnz = pme->pmegrid_nz;
1319
1320     overlap = pme->pme_order - 1;
1321
1322     /* Add periodic overlap in z */
1323     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1324     {
1325         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1326         {
1327             for (iz = 0; iz < overlap; iz++)
1328             {
1329                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1330                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1331             }
1332         }
1333     }
1334
1335     if (pme->nnodes_minor == 1)
1336     {
1337         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1338         {
1339             for (iy = 0; iy < overlap; iy++)
1340             {
1341                 for (iz = 0; iz < nz; iz++)
1342                 {
1343                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1344                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1345                 }
1346             }
1347         }
1348     }
1349
1350     if (pme->nnodes_major == 1)
1351     {
1352         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1353
1354         for (ix = 0; ix < overlap; ix++)
1355         {
1356             for (iy = 0; iy < ny_x; iy++)
1357             {
1358                 for (iz = 0; iz < nz; iz++)
1359                 {
1360                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1361                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1362                 }
1363             }
1364         }
1365     }
1366 }
1367
1368
/* Fill the periodic overlap regions of the PME grid from its start; this
 * copies in the opposite direction of wrap_periodic_pmegrid's additions.
 *
 * For each dimension that is not decomposed over PME ranks (x when
 * nnodes_major == 1, y when nnodes_minor == 1), the first pme_order-1 lines
 * are copied to the lines starting at nk. z is always handled, last, over
 * the full local grid, so it also fills the x/y overlap lines.
 * The y and z steps are parallelized over x with OpenMP (pme->nthread).
 */
static void
unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
{
    int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    pnx = pme->pmegrid_nx;
    pny = pme->pmegrid_ny;
    pnz = pme->pmegrid_nz;

    overlap = pme->pme_order - 1;

    if (pme->nnodes_major == 1)
    {
        /* Only y < ny is valid here when y is also undecomposed; otherwise
         * the whole local y extent is copied.
         */
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);

        for (ix = 0; ix < overlap; ix++)
        {
            int iy, iz;

            for (iy = 0; iy < ny_x; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
                        pmegrid[(ix*pny+iy)*pnz+iz];
                }
            }
        }
    }

    if (pme->nnodes_minor == 1)
    {
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
        for (ix = 0; ix < pme->pmegrid_nx; ix++)
        {
            int iy, iz;

            for (iy = 0; iy < overlap; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
                        pmegrid[(ix*pny+iy)*pnz+iz];
                }
            }
        }
    }

    /* Copy periodic overlap in z */
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
    for (ix = 0; ix < pme->pmegrid_nx; ix++)
    {
        int iy, iz;

        for (iy = 0; iy < pme->pmegrid_ny; iy++)
        {
            for (iz = 0; iz < overlap; iz++)
            {
                pmegrid[(ix*pny+iy)*pnz+nz+iz] =
                    pmegrid[(ix*pny+iy)*pnz+iz];
            }
        }
    }
}
1437
1438
/* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
/* Accumulate the order x order x order B-spline stencil of one charge into
 * grid: grid[(i0+a)*pny*pnz + (j0+b)*pnz + (k0+c)] += qn*thx[a]*thy[b]*thz[c].
 * Not self-contained: the caller must have in scope the loop counters
 * ithx/ithy/ithz, the stencil start i0/j0/k0, the charge qn, the weights
 * thx/thy/thz, the strides pny/pnz, grid, and the scratch variables
 * index_x/index_xy/index_xyz and valx/valxy.
 */
#define DO_BSPLINE(order)                            \
    for (ithx = 0; (ithx < order); ithx++)                    \
    {                                                    \
        index_x = (i0+ithx)*pny*pnz;                     \
        valx    = qn*thx[ithx];                          \
                                                     \
        for (ithy = 0; (ithy < order); ithy++)                \
        {                                                \
            valxy    = valx*thy[ithy];                   \
            index_xy = index_x+(j0+ithy)*pnz;            \
                                                     \
            for (ithz = 0; (ithz < order); ithz++)            \
            {                                            \
                index_xyz        = index_xy+(k0+ithz);   \
                grid[index_xyz] += valxy*thz[ithz];      \
            }                                            \
        }                                                \
    }
1458
1459
/* Spread the charges of the atoms listed in spline->ind onto pmegrid->grid.
 *
 * The grid (s[XX]*s[YY]*s[ZZ] reals) is zeroed first. For each listed atom n
 * with non-zero charge atc->q[n], an order^3 stencil starting at
 * atc->idx[n] minus pmegrid->offset is accumulated, using the spline
 * weights at spline->theta[d] + nn*order.
 * With PME_SIMD4_SPREAD_GATHER, orders 4 and 5 use the kernel code
 * textually included from pme_simd4.h (selected via the PME_SPREAD_* and
 * PME_ORDER defines); all other cases use DO_BSPLINE.
 * NOTE(review): some locals (e.g. ol, b, localsize, bndsize, thz_aligned)
 * are not referenced directly here; they may be used by the code in
 * pme_simd4.h — verify before removing any.
 */
static void spread_q_bsplines_thread(pmegrid_t                    *pmegrid,
                                     pme_atomcomm_t               *atc,
                                     splinedata_t                 *spline,
                                     pme_spline_work_t gmx_unused *work)
{

    /* spread charges from home atoms to local grid */
    real          *grid;
    pme_overlap_t *ol;
    int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    int       *    idxptr;
    int            order, norder, index_x, index_xy, index_xyz;
    real           valx, valxy, qn;
    real          *thx, *thy, *thz;
    int            localsize, bndsize;
    int            pnx, pny, pnz, ndatatot;
    int            offx, offy, offz;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    /* Scratch for an aligned copy of the z spline weights */
    real           thz_buffer[12], *thz_aligned;

    thz_aligned = gmx_simd4_align_real(thz_buffer);
#endif

    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    ndatatot = pnx*pny*pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = atc->q[n];

        if (qn != 0)
        {
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Stencil start relative to this grid's offset */
            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta[XX] + norder;
            thy = spline->theta[YY] + norder;
            thz = spline->theta[ZZ] + norder;

            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_SPREAD_SIMD4_ORDER4
#else
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_simd4.h"
#else
                    DO_BSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 5
#include "pme_simd4.h"
#else
                    DO_BSPLINE(5);
#endif
                    break;
                default:
                    DO_BSPLINE(order);
                    break;
            }
        }
    }
}
1550
1551 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1552 {
1553 #ifdef PME_SIMD4_SPREAD_GATHER
1554     if (pme_order == 5
1555 #ifndef PME_SIMD4_UNALIGNED
1556         || pme_order == 4
1557 #endif
1558         )
1559     {
1560         /* Round nz up to a multiple of 4 to ensure alignment */
1561         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1562     }
1563 #endif
1564 }
1565
1566 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1567 {
1568 #ifdef PME_SIMD4_SPREAD_GATHER
1569 #ifndef PME_SIMD4_UNALIGNED
1570     if (pme_order == 4)
1571     {
1572         /* Add extra elements to ensured aligned operations do not go
1573          * beyond the allocated grid size.
1574          * Note that for pme_order=5, the pme grid z-size alignment
1575          * ensures that we will not go beyond the grid size.
1576          */
1577         *gridsize += 4;
1578     }
1579 #endif
1580 #endif
1581 }
1582
1583 static void pmegrid_init(pmegrid_t *grid,
1584                          int cx, int cy, int cz,
1585                          int x0, int y0, int z0,
1586                          int x1, int y1, int z1,
1587                          gmx_bool set_alignment,
1588                          int pme_order,
1589                          real *ptr)
1590 {
1591     int nz, gridsize;
1592
1593     grid->ci[XX]     = cx;
1594     grid->ci[YY]     = cy;
1595     grid->ci[ZZ]     = cz;
1596     grid->offset[XX] = x0;
1597     grid->offset[YY] = y0;
1598     grid->offset[ZZ] = z0;
1599     grid->n[XX]      = x1 - x0 + pme_order - 1;
1600     grid->n[YY]      = y1 - y0 + pme_order - 1;
1601     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1602     copy_ivec(grid->n, grid->s);
1603
1604     nz = grid->s[ZZ];
1605     set_grid_alignment(&nz, pme_order);
1606     if (set_alignment)
1607     {
1608         grid->s[ZZ] = nz;
1609     }
1610     else if (nz != grid->s[ZZ])
1611     {
1612         gmx_incons("pmegrid_init call with an unaligned z size");
1613     }
1614
1615     grid->order = pme_order;
1616     if (ptr == NULL)
1617     {
1618         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1619         set_gridsize_alignment(&gridsize, pme_order);
1620         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1621     }
1622     else
1623     {
1624         grid->grid = ptr;
1625     }
1626 }
1627
/* Integer division rounding up, intended for a non-negative enumerator
 * and a positive denominator.
 */
static int div_round_up(int enumerator, int denominator)
{
    const int padded = enumerator + (denominator - 1);

    return padded/denominator;
}
1632
1633 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1634                                   ivec nsub)
1635 {
1636     int gsize_opt, gsize;
1637     int nsx, nsy, nsz;
1638     char *env;
1639
1640     gsize_opt = -1;
1641     for (nsx = 1; nsx <= nthread; nsx++)
1642     {
1643         if (nthread % nsx == 0)
1644         {
1645             for (nsy = 1; nsy <= nthread; nsy++)
1646             {
1647                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1648                 {
1649                     nsz = nthread/(nsx*nsy);
1650
1651                     /* Determine the number of grid points per thread */
1652                     gsize =
1653                         (div_round_up(n[XX], nsx) + ovl)*
1654                         (div_round_up(n[YY], nsy) + ovl)*
1655                         (div_round_up(n[ZZ], nsz) + ovl);
1656
1657                     /* Minimize the number of grids points per thread
1658                      * and, secondarily, the number of cuts in minor dimensions.
1659                      */
1660                     if (gsize_opt == -1 ||
1661                         gsize < gsize_opt ||
1662                         (gsize == gsize_opt &&
1663                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1664                     {
1665                         nsub[XX]  = nsx;
1666                         nsub[YY]  = nsy;
1667                         nsub[ZZ]  = nsz;
1668                         gsize_opt = gsize;
1669                     }
1670                 }
1671             }
1672         }
1673     }
1674
1675     env = getenv("GMX_PME_THREAD_DIVISION");
1676     if (env != NULL)
1677     {
1678         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1679     }
1680
1681     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1682     {
1683         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1684     }
1685 }
1686
/* Set up a PME grid set.
 *
 * nx, ny, nz are grid sizes including the pme_order-1 overlap (the owned
 * sizes n[d] are obtained by subtracting it); nz_base replaces the z size
 * when choosing the thread division. The full grid grids->grid gets its
 * own allocation. When bUseThreads is TRUE, one sub-grid per thread is
 * initialized as a view into a single aligned buffer grids->grid_all, with
 * GMX_CACHE_SEP reals before and between the sub-grids; otherwise
 * grids->grid_th is set to NULL.
 * Also fills grids->g2t[d][i] (grid line i -> contribution t*tfac to the
 * thread index, with x slowest and z fastest, matching the thread loop
 * order below) and grids->nthread_comm[d] (the smallest number of thread
 * blocks, capped at nc[d], whose lines reach max_comm_lines:
 * overlap_x, overlap_y or pme_order-1).
 */
static void pmegrids_init(pmegrids_t *grids,
                          int nx, int ny, int nz, int nz_base,
                          int pme_order,
                          gmx_bool bUseThreads,
                          int nthread,
                          int overlap_x,
                          int overlap_y)
{
    ivec n, n_base, g0, g1;
    int t, x, y, z, d, i, tfac;
    int max_comm_lines = -1;

    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
                 NULL);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);

    if (bUseThreads)
    {
        ivec nst;
        int gridsize;

        /* Largest thread sub-grid size, used as the stride in grid_all */
        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
                    nx, ny, nz,
                    nst[XX], nst[YY], nst[ZZ]);
        }

        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        snew_aligned(grids->grid_all,
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                     SIMD4_ALIGNMENT);

        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    /* Thread sub-grids point into grid_all; they own no memory */
                    pmegrid_init(&grids->grid_th[t],
                                 x, y, z,
                                 (n[XX]*(x  ))/grids->nc[XX],
                                 (n[YY]*(y  ))/grids->nc[YY],
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
                                 (n[XX]*(x+1))/grids->nc[XX],
                                 (n[YY]*(y+1))/grids->nc[YY],
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }
    else
    {
        grids->grid_th = NULL;
    }

    snew(grids->g2t, DIM);
    tfac = 1;
    for (d = DIM-1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t*tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x;     break;
            case YY: max_comm_lines = overlap_y;     break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        grids->nthread_comm[d] = 0;
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
               grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != NULL)
        {
            fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
                    'x'+d, grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         */
        /* NOTE(review): the while loop above stops at nc[d], so
         * nthread_comm[d] > nc[d] can never hold and this check never
         * fires; the comment above suggests ">=" may have been intended.
         * Confirm before changing, as ">=" would reject configurations
         * that currently run.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
        }
    }
}
1814
1815
1816 static void pmegrids_destroy(pmegrids_t *grids)
1817 {
1818     int t;
1819
1820     if (grids->grid.grid != NULL)
1821     {
1822         sfree(grids->grid.grid);
1823
1824         if (grids->nthread > 0)
1825         {
1826             for (t = 0; t < grids->nthread; t++)
1827             {
1828                 sfree(grids->grid_th[t].grid);
1829             }
1830             sfree(grids->grid_th);
1831         }
1832     }
1833 }
1834
1835
1836 static void realloc_work(pme_work_t *work, int nkx)
1837 {
1838     int simd_width;
1839
1840     if (nkx > work->nalloc)
1841     {
1842         work->nalloc = nkx;
1843         srenew(work->mhx, work->nalloc);
1844         srenew(work->mhy, work->nalloc);
1845         srenew(work->mhz, work->nalloc);
1846         srenew(work->m2, work->nalloc);
1847         /* Allocate an aligned pointer for SIMD operations, including extra
1848          * elements at the end for padding.
1849          */
1850 #ifdef PME_SIMD
1851         simd_width = GMX_SIMD_WIDTH_HERE;
1852 #else
1853         /* We can use any alignment, apart from 0, so we use 4 */
1854         simd_width = 4;
1855 #endif
1856         sfree_aligned(work->denom);
1857         sfree_aligned(work->tmp1);
1858         sfree_aligned(work->eterm);
1859         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1860         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1861         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1862         srenew(work->m2inv, work->nalloc);
1863     }
1864 }
1865
1866
1867 static void free_work(pme_work_t *work)
1868 {
1869     sfree(work->mhx);
1870     sfree(work->mhy);
1871     sfree(work->mhz);
1872     sfree(work->m2);
1873     sfree_aligned(work->denom);
1874     sfree_aligned(work->tmp1);
1875     sfree_aligned(work->eterm);
1876     sfree(work->m2inv);
1877 }
1878
1879
1880 #ifdef PME_SIMD
1881 /* Calculate exponentials through SIMD */
1882 inline static void calc_exponentials(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1883 {
1884     {
1885         const gmx_mm_pr two = gmx_set1_pr(2.0);
1886         gmx_mm_pr f_simd;
1887         gmx_mm_pr lu;
1888         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1889         int kx;
1890         f_simd = gmx_set1_pr(f);
1891         /* We only need to calculate from start. But since start is 0 or 1
1892          * and we want to use aligned loads/stores, we always start from 0.
1893          */
1894         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1895         {
1896             tmp_d1   = gmx_load_pr(d_aligned+kx);
1897             d_inv    = gmx_inv_pr(tmp_d1);
1898             tmp_r    = gmx_load_pr(r_aligned+kx);
1899             tmp_r    = gmx_exp_pr(tmp_r);
1900             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1901             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1902             gmx_store_pr(e_aligned+kx, tmp_e);
1903         }
1904     }
1905 }
1906 #else
1907 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1908 {
1909     int kx;
1910     for (kx = start; kx < end; kx++)
1911     {
1912         d[kx] = 1.0/d[kx];
1913     }
1914     for (kx = start; kx < end; kx++)
1915     {
1916         r[kx] = exp(r[kx]);
1917     }
1918     for (kx = start; kx < end; kx++)
1919     {
1920         e[kx] = f*r[kx]*d[kx];
1921     }
1922 }
1923 #endif
1924
1925
/* Apply the reciprocal-space influence function in place to this thread's
 * share of the local complex PME grid, and optionally compute the mesh
 * energy and virial.
 *
 * The complex grid is ordered y major, z middle, x minor. The
 * local_ndata[YY]*local_ndata[ZZ] (y,z) lines are divided evenly over
 * nthread threads; this call processes the lines of thread `thread`.
 * Each grid value with wave vector m (built from pme->recipbox) is
 * multiplied by
 *   eterm = elfac*exp(-pi^2 |m|^2/ewaldcoeff^2)
 *           / (pi*vol*|m|^2*bsp_mod[XX][kx]*bsp_mod[YY][ky]*bsp_mod[ZZ][kz])
 * with elfac = ONE_4PI_EPS0/epsilon_r. The k = (0,0,0) point is skipped.
 *
 * pme         PME setup. Reads nkx/nky/nkz, recipbox, bsp_mod, epsilon_r and
 *             pfft_setupA, and uses pme->work[thread] as scratch and output.
 * grid        Local complex grid, modified in place.
 * ewaldcoeff  Ewald splitting coefficient.
 * vol         Volume factor in the denominator (presumably the box volume).
 * bEnerVir    When TRUE, also store the energy (0.5*sum) in work->energy
 *             and the symmetric virial (0.25*sums) in work->vir of
 *             pme->work[thread].
 * nthread     Number of threads sharing this grid.
 * thread      Index of the calling thread.
 *
 * Returns local_ndata[YY]*local_ndata[XX] as a "loop count".
 * NOTE(review): the loop runs over local_ndata[YY]*local_ndata[ZZ] lines of
 * local_ndata[XX] points each; confirm which count the caller expects.
 */
static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
                         real ewaldcoeff, real vol,
                         gmx_bool bEnerVir,
                         int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    t_complex *p0;
    int     kx, ky, kz, maxkx, maxky, maxkz;
    int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
    real    mx, my, mz;
    real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real    ets2, struct2, vfactor, ets2vf;
    real    d1, d2, energy = 0;
    real    by, bz;
    real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    pme_work_t *work;
    real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
    real    mhxk, mhyk, mhzk, m2k;
    real    corner_fac;
    ivec    complex_order;
    ivec    local_ndata, local_offset, local_size;
    real    elfac;

    /* Electrostatic conversion factor, scaled by the relative dielectric */
    elfac = ONE_4PI_EPS0/pme->epsilon_r;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);

    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    /* Indices >= maxkx / maxky map to negative wave numbers.
     * NOTE(review): maxkz is assigned but not used below.
     */
    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;
    maxkz = nz/2+1;

    /* Per-thread scratch buffers (sized by realloc_work) */
    work  = &pme->work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    eterm = work->eterm;
    m2inv = work->m2inv;

    /* This thread's contiguous range of (y,z) line indices */
    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;

    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        /* Signed wave number along y */
        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        by = M_PI*vol*pme->bsp_mod[YY][ky];

        kz = iz + local_offset[ZZ];

        /* z wave numbers are used as non-negative only; presumably the z
         * dimension holds half of a real-to-complex transform.
         */
        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        /* Start of this (y,z) line in the local complex grid */
        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];

        /* We should skip the k-space point (0,0,0) */
        /* Note that since here x is the minor index, local_offset[XX]=0 */
        if (local_offset[XX] > 0 || ky > 0 || kz > 0)
        {
            kxstart = local_offset[XX];
        }
        else
        {
            kxstart = local_offset[XX] + 1;
            p0++;
        }
        kxend = local_offset[XX] + local_ndata[XX];

        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the storage
             * of the mh elements in arrays.
             * Because x is the minor grid index, all mh elements
             * depend on kx for triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = kxstart; kx < kxend; kx++)
            {
                m2inv[kx] = 1.0/m2[kx];
            }

            /* eterm[kx] = elfac*exp(tmp1[kx])/denom[kx] */
            calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);

            /* Scale the grid; reuse tmp1 for eterm*|S(m)|^2 */
            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];

                struct2 = 2.0*(d1*d1+d2*d2);

                tmp1[kx] = eterm[kx]*struct2;
            }

            /* Accumulate energy and virial contributions of this line */
            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
                energy  += ets2;

                ets2vf   = ets2*vfactor;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             * In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */

            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            /* eterm[kx] = elfac*exp(tmp1[kx])/denom[kx] */
            calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];
            }
        }
    }

    if (bEnerVir)
    {
        /* Update virial with local values.
         * The virial is symmetric by definition.
         * this virial seems ok for isotropic scaling, but I'm
         * experiencing problems on semiisotropic membranes.
         * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
         */
        work->vir[XX][XX] = 0.25*virxx;
        work->vir[YY][YY] = 0.25*viryy;
        work->vir[ZZ][ZZ] = 0.25*virzz;
        work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
        work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
        work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy = 0.5*energy;
    }

    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
2177
2178 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2179                              real *mesh_energy, matrix vir)
2180 {
2181     /* This function sums output over threads
2182      * and should therefore only be called after thread synchronization.
2183      */
2184     int thread;
2185
2186     *mesh_energy = pme->work[0].energy;
2187     copy_mat(pme->work[0].vir, vir);
2188
2189     for (thread = 1; thread < nthread; thread++)
2190     {
2191         *mesh_energy += pme->work[thread].energy;
2192         m_add(vir, pme->work[thread].vir, vir);
2193     }
2194 }
2195
/* Accumulate the B-spline interpolated gradient of the grid for one atom.
 *
 * Expects these names in the enclosing scope: the stencil start indices
 * i0, j0, k0; the grid strides pny and pnz and the array grid; the spline
 * weights thx, thy, thz and their derivatives dthx, dthy, dthz for this atom;
 * and the scratch variables ithx, ithy, ithz, index_x, index_xy, tx, ty,
 * dx, dy, fxy1, fz1 and gval.
 * Adds the derivatives along x, y and z (in grid-index units) to fx, fy
 * and fz, which the caller must initialize.
 * Callers pass a literal order where possible, presumably so the compiler
 * can unroll the loops.
 */
#define DO_FSPLINE(order)                      \
    for (ithx = 0; (ithx < order); ithx++)              \
    {                                              \
        index_x = (i0+ithx)*pny*pnz;               \
        tx      = thx[ithx];                       \
        dx      = dthx[ithx];                      \
                                               \
        for (ithy = 0; (ithy < order); ithy++)          \
        {                                          \
            index_xy = index_x+(j0+ithy)*pnz;      \
            ty       = thy[ithy];                  \
            dy       = dthy[ithy];                 \
            fxy1     = fz1 = 0;                    \
                                               \
            for (ithz = 0; (ithz < order); ithz++)      \
            {                                      \
                gval  = grid[index_xy+(k0+ithz)];  \
                fxy1 += thz[ithz]*gval;            \
                fz1  += dthz[ithz]*gval;           \
            }                                      \
            fx += dx*ty*fxy1;                      \
            fy += tx*dy*fxy1;                      \
            fz += tx*ty*fz1;                       \
        }                                          \
    }
2221
2222
/* Interpolate mesh forces onto the atoms listed in spline->ind.
 *
 * For each listed atom n with nonzero scaled charge, the gradient of the
 * grid (strides pmegrid_ny*pmegrid_nz and pmegrid_nz) is interpolated with
 * the B-spline weights theta and derivatives dtheta stored in spline at
 * offset nn*pme_order, where nn is the atom's position in the list.
 * The gradient is converted from grid-index units to Cartesian units using
 * nkx/nky/nkz and recipbox, and -scale*q[n]*gradient is added to atc->f[n].
 *
 * bClearF  When TRUE, atc->f[n] is zeroed first. This also applies to atoms
 *          with zero charge.
 * scale    Factor applied to the atom charges.
 */
static void gather_f_bsplines(gmx_pme_t pme, real *grid,
                              gmx_bool bClearF, pme_atomcomm_t *atc,
                              splinedata_t *spline,
                              real scale)
{
    /* sum forces for local particles */
    int     nn, n, ithx, ithy, ithz, i0, j0, k0;
    int     index_x, index_xy;
    int     nx, ny, nz, pnx, pny, pnz;
    int *   idxptr;
    real    tx, ty, dx, dy, qn;
    real    fx, fy, fz, gval;
    real    fxy1, fz1;
    real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
    int     norder;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    int     order;

    /* NOTE(review): not used in the code visible here; presumably read by
     * the pme_simd4.h kernels included below.
     */
    pme_spline_work_t *work;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    /* Aligned scratch copies of the z spline weights for the SIMD4 kernel */
    real           thz_buffer[12],  *thz_aligned;
    real           dthz_buffer[12], *dthz_aligned;

    thz_aligned  = gmx_simd4_align_real(thz_buffer);
    dthz_aligned = gmx_simd4_align_real(dthz_buffer);
#endif

    work = pme->spline_work;

    /* Initial values; the spline pointers are reset per atom below */
    order = pme->pme_order;
    thx   = spline->theta[XX];
    thy   = spline->theta[YY];
    thz   = spline->theta[ZZ];
    dthx  = spline->dtheta[XX];
    dthy  = spline->dtheta[YY];
    dthz  = spline->dtheta[ZZ];
    nx    = pme->nkx;
    ny    = pme->nky;
    nz    = pme->nkz;
    pnx   = pme->pmegrid_nx;
    pny   = pme->pmegrid_ny;
    pnz   = pme->pmegrid_nz;

    rxx   = pme->recipbox[XX][XX];
    ryx   = pme->recipbox[YY][XX];
    ryy   = pme->recipbox[YY][YY];
    rzx   = pme->recipbox[ZZ][XX];
    rzy   = pme->recipbox[ZZ][YY];
    rzz   = pme->recipbox[ZZ][ZZ];

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = scale*atc->q[n];

        if (bClearF)
        {
            atc->f[n][XX] = 0;
            atc->f[n][YY] = 0;
            atc->f[n][ZZ] = 0;
        }
        if (qn != 0)
        {
            fx     = 0;
            fy     = 0;
            fz     = 0;
            idxptr = atc->idx[n];
            norder = nn*order;

            /* First grid index of the interpolation stencil */
            i0   = idxptr[XX];
            j0   = idxptr[YY];
            k0   = idxptr[ZZ];

            /* Pointer arithmetic alert, next six statements */
            thx  = spline->theta[XX] + norder;
            thy  = spline->theta[YY] + norder;
            thz  = spline->theta[ZZ] + norder;
            dthx = spline->dtheta[XX] + norder;
            dthy = spline->dtheta[YY] + norder;
            dthz = spline->dtheta[ZZ] + norder;

            /* Accumulate fx, fy, fz; specialized kernels for orders 4 and 5 */
            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_GATHER_F_SIMD4_ORDER4
#else
#define PME_GATHER_F_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_simd4.h"
#else
                    DO_FSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_GATHER_F_SIMD4_ALIGNED
#define PME_ORDER 5
#include "pme_simd4.h"
#else
                    DO_FSPLINE(5);
#endif
                    break;
                default:
                    DO_FSPLINE(order);
                    break;
            }

            /* Transform the grid-index gradient to Cartesian components */
            atc->f[n][XX] += -qn*( fx*nx*rxx );
            atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
            atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
        }
    }
    /* Since the energy and not forces are interpolated
     * the net force might not be exactly zero.
     * This can be solved by also interpolating F, but
     * that comes at a cost.
     * A better hack is to remove the net force every
     * step, but that must be done at a higher level
     * since this routine doesn't see all atoms if running
     * in parallel. Don't know how important it is?  EL 990726
     */
}
2349
2350
2351 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2352                                    pme_atomcomm_t *atc)
2353 {
2354     splinedata_t *spline;
2355     int     n, ithx, ithy, ithz, i0, j0, k0;
2356     int     index_x, index_xy;
2357     int *   idxptr;
2358     real    energy, pot, tx, ty, qn, gval;
2359     real    *thx, *thy, *thz;
2360     int     norder;
2361     int     order;
2362
2363     spline = &atc->spline[0];
2364
2365     order = pme->pme_order;
2366
2367     energy = 0;
2368     for (n = 0; (n < atc->n); n++)
2369     {
2370         qn      = atc->q[n];
2371
2372         if (qn != 0)
2373         {
2374             idxptr = atc->idx[n];
2375             norder = n*order;
2376
2377             i0   = idxptr[XX];
2378             j0   = idxptr[YY];
2379             k0   = idxptr[ZZ];
2380
2381             /* Pointer arithmetic alert, next three statements */
2382             thx  = spline->theta[XX] + norder;
2383             thy  = spline->theta[YY] + norder;
2384             thz  = spline->theta[ZZ] + norder;
2385
2386             pot = 0;
2387             for (ithx = 0; (ithx < order); ithx++)
2388             {
2389                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2390                 tx      = thx[ithx];
2391
2392                 for (ithy = 0; (ithy < order); ithy++)
2393                 {
2394                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2395                     ty       = thy[ithy];
2396
2397                     for (ithz = 0; (ithz < order); ithz++)
2398                     {
2399                         gval  = grid[index_xy+(k0+ithz)];
2400                         pot  += tx*ty*thz[ithz]*gval;
2401                     }
2402
2403                 }
2404             }
2405
2406             energy += pot*qn;
2407         }
2408     }
2409
2410     return energy;
2411 }
2412
/* Compute B-spline interpolation weights and their derivatives for one atom.
 *
 * Macro to force loop unrolling by fixing order.
 * This gives a significant performance gain.
 *
 * Expects these names in the enclosing scope: xptr (per-dimension offset
 * of the atom from its lower grid cell limit), i (slot index of the atom)
 * and the output arrays theta and dtheta. For each dimension j it stores
 * the order weights in theta[j][i*order .. i*order+order-1] and their
 * derivatives in dtheta[j][...]. The derivatives are differences of the
 * order-1 spline, taken before the final recursion step.
 * order must not exceed PME_ORDER_MAX, the size of the local buffers.
 */
#define CALC_SPLINE(order)                     \
    {                                              \
        int j, k, l;                                 \
        real dr, div;                               \
        real data[PME_ORDER_MAX];                  \
        real ddata[PME_ORDER_MAX];                 \
                                               \
        for (j = 0; (j < DIM); j++)                     \
        {                                          \
            dr  = xptr[j];                         \
                                               \
            /* dr is relative offset from lower cell limit */ \
            data[order-1] = 0;                     \
            data[1]       = dr;                          \
            data[0]       = 1 - dr;                      \
                                               \
            for (k = 3; (k < order); k++)               \
            {                                      \
                div       = 1.0/(k - 1.0);               \
                data[k-1] = div*dr*data[k-2];      \
                for (l = 1; (l < (k-1)); l++)           \
                {                                  \
                    data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
                                       data[k-l-1]);                \
                }                                  \
                data[0] = div*(1-dr)*data[0];      \
            }                                      \
            /* differentiate */                    \
            ddata[0] = -data[0];                   \
            for (k = 1; (k < order); k++)               \
            {                                      \
                ddata[k] = data[k-1] - data[k];    \
            }                                      \
                                               \
            div           = 1.0/(order - 1);                 \
            data[order-1] = div*dr*data[order-2];  \
            for (l = 1; (l < (order-1)); l++)           \
            {                                      \
                data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
                                       (order-l-dr)*data[order-l-1]); \
            }                                      \
            data[0] = div*(1 - dr)*data[0];        \
                                               \
            for (k = 0; k < order; k++)                 \
            {                                      \
                theta[j][i*order+k]  = data[k];    \
                dtheta[j][i*order+k] = ddata[k];   \
            }                                      \
        }                                          \
    }
2466
2467 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2468                    rvec fractx[], int nr, int ind[], real charge[],
2469                    gmx_bool bFreeEnergy)
2470 {
2471     /* construct splines for local atoms */
2472     int  i, ii;
2473     real *xptr;
2474
2475     for (i = 0; i < nr; i++)
2476     {
2477         /* With free energy we do not use the charge check.
2478          * In most cases this will be more efficient than calling make_bsplines
2479          * twice, since usually more than half the particles have charges.
2480          */
2481         ii = ind[i];
2482         if (bFreeEnergy || charge[ii] != 0.0)
2483         {
2484             xptr = fractx[ii];
2485             switch (order)
2486             {
2487                 case 4:  CALC_SPLINE(4);     break;
2488                 case 5:  CALC_SPLINE(5);     break;
2489                 default: CALC_SPLINE(order); break;
2490             }
2491         }
2492     }
2493 }
2494
2495
2496 void make_dft_mod(real *mod, real *data, int ndata)
2497 {
2498     int i, j;
2499     real sc, ss, arg;
2500
2501     for (i = 0; i < ndata; i++)
2502     {
2503         sc = ss = 0;
2504         for (j = 0; j < ndata; j++)
2505         {
2506             arg = (2.0*M_PI*i*j)/ndata;
2507             sc += data[j]*cos(arg);
2508             ss += data[j]*sin(arg);
2509         }
2510         mod[i] = sc*sc+ss*ss;
2511     }
2512     for (i = 0; i < ndata; i++)
2513     {
2514         if (mod[i] < 1e-7)
2515         {
2516             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2517         }
2518     }
2519 }
2520
2521
2522 static void make_bspline_moduli(splinevec bsp_mod,
2523                                 int nx, int ny, int nz, int order)
2524 {
2525     int nmax = max(nx, max(ny, nz));
2526     real *data, *ddata, *bsp_data;
2527     int i, k, l;
2528     real div;
2529
2530     snew(data, order);
2531     snew(ddata, order);
2532     snew(bsp_data, nmax);
2533
2534     data[order-1] = 0;
2535     data[1]       = 0;
2536     data[0]       = 1;
2537
2538     for (k = 3; k < order; k++)
2539     {
2540         div       = 1.0/(k-1.0);
2541         data[k-1] = 0;
2542         for (l = 1; l < (k-1); l++)
2543         {
2544             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2545         }
2546         data[0] = div*data[0];
2547     }
2548     /* differentiate */
2549     ddata[0] = -data[0];
2550     for (k = 1; k < order; k++)
2551     {
2552         ddata[k] = data[k-1]-data[k];
2553     }
2554     div           = 1.0/(order-1);
2555     data[order-1] = 0;
2556     for (l = 1; l < (order-1); l++)
2557     {
2558         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2559     }
2560     data[0] = div*data[0];
2561
2562     for (i = 0; i < nmax; i++)
2563     {
2564         bsp_data[i] = 0;
2565     }
2566     for (i = 1; i <= order; i++)
2567     {
2568         bsp_data[i] = data[i-1];
2569     }
2570
2571     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2572     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2573     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2574
2575     sfree(data);
2576     sfree(ddata);
2577     sfree(bsp_data);
2578 }
2579
2580
/* Return the P3M optimal influence function polynomial in z for the given
 * interpolation order (2..8). Any other order yields 0.0.
 */
static double do_p3m_influence(double z, int order)
{
    const double z2 = z*z;
    const double z4 = z2*z2;
    double       infl;

    /* The formula and most constants can be found in:
     * Ballenegger et al., JCTC 8, 936 (2012)
     */
    switch (order)
    {
        case 2:
            infl = 1.0 - 2.0*z2/3.0;
            break;
        case 3:
            infl = 1.0 - z2 + 2.0*z4/15.0;
            break;
        case 4:
            infl = 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
            break;
        case 5:
            infl = 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
            break;
        case 6:
            infl = 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
            break;
        case 7:
            infl = 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
            break;
        case 8:
            infl = 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
            break;
        default:
            infl = 0.0;
            break;
    }

    return infl;
}
2618
2619 /* Calculate the P3M B-spline moduli for one dimension */
2620 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2621 {
2622     double zarg, zai, sinzai, infl;
2623     int    maxk, i;
2624
2625     if (order > 8)
2626     {
2627         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2628     }
2629
2630     zarg = M_PI/n;
2631
2632     maxk = (n + 1)/2;
2633
2634     for (i = -maxk; i < 0; i++)
2635     {
2636         zai          = zarg*i;
2637         sinzai       = sin(zai);
2638         infl         = do_p3m_influence(sinzai, order);
2639         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2640     }
2641     bsp_mod[0] = 1.0;
2642     for (i = 1; i < maxk; i++)
2643     {
2644         zai        = zarg*i;
2645         sinzai     = sin(zai);
2646         infl       = do_p3m_influence(sinzai, order);
2647         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2648     }
2649 }
2650
2651 /* Calculate the P3M B-spline moduli */
2652 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2653                                     int nx, int ny, int nz, int order)
2654 {
2655     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2656     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2657     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2658 }
2659
2660
2661 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2662 {
2663     int nslab, n, i;
2664     int fw, bw;
2665
2666     nslab = atc->nslab;
2667
2668     n = 0;
2669     for (i = 1; i <= nslab/2; i++)
2670     {
2671         fw = (atc->nodeid + i) % nslab;
2672         bw = (atc->nodeid - i + nslab) % nslab;
2673         if (n < nslab - 1)
2674         {
2675             atc->node_dest[n] = fw;
2676             atc->node_src[n]  = bw;
2677             n++;
2678         }
2679         if (n < nslab - 1)
2680         {
2681             atc->node_dest[n] = bw;
2682             atc->node_src[n]  = fw;
2683             n++;
2684         }
2685     }
2686 }
2687
2688 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2689 {
2690     int thread;
2691
2692     if (NULL != log)
2693     {
2694         fprintf(log, "Destroying PME data structures.\n");
2695     }
2696
2697     sfree((*pmedata)->nnx);
2698     sfree((*pmedata)->nny);
2699     sfree((*pmedata)->nnz);
2700
2701     pmegrids_destroy(&(*pmedata)->pmegridA);
2702
2703     sfree((*pmedata)->fftgridA);
2704     sfree((*pmedata)->cfftgridA);
2705     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2706
2707     if ((*pmedata)->pmegridB.grid.grid != NULL)
2708     {
2709         pmegrids_destroy(&(*pmedata)->pmegridB);
2710         sfree((*pmedata)->fftgridB);
2711         sfree((*pmedata)->cfftgridB);
2712         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2713     }
2714     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2715     {
2716         free_work(&(*pmedata)->work[thread]);
2717     }
2718     sfree((*pmedata)->work);
2719
2720     sfree(*pmedata);
2721     *pmedata = NULL;
2722
2723     return 0;
2724 }
2725
/* Round n up to the nearest multiple of f (intended for n >= 0, f > 0). */
static int mult_up(int n, int f)
{
    const int nblocks = (n + f - 1)/f;

    return nblocks*f;
}
2730
2731
2732 static double pme_load_imbalance(gmx_pme_t pme)
2733 {
2734     int    nma, nmi;
2735     double n1, n2, n3;
2736
2737     nma = pme->nnodes_major;
2738     nmi = pme->nnodes_minor;
2739
2740     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2741     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2742     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2743
2744     /* pme_solve is roughly double the cost of an fft */
2745
2746     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2747 }
2748
2749 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
2750                           int dimind, gmx_bool bSpread)
2751 {
2752     int nk, k, s, thread;
2753
2754     atc->dimind    = dimind;
2755     atc->nslab     = 1;
2756     atc->nodeid    = 0;
2757     atc->pd_nalloc = 0;
2758 #ifdef GMX_MPI
2759     if (pme->nnodes > 1)
2760     {
2761         atc->mpi_comm = pme->mpi_comm_d[dimind];
2762         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2763         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2764     }
2765     if (debug)
2766     {
2767         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2768     }
2769 #endif
2770
2771     atc->bSpread   = bSpread;
2772     atc->pme_order = pme->pme_order;
2773
2774     if (atc->nslab > 1)
2775     {
2776         /* These three allocations are not required for particle decomp. */
2777         snew(atc->node_dest, atc->nslab);
2778         snew(atc->node_src, atc->nslab);
2779         setup_coordinate_communication(atc);
2780
2781         snew(atc->count_thread, pme->nthread);
2782         for (thread = 0; thread < pme->nthread; thread++)
2783         {
2784             snew(atc->count_thread[thread], atc->nslab);
2785         }
2786         atc->count = atc->count_thread[0];
2787         snew(atc->rcount, atc->nslab);
2788         snew(atc->buf_index, atc->nslab);
2789     }
2790
2791     atc->nthread = pme->nthread;
2792     if (atc->nthread > 1)
2793     {
2794         snew(atc->thread_plist, atc->nthread);
2795     }
2796     snew(atc->spline, atc->nthread);
2797     for (thread = 0; thread < atc->nthread; thread++)
2798     {
2799         if (atc->nthread > 1)
2800         {
2801             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2802             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2803         }
2804         snew(atc->spline[thread].thread_one, pme->nthread);
2805         atc->spline[thread].thread_one[thread] = 1;
2806     }
2807 }
2808
2809 static void
2810 init_overlap_comm(pme_overlap_t *  ol,
2811                   int              norder,
2812 #ifdef GMX_MPI
2813                   MPI_Comm         comm,
2814 #endif
2815                   int              nnodes,
2816                   int              nodeid,
2817                   int              ndata,
2818                   int              commplainsize)
2819 {
2820     int lbnd, rbnd, maxlr, b, i;
2821     int exten;
2822     int nn, nk;
2823     pme_grid_comm_t *pgc;
2824     gmx_bool bCont;
2825     int fft_start, fft_end, send_index1, recv_index1;
2826 #ifdef GMX_MPI
2827     MPI_Status stat;
2828
2829     ol->mpi_comm = comm;
2830 #endif
2831
2832     ol->nnodes = nnodes;
2833     ol->nodeid = nodeid;
2834
2835     /* Linear translation of the PME grid won't affect reciprocal space
2836      * calculations, so to optimize we only interpolate "upwards",
2837      * which also means we only have to consider overlap in one direction.
2838      * I.e., particles on this node might also be spread to grid indices
2839      * that belong to higher nodes (modulo nnodes)
2840      */
2841
2842     snew(ol->s2g0, ol->nnodes+1);
2843     snew(ol->s2g1, ol->nnodes);
2844     if (debug)
2845     {
2846         fprintf(debug, "PME slab boundaries:");
2847     }
2848     for (i = 0; i < nnodes; i++)
2849     {
2850         /* s2g0 the local interpolation grid start.
2851          * s2g1 the local interpolation grid end.
2852          * Because grid overlap communication only goes forward,
2853          * the grid the slabs for fft's should be rounded down.
2854          */
2855         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2856         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2857
2858         if (debug)
2859         {
2860             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2861         }
2862     }
2863     ol->s2g0[nnodes] = ndata;
2864     if (debug)
2865     {
2866         fprintf(debug, "\n");
2867     }
2868
2869     /* Determine with how many nodes we need to communicate the grid overlap */
2870     b = 0;
2871     do
2872     {
2873         b++;
2874         bCont = FALSE;
2875         for (i = 0; i < nnodes; i++)
2876         {
2877             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2878                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2879             {
2880                 bCont = TRUE;
2881             }
2882         }
2883     }
2884     while (bCont && b < nnodes);
2885     ol->noverlap_nodes = b - 1;
2886
2887     snew(ol->send_id, ol->noverlap_nodes);
2888     snew(ol->recv_id, ol->noverlap_nodes);
2889     for (b = 0; b < ol->noverlap_nodes; b++)
2890     {
2891         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2892         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2893     }
2894     snew(ol->comm_data, ol->noverlap_nodes);
2895
2896     ol->send_size = 0;
2897     for (b = 0; b < ol->noverlap_nodes; b++)
2898     {
2899         pgc = &ol->comm_data[b];
2900         /* Send */
2901         fft_start        = ol->s2g0[ol->send_id[b]];
2902         fft_end          = ol->s2g0[ol->send_id[b]+1];
2903         if (ol->send_id[b] < nodeid)
2904         {
2905             fft_start += ndata;
2906             fft_end   += ndata;
2907         }
2908         send_index1       = ol->s2g1[nodeid];
2909         send_index1       = min(send_index1, fft_end);
2910         pgc->send_index0  = fft_start;
2911         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2912         ol->send_size    += pgc->send_nindex;
2913
2914         /* We always start receiving to the first index of our slab */
2915         fft_start        = ol->s2g0[ol->nodeid];
2916         fft_end          = ol->s2g0[ol->nodeid+1];
2917         recv_index1      = ol->s2g1[ol->recv_id[b]];
2918         if (ol->recv_id[b] > nodeid)
2919         {
2920             recv_index1 -= ndata;
2921         }
2922         recv_index1      = min(recv_index1, fft_end);
2923         pgc->recv_index0 = fft_start;
2924         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2925     }
2926
2927 #ifdef GMX_MPI
2928     /* Communicate the buffer sizes to receive */
2929     for (b = 0; b < ol->noverlap_nodes; b++)
2930     {
2931         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2932                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2933                      ol->mpi_comm, &stat);
2934     }
2935 #endif
2936
2937     /* For non-divisible grid we need pme_order iso pme_order-1 */
2938     snew(ol->sendbuf, norder*commplainsize);
2939     snew(ol->recvbuf, norder*commplainsize);
2940 }
2941
2942 static void
2943 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2944                               int **global_to_local,
2945                               real **fraction_shift)
2946 {
2947     int i;
2948     int * gtl;
2949     real * fsh;
2950
2951     snew(gtl, 5*n);
2952     snew(fsh, 5*n);
2953     for (i = 0; (i < 5*n); i++)
2954     {
2955         /* Determine the global to local grid index */
2956         gtl[i] = (i - local_start + n) % n;
2957         /* For coordinates that fall within the local grid the fraction
2958          * is correct, we don't need to shift it.
2959          */
2960         fsh[i] = 0;
2961         if (local_range < n)
2962         {
2963             /* Due to rounding issues i could be 1 beyond the lower or
2964              * upper boundary of the local grid. Correct the index for this.
2965              * If we shift the index, we need to shift the fraction by
2966              * the same amount in the other direction to not affect
2967              * the weights.
2968              * Note that due to this shifting the weights at the end of
2969              * the spline might change, but that will only involve values
2970              * between zero and values close to the precision of a real,
2971              * which is anyhow the accuracy of the whole mesh calculation.
2972              */
2973             /* With local_range=0 we should not change i=local_start */
2974             if (i % n != local_start)
2975             {
2976                 if (gtl[i] == n-1)
2977                 {
2978                     gtl[i] = 0;
2979                     fsh[i] = -1;
2980                 }
2981                 else if (gtl[i] == local_range)
2982                 {
2983                     gtl[i] = local_range - 1;
2984                     fsh[i] = 1;
2985                 }
2986             }
2987         }
2988     }
2989
2990     *global_to_local = gtl;
2991     *fraction_shift  = fsh;
2992 }
2993
2994 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
2995 {
2996     pme_spline_work_t *work;
2997
2998 #ifdef PME_SIMD4_SPREAD_GATHER
2999     real         tmp[12], *tmp_aligned;
3000     gmx_simd4_pr zero_S;
3001     gmx_simd4_pr real_mask_S0, real_mask_S1;
3002     int          of, i;
3003
3004     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3005
3006     tmp_aligned = gmx_simd4_align_real(tmp);
3007
3008     zero_S = gmx_simd4_setzero_pr();
3009
3010     /* Generate bit masks to mask out the unused grid entries,
3011      * as we only operate on order of the 8 grid entries that are
3012      * load into 2 SIMD registers.
3013      */
3014     for (of = 0; of < 8-(order-1); of++)
3015     {
3016         for (i = 0; i < 8; i++)
3017         {
3018             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3019         }
3020         real_mask_S0      = gmx_simd4_load_pr(tmp_aligned);
3021         real_mask_S1      = gmx_simd4_load_pr(tmp_aligned+4);
3022         work->mask_S0[of] = gmx_simd4_cmplt_pr(real_mask_S0, zero_S);
3023         work->mask_S1[of] = gmx_simd4_cmplt_pr(real_mask_S1, zero_S);
3024     }
3025 #else
3026     work = NULL;
3027 #endif
3028
3029     return work;
3030 }
3031
3032 void gmx_pme_check_restrictions(int pme_order,
3033                                 int nkx, int nky, int nkz,
3034                                 int nnodes_major,
3035                                 int nnodes_minor,
3036                                 gmx_bool bUseThreads,
3037                                 gmx_bool bFatal,
3038                                 gmx_bool *bValidSettings)
3039 {
3040     if (pme_order > PME_ORDER_MAX)
3041     {
3042         if (!bFatal)
3043         {
3044             *bValidSettings = FALSE;
3045             return;
3046         }
3047         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3048                   pme_order, PME_ORDER_MAX);
3049     }
3050
3051     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3052         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3053         nkz <= pme_order)
3054     {
3055         if (!bFatal)
3056         {
3057             *bValidSettings = FALSE;
3058             return;
3059         }
3060         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3061                   pme_order);
3062     }
3063
3064     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3065      * We only allow multiple communication pulses in dim 1, not in dim 0.
3066      */
3067     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3068                         nkx != nnodes_major*(pme_order - 1)))
3069     {
3070         if (!bFatal)
3071         {
3072             *bValidSettings = FALSE;
3073             return;
3074         }
3075         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pmeorder-1. To resolve this issue, use less nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3076                   nkx/(double)nnodes_major, pme_order);
3077     }
3078
3079     if (bValidSettings != NULL)
3080     {
3081         *bValidSettings = TRUE;
3082     }
3083
3084     return;
3085 }
3086
3087 int gmx_pme_init(gmx_pme_t *         pmedata,
3088                  t_commrec *         cr,
3089                  int                 nnodes_major,
3090                  int                 nnodes_minor,
3091                  t_inputrec *        ir,
3092                  int                 homenr,
3093                  gmx_bool            bFreeEnergy,
3094                  gmx_bool            bReproducible,
3095                  int                 nthread)
3096 {
3097     gmx_pme_t pme = NULL;
3098
3099     int  use_threads, sum_use_threads;
3100     ivec ndata;
3101
3102     if (debug)
3103     {
3104         fprintf(debug, "Creating PME data structures.\n");
3105     }
3106     snew(pme, 1);
3107
3108     pme->redist_init         = FALSE;
3109     pme->sum_qgrid_tmp       = NULL;
3110     pme->sum_qgrid_dd_tmp    = NULL;
3111     pme->buf_nalloc          = 0;
3112     pme->redist_buf_nalloc   = 0;
3113
3114     pme->nnodes              = 1;
3115     pme->bPPnode             = TRUE;
3116
3117     pme->nnodes_major        = nnodes_major;
3118     pme->nnodes_minor        = nnodes_minor;
3119
3120 #ifdef GMX_MPI
3121     if (nnodes_major*nnodes_minor > 1)
3122     {
3123         pme->mpi_comm = cr->mpi_comm_mygroup;
3124
3125         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3126         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3127         if (pme->nnodes != nnodes_major*nnodes_minor)
3128         {
3129             gmx_incons("PME node count mismatch");
3130         }
3131     }
3132     else
3133     {
3134         pme->mpi_comm = MPI_COMM_NULL;
3135     }
3136 #endif
3137
3138     if (pme->nnodes == 1)
3139     {
3140 #ifdef GMX_MPI
3141         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3142         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3143 #endif
3144         pme->ndecompdim   = 0;
3145         pme->nodeid_major = 0;
3146         pme->nodeid_minor = 0;
3147 #ifdef GMX_MPI
3148         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3149 #endif
3150     }
3151     else
3152     {
3153         if (nnodes_minor == 1)
3154         {
3155 #ifdef GMX_MPI
3156             pme->mpi_comm_d[0] = pme->mpi_comm;
3157             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3158 #endif
3159             pme->ndecompdim   = 1;
3160             pme->nodeid_major = pme->nodeid;
3161             pme->nodeid_minor = 0;
3162
3163         }
3164         else if (nnodes_major == 1)
3165         {
3166 #ifdef GMX_MPI
3167             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3168             pme->mpi_comm_d[1] = pme->mpi_comm;
3169 #endif
3170             pme->ndecompdim   = 1;
3171             pme->nodeid_major = 0;
3172             pme->nodeid_minor = pme->nodeid;
3173         }
3174         else
3175         {
3176             if (pme->nnodes % nnodes_major != 0)
3177             {
3178                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3179             }
3180             pme->ndecompdim = 2;
3181
3182 #ifdef GMX_MPI
3183             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3184                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3185             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3186                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3187
3188             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3189             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3190             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3191             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3192 #endif
3193         }
3194         pme->bPPnode = (cr->duty & DUTY_PP);
3195     }
3196
3197     pme->nthread = nthread;
3198
3199     /* Check if any of the PME MPI ranks uses threads */
3200     use_threads = (pme->nthread > 1 ? 1 : 0);
3201 #ifdef GMX_MPI
3202     if (pme->nnodes > 1)
3203     {
3204         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3205                       MPI_SUM, pme->mpi_comm);
3206     }
3207     else
3208 #endif
3209     {
3210         sum_use_threads = use_threads;
3211     }
3212     pme->bUseThreads = (sum_use_threads > 0);
3213
3214     if (ir->ePBC == epbcSCREW)
3215     {
3216         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3217     }
3218
3219     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3220     pme->nkx         = ir->nkx;
3221     pme->nky         = ir->nky;
3222     pme->nkz         = ir->nkz;
3223     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3224     pme->pme_order   = ir->pme_order;
3225     pme->epsilon_r   = ir->epsilon_r;
3226
3227     /* If we violate restrictions, generate a fatal error here */
3228     gmx_pme_check_restrictions(pme->pme_order,
3229                                pme->nkx, pme->nky, pme->nkz,
3230                                pme->nnodes_major,
3231                                pme->nnodes_minor,
3232                                pme->bUseThreads,
3233                                TRUE,
3234                                NULL);
3235
3236     if (pme->nnodes > 1)
3237     {
3238         double imbal;
3239
3240 #ifdef GMX_MPI
3241         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3242         MPI_Type_commit(&(pme->rvec_mpi));
3243 #endif
3244
3245         /* Note that the charge spreading and force gathering, which usually
3246          * takes about the same amount of time as FFT+solve_pme,
3247          * is always fully load balanced
3248          * (unless the charge distribution is inhomogeneous).
3249          */
3250
3251         imbal = pme_load_imbalance(pme);
3252         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3253         {
3254             fprintf(stderr,
3255                     "\n"
3256                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3257                     "      For optimal PME load balancing\n"
3258                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3259                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3260                     "\n",
3261                     (int)((imbal-1)*100 + 0.5),
3262                     pme->nkx, pme->nky, pme->nnodes_major,
3263                     pme->nky, pme->nkz, pme->nnodes_minor);
3264         }
3265     }
3266
3267     /* For non-divisible grid we need pme_order iso pme_order-1 */
3268     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3269      * y is always copied through a buffer: we don't need padding in z,
3270      * but we do need the overlap in x because of the communication order.
3271      */
3272     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3273 #ifdef GMX_MPI
3274                       pme->mpi_comm_d[0],
3275 #endif
3276                       pme->nnodes_major, pme->nodeid_major,
3277                       pme->nkx,
3278                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3279
3280     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3281      * We do this with an offset buffer of equal size, so we need to allocate
3282      * extra for the offset. That's what the (+1)*pme->nkz is for.
3283      */
3284     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3285 #ifdef GMX_MPI
3286                       pme->mpi_comm_d[1],
3287 #endif
3288                       pme->nnodes_minor, pme->nodeid_minor,
3289                       pme->nky,
3290                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3291
3292     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3293      * Note that gmx_pme_check_restrictions checked for this already.
3294      */
3295     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3296     {
3297         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3298     }
3299
3300     snew(pme->bsp_mod[XX], pme->nkx);
3301     snew(pme->bsp_mod[YY], pme->nky);
3302     snew(pme->bsp_mod[ZZ], pme->nkz);
3303
3304     /* The required size of the interpolation grid, including overlap.
3305      * The allocated size (pmegrid_n?) might be slightly larger.
3306      */
3307     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3308         pme->overlap[0].s2g0[pme->nodeid_major];
3309     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3310         pme->overlap[1].s2g0[pme->nodeid_minor];
3311     pme->pmegrid_nz_base = pme->nkz;
3312     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3313     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3314
3315     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3316     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3317     pme->pmegrid_start_iz = 0;
3318
3319     make_gridindex5_to_localindex(pme->nkx,
3320                                   pme->pmegrid_start_ix,
3321                                   pme->pmegrid_nx - (pme->pme_order-1),
3322                                   &pme->nnx, &pme->fshx);
3323     make_gridindex5_to_localindex(pme->nky,
3324                                   pme->pmegrid_start_iy,
3325                                   pme->pmegrid_ny - (pme->pme_order-1),
3326                                   &pme->nny, &pme->fshy);
3327     make_gridindex5_to_localindex(pme->nkz,
3328                                   pme->pmegrid_start_iz,
3329                                   pme->pmegrid_nz_base,
3330                                   &pme->nnz, &pme->fshz);
3331
3332     pmegrids_init(&pme->pmegridA,
3333                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3334                   pme->pmegrid_nz_base,
3335                   pme->pme_order,
3336                   pme->bUseThreads,
3337                   pme->nthread,
3338                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3339                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3340
3341     pme->spline_work = make_pme_spline_work(pme->pme_order);
3342
3343     ndata[0] = pme->nkx;
3344     ndata[1] = pme->nky;
3345     ndata[2] = pme->nkz;
3346
3347     /* This routine will allocate the grid data to fit the FFTs */
3348     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3349                             &pme->fftgridA, &pme->cfftgridA,
3350                             pme->mpi_comm_d,
3351                             bReproducible, pme->nthread);
3352
3353     if (bFreeEnergy)
3354     {
3355         pmegrids_init(&pme->pmegridB,
3356                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3357                       pme->pmegrid_nz_base,
3358                       pme->pme_order,
3359                       pme->bUseThreads,
3360                       pme->nthread,
3361                       pme->nkx % pme->nnodes_major != 0,
3362                       pme->nky % pme->nnodes_minor != 0);
3363
3364         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3365                                 &pme->fftgridB, &pme->cfftgridB,
3366                                 pme->mpi_comm_d,
3367                                 bReproducible, pme->nthread);
3368     }
3369     else
3370     {
3371         pme->pmegridB.grid.grid = NULL;
3372         pme->fftgridB           = NULL;
3373         pme->cfftgridB          = NULL;
3374     }
3375
3376     if (!pme->bP3M)
3377     {
3378         /* Use plain SPME B-spline interpolation */
3379         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3380     }
3381     else
3382     {
3383         /* Use the P3M grid-optimized influence function */
3384         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3385     }
3386
3387     /* Use atc[0] for spreading */
3388     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3389     if (pme->ndecompdim >= 2)
3390     {
3391         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3392     }
3393
3394     if (pme->nnodes == 1)
3395     {
3396         pme->atc[0].n = homenr;
3397         pme_realloc_atomcomm_things(&pme->atc[0]);
3398     }
3399
3400     {
3401         int thread;
3402
3403         /* Use fft5d, order after FFT is y major, z, x minor */
3404
3405         snew(pme->work, pme->nthread);
3406         for (thread = 0; thread < pme->nthread; thread++)
3407         {
3408             realloc_work(&pme->work[thread], pme->nkx);
3409         }
3410     }
3411
3412     *pmedata = pme;
3413
3414     return 0;
3415 }
3416
3417 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3418 {
3419     int d, t;
3420
3421     for (d = 0; d < DIM; d++)
3422     {
3423         if (new->grid.n[d] > old->grid.n[d])
3424         {
3425             return;
3426         }
3427     }
3428
3429     sfree_aligned(new->grid.grid);
3430     new->grid.grid = old->grid.grid;
3431
3432     if (new->grid_th != NULL && new->nthread == old->nthread)
3433     {
3434         sfree_aligned(new->grid_all);
3435         for (t = 0; t < new->nthread; t++)
3436         {
3437             new->grid_th[t].grid = old->grid_th[t].grid;
3438         }
3439     }
3440 }
3441
3442 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3443                    t_commrec *         cr,
3444                    gmx_pme_t           pme_src,
3445                    const t_inputrec *  ir,
3446                    ivec                grid_size)
3447 {
3448     t_inputrec irc;
3449     int homenr;
3450     int ret;
3451
3452     irc     = *ir;
3453     irc.nkx = grid_size[XX];
3454     irc.nky = grid_size[YY];
3455     irc.nkz = grid_size[ZZ];
3456
3457     if (pme_src->nnodes == 1)
3458     {
3459         homenr = pme_src->atc[0].n;
3460     }
3461     else
3462     {
3463         homenr = -1;
3464     }
3465
3466     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3467                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3468
3469     if (ret == 0)
3470     {
3471         /* We can easily reuse the allocated pme grids in pme_src */
3472         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3473         /* We would like to reuse the fft grids, but that's harder */
3474     }
3475
3476     return ret;
3477 }
3478
3479
3480 static void copy_local_grid(gmx_pme_t pme,
3481                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3482 {
3483     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3484     int  fft_my, fft_mz;
3485     int  nsx, nsy, nsz;
3486     ivec nf;
3487     int  offx, offy, offz, x, y, z, i0, i0t;
3488     int  d;
3489     pmegrid_t *pmegrid;
3490     real *grid_th;
3491
3492     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3493                                    local_fft_ndata,
3494                                    local_fft_offset,
3495                                    local_fft_size);
3496     fft_my = local_fft_size[YY];
3497     fft_mz = local_fft_size[ZZ];
3498
3499     pmegrid = &pmegrids->grid_th[thread];
3500
3501     nsx = pmegrid->s[XX];
3502     nsy = pmegrid->s[YY];
3503     nsz = pmegrid->s[ZZ];
3504
3505     for (d = 0; d < DIM; d++)
3506     {
3507         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3508                     local_fft_ndata[d] - pmegrid->offset[d]);
3509     }
3510
3511     offx = pmegrid->offset[XX];
3512     offy = pmegrid->offset[YY];
3513     offz = pmegrid->offset[ZZ];
3514
3515     /* Directly copy the non-overlapping parts of the local grids.
3516      * This also initializes the full grid.
3517      */
3518     grid_th = pmegrid->grid;
3519     for (x = 0; x < nf[XX]; x++)
3520     {
3521         for (y = 0; y < nf[YY]; y++)
3522         {
3523             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3524             i0t = (x*nsy + y)*nsz;
3525             for (z = 0; z < nf[ZZ]; z++)
3526             {
3527                 fftgrid[i0+z] = grid_th[i0t+z];
3528             }
3529         }
3530     }
3531 }
3532
3533 static void
3534 reduce_threadgrid_overlap(gmx_pme_t pme,
3535                           const pmegrids_t *pmegrids, int thread,
3536                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3537 {
3538     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3539     int  fft_nx, fft_ny, fft_nz;
3540     int  fft_my, fft_mz;
3541     int  buf_my = -1;
3542     int  nsx, nsy, nsz;
3543     ivec ne;
3544     int  offx, offy, offz, x, y, z, i0, i0t;
3545     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3546     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3547     gmx_bool bCommX, bCommY;
3548     int  d;
3549     int  thread_f;
3550     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3551     const real *grid_th;
3552     real *commbuf = NULL;
3553
3554     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3555                                    local_fft_ndata,
3556                                    local_fft_offset,
3557                                    local_fft_size);
3558     fft_nx = local_fft_ndata[XX];
3559     fft_ny = local_fft_ndata[YY];
3560     fft_nz = local_fft_ndata[ZZ];
3561
3562     fft_my = local_fft_size[YY];
3563     fft_mz = local_fft_size[ZZ];
3564
3565     /* This routine is called when all thread have finished spreading.
3566      * Here each thread sums grid contributions calculated by other threads
3567      * to the thread local grid volume.
3568      * To minimize the number of grid copying operations,
3569      * this routines sums immediately from the pmegrid to the fftgrid.
3570      */
3571
3572     /* Determine which part of the full node grid we should operate on,
3573      * this is our thread local part of the full grid.
3574      */
3575     pmegrid = &pmegrids->grid_th[thread];
3576
3577     for (d = 0; d < DIM; d++)
3578     {
3579         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3580                     local_fft_ndata[d]);
3581     }
3582
3583     offx = pmegrid->offset[XX];
3584     offy = pmegrid->offset[YY];
3585     offz = pmegrid->offset[ZZ];
3586
3587
3588     bClearBufX  = TRUE;
3589     bClearBufY  = TRUE;
3590     bClearBufXY = TRUE;
3591
3592     /* Now loop over all the thread data blocks that contribute
3593      * to the grid region we (our thread) are operating on.
3594      */
3595     /* Note that ffy_nx/y is equal to the number of grid points
3596      * between the first point of our node grid and the one of the next node.
3597      */
3598     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3599     {
3600         fx     = pmegrid->ci[XX] + sx;
3601         ox     = 0;
3602         bCommX = FALSE;
3603         if (fx < 0)
3604         {
3605             fx    += pmegrids->nc[XX];
3606             ox    -= fft_nx;
3607             bCommX = (pme->nnodes_major > 1);
3608         }
3609         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3610         ox       += pmegrid_g->offset[XX];
3611         if (!bCommX)
3612         {
3613             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3614         }
3615         else
3616         {
3617             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3618         }
3619
3620         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3621         {
3622             fy     = pmegrid->ci[YY] + sy;
3623             oy     = 0;
3624             bCommY = FALSE;
3625             if (fy < 0)
3626             {
3627                 fy    += pmegrids->nc[YY];
3628                 oy    -= fft_ny;
3629                 bCommY = (pme->nnodes_minor > 1);
3630             }
3631             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3632             oy       += pmegrid_g->offset[YY];
3633             if (!bCommY)
3634             {
3635                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3636             }
3637             else
3638             {
3639                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3640             }
3641
3642             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3643             {
3644                 fz = pmegrid->ci[ZZ] + sz;
3645                 oz = 0;
3646                 if (fz < 0)
3647                 {
3648                     fz += pmegrids->nc[ZZ];
3649                     oz -= fft_nz;
3650                 }
3651                 pmegrid_g = &pmegrids->grid_th[fz];
3652                 oz       += pmegrid_g->offset[ZZ];
3653                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3654
3655                 if (sx == 0 && sy == 0 && sz == 0)
3656                 {
3657                     /* We have already added our local contribution
3658                      * before calling this routine, so skip it here.
3659                      */
3660                     continue;
3661                 }
3662
3663                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3664
3665                 pmegrid_f = &pmegrids->grid_th[thread_f];
3666
3667                 grid_th = pmegrid_f->grid;
3668
3669                 nsx = pmegrid_f->s[XX];
3670                 nsy = pmegrid_f->s[YY];
3671                 nsz = pmegrid_f->s[ZZ];
3672
3673 #ifdef DEBUG_PME_REDUCE
3674                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3675                        pme->nodeid, thread, thread_f,
3676                        pme->pmegrid_start_ix,
3677                        pme->pmegrid_start_iy,
3678                        pme->pmegrid_start_iz,
3679                        sx, sy, sz,
3680                        offx-ox, tx1-ox, offx, tx1,
3681                        offy-oy, ty1-oy, offy, ty1,
3682                        offz-oz, tz1-oz, offz, tz1);
3683 #endif
3684
3685                 if (!(bCommX || bCommY))
3686                 {
3687                     /* Copy from the thread local grid to the node grid */
3688                     for (x = offx; x < tx1; x++)
3689                     {
3690                         for (y = offy; y < ty1; y++)
3691                         {
3692                             i0  = (x*fft_my + y)*fft_mz;
3693                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3694                             for (z = offz; z < tz1; z++)
3695                             {
3696                                 fftgrid[i0+z] += grid_th[i0t+z];
3697                             }
3698                         }
3699                     }
3700                 }
3701                 else
3702                 {
3703                     /* The order of this conditional decides
3704                      * where the corner volume gets stored with x+y decomp.
3705                      */
3706                     if (bCommY)
3707                     {
3708                         commbuf = commbuf_y;
3709                         buf_my  = ty1 - offy;
3710                         if (bCommX)
3711                         {
3712                             /* We index commbuf modulo the local grid size */
3713                             commbuf += buf_my*fft_nx*fft_nz;
3714
3715                             bClearBuf   = bClearBufXY;
3716                             bClearBufXY = FALSE;
3717                         }
3718                         else
3719                         {
3720                             bClearBuf  = bClearBufY;
3721                             bClearBufY = FALSE;
3722                         }
3723                     }
3724                     else
3725                     {
3726                         commbuf    = commbuf_x;
3727                         buf_my     = fft_ny;
3728                         bClearBuf  = bClearBufX;
3729                         bClearBufX = FALSE;
3730                     }
3731
3732                     /* Copy to the communication buffer */
3733                     for (x = offx; x < tx1; x++)
3734                     {
3735                         for (y = offy; y < ty1; y++)
3736                         {
3737                             i0  = (x*buf_my + y)*fft_nz;
3738                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3739
3740                             if (bClearBuf)
3741                             {
3742                                 /* First access of commbuf, initialize it */
3743                                 for (z = offz; z < tz1; z++)
3744                                 {
3745                                     commbuf[i0+z]  = grid_th[i0t+z];
3746                                 }
3747                             }
3748                             else
3749                             {
3750                                 for (z = offz; z < tz1; z++)
3751                                 {
3752                                     commbuf[i0+z] += grid_th[i0t+z];
3753                                 }
3754                             }
3755                         }
3756                     }
3757                 }
3758             }
3759         }
3760     }
3761 }
3762
3763
3764 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3765 {
3766     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3767     pme_overlap_t *overlap;
3768     int  send_index0, send_nindex;
3769     int  recv_nindex;
3770 #ifdef GMX_MPI
3771     MPI_Status stat;
3772 #endif
3773     int  send_size_y, recv_size_y;
3774     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3775     real *sendptr, *recvptr;
3776     int  x, y, z, indg, indb;
3777
3778     /* Note that this routine is only used for forward communication.
3779      * Since the force gathering, unlike the charge spreading,
3780      * can be trivially parallelized over the particles,
3781      * the backwards process is much simpler and can use the "old"
3782      * communication setup.
3783      */
3784
3785     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3786                                    local_fft_ndata,
3787                                    local_fft_offset,
3788                                    local_fft_size);
3789
3790     if (pme->nnodes_minor > 1)
3791     {
3792         /* Major dimension */
3793         overlap = &pme->overlap[1];
3794
3795         if (pme->nnodes_major > 1)
3796         {
3797             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3798         }
3799         else
3800         {
3801             size_yx = 0;
3802         }
3803         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3804
3805         send_size_y = overlap->send_size;
3806
3807         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3808         {
3809             send_id       = overlap->send_id[ipulse];
3810             recv_id       = overlap->recv_id[ipulse];
3811             send_index0   =
3812                 overlap->comm_data[ipulse].send_index0 -
3813                 overlap->comm_data[0].send_index0;
3814             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3815             /* We don't use recv_index0, as we always receive starting at 0 */
3816             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3817             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3818
3819             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3820             recvptr = overlap->recvbuf;
3821
3822 #ifdef GMX_MPI
3823             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3824                          send_id, ipulse,
3825                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3826                          recv_id, ipulse,
3827                          overlap->mpi_comm, &stat);
3828 #endif
3829
3830             for (x = 0; x < local_fft_ndata[XX]; x++)
3831             {
3832                 for (y = 0; y < recv_nindex; y++)
3833                 {
3834                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3835                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3836                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3837                     {
3838                         fftgrid[indg+z] += recvptr[indb+z];
3839                     }
3840                 }
3841             }
3842
3843             if (pme->nnodes_major > 1)
3844             {
3845                 /* Copy from the received buffer to the send buffer for dim 0 */
3846                 sendptr = pme->overlap[0].sendbuf;
3847                 for (x = 0; x < size_yx; x++)
3848                 {
3849                     for (y = 0; y < recv_nindex; y++)
3850                     {
3851                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3852                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3853                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3854                         {
3855                             sendptr[indg+z] += recvptr[indb+z];
3856                         }
3857                     }
3858                 }
3859             }
3860         }
3861     }
3862
3863     /* We only support a single pulse here.
3864      * This is not a severe limitation, as this code is only used
3865      * with OpenMP and with OpenMP the (PME) domains can be larger.
3866      */
3867     if (pme->nnodes_major > 1)
3868     {
3869         /* Major dimension */
3870         overlap = &pme->overlap[0];
3871
3872         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3873         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3874
3875         ipulse = 0;
3876
3877         send_id       = overlap->send_id[ipulse];
3878         recv_id       = overlap->recv_id[ipulse];
3879         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3880         /* We don't use recv_index0, as we always receive starting at 0 */
3881         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3882
3883         sendptr = overlap->sendbuf;
3884         recvptr = overlap->recvbuf;
3885
3886         if (debug != NULL)
3887         {
3888             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3889                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3890         }
3891
3892 #ifdef GMX_MPI
3893         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3894                      send_id, ipulse,
3895                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3896                      recv_id, ipulse,
3897                      overlap->mpi_comm, &stat);
3898 #endif
3899
3900         for (x = 0; x < recv_nindex; x++)
3901         {
3902             for (y = 0; y < local_fft_ndata[YY]; y++)
3903             {
3904                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3905                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3906                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3907                 {
3908                     fftgrid[indg+z] += recvptr[indb+z];
3909                 }
3910             }
3911         }
3912     }
3913 }
3914
3915
3916 static void spread_on_grid(gmx_pme_t pme,
3917                            pme_atomcomm_t *atc, pmegrids_t *grids,
3918                            gmx_bool bCalcSplines, gmx_bool bSpread,
3919                            real *fftgrid)
3920 {
3921     int nthread, thread;
3922 #ifdef PME_TIME_THREADS
3923     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3924     static double cs1     = 0, cs2 = 0, cs3 = 0;
3925     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3926     static int cnt        = 0;
3927 #endif
3928
3929     nthread = pme->nthread;
3930     assert(nthread > 0);
3931
3932 #ifdef PME_TIME_THREADS
3933     c1 = omp_cyc_start();
3934 #endif
3935     if (bCalcSplines)
3936     {
3937 #pragma omp parallel for num_threads(nthread) schedule(static)
3938         for (thread = 0; thread < nthread; thread++)
3939         {
3940             int start, end;
3941
3942             start = atc->n* thread   /nthread;
3943             end   = atc->n*(thread+1)/nthread;
3944
3945             /* Compute fftgrid index for all atoms,
3946              * with help of some extra variables.
3947              */
3948             calc_interpolation_idx(pme, atc, start, end, thread);
3949         }
3950     }
3951 #ifdef PME_TIME_THREADS
3952     c1   = omp_cyc_end(c1);
3953     cs1 += (double)c1;
3954 #endif
3955
3956 #ifdef PME_TIME_THREADS
3957     c2 = omp_cyc_start();
3958 #endif
3959 #pragma omp parallel for num_threads(nthread) schedule(static)
3960     for (thread = 0; thread < nthread; thread++)
3961     {
3962         splinedata_t *spline;
3963         pmegrid_t *grid = NULL;
3964
3965         /* make local bsplines  */
3966         if (grids == NULL || !pme->bUseThreads)
3967         {
3968             spline = &atc->spline[0];
3969
3970             spline->n = atc->n;
3971
3972             if (bSpread)
3973             {
3974                 grid = &grids->grid;
3975             }
3976         }
3977         else
3978         {
3979             spline = &atc->spline[thread];
3980
3981             if (grids->nthread == 1)
3982             {
3983                 /* One thread, we operate on all charges */
3984                 spline->n = atc->n;
3985             }
3986             else
3987             {
3988                 /* Get the indices our thread should operate on */
3989                 make_thread_local_ind(atc, thread, spline);
3990             }
3991
3992             grid = &grids->grid_th[thread];
3993         }
3994
3995         if (bCalcSplines)
3996         {
3997             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
3998                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
3999         }
4000
4001         if (bSpread)
4002         {
4003             /* put local atoms on grid. */
4004 #ifdef PME_TIME_SPREAD
4005             ct1a = omp_cyc_start();
4006 #endif
4007             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4008
4009             if (pme->bUseThreads)
4010             {
4011                 copy_local_grid(pme, grids, thread, fftgrid);
4012             }
4013 #ifdef PME_TIME_SPREAD
4014             ct1a          = omp_cyc_end(ct1a);
4015             cs1a[thread] += (double)ct1a;
4016 #endif
4017         }
4018     }
4019 #ifdef PME_TIME_THREADS
4020     c2   = omp_cyc_end(c2);
4021     cs2 += (double)c2;
4022 #endif
4023
4024     if (bSpread && pme->bUseThreads)
4025     {
4026 #ifdef PME_TIME_THREADS
4027         c3 = omp_cyc_start();
4028 #endif
4029 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4030         for (thread = 0; thread < grids->nthread; thread++)
4031         {
4032             reduce_threadgrid_overlap(pme, grids, thread,
4033                                       fftgrid,
4034                                       pme->overlap[0].sendbuf,
4035                                       pme->overlap[1].sendbuf);
4036         }
4037 #ifdef PME_TIME_THREADS
4038         c3   = omp_cyc_end(c3);
4039         cs3 += (double)c3;
4040 #endif
4041
4042         if (pme->nnodes > 1)
4043         {
4044             /* Communicate the overlapping part of the fftgrid.
4045              * For this communication call we need to check pme->bUseThreads
4046              * to have all ranks communicate here, regardless of pme->nthread.
4047              */
4048             sum_fftgrid_dd(pme, fftgrid);
4049         }
4050     }
4051
4052 #ifdef PME_TIME_THREADS
4053     cnt++;
4054     if (cnt % 20 == 0)
4055     {
4056         printf("idx %.2f spread %.2f red %.2f",
4057                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4058 #ifdef PME_TIME_SPREAD
4059         for (thread = 0; thread < nthread; thread++)
4060         {
4061             printf(" %.2f", cs1a[thread]*1e-9);
4062         }
4063 #endif
4064         printf("\n");
4065     }
4066 #endif
4067 }
4068
4069
4070 static void dump_grid(FILE *fp,
4071                       int sx, int sy, int sz, int nx, int ny, int nz,
4072                       int my, int mz, const real *g)
4073 {
4074     int x, y, z;
4075
4076     for (x = 0; x < nx; x++)
4077     {
4078         for (y = 0; y < ny; y++)
4079         {
4080             for (z = 0; z < nz; z++)
4081             {
4082                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4083                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4084             }
4085         }
4086     }
4087 }
4088
4089 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4090 {
4091     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4092
4093     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4094                                    local_fft_ndata,
4095                                    local_fft_offset,
4096                                    local_fft_size);
4097
4098     dump_grid(stderr,
4099               pme->pmegrid_start_ix,
4100               pme->pmegrid_start_iy,
4101               pme->pmegrid_start_iz,
4102               pme->pmegrid_nx-pme->pme_order+1,
4103               pme->pmegrid_ny-pme->pme_order+1,
4104               pme->pmegrid_nz-pme->pme_order+1,
4105               local_fft_size[YY],
4106               local_fft_size[ZZ],
4107               fftgrid);
4108 }
4109
4110
4111 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4112 {
4113     pme_atomcomm_t *atc;
4114     pmegrids_t *grid;
4115
4116     if (pme->nnodes > 1)
4117     {
4118         gmx_incons("gmx_pme_calc_energy called in parallel");
4119     }
4120     if (pme->bFEP > 1)
4121     {
4122         gmx_incons("gmx_pme_calc_energy with free energy");
4123     }
4124
4125     atc            = &pme->atc_energy;
4126     atc->nthread   = 1;
4127     if (atc->spline == NULL)
4128     {
4129         snew(atc->spline, atc->nthread);
4130     }
4131     atc->nslab     = 1;
4132     atc->bSpread   = TRUE;
4133     atc->pme_order = pme->pme_order;
4134     atc->n         = n;
4135     pme_realloc_atomcomm_things(atc);
4136     atc->x         = x;
4137     atc->q         = q;
4138
4139     /* We only use the A-charges grid */
4140     grid = &pme->pmegridA;
4141
4142     /* Only calculate the spline coefficients, don't actually spread */
4143     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4144
4145     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4146 }
4147
4148
4149 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4150                                    gmx_walltime_accounting_t walltime_accounting,
4151                                    t_nrnb *nrnb, t_inputrec *ir,
4152                                    gmx_large_int_t step)
4153 {
4154     /* Reset all the counters related to performance over the run */
4155     wallcycle_stop(wcycle, ewcRUN);
4156     wallcycle_reset_all(wcycle);
4157     init_nrnb(nrnb);
4158     if (ir->nsteps >= 0)
4159     {
4160         /* ir->nsteps is not used here, but we update it for consistency */
4161         ir->nsteps -= step - ir->init_step;
4162     }
4163     ir->init_step = step;
4164     wallcycle_start(wcycle, ewcRUN);
4165     walltime_accounting_start(walltime_accounting);
4166 }
4167
4168
4169 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4170                                ivec grid_size,
4171                                t_commrec *cr, t_inputrec *ir,
4172                                gmx_pme_t *pme_ret)
4173 {
4174     int ind;
4175     gmx_pme_t pme = NULL;
4176
4177     ind = 0;
4178     while (ind < *npmedata)
4179     {
4180         pme = (*pmedata)[ind];
4181         if (pme->nkx == grid_size[XX] &&
4182             pme->nky == grid_size[YY] &&
4183             pme->nkz == grid_size[ZZ])
4184         {
4185             *pme_ret = pme;
4186
4187             return;
4188         }
4189
4190         ind++;
4191     }
4192
4193     (*npmedata)++;
4194     srenew(*pmedata, *npmedata);
4195
4196     /* Generate a new PME data structure, copying part of the old pointers */
4197     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4198
4199     *pme_ret = (*pmedata)[ind];
4200 }
4201
4202
4203 int gmx_pmeonly(gmx_pme_t pme,
4204                 t_commrec *cr,    t_nrnb *nrnb,
4205                 gmx_wallcycle_t wcycle,
4206                 gmx_walltime_accounting_t walltime_accounting,
4207                 real ewaldcoeff,
4208                 t_inputrec *ir)
4209 {
4210     int npmedata;
4211     gmx_pme_t *pmedata;
4212     gmx_pme_pp_t pme_pp;
4213     int  ret;
4214     int  natoms;
4215     matrix box;
4216     rvec *x_pp      = NULL, *f_pp = NULL;
4217     real *chargeA   = NULL, *chargeB = NULL;
4218     real lambda     = 0;
4219     int  maxshift_x = 0, maxshift_y = 0;
4220     real energy, dvdlambda;
4221     matrix vir;
4222     float cycles;
4223     int  count;
4224     gmx_bool bEnerVir;
4225     gmx_large_int_t step, step_rel;
4226     ivec grid_switch;
4227
4228     /* This data will only use with PME tuning, i.e. switching PME grids */
4229     npmedata = 1;
4230     snew(pmedata, npmedata);
4231     pmedata[0] = pme;
4232
4233     pme_pp = gmx_pme_pp_init(cr);
4234
4235     init_nrnb(nrnb);
4236
4237     count = 0;
4238     do /****** this is a quasi-loop over time steps! */
4239     {
4240         /* The reason for having a loop here is PME grid tuning/switching */
4241         do
4242         {
4243             /* Domain decomposition */
4244             ret = gmx_pme_recv_q_x(pme_pp,
4245                                    &natoms,
4246                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4247                                    &maxshift_x, &maxshift_y,
4248                                    &pme->bFEP, &lambda,
4249                                    &bEnerVir,
4250                                    &step,
4251                                    grid_switch, &ewaldcoeff);
4252
4253             if (ret == pmerecvqxSWITCHGRID)
4254             {
4255                 /* Switch the PME grid to grid_switch */
4256                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4257             }
4258
4259             if (ret == pmerecvqxRESETCOUNTERS)
4260             {
4261                 /* Reset the cycle and flop counters */
4262                 reset_pmeonly_counters(wcycle, walltime_accounting, nrnb, ir, step);
4263             }
4264         }
4265         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4266
4267         if (ret == pmerecvqxFINISH)
4268         {
4269             /* We should stop: break out of the loop */
4270             break;
4271         }
4272
4273         step_rel = step - ir->init_step;
4274
4275         if (count == 0)
4276         {
4277             wallcycle_start(wcycle, ewcRUN);
4278             walltime_accounting_start(walltime_accounting);
4279         }
4280
4281         wallcycle_start(wcycle, ewcPMEMESH);
4282
4283         dvdlambda = 0;
4284         clear_mat(vir);
4285         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4286                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4287                    &energy, lambda, &dvdlambda,
4288                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4289
4290         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4291
4292         gmx_pme_send_force_vir_ener(pme_pp,
4293                                     f_pp, vir, energy, dvdlambda,
4294                                     cycles);
4295
4296         count++;
4297     } /***** end of quasi-loop, we stop with the break above */
4298     while (TRUE);
4299
4300     walltime_accounting_end(walltime_accounting);
4301
4302     return 0;
4303 }
4304
4305 int gmx_pme_do(gmx_pme_t pme,
4306                int start,       int homenr,
4307                rvec x[],        rvec f[],
4308                real *chargeA,   real *chargeB,
4309                matrix box, t_commrec *cr,
4310                int  maxshift_x, int maxshift_y,
4311                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4312                matrix vir,      real ewaldcoeff,
4313                real *energy,    real lambda,
4314                real *dvdlambda, int flags)
4315 {
4316     int     q, d, i, j, ntot, npme;
4317     int     nx, ny, nz;
4318     int     n_d, local_ny;
4319     pme_atomcomm_t *atc = NULL;
4320     pmegrids_t *pmegrid = NULL;
4321     real    *grid       = NULL;
4322     real    *ptr;
4323     rvec    *x_d, *f_d;
4324     real    *charge = NULL, *q_d;
4325     real    energy_AB[2];
4326     matrix  vir_AB[2];
4327     gmx_bool bClearF;
4328     gmx_parallel_3dfft_t pfft_setup;
4329     real *  fftgrid;
4330     t_complex * cfftgrid;
4331     int     thread;
4332     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4333     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4334
4335     assert(pme->nnodes > 0);
4336     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4337
4338     if (pme->nnodes > 1)
4339     {
4340         atc      = &pme->atc[0];
4341         atc->npd = homenr;
4342         if (atc->npd > atc->pd_nalloc)
4343         {
4344             atc->pd_nalloc = over_alloc_dd(atc->npd);
4345             srenew(atc->pd, atc->pd_nalloc);
4346         }
4347         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4348     }
4349     else
4350     {
4351         /* This could be necessary for TPI */
4352         pme->atc[0].n = homenr;
4353     }
4354
4355     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4356     {
4357         if (q == 0)
4358         {
4359             pmegrid    = &pme->pmegridA;
4360             fftgrid    = pme->fftgridA;
4361             cfftgrid   = pme->cfftgridA;
4362             pfft_setup = pme->pfft_setupA;
4363             charge     = chargeA+start;
4364         }
4365         else
4366         {
4367             pmegrid    = &pme->pmegridB;
4368             fftgrid    = pme->fftgridB;
4369             cfftgrid   = pme->cfftgridB;
4370             pfft_setup = pme->pfft_setupB;
4371             charge     = chargeB+start;
4372         }
4373         grid = pmegrid->grid.grid;
4374         /* Unpack structure */
4375         if (debug)
4376         {
4377             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4378                     cr->nnodes, cr->nodeid);
4379             fprintf(debug, "Grid = %p\n", (void*)grid);
4380             if (grid == NULL)
4381             {
4382                 gmx_fatal(FARGS, "No grid!");
4383             }
4384         }
4385         where();
4386
4387         m_inv_ur0(box, pme->recipbox);
4388
4389         if (pme->nnodes == 1)
4390         {
4391             atc = &pme->atc[0];
4392             if (DOMAINDECOMP(cr))
4393             {
4394                 atc->n = homenr;
4395                 pme_realloc_atomcomm_things(atc);
4396             }
4397             atc->x = x;
4398             atc->q = charge;
4399             atc->f = f;
4400         }
4401         else
4402         {
4403             wallcycle_start(wcycle, ewcPME_REDISTXF);
4404             for (d = pme->ndecompdim-1; d >= 0; d--)
4405             {
4406                 if (d == pme->ndecompdim-1)
4407                 {
4408                     n_d = homenr;
4409                     x_d = x + start;
4410                     q_d = charge;
4411                 }
4412                 else
4413                 {
4414                     n_d = pme->atc[d+1].n;
4415                     x_d = atc->x;
4416                     q_d = atc->q;
4417                 }
4418                 atc      = &pme->atc[d];
4419                 atc->npd = n_d;
4420                 if (atc->npd > atc->pd_nalloc)
4421                 {
4422                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4423                     srenew(atc->pd, atc->pd_nalloc);
4424                 }
4425                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4426                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4427                 where();
4428
4429                 /* Redistribute x (only once) and qA or qB */
4430                 if (DOMAINDECOMP(cr))
4431                 {
4432                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4433                 }
4434                 else
4435                 {
4436                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4437                 }
4438             }
4439             where();
4440
4441             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4442         }
4443
4444         if (debug)
4445         {
4446             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4447                     cr->nodeid, atc->n);
4448         }
4449
4450         if (flags & GMX_PME_SPREAD_Q)
4451         {
4452             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4453
4454             /* Spread the charges on a grid */
4455             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4456
4457             if (q == 0)
4458             {
4459                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4460             }
4461             inc_nrnb(nrnb, eNR_SPREADQBSP,
4462                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4463
4464             if (!pme->bUseThreads)
4465             {
4466                 wrap_periodic_pmegrid(pme, grid);
4467
4468                 /* sum contributions to local grid from other nodes */
4469 #ifdef GMX_MPI
4470                 if (pme->nnodes > 1)
4471                 {
4472                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4473                     where();
4474                 }
4475 #endif
4476
4477                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4478             }
4479
4480             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4481
4482             /*
4483                dump_local_fftgrid(pme,fftgrid);
4484                exit(0);
4485              */
4486         }
4487
4488         /* Here we start a large thread parallel region */
4489 #pragma omp parallel num_threads(pme->nthread) private(thread)
4490         {
4491             thread = gmx_omp_get_thread_num();
4492             if (flags & GMX_PME_SOLVE)
4493             {
4494                 int loop_count;
4495
4496                 /* do 3d-fft */
4497                 if (thread == 0)
4498                 {
4499                     wallcycle_start(wcycle, ewcPME_FFT);
4500                 }
4501                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4502                                            thread, wcycle);
4503                 if (thread == 0)
4504                 {
4505                     wallcycle_stop(wcycle, ewcPME_FFT);
4506                 }
4507                 where();
4508
4509                 /* solve in k-space for our local cells */
4510                 if (thread == 0)
4511                 {
4512                     wallcycle_start(wcycle, ewcPME_SOLVE);
4513                 }
4514                 loop_count =
4515                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4516                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4517                                   bCalcEnerVir,
4518                                   pme->nthread, thread);
4519                 if (thread == 0)
4520                 {
4521                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4522                     where();
4523                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4524                 }
4525             }
4526
4527             if (bCalcF)
4528             {
4529                 /* do 3d-invfft */
4530                 if (thread == 0)
4531                 {
4532                     where();
4533                     wallcycle_start(wcycle, ewcPME_FFT);
4534                 }
4535                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4536                                            thread, wcycle);
4537                 if (thread == 0)
4538                 {
4539                     wallcycle_stop(wcycle, ewcPME_FFT);
4540
4541                     where();
4542
4543                     if (pme->nodeid == 0)
4544                     {
4545                         ntot  = pme->nkx*pme->nky*pme->nkz;
4546                         npme  = ntot*log((real)ntot)/log(2.0);
4547                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4548                     }
4549
4550                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4551                 }
4552
4553                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4554             }
4555         }
4556         /* End of thread parallel section.
4557          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4558          */
4559
4560         if (bCalcF)
4561         {
4562             /* distribute local grid to all nodes */
4563 #ifdef GMX_MPI
4564             if (pme->nnodes > 1)
4565             {
4566                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4567             }
4568 #endif
4569             where();
4570
4571             unwrap_periodic_pmegrid(pme, grid);
4572
4573             /* interpolate forces for our local atoms */
4574
4575             where();
4576
4577             /* If we are running without parallelization,
4578              * atc->f is the actual force array, not a buffer,
4579              * therefore we should not clear it.
4580              */
4581             bClearF = (q == 0 && PAR(cr));
4582 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4583             for (thread = 0; thread < pme->nthread; thread++)
4584             {
4585                 gather_f_bsplines(pme, grid, bClearF, atc,
4586                                   &atc->spline[thread],
4587                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4588             }
4589
4590             where();
4591
4592             inc_nrnb(nrnb, eNR_GATHERFBSP,
4593                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4594             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4595         }
4596
4597         if (bCalcEnerVir)
4598         {
4599             /* This should only be called on the master thread
4600              * and after the threads have synchronized.
4601              */
4602             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4603         }
4604     } /* of q-loop */
4605
4606     if (bCalcF && pme->nnodes > 1)
4607     {
4608         wallcycle_start(wcycle, ewcPME_REDISTXF);
4609         for (d = 0; d < pme->ndecompdim; d++)
4610         {
4611             atc = &pme->atc[d];
4612             if (d == pme->ndecompdim - 1)
4613             {
4614                 n_d = homenr;
4615                 f_d = f + start;
4616             }
4617             else
4618             {
4619                 n_d = pme->atc[d+1].n;
4620                 f_d = pme->atc[d+1].f;
4621             }
4622             if (DOMAINDECOMP(cr))
4623             {
4624                 dd_pmeredist_f(pme, atc, n_d, f_d,
4625                                d == pme->ndecompdim-1 && pme->bPPnode);
4626             }
4627             else
4628             {
4629                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4630             }
4631         }
4632
4633         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4634     }
4635     where();
4636
4637     if (bCalcEnerVir)
4638     {
4639         if (!pme->bFEP)
4640         {
4641             *energy = energy_AB[0];
4642             m_add(vir, vir_AB[0], vir);
4643         }
4644         else
4645         {
4646             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4647             *dvdlambda += energy_AB[1] - energy_AB[0];
4648             for (i = 0; i < DIM; i++)
4649             {
4650                 for (j = 0; j < DIM; j++)
4651                 {
4652                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4653                         lambda*vir_AB[1][i][j];
4654                 }
4655             }
4656         }
4657     }
4658     else
4659     {
4660         *energy = 0;
4661     }
4662
4663     if (debug)
4664     {
4665         fprintf(debug, "PME mesh energy: %g\n", *energy);
4666     }
4667
4668     return 0;
4669 }