1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
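/* Worked example for check 1 (illustrative; this assumes the usual GROMACS
 * convention that recipbox is the inverse of the box matrix whose rows are
 * the box vectors, without the 2*pi factor):
 *
 *    box      =  2  0  0        recipbox =  1/2    0    0
 *                0  2  0                     0    1/2   0
 *                2  2  6                   -1/6  -1/6  1/6
 *
 * so the fractional coordinates computed in calc_interpolation_idx() come
 * out as sx = x/2 - z/6, sy = y/2 - z/6, sz = z/6.
 */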
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #include "gromacs/fft/parallel_3dfft.h"
64 #include "gromacs/utility/gmxmpi.h"
65
66 #include <stdio.h>
67 #include <string.h>
68 #include <math.h>
69 #include <assert.h>
70 #include "typedefs.h"
71 #include "txtdump.h"
72 #include "vec.h"
73 #include "gmxcomplex.h"
74 #include "smalloc.h"
75 #include "gromacs/fileio/futil.h"
76 #include "coulomb.h"
77 #include "gmx_fatal.h"
78 #include "pme.h"
79 #include "network.h"
80 #include "physics.h"
81 #include "nrnb.h"
82 #include "gmx_wallcycle.h"
83 #include "gromacs/fileio/pdbio.h"
84 #include "gmx_cyclecounter.h"
85 #include "gmx_omp.h"
86 #include "macros.h"
87
88
89 /* Include the SIMD macro file and then check for support */
90 #include "gmx_simd_macros.h"
91 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
92 /* Turn on arbitrary width SIMD intrinsics for PME solve */
93 #define PME_SIMD
94 #endif
95
96 /* Include the 4-wide SIMD macro file */
97 #include "gmx_simd4_macros.h"
98 /* Check if we have 4-wide SIMD macro support */
99 #ifdef GMX_HAVE_SIMD4_MACROS
100 /* Do PME spread and gather with 4-wide SIMD.
101  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
102  */
103 #define PME_SIMD4_SPREAD_GATHER
104
105 #ifdef GMX_SIMD4_HAVE_UNALIGNED
106 /* With PME-order=4 on x86, unaligned load+store is slightly faster
107  * than doubling all SIMD operations when using aligned load+store.
108  */
109 #define PME_SIMD4_UNALIGNED
110 #endif
111 #endif
112
113 #define DFT_TOL 1e-7
114 /* #define PRT_FORCE */
115 /* conditions for on-the-fly time measurement */
116 /* #define TAKETIME (step > 1 && timesteps < 10) */
117 #define TAKETIME FALSE
118
119 /* #define PME_TIME_THREADS */
120
121 #ifdef GMX_DOUBLE
122 #define mpi_type MPI_DOUBLE
123 #else
124 #define mpi_type MPI_FLOAT
125 #endif
126
127 #ifdef PME_SIMD4_SPREAD_GATHER
128 #define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
129 #else
130 /* We can use any alignment, apart from 0, so we use 4 reals */
131 #define SIMD4_ALIGNMENT  (4*sizeof(real))
132 #endif
133
134 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
135  * to preserve alignment.
136  */
137 #define GMX_CACHE_SEP 64
138
139 /* We only define a maximum to be able to use local arrays without allocation.
140  * An order larger than 12 should never be needed, even for test cases.
141  * If needed it can be changed here.
142  */
143 #define PME_ORDER_MAX 12
144
145 /* Internal datastructures */
146 typedef struct {
147     int send_index0;
148     int send_nindex;
149     int recv_index0;
150     int recv_nindex;
151     int recv_size;   /* Receive buffer width, used with OpenMP */
152 } pme_grid_comm_t;
153
154 typedef struct {
155 #ifdef GMX_MPI
156     MPI_Comm         mpi_comm;
157 #endif
158     int              nnodes, nodeid;
159     int             *s2g0;
160     int             *s2g1;
161     int              noverlap_nodes;
162     int             *send_id, *recv_id;
163     int              send_size; /* Send buffer width, used with OpenMP */
164     pme_grid_comm_t *comm_data;
165     real            *sendbuf;
166     real            *recvbuf;
167 } pme_overlap_t;
168
169 typedef struct {
170     int *n;      /* Cumulative counts of the number of particles per thread */
171     int  nalloc; /* Allocation size of i */
172     int *i;      /* Particle indices ordered on thread index (n) */
173 } thread_plist_t;
174
175 typedef struct {
176     int      *thread_one;
177     int       n;
178     int      *ind;
179     splinevec theta;
180     real     *ptr_theta_z;
181     splinevec dtheta;
182     real     *ptr_dtheta_z;
183 } splinedata_t;
184
185 typedef struct {
186     int      dimind;        /* The index of the dimension, 0=x, 1=y */
187     int      nslab;
188     int      nodeid;
189 #ifdef GMX_MPI
190     MPI_Comm mpi_comm;
191 #endif
192
193     int     *node_dest;     /* The nodes to send x and q to with DD */
194     int     *node_src;      /* The nodes to receive x and q from with DD */
195     int     *buf_index;     /* Index for commnode into the buffers */
196
197     int      maxshift;
198
199     int      npd;
200     int      pd_nalloc;
201     int     *pd;
202     int     *count;         /* The number of atoms to send to each node */
203     int    **count_thread;
204     int     *rcount;        /* The number of atoms to receive */
205
206     int      n;
207     int      nalloc;
208     rvec    *x;
209     real    *q;
210     rvec    *f;
211     gmx_bool bSpread;       /* These coordinates are used for spreading */
212     int      pme_order;
213     ivec    *idx;
214     rvec    *fractx;            /* Fractional coordinate relative to the
215                                  * lower cell boundary
216                                  */
217     int             nthread;
218     int            *thread_idx; /* Which thread should spread which charge */
219     thread_plist_t *thread_plist;
220     splinedata_t   *spline;
221 } pme_atomcomm_t;
222
223 #define FLBS  3
224 #define FLBSZ 4
225
226 typedef struct {
227     ivec  ci;     /* The spatial location of this grid         */
228     ivec  n;      /* The used size of *grid, including order-1 */
229     ivec  offset; /* The grid offset from the full node grid   */
230     int   order;  /* PME spreading order                       */
231     ivec  s;      /* The allocated size of *grid, s >= n       */
232     real *grid;   /* The grid of the local thread, size n      */
233 } pmegrid_t;
234
235 typedef struct {
236     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
237     int        nthread;      /* The number of threads operating on this grid     */
238     ivec       nc;           /* The local spatial decomposition over the threads */
239     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
240     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
241     int      **g2t;          /* The grid to thread index                         */
242     ivec       nthread_comm; /* The number of threads to communicate with        */
243 } pmegrids_t;
244
245
246 typedef struct {
247 #ifdef PME_SIMD4_SPREAD_GATHER
248     /* Masks for 4-wide SIMD aligned spreading and gathering */
249     gmx_simd4_pb mask_S0[6], mask_S1[6];
250 #else
251     int    dummy; /* C89 requires that a struct has at least one member */
252 #endif
253 } pme_spline_work_t;
254
255 typedef struct {
256     /* work data for solve_pme */
257     int      nalloc;
258     real *   mhx;
259     real *   mhy;
260     real *   mhz;
261     real *   m2;
262     real *   denom;
263     real *   tmp1_alloc;
264     real *   tmp1;
265     real *   eterm;
266     real *   m2inv;
267
268     real     energy;
269     matrix   vir;
270 } pme_work_t;
271
272 typedef struct gmx_pme {
273     int           ndecompdim; /* The number of decomposition dimensions */
274     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
275     int           nodeid_major;
276     int           nodeid_minor;
277     int           nnodes;    /* The number of nodes doing PME */
278     int           nnodes_major;
279     int           nnodes_minor;
280
281     MPI_Comm      mpi_comm;
282     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
283 #ifdef GMX_MPI
284     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
285 #endif
286
287     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
288     int        nthread;       /* The number of threads doing PME on our rank */
289
290     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
291     gmx_bool   bFEP;          /* Compute Free energy contribution */
292     int        nkx, nky, nkz; /* Grid dimensions */
293     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
294     int        pme_order;
295     real       epsilon_r;
296
297     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
298     pmegrids_t pmegridB;
299     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
300     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
301     /* pmegrid_nz might be larger than strictly necessary to ensure
302      * memory alignment; pmegrid_nz_base gives the real base size.
303      */
304     int     pmegrid_nz_base;
305     /* The local PME grid starting indices */
306     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
307
308     /* Work data for spreading and gathering */
309     pme_spline_work_t    *spline_work;
310
311     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
312     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
313     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
314
315     t_complex            *cfftgridA;  /* Grids for complex FFT data */
316     t_complex            *cfftgridB;
317     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
318
319     gmx_parallel_3dfft_t  pfft_setupA;
320     gmx_parallel_3dfft_t  pfft_setupB;
321
322     int                  *nnx, *nny, *nnz;
323     real                 *fshx, *fshy, *fshz;
324
325     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
326     matrix                recipbox;
327     splinevec             bsp_mod;
328
329     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
330
331     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
332
333     rvec                 *bufv;       /* Communication buffer */
334     real                 *bufr;       /* Communication buffer */
335     int                   buf_nalloc; /* The communication buffer size */
336
337     /* thread local work data for solve_pme */
338     pme_work_t *work;
339
340     /* Work data for PME_redist */
341     gmx_bool redist_init;
342     int *    scounts;
343     int *    rcounts;
344     int *    sdispls;
345     int *    rdispls;
346     int *    sidx;
347     int *    idxa;
348     real *   redist_buf;
349     int      redist_buf_nalloc;
350
351     /* Work data for sum_qgrid */
352     real *   sum_qgrid_tmp;
353     real *   sum_qgrid_dd_tmp;
354 } t_gmx_pme;
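/* Note on the A/B pairs above (pmegridA/B, fftgridA/B, cfftgridA/B,
 * pfft_setupA/B): as used in the rest of this file they hold the two
 * free-energy charge states when bFEP is set; without perturbation only the
 * A set is used.
 */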
355
356
357 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
358                                    int start, int end, int thread)
359 {
360     int             i;
361     int            *idxptr, tix, tiy, tiz;
362     real           *xptr, *fptr, tx, ty, tz;
363     real            rxx, ryx, ryy, rzx, rzy, rzz;
364     int             nx, ny, nz;
365     int             start_ix, start_iy, start_iz;
366     int            *g2tx, *g2ty, *g2tz;
367     gmx_bool        bThreads;
368     int            *thread_idx = NULL;
369     thread_plist_t *tpl        = NULL;
370     int            *tpl_n      = NULL;
371     int             thread_i;
372
373     nx  = pme->nkx;
374     ny  = pme->nky;
375     nz  = pme->nkz;
376
377     start_ix = pme->pmegrid_start_ix;
378     start_iy = pme->pmegrid_start_iy;
379     start_iz = pme->pmegrid_start_iz;
380
381     rxx = pme->recipbox[XX][XX];
382     ryx = pme->recipbox[YY][XX];
383     ryy = pme->recipbox[YY][YY];
384     rzx = pme->recipbox[ZZ][XX];
385     rzy = pme->recipbox[ZZ][YY];
386     rzz = pme->recipbox[ZZ][ZZ];
387
388     g2tx = pme->pmegridA.g2t[XX];
389     g2ty = pme->pmegridA.g2t[YY];
390     g2tz = pme->pmegridA.g2t[ZZ];
391
392     bThreads = (atc->nthread > 1);
393     if (bThreads)
394     {
395         thread_idx = atc->thread_idx;
396
397         tpl   = &atc->thread_plist[thread];
398         tpl_n = tpl->n;
399         for (i = 0; i < atc->nthread; i++)
400         {
401             tpl_n[i] = 0;
402         }
403     }
404
405     for (i = start; i < end; i++)
406     {
407         xptr   = atc->x[i];
408         idxptr = atc->idx[i];
409         fptr   = atc->fractx[i];
410
411         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
412         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
413         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
414         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
415
416         tix = (int)(tx);
417         tiy = (int)(ty);
418         tiz = (int)(tz);
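        /* Example with hypothetical numbers: for nx = 48 and a fractional
         * x-coordinate of -0.02, tx = 48*(-0.02 + 2.0) = 95.04, so tix = 95
         * and tx - tix = 0.04; the precomputed nnx/fshx tables below fold
         * this shifted index back onto the (local) grid.
         */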
419
420         /* Because decomposition only occurs in x and y,
421          * we never have a fraction correction in z.
422          */
423         fptr[XX] = tx - tix + pme->fshx[tix];
424         fptr[YY] = ty - tiy + pme->fshy[tiy];
425         fptr[ZZ] = tz - tiz;
426
427         idxptr[XX] = pme->nnx[tix];
428         idxptr[YY] = pme->nny[tiy];
429         idxptr[ZZ] = pme->nnz[tiz];
430
431 #ifdef DEBUG
432         range_check(idxptr[XX], 0, pme->pmegrid_nx);
433         range_check(idxptr[YY], 0, pme->pmegrid_ny);
434         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
435 #endif
436
437         if (bThreads)
438         {
439             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
440             thread_idx[i] = thread_i;
441             tpl_n[thread_i]++;
442         }
443     }
444
445     if (bThreads)
446     {
447         /* Make a list of particle indices sorted on thread */
448
449         /* Get the cumulative count */
450         for (i = 1; i < atc->nthread; i++)
451         {
452             tpl_n[i] += tpl_n[i-1];
453         }
454         /* The current implementation distributes particles equally
455          * over the threads, so we could actually allocate for that
456          * in pme_realloc_atomcomm_things.
457          */
458         if (tpl_n[atc->nthread-1] > tpl->nalloc)
459         {
460             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
461             srenew(tpl->i, tpl->nalloc);
462         }
463         /* Set tpl_n to the cumulative start */
464         for (i = atc->nthread-1; i >= 1; i--)
465         {
466             tpl_n[i] = tpl_n[i-1];
467         }
468         tpl_n[0] = 0;
469
470         /* Fill our thread local array with indices sorted on thread */
471         for (i = start; i < end; i++)
472         {
473             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
474         }
475         /* Now tpl_n contains the cumulative count again */
476     }
477 }
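/* The bThreads branch above builds the per-thread particle lists with a
 * two-pass counting sort.  Minimal sketch on hypothetical arrays (key[i] is
 * the destination thread of particle i, nt the number of threads):
 *
 *     for (i = 0; i < n; i++)     { count[key[i]]++; }
 *     for (t = 1; t < nt; t++)    { count[t] += count[t-1]; }   inclusive prefix sum
 *     for (t = nt-1; t >= 1; t--) { count[t] = count[t-1]; }    shift to start offsets
 *     count[0] = 0;
 *     for (i = 0; i < n; i++)     { out[count[key[i]]++] = i; } stable scatter
 *
 * E.g. counts {3, 2} give start offsets {0, 3}: the three thread-0 particles
 * land in out[0..2] and the two thread-1 particles in out[3..4].
 */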
478
479 static void make_thread_local_ind(pme_atomcomm_t *atc,
480                                   int thread, splinedata_t *spline)
481 {
482     int             n, t, i, start, end;
483     thread_plist_t *tpl;
484
485     /* Combine the indices made by each thread into one index */
486
487     n     = 0;
488     start = 0;
489     for (t = 0; t < atc->nthread; t++)
490     {
491         tpl = &atc->thread_plist[t];
492         /* Copy our part (start - end) from the list of thread t */
493         if (thread > 0)
494         {
495             start = tpl->n[thread-1];
496         }
497         end = tpl->n[thread];
498         for (i = start; i < end; i++)
499         {
500             spline->ind[n++] = tpl->i[i];
501         }
502     }
503
504     spline->n = n;
505 }
506
507
508 static void pme_calc_pidx(int start, int end,
509                           matrix recipbox, rvec x[],
510                           pme_atomcomm_t *atc, int *count)
511 {
512     int   nslab, i;
513     int   si;
514     real *xptr, s;
515     real  rxx, ryx, rzx, ryy, rzy;
516     int  *pd;
517
518     /* Calculate the PME task index (pidx) for each particle.
519      * Here we always assign equally sized slabs to each node
520      * for load balancing reasons (the PME grid spacing is not used).
521      */
522
523     nslab = atc->nslab;
524     pd    = atc->pd;
525
526     /* Reset the count */
527     for (i = 0; i < nslab; i++)
528     {
529         count[i] = 0;
530     }
531
532     if (atc->dimind == 0)
533     {
534         rxx = recipbox[XX][XX];
535         ryx = recipbox[YY][XX];
536         rzx = recipbox[ZZ][XX];
537         /* Calculate the node index in x-dimension */
538         for (i = start; i < end; i++)
539         {
540             xptr   = x[i];
541             /* Fractional coordinates along box vectors */
542             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
543             si    = (int)(s + 2*nslab) % nslab;
544             pd[i] = si;
545             count[si]++;
546         }
547     }
548     else
549     {
550         ryy = recipbox[YY][YY];
551         rzy = recipbox[ZZ][YY];
552         /* Calculate the node index in y-dimension */
553         for (i = start; i < end; i++)
554         {
555             xptr   = x[i];
556             /* Fractional coordinates along box vectors */
557             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
558             si    = (int)(s + 2*nslab) % nslab;
559             pd[i] = si;
560             count[si]++;
561         }
562     }
563 }
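/* Wrap-around example for pme_calc_pidx (illustrative): with nslab = 4, a
 * fractional coordinate of -0.1 gives s = -0.4 and
 * (int)(s + 2*nslab) % nslab = (int)7.6 % 4 = 3, so the atom goes to the
 * last slab; a fraction of 1.05 gives (int)12.2 % 4 = 0, wrapping back to
 * slab 0.
 */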
564
565 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
566                                   pme_atomcomm_t *atc)
567 {
568     int nthread, thread, slab;
569
570     nthread = atc->nthread;
571
572 #pragma omp parallel for num_threads(nthread) schedule(static)
573     for (thread = 0; thread < nthread; thread++)
574     {
575         pme_calc_pidx(natoms* thread   /nthread,
576                       natoms*(thread+1)/nthread,
577                       recipbox, x, atc, atc->count_thread[thread]);
578     }
579     /* Non-parallel reduction, since nslab is small */
580
581     for (thread = 1; thread < nthread; thread++)
582     {
583         for (slab = 0; slab < atc->nslab; slab++)
584         {
585             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
586         }
587     }
588 }
589
590 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
591 {
592     const int padding = 4;
593     int       i;
594
595     srenew(th[XX], nalloc);
596     srenew(th[YY], nalloc);
597     /* In z we add padding, this is only required for the aligned SIMD code */
598     sfree_aligned(*ptr_z);
599     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
600     th[ZZ] = *ptr_z + padding;
601
602     for (i = 0; i < padding; i++)
603     {
604         (*ptr_z)[               i] = 0;
605         (*ptr_z)[padding+nalloc+i] = 0;
606     }
607 }
608
609 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
610 {
611     int i, d;
612
613     srenew(spline->ind, atc->nalloc);
614     /* Initialize the index to identity so it works without threads */
615     for (i = 0; i < atc->nalloc; i++)
616     {
617         spline->ind[i] = i;
618     }
619
620     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
621                       atc->pme_order*atc->nalloc);
622     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
623                       atc->pme_order*atc->nalloc);
624 }
625
626 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
627 {
628     int nalloc_old, i, j, nalloc_tpl;
629
630     /* We have to avoid a NULL pointer for atc->x to prevent
631      * possible fatal errors in MPI routines.
632      */
633     if (atc->n > atc->nalloc || atc->nalloc == 0)
634     {
635         nalloc_old  = atc->nalloc;
636         atc->nalloc = over_alloc_dd(max(atc->n, 1));
637
638         if (atc->nslab > 1)
639         {
640             srenew(atc->x, atc->nalloc);
641             srenew(atc->q, atc->nalloc);
642             srenew(atc->f, atc->nalloc);
643             for (i = nalloc_old; i < atc->nalloc; i++)
644             {
645                 clear_rvec(atc->f[i]);
646             }
647         }
648         if (atc->bSpread)
649         {
650             srenew(atc->fractx, atc->nalloc);
651             srenew(atc->idx, atc->nalloc);
652
653             if (atc->nthread > 1)
654             {
655                 srenew(atc->thread_idx, atc->nalloc);
656             }
657
658             for (i = 0; i < atc->nthread; i++)
659             {
660                 pme_realloc_splinedata(&atc->spline[i], atc);
661             }
662         }
663     }
664 }
665
666 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
667                          int n, gmx_bool bXF, rvec *x_f, real *charge,
668                          pme_atomcomm_t *atc)
669 /* Redistribute particle data for PME calculation */
670 /* domain decomposition by x coordinate           */
671 {
672     int *idxa;
673     int  i, ii;
674
675     if (FALSE == pme->redist_init)
676     {
677         snew(pme->scounts, atc->nslab);
678         snew(pme->rcounts, atc->nslab);
679         snew(pme->sdispls, atc->nslab);
680         snew(pme->rdispls, atc->nslab);
681         snew(pme->sidx, atc->nslab);
682         pme->redist_init = TRUE;
683     }
684     if (n > pme->redist_buf_nalloc)
685     {
686         pme->redist_buf_nalloc = over_alloc_dd(n);
687         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
688     }
689
690     pme->idxa = atc->pd;
691
692 #ifdef GMX_MPI
693     if (forw && bXF)
694     {
695         /* forward, redistribution from pp to pme */
696
697         /* Calculate send counts and exchange them with other nodes */
698         for (i = 0; (i < atc->nslab); i++)
699         {
700             pme->scounts[i] = 0;
701         }
702         for (i = 0; (i < n); i++)
703         {
704             pme->scounts[pme->idxa[i]]++;
705         }
706         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
707
708         /* Calculate send and receive displacements and index into send
709            buffer */
710         pme->sdispls[0] = 0;
711         pme->rdispls[0] = 0;
712         pme->sidx[0]    = 0;
713         for (i = 1; i < atc->nslab; i++)
714         {
715             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
716             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
717             pme->sidx[i]    = pme->sdispls[i];
718         }
719         /* Total # of particles to be received */
720         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
721
722         pme_realloc_atomcomm_things(atc);
723
724         /* Copy particle coordinates into send buffer and exchange*/
725         for (i = 0; (i < n); i++)
726         {
727             ii = DIM*pme->sidx[pme->idxa[i]];
728             pme->sidx[pme->idxa[i]]++;
729             pme->redist_buf[ii+XX] = x_f[i][XX];
730             pme->redist_buf[ii+YY] = x_f[i][YY];
731             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
732         }
733         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
734                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
735                       pme->rvec_mpi, atc->mpi_comm);
736     }
737     if (forw)
738     {
739         /* Copy charge into send buffer and exchange*/
740         for (i = 0; i < atc->nslab; i++)
741         {
742             pme->sidx[i] = pme->sdispls[i];
743         }
744         for (i = 0; (i < n); i++)
745         {
746             ii = pme->sidx[pme->idxa[i]];
747             pme->sidx[pme->idxa[i]]++;
748             pme->redist_buf[ii] = charge[i];
749         }
750         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
751                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
752                       atc->mpi_comm);
753     }
754     else   /* backward, redistribution from pme to pp */
755     {
756         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
757                       pme->redist_buf, pme->scounts, pme->sdispls,
758                       pme->rvec_mpi, atc->mpi_comm);
759
760         /* Copy data from receive buffer */
761         for (i = 0; i < atc->nslab; i++)
762         {
763             pme->sidx[i] = pme->sdispls[i];
764         }
765         for (i = 0; (i < n); i++)
766         {
767             ii          = DIM*pme->sidx[pme->idxa[i]];
768             x_f[i][XX] += pme->redist_buf[ii+XX];
769             x_f[i][YY] += pme->redist_buf[ii+YY];
770             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
771             pme->sidx[pme->idxa[i]]++;
772         }
773     }
774 #endif
775 }
776
777 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
778                             gmx_bool bBackward, int shift,
779                             void *buf_s, int nbyte_s,
780                             void *buf_r, int nbyte_r)
781 {
782 #ifdef GMX_MPI
783     int        dest, src;
784     MPI_Status stat;
785
786     if (bBackward == FALSE)
787     {
788         dest = atc->node_dest[shift];
789         src  = atc->node_src[shift];
790     }
791     else
792     {
793         dest = atc->node_src[shift];
794         src  = atc->node_dest[shift];
795     }
796
797     if (nbyte_s > 0 && nbyte_r > 0)
798     {
799         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
800                      dest, shift,
801                      buf_r, nbyte_r, MPI_BYTE,
802                      src, shift,
803                      atc->mpi_comm, &stat);
804     }
805     else if (nbyte_s > 0)
806     {
807         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
808                  dest, shift,
809                  atc->mpi_comm);
810     }
811     else if (nbyte_r > 0)
812     {
813         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
814                  src, shift,
815                  atc->mpi_comm, &stat);
816     }
817 #endif
818 }
819
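/* Redistribute coordinates (when bX is set) and charges to the PME nodes
 * owning the slabs assigned in atc->pd, exchanging counts and data with the
 * neighboring nodes listed in atc->node_dest/node_src.
 */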
820 static void dd_pmeredist_x_q(gmx_pme_t pme,
821                              int n, gmx_bool bX, rvec *x, real *charge,
822                              pme_atomcomm_t *atc)
823 {
824     int *commnode, *buf_index;
825     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
826
827     commnode  = atc->node_dest;
828     buf_index = atc->buf_index;
829
830     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
831
832     nsend = 0;
833     for (i = 0; i < nnodes_comm; i++)
834     {
835         buf_index[commnode[i]] = nsend;
836         nsend                 += atc->count[commnode[i]];
837     }
838     if (bX)
839     {
840         if (atc->count[atc->nodeid] + nsend != n)
841         {
842             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
843                       "This usually means that your system is not well equilibrated.",
844                       n - (atc->count[atc->nodeid] + nsend),
845                       pme->nodeid, 'x'+atc->dimind);
846         }
847
848         if (nsend > pme->buf_nalloc)
849         {
850             pme->buf_nalloc = over_alloc_dd(nsend);
851             srenew(pme->bufv, pme->buf_nalloc);
852             srenew(pme->bufr, pme->buf_nalloc);
853         }
854
855         atc->n = atc->count[atc->nodeid];
856         for (i = 0; i < nnodes_comm; i++)
857         {
858             scount = atc->count[commnode[i]];
859             /* Communicate the count */
860             if (debug)
861             {
862                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
863                         atc->dimind, atc->nodeid, commnode[i], scount);
864             }
865             pme_dd_sendrecv(atc, FALSE, i,
866                             &scount, sizeof(int),
867                             &atc->rcount[i], sizeof(int));
868             atc->n += atc->rcount[i];
869         }
870
871         pme_realloc_atomcomm_things(atc);
872     }
873
874     local_pos = 0;
875     for (i = 0; i < n; i++)
876     {
877         node = atc->pd[i];
878         if (node == atc->nodeid)
879         {
880             /* Copy direct to the receive buffer */
881             if (bX)
882             {
883                 copy_rvec(x[i], atc->x[local_pos]);
884             }
885             atc->q[local_pos] = charge[i];
886             local_pos++;
887         }
888         else
889         {
890             /* Copy to the send buffer */
891             if (bX)
892             {
893                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
894             }
895             pme->bufr[buf_index[node]] = charge[i];
896             buf_index[node]++;
897         }
898     }
899
900     buf_pos = 0;
901     for (i = 0; i < nnodes_comm; i++)
902     {
903         scount = atc->count[commnode[i]];
904         rcount = atc->rcount[i];
905         if (scount > 0 || rcount > 0)
906         {
907             if (bX)
908             {
909                 /* Communicate the coordinates */
910                 pme_dd_sendrecv(atc, FALSE, i,
911                                 pme->bufv[buf_pos], scount*sizeof(rvec),
912                                 atc->x[local_pos], rcount*sizeof(rvec));
913             }
914             /* Communicate the charges */
915             pme_dd_sendrecv(atc, FALSE, i,
916                             pme->bufr+buf_pos, scount*sizeof(real),
917                             atc->q+local_pos, rcount*sizeof(real));
918             buf_pos   += scount;
919             local_pos += atc->rcount[i];
920         }
921     }
922 }
923
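/* Communicate the PME forces back to their home nodes; this is the reverse
 * of dd_pmeredist_x_q.  With bAddF the received forces are added to f,
 * otherwise they overwrite f.
 */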
924 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
925                            int n, rvec *f,
926                            gmx_bool bAddF)
927 {
928     int *commnode, *buf_index;
929     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
930
931     commnode  = atc->node_dest;
932     buf_index = atc->buf_index;
933
934     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
935
936     local_pos = atc->count[atc->nodeid];
937     buf_pos   = 0;
938     for (i = 0; i < nnodes_comm; i++)
939     {
940         scount = atc->rcount[i];
941         rcount = atc->count[commnode[i]];
942         if (scount > 0 || rcount > 0)
943         {
944             /* Communicate the forces */
945             pme_dd_sendrecv(atc, TRUE, i,
946                             atc->f[local_pos], scount*sizeof(rvec),
947                             pme->bufv[buf_pos], rcount*sizeof(rvec));
948             local_pos += scount;
949         }
950         buf_index[commnode[i]] = buf_pos;
951         buf_pos               += rcount;
952     }
953
954     local_pos = 0;
955     if (bAddF)
956     {
957         for (i = 0; i < n; i++)
958         {
959             node = atc->pd[i];
960             if (node == atc->nodeid)
961             {
962                 /* Add from the local force array */
963                 rvec_inc(f[i], atc->f[local_pos]);
964                 local_pos++;
965             }
966             else
967             {
968                 /* Add from the receive buffer */
969                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
970                 buf_index[node]++;
971             }
972         }
973     }
974     else
975     {
976         for (i = 0; i < n; i++)
977         {
978             node = atc->pd[i];
979             if (node == atc->nodeid)
980             {
981                 /* Copy from the local force array */
982                 copy_rvec(atc->f[local_pos], f[i]);
983                 local_pos++;
984             }
985             else
986             {
987                 /* Copy from the receive buffer */
988                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
989                 buf_index[node]++;
990             }
991         }
992     }
993 }
994
995 #ifdef GMX_MPI
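/* Sum (GMX_SUM_QGRID_FORWARD) or redistribute (backward) the charge-grid
 * overlap regions with the neighboring PME nodes: first along the minor (y)
 * dimension, which requires packing into contiguous send/receive buffers,
 * then along the major (x) dimension, where whole y-z planes are already
 * contiguous in memory.
 */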
996 static void
997 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
998 {
999     pme_overlap_t *overlap;
1000     int            send_index0, send_nindex;
1001     int            recv_index0, recv_nindex;
1002     MPI_Status     stat;
1003     int            i, j, k, ix, iy, iz, icnt;
1004     int            ipulse, send_id, recv_id, datasize;
1005     real          *p;
1006     real          *sendptr, *recvptr;
1007
1008     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
1009     overlap = &pme->overlap[1];
1010
1011     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1012     {
1013         /* Since we have already (un)wrapped the overlap in the z-dimension,
1014          * we only have to communicate 0 to nkz (not pmegrid_nz).
1015          */
1016         if (direction == GMX_SUM_QGRID_FORWARD)
1017         {
1018             send_id       = overlap->send_id[ipulse];
1019             recv_id       = overlap->recv_id[ipulse];
1020             send_index0   = overlap->comm_data[ipulse].send_index0;
1021             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1022             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1023             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1024         }
1025         else
1026         {
1027             send_id       = overlap->recv_id[ipulse];
1028             recv_id       = overlap->send_id[ipulse];
1029             send_index0   = overlap->comm_data[ipulse].recv_index0;
1030             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1031             recv_index0   = overlap->comm_data[ipulse].send_index0;
1032             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1033         }
1034
1035         /* Copy data to contiguous send buffer */
1036         if (debug)
1037         {
1038             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1039                     pme->nodeid, overlap->nodeid, send_id,
1040                     pme->pmegrid_start_iy,
1041                     send_index0-pme->pmegrid_start_iy,
1042                     send_index0-pme->pmegrid_start_iy+send_nindex);
1043         }
1044         icnt = 0;
1045         for (i = 0; i < pme->pmegrid_nx; i++)
1046         {
1047             ix = i;
1048             for (j = 0; j < send_nindex; j++)
1049             {
1050                 iy = j + send_index0 - pme->pmegrid_start_iy;
1051                 for (k = 0; k < pme->nkz; k++)
1052                 {
1053                     iz = k;
1054                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1055                 }
1056             }
1057         }
1058
1059         datasize      = pme->pmegrid_nx * pme->nkz;
1060
1061         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1062                      send_id, ipulse,
1063                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1064                      recv_id, ipulse,
1065                      overlap->mpi_comm, &stat);
1066
1067         /* Get data from contiguous recv buffer */
1068         if (debug)
1069         {
1070             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1071                     pme->nodeid, overlap->nodeid, recv_id,
1072                     pme->pmegrid_start_iy,
1073                     recv_index0-pme->pmegrid_start_iy,
1074                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1075         }
1076         icnt = 0;
1077         for (i = 0; i < pme->pmegrid_nx; i++)
1078         {
1079             ix = i;
1080             for (j = 0; j < recv_nindex; j++)
1081             {
1082                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1083                 for (k = 0; k < pme->nkz; k++)
1084                 {
1085                     iz = k;
1086                     if (direction == GMX_SUM_QGRID_FORWARD)
1087                     {
1088                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1089                     }
1090                     else
1091                     {
1092                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1093                     }
1094                 }
1095             }
1096         }
1097     }
1098
1099     /* Major dimension is easier, no copying required,
1100      * but we might have to sum to separate array.
1101      * Since we don't copy, we have to communicate up to pmegrid_nz,
1102      * not nkz as for the minor direction.
1103      */
1104     overlap = &pme->overlap[0];
1105
1106     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1107     {
1108         if (direction == GMX_SUM_QGRID_FORWARD)
1109         {
1110             send_id       = overlap->send_id[ipulse];
1111             recv_id       = overlap->recv_id[ipulse];
1112             send_index0   = overlap->comm_data[ipulse].send_index0;
1113             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1114             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1115             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1116             recvptr       = overlap->recvbuf;
1117         }
1118         else
1119         {
1120             send_id       = overlap->recv_id[ipulse];
1121             recv_id       = overlap->send_id[ipulse];
1122             send_index0   = overlap->comm_data[ipulse].recv_index0;
1123             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1124             recv_index0   = overlap->comm_data[ipulse].send_index0;
1125             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1126             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1127         }
1128
1129         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1130         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1131
1132         if (debug)
1133         {
1134             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1135                     pme->nodeid, overlap->nodeid, send_id,
1136                     pme->pmegrid_start_ix,
1137                     send_index0-pme->pmegrid_start_ix,
1138                     send_index0-pme->pmegrid_start_ix+send_nindex);
1139             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1140                     pme->nodeid, overlap->nodeid, recv_id,
1141                     pme->pmegrid_start_ix,
1142                     recv_index0-pme->pmegrid_start_ix,
1143                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1144         }
1145
1146         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1147                      send_id, ipulse,
1148                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1149                      recv_id, ipulse,
1150                      overlap->mpi_comm, &stat);
1151
1152         /* ADD data from contiguous recv buffer */
1153         if (direction == GMX_SUM_QGRID_FORWARD)
1154         {
1155             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1156             for (i = 0; i < recv_nindex*datasize; i++)
1157             {
1158                 p[i] += overlap->recvbuf[i];
1159             }
1160         }
1161     }
1162 }
1163 #endif
1164
1165
1166 static int
1167 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1168 {
1169     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1170     ivec    local_pme_size;
1171     int     i, ix, iy, iz;
1172     int     pmeidx, fftidx;
1173
1174     /* Dimensions should be identical for A/B grid, so we just use A here */
1175     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1176                                    local_fft_ndata,
1177                                    local_fft_offset,
1178                                    local_fft_size);
1179
1180     local_pme_size[0] = pme->pmegrid_nx;
1181     local_pme_size[1] = pme->pmegrid_ny;
1182     local_pme_size[2] = pme->pmegrid_nz;
1183
1184     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1185        the offset is identical, and the PME grid always has more data (due to overlap)
1186      */
1187     {
1188 #ifdef DEBUG_PME
1189         FILE *fp, *fp2;
1190         char  fn[STRLEN], format[STRLEN];
1191         real  val;
1192         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1193         fp = ffopen(fn, "w");
1194         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1195         fp2 = ffopen(fn, "w");
1196         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1197 #endif
1198
1199         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1200         {
1201             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1202             {
1203                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1204                 {
1205                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1206                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1207                     fftgrid[fftidx] = pmegrid[pmeidx];
1208 #ifdef DEBUG_PME
1209                     val = 100*pmegrid[pmeidx];
1210                     if (pmegrid[pmeidx] != 0)
1211                     {
1212                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1213                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1214                     }
1215                     if (pmegrid[pmeidx] != 0)
1216                     {
1217                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1218                                 "qgrid",
1219                                 pme->pmegrid_start_ix + ix,
1220                                 pme->pmegrid_start_iy + iy,
1221                                 pme->pmegrid_start_iz + iz,
1222                                 pmegrid[pmeidx]);
1223                     }
1224 #endif
1225                 }
1226             }
1227         }
1228 #ifdef DEBUG_PME
1229         ffclose(fp);
1230         ffclose(fp2);
1231 #endif
1232     }
1233     return 0;
1234 }
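/* Indexing note for the copy above (illustrative): both grids are row-major
 * in (x,y,z) but with different strides, so element (ix,iy,iz) sits at
 *     pmegrid[(ix*local_pme_size[YY] + iy)*local_pme_size[ZZ] + iz]   and
 *     fftgrid[(ix*local_fft_size[YY] + iy)*local_fft_size[ZZ] + iz],
 * where local_pme_size includes the order-1 spreading overlap and
 * local_fft_size may include FFT padding in z.
 */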
1235
1236
1237 static gmx_cycles_t omp_cyc_start()
1238 {
1239     return gmx_cycles_read();
1240 }
1241
1242 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1243 {
1244     return gmx_cycles_read() - c;
1245 }
1246
1247
1248 static int
1249 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1250                         int nthread, int thread)
1251 {
1252     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1253     ivec          local_pme_size;
1254     int           ixy0, ixy1, ixy, ix, iy, iz;
1255     int           pmeidx, fftidx;
1256 #ifdef PME_TIME_THREADS
1257     gmx_cycles_t  c1;
1258     static double cs1 = 0;
1259     static int    cnt = 0;
1260 #endif
1261
1262 #ifdef PME_TIME_THREADS
1263     c1 = omp_cyc_start();
1264 #endif
1265     /* Dimensions should be identical for A/B grid, so we just use A here */
1266     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1267                                    local_fft_ndata,
1268                                    local_fft_offset,
1269                                    local_fft_size);
1270
1271     local_pme_size[0] = pme->pmegrid_nx;
1272     local_pme_size[1] = pme->pmegrid_ny;
1273     local_pme_size[2] = pme->pmegrid_nz;
1274
1275     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1276        the offset is identical, and the PME grid always has more data (due to overlap)
1277      */
1278     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1279     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1280
1281     for (ixy = ixy0; ixy < ixy1; ixy++)
1282     {
1283         ix = ixy/local_fft_ndata[YY];
1284         iy = ixy - ix*local_fft_ndata[YY];
1285
1286         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1287         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1288         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1289         {
1290             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1291         }
1292     }
1293
1294 #ifdef PME_TIME_THREADS
1295     c1   = omp_cyc_end(c1);
1296     cs1 += (double)c1;
1297     cnt++;
1298     if (cnt % 20 == 0)
1299     {
1300         printf("copy %.2f\n", cs1*1e-9);
1301     }
1302 #endif
1303
1304     return 0;
1305 }
1306
1307
1308 static void
1309 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1310 {
1311     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1312
1313     nx = pme->nkx;
1314     ny = pme->nky;
1315     nz = pme->nkz;
1316
1317     pnx = pme->pmegrid_nx;
1318     pny = pme->pmegrid_ny;
1319     pnz = pme->pmegrid_nz;
1320
1321     overlap = pme->pme_order - 1;
1322
1323     /* Add periodic overlap in z */
1324     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1325     {
1326         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1327         {
1328             for (iz = 0; iz < overlap; iz++)
1329             {
1330                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1331                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1332             }
1333         }
1334     }
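    /* Example (illustrative): with pme_order = 4 the overlap is 3, so for
     * nkz = 20 the spline contributions spread to iz = 20, 21, 22 are folded
     * back onto iz = 0, 1, 2 here.
     */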
1335
1336     if (pme->nnodes_minor == 1)
1337     {
1338         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1339         {
1340             for (iy = 0; iy < overlap; iy++)
1341             {
1342                 for (iz = 0; iz < nz; iz++)
1343                 {
1344                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1345                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1346                 }
1347             }
1348         }
1349     }
1350
1351     if (pme->nnodes_major == 1)
1352     {
1353         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1354
1355         for (ix = 0; ix < overlap; ix++)
1356         {
1357             for (iy = 0; iy < ny_x; iy++)
1358             {
1359                 for (iz = 0; iz < nz; iz++)
1360                 {
1361                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1362                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1363                 }
1364             }
1365         }
1366     }
1367 }
1368
1369
1370 static void
1371 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1372 {
1373     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1374
1375     nx = pme->nkx;
1376     ny = pme->nky;
1377     nz = pme->nkz;
1378
1379     pnx = pme->pmegrid_nx;
1380     pny = pme->pmegrid_ny;
1381     pnz = pme->pmegrid_nz;
1382
1383     overlap = pme->pme_order - 1;
1384
1385     if (pme->nnodes_major == 1)
1386     {
1387         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1388
1389         for (ix = 0; ix < overlap; ix++)
1390         {
1391             int iy, iz;
1392
1393             for (iy = 0; iy < ny_x; iy++)
1394             {
1395                 for (iz = 0; iz < nz; iz++)
1396                 {
1397                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1398                         pmegrid[(ix*pny+iy)*pnz+iz];
1399                 }
1400             }
1401         }
1402     }
1403
1404     if (pme->nnodes_minor == 1)
1405     {
1406 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1407         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1408         {
1409             int iy, iz;
1410
1411             for (iy = 0; iy < overlap; iy++)
1412             {
1413                 for (iz = 0; iz < nz; iz++)
1414                 {
1415                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1416                         pmegrid[(ix*pny+iy)*pnz+iz];
1417                 }
1418             }
1419         }
1420     }
1421
1422     /* Copy periodic overlap in z */
1423 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1424     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1425     {
1426         int iy, iz;
1427
1428         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1429         {
1430             for (iz = 0; iz < overlap; iz++)
1431             {
1432                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1433                     pmegrid[(ix*pny+iy)*pnz+iz];
1434             }
1435         }
1436     }
1437 }
1438
1439
1440 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1441 #define DO_BSPLINE(order)                            \
1442     for (ithx = 0; (ithx < order); ithx++)                    \
1443     {                                                    \
1444         index_x = (i0+ithx)*pny*pnz;                     \
1445         valx    = qn*thx[ithx];                          \
1446                                                      \
1447         for (ithy = 0; (ithy < order); ithy++)                \
1448         {                                                \
1449             valxy    = valx*thy[ithy];                   \
1450             index_xy = index_x+(j0+ithy)*pnz;            \
1451                                                      \
1452             for (ithz = 0; (ithz < order); ithz++)            \
1453             {                                            \
1454                 index_xyz        = index_xy+(k0+ithz);   \
1455                 grid[index_xyz] += valxy*thz[ithz];      \
1456             }                                            \
1457         }                                                \
1458     }
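/* Reading of DO_BSPLINE (illustrative): each charge qn is spread onto an
 * order x order x order block of grid points starting at (i0, j0, k0), with
 * the separable weight thx[ithx]*thy[ithy]*thz[ithz]; for the common
 * pme_order = 4 this touches 4*4*4 = 64 grid points per charge.
 */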
1459
1460
1461 static void spread_q_bsplines_thread(pmegrid_t                    *pmegrid,
1462                                      pme_atomcomm_t               *atc,
1463                                      splinedata_t                 *spline,
1464                                      pme_spline_work_t gmx_unused *work)
1465 {
1466
1467     /* spread charges from home atoms to local grid */
1468     real          *grid;
1469     pme_overlap_t *ol;
1470     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1471     int       *    idxptr;
1472     int            order, norder, index_x, index_xy, index_xyz;
1473     real           valx, valxy, qn;
1474     real          *thx, *thy, *thz;
1475     int            localsize, bndsize;
1476     int            pnx, pny, pnz, ndatatot;
1477     int            offx, offy, offz;
1478
1479 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1480     real           thz_buffer[12], *thz_aligned;
1481
1482     thz_aligned = gmx_simd4_align_real(thz_buffer);
1483 #endif
1484
1485     pnx = pmegrid->s[XX];
1486     pny = pmegrid->s[YY];
1487     pnz = pmegrid->s[ZZ];
1488
1489     offx = pmegrid->offset[XX];
1490     offy = pmegrid->offset[YY];
1491     offz = pmegrid->offset[ZZ];
1492
1493     ndatatot = pnx*pny*pnz;
1494     grid     = pmegrid->grid;
1495     for (i = 0; i < ndatatot; i++)
1496     {
1497         grid[i] = 0;
1498     }
1499
1500     order = pmegrid->order;
1501
1502     for (nn = 0; nn < spline->n; nn++)
1503     {
1504         n  = spline->ind[nn];
1505         qn = atc->q[n];
1506
1507         if (qn != 0)
1508         {
1509             idxptr = atc->idx[n];
1510             norder = nn*order;
1511
1512             i0   = idxptr[XX] - offx;
1513             j0   = idxptr[YY] - offy;
1514             k0   = idxptr[ZZ] - offz;
1515
1516             thx = spline->theta[XX] + norder;
1517             thy = spline->theta[YY] + norder;
1518             thz = spline->theta[ZZ] + norder;
1519
1520             switch (order)
1521             {
1522                 case 4:
1523 #ifdef PME_SIMD4_SPREAD_GATHER
1524 #ifdef PME_SIMD4_UNALIGNED
1525 #define PME_SPREAD_SIMD4_ORDER4
1526 #else
1527 #define PME_SPREAD_SIMD4_ALIGNED
1528 #define PME_ORDER 4
1529 #endif
1530 #include "pme_simd4.h"
1531 #else
1532                     DO_BSPLINE(4);
1533 #endif
1534                     break;
1535                 case 5:
1536 #ifdef PME_SIMD4_SPREAD_GATHER
1537 #define PME_SPREAD_SIMD4_ALIGNED
1538 #define PME_ORDER 5
1539 #include "pme_simd4.h"
1540 #else
1541                     DO_BSPLINE(5);
1542 #endif
1543                     break;
1544                 default:
1545                     DO_BSPLINE(order);
1546                     break;
1547             }
1548         }
1549     }
1550 }
1551
1552 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1553 {
1554 #ifdef PME_SIMD4_SPREAD_GATHER
1555     if (pme_order == 5
1556 #ifndef PME_SIMD4_UNALIGNED
1557         || pme_order == 4
1558 #endif
1559         )
1560     {
1561         /* Round nz up to a multiple of 4 to ensure alignment */
1562         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
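        /* e.g. 29 -> 32 and 32 -> 32; with the z-size a multiple of 4 reals,
         * consecutive z-lines of a SIMD4-aligned allocation stay aligned.
         */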
1563     }
1564 #endif
1565 }
1566
1567 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1568 {
1569 #ifdef PME_SIMD4_SPREAD_GATHER
1570 #ifndef PME_SIMD4_UNALIGNED
1571     if (pme_order == 4)
1572     {
1573         /* Add extra elements to ensure that aligned operations do not go
1574          * beyond the allocated grid size.
1575          * Note that for pme_order=5, the pme grid z-size alignment
1576          * ensures that we will not go beyond the grid size.
1577          */
1578         *gridsize += 4;
1579     }
1580 #endif
1581 #endif
1582 }
1583
1584 static void pmegrid_init(pmegrid_t *grid,
1585                          int cx, int cy, int cz,
1586                          int x0, int y0, int z0,
1587                          int x1, int y1, int z1,
1588                          gmx_bool set_alignment,
1589                          int pme_order,
1590                          real *ptr)
1591 {
1592     int nz, gridsize;
1593
1594     grid->ci[XX]     = cx;
1595     grid->ci[YY]     = cy;
1596     grid->ci[ZZ]     = cz;
1597     grid->offset[XX] = x0;
1598     grid->offset[YY] = y0;
1599     grid->offset[ZZ] = z0;
1600     grid->n[XX]      = x1 - x0 + pme_order - 1;
1601     grid->n[YY]      = y1 - y0 + pme_order - 1;
1602     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1603     copy_ivec(grid->n, grid->s);
1604
1605     nz = grid->s[ZZ];
1606     set_grid_alignment(&nz, pme_order);
1607     if (set_alignment)
1608     {
1609         grid->s[ZZ] = nz;
1610     }
1611     else if (nz != grid->s[ZZ])
1612     {
1613         gmx_incons("pmegrid_init call with an unaligned z size");
1614     }
1615
1616     grid->order = pme_order;
1617     if (ptr == NULL)
1618     {
1619         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1620         set_gridsize_alignment(&gridsize, pme_order);
1621         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1622     }
1623     else
1624     {
1625         grid->grid = ptr;
1626     }
1627 }
1628
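     /* Integer division rounded upwards (ceiling) for positive arguments */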
1629 static int div_round_up(int numerator, int denominator)
1630 {
1631     return (numerator + denominator - 1)/denominator;
1632 }
1633
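     /* Choose a thread decomposition nsub[] of the grid, with
      * nsub[XX]*nsub[YY]*nsub[ZZ] == nthread, that minimizes the number of
      * grid points (including the overlap ovl) per thread. The division can
      * be overridden with the GMX_PME_THREAD_DIVISION environment variable.
      */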
1634 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1635                                   ivec nsub)
1636 {
1637     int gsize_opt, gsize;
1638     int nsx, nsy, nsz;
1639     char *env;
1640
1641     gsize_opt = -1;
1642     for (nsx = 1; nsx <= nthread; nsx++)
1643     {
1644         if (nthread % nsx == 0)
1645         {
1646             for (nsy = 1; nsy <= nthread; nsy++)
1647             {
1648                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1649                 {
1650                     nsz = nthread/(nsx*nsy);
1651
1652                     /* Determine the number of grid points per thread */
1653                     gsize =
1654                         (div_round_up(n[XX], nsx) + ovl)*
1655                         (div_round_up(n[YY], nsy) + ovl)*
1656                         (div_round_up(n[ZZ], nsz) + ovl);
1657
1658                     /* Minimize the number of grid points per thread
1659                      * and, secondarily, the number of cuts in minor dimensions.
1660                      */
1661                     if (gsize_opt == -1 ||
1662                         gsize < gsize_opt ||
1663                         (gsize == gsize_opt &&
1664                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1665                     {
1666                         nsub[XX]  = nsx;
1667                         nsub[YY]  = nsy;
1668                         nsub[ZZ]  = nsz;
1669                         gsize_opt = gsize;
1670                     }
1671                 }
1672             }
1673         }
1674     }
1675
1676     env = getenv("GMX_PME_THREAD_DIVISION");
1677     if (env != NULL)
1678     {
1679         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1680     }
1681
1682     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1683     {
1684         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1685     }
1686 }
1687
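     /* Set up the local PME grid and, when thread spreading is used, the
      * thread-local subgrids, the grid-to-thread index tables g2t and the
      * per-dimension thread communication ranges.
      */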
1688 static void pmegrids_init(pmegrids_t *grids,
1689                           int nx, int ny, int nz, int nz_base,
1690                           int pme_order,
1691                           gmx_bool bUseThreads,
1692                           int nthread,
1693                           int overlap_x,
1694                           int overlap_y)
1695 {
1696     ivec n, n_base, g0, g1;
1697     int t, x, y, z, d, i, tfac;
1698     int max_comm_lines = -1;
1699
1700     n[XX] = nx - (pme_order - 1);
1701     n[YY] = ny - (pme_order - 1);
1702     n[ZZ] = nz - (pme_order - 1);
1703
1704     copy_ivec(n, n_base);
1705     n_base[ZZ] = nz_base;
1706
1707     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1708                  NULL);
1709
1710     grids->nthread = nthread;
1711
1712     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1713
1714     if (bUseThreads)
1715     {
1716         ivec nst;
1717         int gridsize;
1718
1719         for (d = 0; d < DIM; d++)
1720         {
1721             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1722         }
1723         set_grid_alignment(&nst[ZZ], pme_order);
1724
1725         if (debug)
1726         {
1727             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1728                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1729             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1730                     nx, ny, nz,
1731                     nst[XX], nst[YY], nst[ZZ]);
1732         }
1733
1734         snew(grids->grid_th, grids->nthread);
1735         t        = 0;
1736         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1737         set_gridsize_alignment(&gridsize, pme_order);
1738         snew_aligned(grids->grid_all,
1739                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1740                      SIMD4_ALIGNMENT);
1741
1742         for (x = 0; x < grids->nc[XX]; x++)
1743         {
1744             for (y = 0; y < grids->nc[YY]; y++)
1745             {
1746                 for (z = 0; z < grids->nc[ZZ]; z++)
1747                 {
1748                     pmegrid_init(&grids->grid_th[t],
1749                                  x, y, z,
1750                                  (n[XX]*(x  ))/grids->nc[XX],
1751                                  (n[YY]*(y  ))/grids->nc[YY],
1752                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1753                                  (n[XX]*(x+1))/grids->nc[XX],
1754                                  (n[YY]*(y+1))/grids->nc[YY],
1755                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1756                                  TRUE,
1757                                  pme_order,
1758                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1759                     t++;
1760                 }
1761             }
1762         }
1763     }
1764     else
1765     {
1766         grids->grid_th = NULL;
1767     }
1768
1769     snew(grids->g2t, DIM);
1770     tfac = 1;
1771     for (d = DIM-1; d >= 0; d--)
1772     {
1773         snew(grids->g2t[d], n[d]);
1774         t = 0;
1775         for (i = 0; i < n[d]; i++)
1776         {
1777             /* The second check should match the parameters
1778              * of the pmegrid_init call above.
1779              */
1780             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1781             {
1782                 t++;
1783             }
1784             grids->g2t[d][i] = t*tfac;
1785         }
1786
1787         tfac *= grids->nc[d];
1788
1789         switch (d)
1790         {
1791             case XX: max_comm_lines = overlap_x;     break;
1792             case YY: max_comm_lines = overlap_y;     break;
1793             case ZZ: max_comm_lines = pme_order - 1; break;
1794         }
1795         grids->nthread_comm[d] = 0;
1796         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1797                grids->nthread_comm[d] < grids->nc[d])
1798         {
1799             grids->nthread_comm[d]++;
1800         }
1801         if (debug != NULL)
1802         {
1803             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1804                     'x'+d, grids->nthread_comm[d]);
1805         }
1806         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1807          * work, but this is not a problematic restriction.
1808          */
1809         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1810         {
1811             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1812         }
1813     }
1814 }
1815
1816
1817 static void pmegrids_destroy(pmegrids_t *grids)
1818 {
1819     int t;
1820
1821     if (grids->grid.grid != NULL)
1822     {
1823         sfree(grids->grid.grid);
1824
1825         if (grids->nthread > 0)
1826         {
1827             for (t = 0; t < grids->nthread; t++)
1828             {
1829                 sfree(grids->grid_th[t].grid);
1830             }
1831             sfree(grids->grid_th);
1832         }
1833     }
1834 }
1835
1836
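     /* (Re)allocate the per-thread work arrays used in solve_pme_yzx for up
      * to nkx wave vectors; the arrays used in SIMD loops get aligned
      * storage with padding at the end.
      */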
1837 static void realloc_work(pme_work_t *work, int nkx)
1838 {
1839     int simd_width;
1840
1841     if (nkx > work->nalloc)
1842     {
1843         work->nalloc = nkx;
1844         srenew(work->mhx, work->nalloc);
1845         srenew(work->mhy, work->nalloc);
1846         srenew(work->mhz, work->nalloc);
1847         srenew(work->m2, work->nalloc);
1848         /* Allocate an aligned pointer for SIMD operations, including extra
1849          * elements at the end for padding.
1850          */
1851 #ifdef PME_SIMD
1852         simd_width = GMX_SIMD_WIDTH_HERE;
1853 #else
1854         /* We can use any alignment, apart from 0, so we use 4 */
1855         simd_width = 4;
1856 #endif
1857         sfree_aligned(work->denom);
1858         sfree_aligned(work->tmp1);
1859         sfree_aligned(work->eterm);
1860         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1861         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1862         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1863         srenew(work->m2inv, work->nalloc);
1864     }
1865 }
1866
1867
1868 static void free_work(pme_work_t *work)
1869 {
1870     sfree(work->mhx);
1871     sfree(work->mhy);
1872     sfree(work->mhz);
1873     sfree(work->m2);
1874     sfree_aligned(work->denom);
1875     sfree_aligned(work->tmp1);
1876     sfree_aligned(work->eterm);
1877     sfree(work->m2inv);
1878 }
1879
1880
1881 #ifdef PME_SIMD
1882 /* Calculate exponentials through SIMD */
1883 inline static void calc_exponentials(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1884 {
1885     {
1886         const gmx_mm_pr two = gmx_set1_pr(2.0);
1887         gmx_mm_pr f_simd;
1888         gmx_mm_pr lu;
1889         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1890         int kx;
1891         f_simd = gmx_set1_pr(f);
1892         /* We only need to calculate from start. But since start is 0 or 1
1893          * and we want to use aligned loads/stores, we always start from 0.
1894          */
1895         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1896         {
1897             tmp_d1   = gmx_load_pr(d_aligned+kx);
1898             d_inv    = gmx_inv_pr(tmp_d1);
1899             tmp_r    = gmx_load_pr(r_aligned+kx);
1900             tmp_r    = gmx_exp_pr(tmp_r);
1901             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1902             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1903             gmx_store_pr(e_aligned+kx, tmp_e);
1904         }
1905     }
1906 }
1907 #else
1908 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1909 {
1910     int kx;
1911     for (kx = start; kx < end; kx++)
1912     {
1913         d[kx] = 1.0/d[kx];
1914     }
1915     for (kx = start; kx < end; kx++)
1916     {
1917         r[kx] = exp(r[kx]);
1918     }
1919     for (kx = start; kx < end; kx++)
1920     {
1921         e[kx] = f*r[kx]*d[kx];
1922     }
1923 }
1924 #endif
1925
1926
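     /* Reciprocal-space solve for this thread's share of the transformed
      * grid: each Fourier coefficient is multiplied by the influence function
      * eterm = elfac*exp(-pi^2 m^2/ewaldcoeff^2)/(pi V m^2 prod(bsp_mod)),
      * and with bEnerVir set the reciprocal-space energy and virial are
      * accumulated in pme->work[thread]. Returns a loop count that the
      * caller can use for operation counting.
      */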
1927 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1928                          real ewaldcoeff, real vol,
1929                          gmx_bool bEnerVir,
1930                          int nthread, int thread)
1931 {
1932     /* do recip sum over local cells in grid */
1933     /* y major, z middle, x minor or continuous */
1934     t_complex *p0;
1935     int     kx, ky, kz, maxkx, maxky, maxkz;
1936     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1937     real    mx, my, mz;
1938     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1939     real    ets2, struct2, vfactor, ets2vf;
1940     real    d1, d2, energy = 0;
1941     real    by, bz;
1942     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1943     real    rxx, ryx, ryy, rzx, rzy, rzz;
1944     pme_work_t *work;
1945     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1946     real    mhxk, mhyk, mhzk, m2k;
1947     real    corner_fac;
1948     ivec    complex_order;
1949     ivec    local_ndata, local_offset, local_size;
1950     real    elfac;
1951
1952     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1953
1954     nx = pme->nkx;
1955     ny = pme->nky;
1956     nz = pme->nkz;
1957
1958     /* Dimensions should be identical for A/B grid, so we just use A here */
1959     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1960                                       complex_order,
1961                                       local_ndata,
1962                                       local_offset,
1963                                       local_size);
1964
1965     rxx = pme->recipbox[XX][XX];
1966     ryx = pme->recipbox[YY][XX];
1967     ryy = pme->recipbox[YY][YY];
1968     rzx = pme->recipbox[ZZ][XX];
1969     rzy = pme->recipbox[ZZ][YY];
1970     rzz = pme->recipbox[ZZ][ZZ];
1971
1972     maxkx = (nx+1)/2;
1973     maxky = (ny+1)/2;
1974     maxkz = nz/2+1;
1975
1976     work  = &pme->work[thread];
1977     mhx   = work->mhx;
1978     mhy   = work->mhy;
1979     mhz   = work->mhz;
1980     m2    = work->m2;
1981     denom = work->denom;
1982     tmp1  = work->tmp1;
1983     eterm = work->eterm;
1984     m2inv = work->m2inv;
1985
1986     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1987     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1988
1989     for (iyz = iyz0; iyz < iyz1; iyz++)
1990     {
1991         iy = iyz/local_ndata[ZZ];
1992         iz = iyz - iy*local_ndata[ZZ];
1993
1994         ky = iy + local_offset[YY];
1995
1996         if (ky < maxky)
1997         {
1998             my = ky;
1999         }
2000         else
2001         {
2002             my = (ky - ny);
2003         }
2004
2005         by = M_PI*vol*pme->bsp_mod[YY][ky];
2006
2007         kz = iz + local_offset[ZZ];
2008
2009         mz = kz;
2010
2011         bz = pme->bsp_mod[ZZ][kz];
2012
2013         /* 0.5 correction for corner points */
2014         corner_fac = 1;
2015         if (kz == 0 || kz == (nz+1)/2)
2016         {
2017             corner_fac = 0.5;
2018         }
2019
2020         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2021
2022         /* We should skip the k-space point (0,0,0) */
2023         /* Note that since here x is the minor index, local_offset[XX]=0 */
2024         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2025         {
2026             kxstart = local_offset[XX];
2027         }
2028         else
2029         {
2030             kxstart = local_offset[XX] + 1;
2031             p0++;
2032         }
2033         kxend = local_offset[XX] + local_ndata[XX];
2034
2035         if (bEnerVir)
2036         {
2037             /* More expensive inner loop, especially because of the storage
2038              * of the mh elements in arrays.
2039              * Because x is the minor grid index, all mh elements
2040              * depend on kx for triclinic unit cells.
2041              */
2042
2043             /* Two explicit loops to avoid a conditional inside the loop */
2044             for (kx = kxstart; kx < maxkx; kx++)
2045             {
2046                 mx = kx;
2047
2048                 mhxk      = mx * rxx;
2049                 mhyk      = mx * ryx + my * ryy;
2050                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2051                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2052                 mhx[kx]   = mhxk;
2053                 mhy[kx]   = mhyk;
2054                 mhz[kx]   = mhzk;
2055                 m2[kx]    = m2k;
2056                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2057                 tmp1[kx]  = -factor*m2k;
2058             }
2059
2060             for (kx = maxkx; kx < kxend; kx++)
2061             {
2062                 mx = (kx - nx);
2063
2064                 mhxk      = mx * rxx;
2065                 mhyk      = mx * ryx + my * ryy;
2066                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2067                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2068                 mhx[kx]   = mhxk;
2069                 mhy[kx]   = mhyk;
2070                 mhz[kx]   = mhzk;
2071                 m2[kx]    = m2k;
2072                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2073                 tmp1[kx]  = -factor*m2k;
2074             }
2075
2076             for (kx = kxstart; kx < kxend; kx++)
2077             {
2078                 m2inv[kx] = 1.0/m2[kx];
2079             }
2080
2081             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2082
2083             for (kx = kxstart; kx < kxend; kx++, p0++)
2084             {
2085                 d1      = p0->re;
2086                 d2      = p0->im;
2087
2088                 p0->re  = d1*eterm[kx];
2089                 p0->im  = d2*eterm[kx];
2090
2091                 struct2 = 2.0*(d1*d1+d2*d2);
2092
2093                 tmp1[kx] = eterm[kx]*struct2;
2094             }
2095
2096             for (kx = kxstart; kx < kxend; kx++)
2097             {
2098                 ets2     = corner_fac*tmp1[kx];
2099                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2100                 energy  += ets2;
2101
2102                 ets2vf   = ets2*vfactor;
2103                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2104                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2105                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2106                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2107                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2108                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2109             }
2110         }
2111         else
2112         {
2113             /* We don't need to calculate the energy and the virial.
2114              * In this case the triclinic overhead is small.
2115              */
2116
2117             /* Two explicit loops to avoid a conditional inside the loop */
2118
2119             for (kx = kxstart; kx < maxkx; kx++)
2120             {
2121                 mx = kx;
2122
2123                 mhxk      = mx * rxx;
2124                 mhyk      = mx * ryx + my * ryy;
2125                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2126                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2127                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2128                 tmp1[kx]  = -factor*m2k;
2129             }
2130
2131             for (kx = maxkx; kx < kxend; kx++)
2132             {
2133                 mx = (kx - nx);
2134
2135                 mhxk      = mx * rxx;
2136                 mhyk      = mx * ryx + my * ryy;
2137                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2138                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2139                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2140                 tmp1[kx]  = -factor*m2k;
2141             }
2142
2143             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2144
2145             for (kx = kxstart; kx < kxend; kx++, p0++)
2146             {
2147                 d1      = p0->re;
2148                 d2      = p0->im;
2149
2150                 p0->re  = d1*eterm[kx];
2151                 p0->im  = d2*eterm[kx];
2152             }
2153         }
2154     }
2155
2156     if (bEnerVir)
2157     {
2158         /* Update virial with local values.
2159          * The virial is symmetric by definition.
2160          * This virial seems OK for isotropic scaling, but I'm
2161          * experiencing problems with semi-isotropic membranes.
2162          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2163          */
2164         work->vir[XX][XX] = 0.25*virxx;
2165         work->vir[YY][YY] = 0.25*viryy;
2166         work->vir[ZZ][ZZ] = 0.25*virzz;
2167         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2168         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2169         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2170
2171         /* This energy should be corrected for a charged system */
2172         work->energy = 0.5*energy;
2173     }
2174
2175     /* Return the loop count */
2176     return local_ndata[YY]*local_ndata[XX];
2177 }
2178
2179 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2180                              real *mesh_energy, matrix vir)
2181 {
2182     /* This function sums output over threads
2183      * and should therefore only be called after thread synchronization.
2184      */
2185     int thread;
2186
2187     *mesh_energy = pme->work[0].energy;
2188     copy_mat(pme->work[0].vir, vir);
2189
2190     for (thread = 1; thread < nthread; thread++)
2191     {
2192         *mesh_energy += pme->work[thread].energy;
2193         m_add(vir, pme->work[thread].vir, vir);
2194     }
2195 }
2196
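     /* Accumulate, for one atom, the force gradient components fx, fy and fz
      * from an order^3 block of grid points, using products of the B-spline
      * values (thx/thy/thz) and their derivatives (dthx/dthy/dthz).
      */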
2197 #define DO_FSPLINE(order)                      \
2198     for (ithx = 0; (ithx < order); ithx++)              \
2199     {                                              \
2200         index_x = (i0+ithx)*pny*pnz;               \
2201         tx      = thx[ithx];                       \
2202         dx      = dthx[ithx];                      \
2203                                                \
2204         for (ithy = 0; (ithy < order); ithy++)          \
2205         {                                          \
2206             index_xy = index_x+(j0+ithy)*pnz;      \
2207             ty       = thy[ithy];                  \
2208             dy       = dthy[ithy];                 \
2209             fxy1     = fz1 = 0;                    \
2210                                                \
2211             for (ithz = 0; (ithz < order); ithz++)      \
2212             {                                      \
2213                 gval  = grid[index_xy+(k0+ithz)];  \
2214                 fxy1 += thz[ithz]*gval;            \
2215                 fz1  += dthz[ithz]*gval;           \
2216             }                                      \
2217             fx += dx*ty*fxy1;                      \
2218             fy += tx*dy*fxy1;                      \
2219             fz += tx*ty*fz1;                       \
2220         }                                          \
2221     }
2222
2223
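     /* Gather the forces on the local atoms from the potential grid using
      * the stored B-spline values and derivatives (see DO_FSPLINE above).
      * With bClearF the force array is reset first; otherwise the scaled
      * contribution is added, e.g. when accumulating A- and B-state grids.
      */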
2224 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2225                               gmx_bool bClearF, pme_atomcomm_t *atc,
2226                               splinedata_t *spline,
2227                               real scale)
2228 {
2229     /* sum forces for local particles */
2230     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2231     int     index_x, index_xy;
2232     int     nx, ny, nz, pnx, pny, pnz;
2233     int *   idxptr;
2234     real    tx, ty, dx, dy, qn;
2235     real    fx, fy, fz, gval;
2236     real    fxy1, fz1;
2237     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2238     int     norder;
2239     real    rxx, ryx, ryy, rzx, rzy, rzz;
2240     int     order;
2241
2242     pme_spline_work_t *work;
2243
2244 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2245     real           thz_buffer[12],  *thz_aligned;
2246     real           dthz_buffer[12], *dthz_aligned;
2247
2248     thz_aligned  = gmx_simd4_align_real(thz_buffer);
2249     dthz_aligned = gmx_simd4_align_real(dthz_buffer);
2250 #endif
2251
2252     work = pme->spline_work;
2253
2254     order = pme->pme_order;
2255     thx   = spline->theta[XX];
2256     thy   = spline->theta[YY];
2257     thz   = spline->theta[ZZ];
2258     dthx  = spline->dtheta[XX];
2259     dthy  = spline->dtheta[YY];
2260     dthz  = spline->dtheta[ZZ];
2261     nx    = pme->nkx;
2262     ny    = pme->nky;
2263     nz    = pme->nkz;
2264     pnx   = pme->pmegrid_nx;
2265     pny   = pme->pmegrid_ny;
2266     pnz   = pme->pmegrid_nz;
2267
2268     rxx   = pme->recipbox[XX][XX];
2269     ryx   = pme->recipbox[YY][XX];
2270     ryy   = pme->recipbox[YY][YY];
2271     rzx   = pme->recipbox[ZZ][XX];
2272     rzy   = pme->recipbox[ZZ][YY];
2273     rzz   = pme->recipbox[ZZ][ZZ];
2274
2275     for (nn = 0; nn < spline->n; nn++)
2276     {
2277         n  = spline->ind[nn];
2278         qn = scale*atc->q[n];
2279
2280         if (bClearF)
2281         {
2282             atc->f[n][XX] = 0;
2283             atc->f[n][YY] = 0;
2284             atc->f[n][ZZ] = 0;
2285         }
2286         if (qn != 0)
2287         {
2288             fx     = 0;
2289             fy     = 0;
2290             fz     = 0;
2291             idxptr = atc->idx[n];
2292             norder = nn*order;
2293
2294             i0   = idxptr[XX];
2295             j0   = idxptr[YY];
2296             k0   = idxptr[ZZ];
2297
2298             /* Pointer arithmetic alert, next six statements */
2299             thx  = spline->theta[XX] + norder;
2300             thy  = spline->theta[YY] + norder;
2301             thz  = spline->theta[ZZ] + norder;
2302             dthx = spline->dtheta[XX] + norder;
2303             dthy = spline->dtheta[YY] + norder;
2304             dthz = spline->dtheta[ZZ] + norder;
2305
2306             switch (order)
2307             {
2308                 case 4:
2309 #ifdef PME_SIMD4_SPREAD_GATHER
2310 #ifdef PME_SIMD4_UNALIGNED
2311 #define PME_GATHER_F_SIMD4_ORDER4
2312 #else
2313 #define PME_GATHER_F_SIMD4_ALIGNED
2314 #define PME_ORDER 4
2315 #endif
2316 #include "pme_simd4.h"
2317 #else
2318                     DO_FSPLINE(4);
2319 #endif
2320                     break;
2321                 case 5:
2322 #ifdef PME_SIMD4_SPREAD_GATHER
2323 #define PME_GATHER_F_SIMD4_ALIGNED
2324 #define PME_ORDER 5
2325 #include "pme_simd4.h"
2326 #else
2327                     DO_FSPLINE(5);
2328 #endif
2329                     break;
2330                 default:
2331                     DO_FSPLINE(order);
2332                     break;
2333             }
2334
2335             atc->f[n][XX] += -qn*( fx*nx*rxx );
2336             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2337             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2338         }
2339     }
2340     /* Since the energy and not the forces are interpolated,
2341      * the net force might not be exactly zero.
2342      * This can be solved by also interpolating F, but
2343      * that comes at a cost.
2344      * A better hack is to remove the net force every
2345      * step, but that must be done at a higher level,
2346      * since this routine does not see all atoms when running
2347      * in parallel. How important this is remains unclear.  EL 990726
2348      */
2349 }
2350
2351
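     /* Interpolate the potential at the local atom positions from the grid
      * and return the sum over atoms of charge times interpolated potential.
      */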
2352 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2353                                    pme_atomcomm_t *atc)
2354 {
2355     splinedata_t *spline;
2356     int     n, ithx, ithy, ithz, i0, j0, k0;
2357     int     index_x, index_xy;
2358     int *   idxptr;
2359     real    energy, pot, tx, ty, qn, gval;
2360     real    *thx, *thy, *thz;
2361     int     norder;
2362     int     order;
2363
2364     spline = &atc->spline[0];
2365
2366     order = pme->pme_order;
2367
2368     energy = 0;
2369     for (n = 0; (n < atc->n); n++)
2370     {
2371         qn      = atc->q[n];
2372
2373         if (qn != 0)
2374         {
2375             idxptr = atc->idx[n];
2376             norder = n*order;
2377
2378             i0   = idxptr[XX];
2379             j0   = idxptr[YY];
2380             k0   = idxptr[ZZ];
2381
2382             /* Pointer arithmetic alert, next three statements */
2383             thx  = spline->theta[XX] + norder;
2384             thy  = spline->theta[YY] + norder;
2385             thz  = spline->theta[ZZ] + norder;
2386
2387             pot = 0;
2388             for (ithx = 0; (ithx < order); ithx++)
2389             {
2390                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2391                 tx      = thx[ithx];
2392
2393                 for (ithy = 0; (ithy < order); ithy++)
2394                 {
2395                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2396                     ty       = thy[ithy];
2397
2398                     for (ithz = 0; (ithz < order); ithz++)
2399                     {
2400                         gval  = grid[index_xy+(k0+ithz)];
2401                         pot  += tx*ty*thz[ithz]*gval;
2402                     }
2403
2404                 }
2405             }
2406
2407             energy += pot*qn;
2408         }
2409     }
2410
2411     return energy;
2412 }
2413
2414 /* Macro to force loop unrolling by fixing order.
2415  * This gives a significant performance gain.
2416  */
2417 #define CALC_SPLINE(order)                     \
2418     {                                              \
2419         int j, k, l;                                 \
2420         real dr, div;                               \
2421         real data[PME_ORDER_MAX];                  \
2422         real ddata[PME_ORDER_MAX];                 \
2423                                                \
2424         for (j = 0; (j < DIM); j++)                     \
2425         {                                          \
2426             dr  = xptr[j];                         \
2427                                                \
2428             /* dr is relative offset from lower cell limit */ \
2429             data[order-1] = 0;                     \
2430             data[1]       = dr;                          \
2431             data[0]       = 1 - dr;                      \
2432                                                \
2433             for (k = 3; (k < order); k++)               \
2434             {                                      \
2435                 div       = 1.0/(k - 1.0);               \
2436                 data[k-1] = div*dr*data[k-2];      \
2437                 for (l = 1; (l < (k-1)); l++)           \
2438                 {                                  \
2439                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2440                                        data[k-l-1]);                \
2441                 }                                  \
2442                 data[0] = div*(1-dr)*data[0];      \
2443             }                                      \
2444             /* differentiate */                    \
2445             ddata[0] = -data[0];                   \
2446             for (k = 1; (k < order); k++)               \
2447             {                                      \
2448                 ddata[k] = data[k-1] - data[k];    \
2449             }                                      \
2450                                                \
2451             div           = 1.0/(order - 1);                 \
2452             data[order-1] = div*dr*data[order-2];  \
2453             for (l = 1; (l < (order-1)); l++)           \
2454             {                                      \
2455                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2456                                        (order-l-dr)*data[order-l-1]); \
2457             }                                      \
2458             data[0] = div*(1 - dr)*data[0];        \
2459                                                \
2460             for (k = 0; k < order; k++)                 \
2461             {                                      \
2462                 theta[j][i*order+k]  = data[k];    \
2463                 dtheta[j][i*order+k] = ddata[k];   \
2464             }                                      \
2465         }                                          \
2466     }
2467
2468 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2469                    rvec fractx[], int nr, int ind[], real charge[],
2470                    gmx_bool bFreeEnergy)
2471 {
2472     /* construct splines for local atoms */
2473     int  i, ii;
2474     real *xptr;
2475
2476     for (i = 0; i < nr; i++)
2477     {
2478         /* With free energy we do not use the charge check.
2479          * In most cases this will be more efficient than calling make_bsplines
2480          * twice, since usually more than half the particles have charges.
2481          */
2482         ii = ind[i];
2483         if (bFreeEnergy || charge[ii] != 0.0)
2484         {
2485             xptr = fractx[ii];
2486             switch (order)
2487             {
2488                 case 4:  CALC_SPLINE(4);     break;
2489                 case 5:  CALC_SPLINE(5);     break;
2490                 default: CALC_SPLINE(order); break;
2491             }
2492         }
2493     }
2494 }
2495
2496
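     /* Compute, for each wave number, the squared modulus of the discrete
      * Fourier transform of data (the B-spline values); these moduli enter
      * the denominator of the influence function in solve_pme_yzx.
      */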
2497 void make_dft_mod(real *mod, real *data, int ndata)
2498 {
2499     int i, j;
2500     real sc, ss, arg;
2501
2502     for (i = 0; i < ndata; i++)
2503     {
2504         sc = ss = 0;
2505         for (j = 0; j < ndata; j++)
2506         {
2507             arg = (2.0*M_PI*i*j)/ndata;
2508             sc += data[j]*cos(arg);
2509             ss += data[j]*sin(arg);
2510         }
2511         mod[i] = sc*sc+ss*ss;
2512     }
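         /* Patch up (near-)zero moduli, which would blow up the influence
          * function, by averaging the two neighbours. This assumes the first
          * and last entries are not themselves (near-)zero, so that i-1 and
          * i+1 stay within the array.
          */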
2513     for (i = 0; i < ndata; i++)
2514     {
2515         if (mod[i] < 1e-7)
2516         {
2517             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2518         }
2519     }
2520 }
2521
2522
2523 static void make_bspline_moduli(splinevec bsp_mod,
2524                                 int nx, int ny, int nz, int order)
2525 {
2526     int nmax = max(nx, max(ny, nz));
2527     real *data, *ddata, *bsp_data;
2528     int i, k, l;
2529     real div;
2530
2531     snew(data, order);
2532     snew(ddata, order);
2533     snew(bsp_data, nmax);
2534
2535     data[order-1] = 0;
2536     data[1]       = 0;
2537     data[0]       = 1;
2538
2539     for (k = 3; k < order; k++)
2540     {
2541         div       = 1.0/(k-1.0);
2542         data[k-1] = 0;
2543         for (l = 1; l < (k-1); l++)
2544         {
2545             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2546         }
2547         data[0] = div*data[0];
2548     }
2549     /* differentiate */
2550     ddata[0] = -data[0];
2551     for (k = 1; k < order; k++)
2552     {
2553         ddata[k] = data[k-1]-data[k];
2554     }
2555     div           = 1.0/(order-1);
2556     data[order-1] = 0;
2557     for (l = 1; l < (order-1); l++)
2558     {
2559         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2560     }
2561     data[0] = div*data[0];
2562
2563     for (i = 0; i < nmax; i++)
2564     {
2565         bsp_data[i] = 0;
2566     }
2567     for (i = 1; i <= order; i++)
2568     {
2569         bsp_data[i] = data[i-1];
2570     }
2571
2572     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2573     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2574     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2575
2576     sfree(data);
2577     sfree(ddata);
2578     sfree(bsp_data);
2579 }
2580
2581
2582 /* Return the P3M optimal influence function */
2583 static double do_p3m_influence(double z, int order)
2584 {
2585     double z2, z4;
2586
2587     z2 = z*z;
2588     z4 = z2*z2;
2589
2590     /* The formula and most constants can be found in:
2591      * Ballenegger et al., JCTC 8, 936 (2012)
2592      */
2593     switch (order)
2594     {
2595         case 2:
2596             return 1.0 - 2.0*z2/3.0;
2597             break;
2598         case 3:
2599             return 1.0 - z2 + 2.0*z4/15.0;
2600             break;
2601         case 4:
2602             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2603             break;
2604         case 5:
2605             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2606             break;
2607         case 6:
2608             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2609             break;
2610         case 7:
2611             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2612         case 8:
2613             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2614             break;
2615     }
2616
2617     return 0.0;
2618 }
2619
2620 /* Calculate the P3M B-spline moduli for one dimension */
2621 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2622 {
2623     double zarg, zai, sinzai, infl;
2624     int    maxk, i;
2625
2626     if (order > 8)
2627     {
2628         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2629     }
2630
2631     zarg = M_PI/n;
2632
2633     maxk = (n + 1)/2;
2634
2635     for (i = -maxk; i < 0; i++)
2636     {
2637         zai          = zarg*i;
2638         sinzai       = sin(zai);
2639         infl         = do_p3m_influence(sinzai, order);
2640         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2641     }
2642     bsp_mod[0] = 1.0;
2643     for (i = 1; i < maxk; i++)
2644     {
2645         zai        = zarg*i;
2646         sinzai     = sin(zai);
2647         infl       = do_p3m_influence(sinzai, order);
2648         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2649     }
2650 }
2651
2652 /* Calculate the P3M B-spline moduli */
2653 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2654                                     int nx, int ny, int nz, int order)
2655 {
2656     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2657     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2658     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2659 }
2660
2661
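     /* Set up the ordered lists of destination and source ranks for the slab
      * coordinate communication, alternating between the forward and backward
      * neighbours of increasing distance.
      */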
2662 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2663 {
2664     int nslab, n, i;
2665     int fw, bw;
2666
2667     nslab = atc->nslab;
2668
2669     n = 0;
2670     for (i = 1; i <= nslab/2; i++)
2671     {
2672         fw = (atc->nodeid + i) % nslab;
2673         bw = (atc->nodeid - i + nslab) % nslab;
2674         if (n < nslab - 1)
2675         {
2676             atc->node_dest[n] = fw;
2677             atc->node_src[n]  = bw;
2678             n++;
2679         }
2680         if (n < nslab - 1)
2681         {
2682             atc->node_dest[n] = bw;
2683             atc->node_src[n]  = fw;
2684             n++;
2685         }
2686     }
2687 }
2688
2689 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2690 {
2691     int thread;
2692
2693     if (NULL != log)
2694     {
2695         fprintf(log, "Destroying PME data structures.\n");
2696     }
2697
2698     sfree((*pmedata)->nnx);
2699     sfree((*pmedata)->nny);
2700     sfree((*pmedata)->nnz);
2701
2702     pmegrids_destroy(&(*pmedata)->pmegridA);
2703
2704     sfree((*pmedata)->fftgridA);
2705     sfree((*pmedata)->cfftgridA);
2706     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2707
2708     if ((*pmedata)->pmegridB.grid.grid != NULL)
2709     {
2710         pmegrids_destroy(&(*pmedata)->pmegridB);
2711         sfree((*pmedata)->fftgridB);
2712         sfree((*pmedata)->cfftgridB);
2713         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2714     }
2715     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2716     {
2717         free_work(&(*pmedata)->work[thread]);
2718     }
2719     sfree((*pmedata)->work);
2720
2721     sfree(*pmedata);
2722     *pmedata = NULL;
2723
2724     return 0;
2725 }
2726
2727 static int mult_up(int n, int f)
2728 {
2729     return ((n + f - 1)/f)*f;
2730 }
2731
2732
2733 static double pme_load_imbalance(gmx_pme_t pme)
2734 {
2735     int    nma, nmi;
2736     double n1, n2, n3;
2737
2738     nma = pme->nnodes_major;
2739     nmi = pme->nnodes_minor;
2740
2741     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2742     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2743     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2744
2745     /* pme_solve is roughly double the cost of an fft */
2746
2747     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2748 }
2749
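     /* Initialize the atom communication data for one decomposition
      * dimension: slab count and rank (taken from the MPI communicator when
      * running in parallel), the send/receive rank lists and counting
      * buffers, and the per-thread spline data.
      */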
2750 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
2751                           int dimind, gmx_bool bSpread)
2752 {
2753     int nk, k, s, thread;
2754
2755     atc->dimind    = dimind;
2756     atc->nslab     = 1;
2757     atc->nodeid    = 0;
2758     atc->pd_nalloc = 0;
2759 #ifdef GMX_MPI
2760     if (pme->nnodes > 1)
2761     {
2762         atc->mpi_comm = pme->mpi_comm_d[dimind];
2763         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2764         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2765     }
2766     if (debug)
2767     {
2768         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2769     }
2770 #endif
2771
2772     atc->bSpread   = bSpread;
2773     atc->pme_order = pme->pme_order;
2774
2775     if (atc->nslab > 1)
2776     {
2777         /* These three allocations are not required for particle decomp. */
2778         snew(atc->node_dest, atc->nslab);
2779         snew(atc->node_src, atc->nslab);
2780         setup_coordinate_communication(atc);
2781
2782         snew(atc->count_thread, pme->nthread);
2783         for (thread = 0; thread < pme->nthread; thread++)
2784         {
2785             snew(atc->count_thread[thread], atc->nslab);
2786         }
2787         atc->count = atc->count_thread[0];
2788         snew(atc->rcount, atc->nslab);
2789         snew(atc->buf_index, atc->nslab);
2790     }
2791
2792     atc->nthread = pme->nthread;
2793     if (atc->nthread > 1)
2794     {
2795         snew(atc->thread_plist, atc->nthread);
2796     }
2797     snew(atc->spline, atc->nthread);
2798     for (thread = 0; thread < atc->nthread; thread++)
2799     {
2800         if (atc->nthread > 1)
2801         {
2802             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2803             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2804         }
2805         snew(atc->spline[thread].thread_one, pme->nthread);
2806         atc->spline[thread].thread_one[thread] = 1;
2807     }
2808 }
2809
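     /* Set up the one-directional grid overlap communication for one
      * decomposition dimension: the slab boundaries s2g0/s2g1, the number of
      * neighbouring ranks the overlap reaches, and the send/receive index
      * ranges and buffers.
      */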
2810 static void
2811 init_overlap_comm(pme_overlap_t *  ol,
2812                   int              norder,
2813 #ifdef GMX_MPI
2814                   MPI_Comm         comm,
2815 #endif
2816                   int              nnodes,
2817                   int              nodeid,
2818                   int              ndata,
2819                   int              commplainsize)
2820 {
2821     int lbnd, rbnd, maxlr, b, i;
2822     int exten;
2823     int nn, nk;
2824     pme_grid_comm_t *pgc;
2825     gmx_bool bCont;
2826     int fft_start, fft_end, send_index1, recv_index1;
2827 #ifdef GMX_MPI
2828     MPI_Status stat;
2829
2830     ol->mpi_comm = comm;
2831 #endif
2832
2833     ol->nnodes = nnodes;
2834     ol->nodeid = nodeid;
2835
2836     /* Linear translation of the PME grid won't affect reciprocal space
2837      * calculations, so to optimize we only interpolate "upwards",
2838      * which also means we only have to consider overlap in one direction.
2839      * I.e., particles on this node might also be spread to grid indices
2840      * that belong to higher nodes (modulo nnodes)
2841      */
2842
2843     snew(ol->s2g0, ol->nnodes+1);
2844     snew(ol->s2g1, ol->nnodes);
2845     if (debug)
2846     {
2847         fprintf(debug, "PME slab boundaries:");
2848     }
2849     for (i = 0; i < nnodes; i++)
2850     {
2851         /* s2g0 the local interpolation grid start.
2852          * s2g1 the local interpolation grid end.
2853          * Because grid overlap communication only goes forward,
2854          * the grid slab boundaries for the FFTs should be rounded down.
2855          */
2856         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2857         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2858
2859         if (debug)
2860         {
2861             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2862         }
2863     }
2864     ol->s2g0[nnodes] = ndata;
2865     if (debug)
2866     {
2867         fprintf(debug, "\n");
2868     }
2869
2870     /* Determine with how many nodes we need to communicate the grid overlap */
2871     b = 0;
2872     do
2873     {
2874         b++;
2875         bCont = FALSE;
2876         for (i = 0; i < nnodes; i++)
2877         {
2878             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2879                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2880             {
2881                 bCont = TRUE;
2882             }
2883         }
2884     }
2885     while (bCont && b < nnodes);
2886     ol->noverlap_nodes = b - 1;
2887
2888     snew(ol->send_id, ol->noverlap_nodes);
2889     snew(ol->recv_id, ol->noverlap_nodes);
2890     for (b = 0; b < ol->noverlap_nodes; b++)
2891     {
2892         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2893         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2894     }
2895     snew(ol->comm_data, ol->noverlap_nodes);
2896
2897     ol->send_size = 0;
2898     for (b = 0; b < ol->noverlap_nodes; b++)
2899     {
2900         pgc = &ol->comm_data[b];
2901         /* Send */
2902         fft_start        = ol->s2g0[ol->send_id[b]];
2903         fft_end          = ol->s2g0[ol->send_id[b]+1];
2904         if (ol->send_id[b] < nodeid)
2905         {
2906             fft_start += ndata;
2907             fft_end   += ndata;
2908         }
2909         send_index1       = ol->s2g1[nodeid];
2910         send_index1       = min(send_index1, fft_end);
2911         pgc->send_index0  = fft_start;
2912         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2913         ol->send_size    += pgc->send_nindex;
2914
2915         /* We always start receiving at the first index of our slab */
2916         fft_start        = ol->s2g0[ol->nodeid];
2917         fft_end          = ol->s2g0[ol->nodeid+1];
2918         recv_index1      = ol->s2g1[ol->recv_id[b]];
2919         if (ol->recv_id[b] > nodeid)
2920         {
2921             recv_index1 -= ndata;
2922         }
2923         recv_index1      = min(recv_index1, fft_end);
2924         pgc->recv_index0 = fft_start;
2925         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2926     }
2927
2928 #ifdef GMX_MPI
2929     /* Communicate the buffer sizes to receive */
2930     for (b = 0; b < ol->noverlap_nodes; b++)
2931     {
2932         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2933                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2934                      ol->mpi_comm, &stat);
2935     }
2936 #endif
2937
2938     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2939     snew(ol->sendbuf, norder*commplainsize);
2940     snew(ol->recvbuf, norder*commplainsize);
2941 }
2942
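     /* Build lookup tables spanning five periodic copies of the grid that
      * map global grid indices to local indices (gtl), together with the
      * fraction shift (fsh) to apply when an index is moved across a local
      * boundary; see the comments in the loop below for the boundary cases.
      */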
2943 static void
2944 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2945                               int **global_to_local,
2946                               real **fraction_shift)
2947 {
2948     int i;
2949     int * gtl;
2950     real * fsh;
2951
2952     snew(gtl, 5*n);
2953     snew(fsh, 5*n);
2954     for (i = 0; (i < 5*n); i++)
2955     {
2956         /* Determine the global to local grid index */
2957         gtl[i] = (i - local_start + n) % n;
2958         /* For coordinates that fall within the local grid the fraction
2959          * is correct, we don't need to shift it.
2960          */
2961         fsh[i] = 0;
2962         if (local_range < n)
2963         {
2964             /* Due to rounding issues i could be 1 beyond the lower or
2965              * upper boundary of the local grid. Correct the index for this.
2966              * If we shift the index, we need to shift the fraction by
2967              * the same amount in the other direction to not affect
2968              * the weights.
2969              * Note that due to this shifting the weights at the end of
2970              * the spline might change, but that will only involve values
2971              * between zero and values close to the precision of a real,
2972              * which is anyhow the accuracy of the whole mesh calculation.
2973              */
2974             /* With local_range=0 we should not change i=local_start */
2975             if (i % n != local_start)
2976             {
2977                 if (gtl[i] == n-1)
2978                 {
2979                     gtl[i] = 0;
2980                     fsh[i] = -1;
2981                 }
2982                 else if (gtl[i] == local_range)
2983                 {
2984                     gtl[i] = local_range - 1;
2985                     fsh[i] = 1;
2986                 }
2987             }
2988         }
2989     }
2990
2991     *global_to_local = gtl;
2992     *fraction_shift  = fsh;
2993 }
2994
2995 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
2996 {
2997     pme_spline_work_t *work;
2998
2999 #ifdef PME_SIMD4_SPREAD_GATHER
3000     real         tmp[12], *tmp_aligned;
3001     gmx_simd4_pr zero_S;
3002     gmx_simd4_pr real_mask_S0, real_mask_S1;
3003     int          of, i;
3004
3005     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3006
3007     tmp_aligned = gmx_simd4_align_real(tmp);
3008
3009     zero_S = gmx_simd4_setzero_pr();
3010
3011     /* Generate bit masks to mask out the unused grid entries,
3012      * as we only operate on 'order' of the 8 grid entries that are
3013      * loaded into 2 SIMD registers.
3014      */
3015     for (of = 0; of < 8-(order-1); of++)
3016     {
3017         for (i = 0; i < 8; i++)
3018         {
3019             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3020         }
3021         real_mask_S0      = gmx_simd4_load_pr(tmp_aligned);
3022         real_mask_S1      = gmx_simd4_load_pr(tmp_aligned+4);
3023         work->mask_S0[of] = gmx_simd4_cmplt_pr(real_mask_S0, zero_S);
3024         work->mask_S1[of] = gmx_simd4_cmplt_pr(real_mask_S1, zero_S);
3025     }
3026 #else
3027     work = NULL;
3028 #endif
3029
3030     return work;
3031 }
3032
3033 void gmx_pme_check_restrictions(int pme_order,
3034                                 int nkx, int nky, int nkz,
3035                                 int nnodes_major,
3036                                 int nnodes_minor,
3037                                 gmx_bool bUseThreads,
3038                                 gmx_bool bFatal,
3039                                 gmx_bool *bValidSettings)
3040 {
3041     if (pme_order > PME_ORDER_MAX)
3042     {
3043         if (!bFatal)
3044         {
3045             *bValidSettings = FALSE;
3046             return;
3047         }
3048         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3049                   pme_order, PME_ORDER_MAX);
3050     }
3051
3052     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3053         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3054         nkz <= pme_order)
3055     {
3056         if (!bFatal)
3057         {
3058             *bValidSettings = FALSE;
3059             return;
3060         }
3061         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3062                   pme_order);
3063     }
3064
3065     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3066      * We only allow multiple communication pulses in dim 1, not in dim 0.
3067      */
3068     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3069                         nkx != nnodes_major*(pme_order - 1)))
3070     {
3071         if (!bFatal)
3072         {
3073             *bValidSettings = FALSE;
3074             return;
3075         }
3076         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3077                   nkx/(double)nnodes_major, pme_order);
3078     }
3079
3080     if (bValidSettings != NULL)
3081     {
3082         *bValidSettings = TRUE;
3083     }
3084
3085     return;
3086 }
3087
3088 int gmx_pme_init(gmx_pme_t *         pmedata,
3089                  t_commrec *         cr,
3090                  int                 nnodes_major,
3091                  int                 nnodes_minor,
3092                  t_inputrec *        ir,
3093                  int                 homenr,
3094                  gmx_bool            bFreeEnergy,
3095                  gmx_bool            bReproducible,
3096                  int                 nthread)
3097 {
3098     gmx_pme_t pme = NULL;
3099
3100     int  use_threads, sum_use_threads;
3101     ivec ndata;
3102
3103     if (debug)
3104     {
3105         fprintf(debug, "Creating PME data structures.\n");
3106     }
3107     snew(pme, 1);
3108
3109     pme->redist_init         = FALSE;
3110     pme->sum_qgrid_tmp       = NULL;
3111     pme->sum_qgrid_dd_tmp    = NULL;
3112     pme->buf_nalloc          = 0;
3113     pme->redist_buf_nalloc   = 0;
3114
3115     pme->nnodes              = 1;
3116     pme->bPPnode             = TRUE;
3117
3118     pme->nnodes_major        = nnodes_major;
3119     pme->nnodes_minor        = nnodes_minor;
3120
3121 #ifdef GMX_MPI
3122     if (nnodes_major*nnodes_minor > 1)
3123     {
3124         pme->mpi_comm = cr->mpi_comm_mygroup;
3125
3126         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3127         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3128         if (pme->nnodes != nnodes_major*nnodes_minor)
3129         {
3130             gmx_incons("PME node count mismatch");
3131         }
3132     }
3133     else
3134     {
3135         pme->mpi_comm = MPI_COMM_NULL;
3136     }
3137 #endif
3138
3139     if (pme->nnodes == 1)
3140     {
3141 #ifdef GMX_MPI
3142         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3143         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3144 #endif
3145         pme->ndecompdim   = 0;
3146         pme->nodeid_major = 0;
3147         pme->nodeid_minor = 0;
3151     }
3152     else
3153     {
3154         if (nnodes_minor == 1)
3155         {
3156 #ifdef GMX_MPI
3157             pme->mpi_comm_d[0] = pme->mpi_comm;
3158             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3159 #endif
3160             pme->ndecompdim   = 1;
3161             pme->nodeid_major = pme->nodeid;
3162             pme->nodeid_minor = 0;
3163
3164         }
3165         else if (nnodes_major == 1)
3166         {
3167 #ifdef GMX_MPI
3168             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3169             pme->mpi_comm_d[1] = pme->mpi_comm;
3170 #endif
3171             pme->ndecompdim   = 1;
3172             pme->nodeid_major = 0;
3173             pme->nodeid_minor = pme->nodeid;
3174         }
3175         else
3176         {
3177             if (pme->nnodes % nnodes_major != 0)
3178             {
3179                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3180             }
3181             pme->ndecompdim = 2;
3182
3183 #ifdef GMX_MPI
3184             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3185                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3186             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3187                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3188
3189             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3190             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3191             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3192             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3193 #endif
3194         }
3195         pme->bPPnode = (cr->duty & DUTY_PP);
3196     }
3197
3198     pme->nthread = nthread;
3199
3200     /* Check if any of the PME MPI ranks uses threads */
3201     use_threads = (pme->nthread > 1 ? 1 : 0);
3202 #ifdef GMX_MPI
3203     if (pme->nnodes > 1)
3204     {
3205         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3206                       MPI_SUM, pme->mpi_comm);
3207     }
3208     else
3209 #endif
3210     {
3211         sum_use_threads = use_threads;
3212     }
3213     pme->bUseThreads = (sum_use_threads > 0);
3214
3215     if (ir->ePBC == epbcSCREW)
3216     {
3217         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3218     }
3219
3220     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3221     pme->nkx         = ir->nkx;
3222     pme->nky         = ir->nky;
3223     pme->nkz         = ir->nkz;
3224     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3225     pme->pme_order   = ir->pme_order;
3226     pme->epsilon_r   = ir->epsilon_r;
3227
3228     /* If we violate restrictions, generate a fatal error here */
3229     gmx_pme_check_restrictions(pme->pme_order,
3230                                pme->nkx, pme->nky, pme->nkz,
3231                                pme->nnodes_major,
3232                                pme->nnodes_minor,
3233                                pme->bUseThreads,
3234                                TRUE,
3235                                NULL);
3236
3237     if (pme->nnodes > 1)
3238     {
3239         double imbal;
3240
3241 #ifdef GMX_MPI
3242         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3243         MPI_Type_commit(&(pme->rvec_mpi));
3244 #endif
3245
3246         /* Note that the charge spreading and force gathering, which usually
3247          * take about the same amount of time as FFT+solve_pme,
3248          * are always fully load balanced
3249          * (unless the charge distribution is inhomogeneous).
3250          */
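        /* pme_load_imbalance() is assumed to return the ratio of the grid
         * work on the most loaded rank to the average over all PME ranks;
         * (int)((imbal-1)*100 + 0.5) below rounds that excess to the nearest
         * percent. A ratio of 1.2 or more (20% wasted FFT/solve time)
         * triggers the note, which is printed only on the master PME rank.
         */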
3251
3252         imbal = pme_load_imbalance(pme);
3253         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3254         {
3255             fprintf(stderr,
3256                     "\n"
3257                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3258                     "      For optimal PME load balancing\n"
3259                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3260                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3261                     "\n",
3262                     (int)((imbal-1)*100 + 0.5),
3263                     pme->nkx, pme->nky, pme->nnodes_major,
3264                     pme->nky, pme->nkz, pme->nnodes_minor);
3265         }
3266     }
3267
3268     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3269     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3270      * y is always copied through a buffer: we don't need padding in z,
3271      * but we do need the overlap in x because of the communication order.
3272      */
3273     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3274 #ifdef GMX_MPI
3275                       pme->mpi_comm_d[0],
3276 #endif
3277                       pme->nnodes_major, pme->nodeid_major,
3278                       pme->nkx,
3279                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3280
3281     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3282      * We do this with an offset buffer of equal size, so we need to allocate
3283      * extra for the offset. That's what the (+1)*pme->nkz is for.
3284      */
3285     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3286 #ifdef GMX_MPI
3287                       pme->mpi_comm_d[1],
3288 #endif
3289                       pme->nnodes_minor, pme->nodeid_minor,
3290                       pme->nky,
3291                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3292
3293     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3294      * Note that gmx_pme_check_restrictions checked for this already.
3295      */
3296     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3297     {
3298         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3299     }
3300
3301     snew(pme->bsp_mod[XX], pme->nkx);
3302     snew(pme->bsp_mod[YY], pme->nky);
3303     snew(pme->bsp_mod[ZZ], pme->nkz);
3304
3305     /* The required size of the interpolation grid, including overlap.
3306      * The allocated size (pmegrid_n?) might be slightly larger.
3307      */
3308     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3309         pme->overlap[0].s2g0[pme->nodeid_major];
3310     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3311         pme->overlap[1].s2g0[pme->nodeid_minor];
3312     pme->pmegrid_nz_base = pme->nkz;
3313     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3314     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
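    /* pmegrid_nx/ny already include the pme_order-1 overlap region through
     * s2g1 - s2g0; along z every rank holds the full dimension plus overlap.
     * set_grid_alignment() presumably rounds pmegrid_nz up so that each
     * z-line can be processed with aligned SIMD loads in the spread and
     * gather kernels.
     */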
3315
3316     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3317     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3318     pme->pmegrid_start_iz = 0;
3319
3320     make_gridindex5_to_localindex(pme->nkx,
3321                                   pme->pmegrid_start_ix,
3322                                   pme->pmegrid_nx - (pme->pme_order-1),
3323                                   &pme->nnx, &pme->fshx);
3324     make_gridindex5_to_localindex(pme->nky,
3325                                   pme->pmegrid_start_iy,
3326                                   pme->pmegrid_ny - (pme->pme_order-1),
3327                                   &pme->nny, &pme->fshy);
3328     make_gridindex5_to_localindex(pme->nkz,
3329                                   pme->pmegrid_start_iz,
3330                                   pme->pmegrid_nz_base,
3331                                   &pme->nnz, &pme->fshz);
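    /* The nnx/nny/nnz and fshx/fshy/fshz tables map a (possibly shifted)
     * global grid index to a local grid index plus fractional shift.
     * The tables are presumably oversized (the "5" in the routine name)
     * so that indices shifted by whole grid periods can be looked up
     * directly, without a modulo in the inner spreading loop.
     */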
3332
3333     pmegrids_init(&pme->pmegridA,
3334                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3335                   pme->pmegrid_nz_base,
3336                   pme->pme_order,
3337                   pme->bUseThreads,
3338                   pme->nthread,
3339                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3340                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3341
3342     pme->spline_work = make_pme_spline_work(pme->pme_order);
3343
3344     ndata[0] = pme->nkx;
3345     ndata[1] = pme->nky;
3346     ndata[2] = pme->nkz;
3347
3348     /* This routine will allocate the grid data to fit the FFTs */
3349     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3350                             &pme->fftgridA, &pme->cfftgridA,
3351                             pme->mpi_comm_d,
3352                             bReproducible, pme->nthread);
3353
3354     if (bFreeEnergy)
3355     {
3356         pmegrids_init(&pme->pmegridB,
3357                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3358                       pme->pmegrid_nz_base,
3359                       pme->pme_order,
3360                       pme->bUseThreads,
3361                       pme->nthread,
3362                       pme->nkx % pme->nnodes_major != 0,
3363                       pme->nky % pme->nnodes_minor != 0);
3364
3365         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3366                                 &pme->fftgridB, &pme->cfftgridB,
3367                                 pme->mpi_comm_d,
3368                                 bReproducible, pme->nthread);
3369     }
3370     else
3371     {
3372         pme->pmegridB.grid.grid = NULL;
3373         pme->fftgridB           = NULL;
3374         pme->cfftgridB          = NULL;
3375     }
3376
3377     if (!pme->bP3M)
3378     {
3379         /* Use plain SPME B-spline interpolation */
3380         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3381     }
3382     else
3383     {
3384         /* Use the P3M grid-optimized influence function */
3385         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3386     }
3387
3388     /* Use atc[0] for spreading */
3389     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3390     if (pme->ndecompdim >= 2)
3391     {
3392         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3393     }
3394
3395     if (pme->nnodes == 1)
3396     {
3397         pme->atc[0].n = homenr;
3398         pme_realloc_atomcomm_things(&pme->atc[0]);
3399     }
3400
3401     {
3402         int thread;
3403
3404         /* Use fft5d; after the FFT the data order is y (major), z, x (minor) */
3405
3406         snew(pme->work, pme->nthread);
3407         for (thread = 0; thread < pme->nthread; thread++)
3408         {
3409             realloc_work(&pme->work[thread], pme->nkx);
3410         }
3411     }
3412
3413     *pmedata = pme;
3414
3415     return 0;
3416 }
3417
3418 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3419 {
3420     int d, t;
3421
3422     for (d = 0; d < DIM; d++)
3423     {
3424         if (new->grid.n[d] > old->grid.n[d])
3425         {
3426             return;
3427         }
3428     }
3429
3430     sfree_aligned(new->grid.grid);
3431     new->grid.grid = old->grid.grid;
3432
3433     if (new->grid_th != NULL && new->nthread == old->nthread)
3434     {
3435         sfree_aligned(new->grid_all);
3436         for (t = 0; t < new->nthread; t++)
3437         {
3438             new->grid_th[t].grid = old->grid_th[t].grid;
3439         }
3440     }
3441 }
3442
3443 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3444                    t_commrec *         cr,
3445                    gmx_pme_t           pme_src,
3446                    const t_inputrec *  ir,
3447                    ivec                grid_size)
3448 {
3449     t_inputrec irc;
3450     int homenr;
3451     int ret;
3452
3453     irc     = *ir;
3454     irc.nkx = grid_size[XX];
3455     irc.nky = grid_size[YY];
3456     irc.nkz = grid_size[ZZ];
3457
3458     if (pme_src->nnodes == 1)
3459     {
3460         homenr = pme_src->atc[0].n;
3461     }
3462     else
3463     {
3464         homenr = -1;
3465     }
3466
3467     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3468                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3469
3470     if (ret == 0)
3471     {
3472         /* We can easily reuse the allocated pme grids in pme_src */
3473         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3474         /* We would like to reuse the fft grids, but that's harder */
3475     }
3476
3477     return ret;
3478 }
3479
3480
3481 static void copy_local_grid(gmx_pme_t pme,
3482                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3483 {
3484     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3485     int  fft_my, fft_mz;
3486     int  nsx, nsy, nsz;
3487     ivec nf;
3488     int  offx, offy, offz, x, y, z, i0, i0t;
3489     int  d;
3490     pmegrid_t *pmegrid;
3491     real *grid_th;
3492
3493     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3494                                    local_fft_ndata,
3495                                    local_fft_offset,
3496                                    local_fft_size);
3497     fft_my = local_fft_size[YY];
3498     fft_mz = local_fft_size[ZZ];
3499
3500     pmegrid = &pmegrids->grid_th[thread];
3501
3502     nsx = pmegrid->s[XX];
3503     nsy = pmegrid->s[YY];
3504     nsz = pmegrid->s[ZZ];
3505
3506     for (d = 0; d < DIM; d++)
3507     {
3508         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3509                     local_fft_ndata[d] - pmegrid->offset[d]);
3510     }
3511
3512     offx = pmegrid->offset[XX];
3513     offy = pmegrid->offset[YY];
3514     offz = pmegrid->offset[ZZ];
3515
3516     /* Directly copy the non-overlapping parts of the local grids.
3517      * This also initializes the full grid.
3518      */
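    /* The fftgrid is indexed with the padded sizes fft_my/fft_mz of the
     * parallel FFT layout, while the thread-local grid uses its own strides
     * nsy/nsz, so the two flattened indices in the loop below differ:
     *   fftgrid: i0  = ((offx+x)*fft_my + (offy+y))*fft_mz + offz
     *   pmegrid: i0t = (x*nsy + y)*nsz
     */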
3519     grid_th = pmegrid->grid;
3520     for (x = 0; x < nf[XX]; x++)
3521     {
3522         for (y = 0; y < nf[YY]; y++)
3523         {
3524             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3525             i0t = (x*nsy + y)*nsz;
3526             for (z = 0; z < nf[ZZ]; z++)
3527             {
3528                 fftgrid[i0+z] = grid_th[i0t+z];
3529             }
3530         }
3531     }
3532 }
3533
3534 static void
3535 reduce_threadgrid_overlap(gmx_pme_t pme,
3536                           const pmegrids_t *pmegrids, int thread,
3537                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3538 {
3539     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3540     int  fft_nx, fft_ny, fft_nz;
3541     int  fft_my, fft_mz;
3542     int  buf_my = -1;
3543     int  nsx, nsy, nsz;
3544     ivec ne;
3545     int  offx, offy, offz, x, y, z, i0, i0t;
3546     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3547     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3548     gmx_bool bCommX, bCommY;
3549     int  d;
3550     int  thread_f;
3551     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3552     const real *grid_th;
3553     real *commbuf = NULL;
3554
3555     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3556                                    local_fft_ndata,
3557                                    local_fft_offset,
3558                                    local_fft_size);
3559     fft_nx = local_fft_ndata[XX];
3560     fft_ny = local_fft_ndata[YY];
3561     fft_nz = local_fft_ndata[ZZ];
3562
3563     fft_my = local_fft_size[YY];
3564     fft_mz = local_fft_size[ZZ];
3565
3566     /* This routine is called when all threads have finished spreading.
3567      * Here each thread sums the grid contributions calculated by other
3568      * threads into its own thread-local grid volume.
3569      * To minimize the number of grid copying operations,
3570      * this routine sums immediately from the pmegrid to the fftgrid.
3571      */
3572
3573     /* Determine which part of the full node grid we should operate on;
3574      * this is our thread-local part of the full grid.
3575      */
3576     pmegrid = &pmegrids->grid_th[thread];
3577
3578     for (d = 0; d < DIM; d++)
3579     {
3580         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3581                     local_fft_ndata[d]);
3582     }
3583
3584     offx = pmegrid->offset[XX];
3585     offy = pmegrid->offset[YY];
3586     offz = pmegrid->offset[ZZ];
3587
3588
3589     bClearBufX  = TRUE;
3590     bClearBufY  = TRUE;
3591     bClearBufXY = TRUE;
3592
3593     /* Now loop over all the thread data blocks that contribute
3594      * to the grid region we (our thread) are operating on.
3595      */
3596     /* Note that fft_nx/ny is equal to the number of grid points
3597      * between the first point of our node grid and that of the next node.
3598      */
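    /* A negative shift sx/sy/sz means the contributing thread block sits at
     * the high end of the node grid (periodic wrap): fx/fy/fz are wrapped
     * back with nc[] and the offsets ox/oy/oz are shifted by -fft_n*.
     * When such a wrap crosses a node boundary (bCommX/bCommY), the overlap
     * is not added to fftgrid directly but staged in commbuf_x/commbuf_y
     * for sum_fftgrid_dd().
     */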
3599     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3600     {
3601         fx     = pmegrid->ci[XX] + sx;
3602         ox     = 0;
3603         bCommX = FALSE;
3604         if (fx < 0)
3605         {
3606             fx    += pmegrids->nc[XX];
3607             ox    -= fft_nx;
3608             bCommX = (pme->nnodes_major > 1);
3609         }
3610         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3611         ox       += pmegrid_g->offset[XX];
3612         if (!bCommX)
3613         {
3614             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3615         }
3616         else
3617         {
3618             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3619         }
3620
3621         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3622         {
3623             fy     = pmegrid->ci[YY] + sy;
3624             oy     = 0;
3625             bCommY = FALSE;
3626             if (fy < 0)
3627             {
3628                 fy    += pmegrids->nc[YY];
3629                 oy    -= fft_ny;
3630                 bCommY = (pme->nnodes_minor > 1);
3631             }
3632             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3633             oy       += pmegrid_g->offset[YY];
3634             if (!bCommY)
3635             {
3636                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3637             }
3638             else
3639             {
3640                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3641             }
3642
3643             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3644             {
3645                 fz = pmegrid->ci[ZZ] + sz;
3646                 oz = 0;
3647                 if (fz < 0)
3648                 {
3649                     fz += pmegrids->nc[ZZ];
3650                     oz -= fft_nz;
3651                 }
3652                 pmegrid_g = &pmegrids->grid_th[fz];
3653                 oz       += pmegrid_g->offset[ZZ];
3654                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3655
3656                 if (sx == 0 && sy == 0 && sz == 0)
3657                 {
3658                     /* We have already added our local contribution
3659                      * before calling this routine, so skip it here.
3660                      */
3661                     continue;
3662                 }
3663
3664                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3665
3666                 pmegrid_f = &pmegrids->grid_th[thread_f];
3667
3668                 grid_th = pmegrid_f->grid;
3669
3670                 nsx = pmegrid_f->s[XX];
3671                 nsy = pmegrid_f->s[YY];
3672                 nsz = pmegrid_f->s[ZZ];
3673
3674 #ifdef DEBUG_PME_REDUCE
3675                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3676                        pme->nodeid, thread, thread_f,
3677                        pme->pmegrid_start_ix,
3678                        pme->pmegrid_start_iy,
3679                        pme->pmegrid_start_iz,
3680                        sx, sy, sz,
3681                        offx-ox, tx1-ox, offx, tx1,
3682                        offy-oy, ty1-oy, offy, ty1,
3683                        offz-oz, tz1-oz, offz, tz1);
3684 #endif
3685
3686                 if (!(bCommX || bCommY))
3687                 {
3688                     /* Copy from the thread local grid to the node grid */
3689                     for (x = offx; x < tx1; x++)
3690                     {
3691                         for (y = offy; y < ty1; y++)
3692                         {
3693                             i0  = (x*fft_my + y)*fft_mz;
3694                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3695                             for (z = offz; z < tz1; z++)
3696                             {
3697                                 fftgrid[i0+z] += grid_th[i0t+z];
3698                             }
3699                         }
3700                     }
3701                 }
3702                 else
3703                 {
3704                     /* The order of this conditional decides
3705                      * where the corner volume gets stored with x+y decomp.
3706                      */
3707                     if (bCommY)
3708                     {
3709                         commbuf = commbuf_y;
3710                         buf_my  = ty1 - offy;
3711                         if (bCommX)
3712                         {
3713                             /* We index commbuf modulo the local grid size */
3714                             commbuf += buf_my*fft_nx*fft_nz;
3715
3716                             bClearBuf   = bClearBufXY;
3717                             bClearBufXY = FALSE;
3718                         }
3719                         else
3720                         {
3721                             bClearBuf  = bClearBufY;
3722                             bClearBufY = FALSE;
3723                         }
3724                     }
3725                     else
3726                     {
3727                         commbuf    = commbuf_x;
3728                         buf_my     = fft_ny;
3729                         bClearBuf  = bClearBufX;
3730                         bClearBufX = FALSE;
3731                     }
3732
3733                     /* Copy to the communication buffer */
3734                     for (x = offx; x < tx1; x++)
3735                     {
3736                         for (y = offy; y < ty1; y++)
3737                         {
3738                             i0  = (x*buf_my + y)*fft_nz;
3739                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3740
3741                             if (bClearBuf)
3742                             {
3743                                 /* First access of commbuf, initialize it */
3744                                 for (z = offz; z < tz1; z++)
3745                                 {
3746                                     commbuf[i0+z]  = grid_th[i0t+z];
3747                                 }
3748                             }
3749                             else
3750                             {
3751                                 for (z = offz; z < tz1; z++)
3752                                 {
3753                                     commbuf[i0+z] += grid_th[i0t+z];
3754                                 }
3755                             }
3756                         }
3757                     }
3758                 }
3759             }
3760         }
3761     }
3762 }
3763
3764
3765 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3766 {
3767     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3768     pme_overlap_t *overlap;
3769     int  send_index0, send_nindex;
3770     int  recv_nindex;
3771 #ifdef GMX_MPI
3772     MPI_Status stat;
3773 #endif
3774     int  send_size_y, recv_size_y;
3775     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3776     real *sendptr, *recvptr;
3777     int  x, y, z, indg, indb;
3778
3779     /* Note that this routine is only used for forward communication.
3780      * Since the force gathering, unlike the charge spreading,
3781      * can be trivially parallelized over the particles,
3782      * the backwards process is much simpler and can use the "old"
3783      * communication setup.
3784      */
3785
3786     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3787                                    local_fft_ndata,
3788                                    local_fft_offset,
3789                                    local_fft_size);
3790
3791     if (pme->nnodes_minor > 1)
3792     {
3793         /* Minor dimension */
3794         overlap = &pme->overlap[1];
3795
3796         if (pme->nnodes_major > 1)
3797         {
3798             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3799         }
3800         else
3801         {
3802             size_yx = 0;
3803         }
3804         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
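        /* datasize also reserves room for size_yx extra x-rows: with a 2D
         * decomposition, part of the y-overlap received here belongs to the
         * x-neighbour and is accumulated into pme->overlap[0].sendbuf below,
         * so it can be forwarded in the major-dimension pulse that follows.
         */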
3805
3806         send_size_y = overlap->send_size;
3807
3808         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3809         {
3810             send_id       = overlap->send_id[ipulse];
3811             recv_id       = overlap->recv_id[ipulse];
3812             send_index0   =
3813                 overlap->comm_data[ipulse].send_index0 -
3814                 overlap->comm_data[0].send_index0;
3815             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3816             /* We don't use recv_index0, as we always receive starting at 0 */
3817             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3818             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3819
3820             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3821             recvptr = overlap->recvbuf;
3822
3823 #ifdef GMX_MPI
3824             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3825                          send_id, ipulse,
3826                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3827                          recv_id, ipulse,
3828                          overlap->mpi_comm, &stat);
3829 #endif
3830
3831             for (x = 0; x < local_fft_ndata[XX]; x++)
3832             {
3833                 for (y = 0; y < recv_nindex; y++)
3834                 {
3835                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3836                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3837                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3838                     {
3839                         fftgrid[indg+z] += recvptr[indb+z];
3840                     }
3841                 }
3842             }
3843
3844             if (pme->nnodes_major > 1)
3845             {
3846                 /* Copy from the received buffer to the send buffer for dim 0 */
3847                 sendptr = pme->overlap[0].sendbuf;
3848                 for (x = 0; x < size_yx; x++)
3849                 {
3850                     for (y = 0; y < recv_nindex; y++)
3851                     {
3852                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3853                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3854                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3855                         {
3856                             sendptr[indg+z] += recvptr[indb+z];
3857                         }
3858                     }
3859                 }
3860             }
3861         }
3862     }
3863
3864     /* We only support a single pulse here.
3865      * This is not a severe limitation, as this code is only used
3866      * with OpenMP, and with OpenMP the (PME) domains can be larger.
3867      */
3868     if (pme->nnodes_major > 1)
3869     {
3870         /* Major dimension */
3871         overlap = &pme->overlap[0];
3872
3873         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3874         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3875
3876         ipulse = 0;
3877
3878         send_id       = overlap->send_id[ipulse];
3879         recv_id       = overlap->recv_id[ipulse];
3880         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3881         /* We don't use recv_index0, as we always receive starting at 0 */
3882         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3883
3884         sendptr = overlap->sendbuf;
3885         recvptr = overlap->recvbuf;
3886
3887         if (debug != NULL)
3888         {
3889             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3890                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3891         }
3892
3893 #ifdef GMX_MPI
3894         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3895                      send_id, ipulse,
3896                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3897                      recv_id, ipulse,
3898                      overlap->mpi_comm, &stat);
3899 #endif
3900
3901         for (x = 0; x < recv_nindex; x++)
3902         {
3903             for (y = 0; y < local_fft_ndata[YY]; y++)
3904             {
3905                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3906                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3907                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3908                 {
3909                     fftgrid[indg+z] += recvptr[indb+z];
3910                 }
3911             }
3912         }
3913     }
3914 }
3915
3916
3917 static void spread_on_grid(gmx_pme_t pme,
3918                            pme_atomcomm_t *atc, pmegrids_t *grids,
3919                            gmx_bool bCalcSplines, gmx_bool bSpread,
3920                            real *fftgrid)
3921 {
3922     int nthread, thread;
3923 #ifdef PME_TIME_THREADS
3924     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3925     static double cs1     = 0, cs2 = 0, cs3 = 0;
3926     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3927     static int cnt        = 0;
3928 #endif
3929
3930     nthread = pme->nthread;
3931     assert(nthread > 0);
3932
3933 #ifdef PME_TIME_THREADS
3934     c1 = omp_cyc_start();
3935 #endif
3936     if (bCalcSplines)
3937     {
3938 #pragma omp parallel for num_threads(nthread) schedule(static)
3939         for (thread = 0; thread < nthread; thread++)
3940         {
3941             int start, end;
3942
3943             start = atc->n* thread   /nthread;
3944             end   = atc->n*(thread+1)/nthread;
3945
3946             /* Compute fftgrid index for all atoms,
3947              * with help of some extra variables.
3948              */
3949             calc_interpolation_idx(pme, atc, start, end, thread);
3950         }
3951     }
3952 #ifdef PME_TIME_THREADS
3953     c1   = omp_cyc_end(c1);
3954     cs1 += (double)c1;
3955 #endif
3956
3957 #ifdef PME_TIME_THREADS
3958     c2 = omp_cyc_start();
3959 #endif
3960 #pragma omp parallel for num_threads(nthread) schedule(static)
3961     for (thread = 0; thread < nthread; thread++)
3962     {
3963         splinedata_t *spline;
3964         pmegrid_t *grid = NULL;
3965
3966         /* make local bsplines  */
3967         if (grids == NULL || !pme->bUseThreads)
3968         {
3969             spline = &atc->spline[0];
3970
3971             spline->n = atc->n;
3972
3973             if (bSpread)
3974             {
3975                 grid = &grids->grid;
3976             }
3977         }
3978         else
3979         {
3980             spline = &atc->spline[thread];
3981
3982             if (grids->nthread == 1)
3983             {
3984                 /* One thread, we operate on all charges */
3985                 spline->n = atc->n;
3986             }
3987             else
3988             {
3989                 /* Get the indices our thread should operate on */
3990                 make_thread_local_ind(atc, thread, spline);
3991             }
3992
3993             grid = &grids->grid_th[thread];
3994         }
3995
3996         if (bCalcSplines)
3997         {
3998             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
3999                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
4000         }
4001
4002         if (bSpread)
4003         {
4004             /* put local atoms on grid. */
4005 #ifdef PME_TIME_SPREAD
4006             ct1a = omp_cyc_start();
4007 #endif
4008             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4009
4010             if (pme->bUseThreads)
4011             {
4012                 copy_local_grid(pme, grids, thread, fftgrid);
4013             }
4014 #ifdef PME_TIME_SPREAD
4015             ct1a          = omp_cyc_end(ct1a);
4016             cs1a[thread] += (double)ct1a;
4017 #endif
4018         }
4019     }
4020 #ifdef PME_TIME_THREADS
4021     c2   = omp_cyc_end(c2);
4022     cs2 += (double)c2;
4023 #endif
4024
4025     if (bSpread && pme->bUseThreads)
4026     {
4027 #ifdef PME_TIME_THREADS
4028         c3 = omp_cyc_start();
4029 #endif
4030 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4031         for (thread = 0; thread < grids->nthread; thread++)
4032         {
4033             reduce_threadgrid_overlap(pme, grids, thread,
4034                                       fftgrid,
4035                                       pme->overlap[0].sendbuf,
4036                                       pme->overlap[1].sendbuf);
4037         }
4038 #ifdef PME_TIME_THREADS
4039         c3   = omp_cyc_end(c3);
4040         cs3 += (double)c3;
4041 #endif
4042
4043         if (pme->nnodes > 1)
4044         {
4045             /* Communicate the overlapping part of the fftgrid.
4046              * For this communication call we need to check pme->bUseThreads
4047              * to have all ranks communicate here, regardless of pme->nthread.
4048              */
4049             sum_fftgrid_dd(pme, fftgrid);
4050         }
4051     }
4052
4053 #ifdef PME_TIME_THREADS
4054     cnt++;
4055     if (cnt % 20 == 0)
4056     {
4057         printf("idx %.2f spread %.2f red %.2f",
4058                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4059 #ifdef PME_TIME_SPREAD
4060         for (thread = 0; thread < nthread; thread++)
4061         {
4062             printf(" %.2f", cs1a[thread]*1e-9);
4063         }
4064 #endif
4065         printf("\n");
4066     }
4067 #endif
4068 }
4069
4070
4071 static void dump_grid(FILE *fp,
4072                       int sx, int sy, int sz, int nx, int ny, int nz,
4073                       int my, int mz, const real *g)
4074 {
4075     int x, y, z;
4076
4077     for (x = 0; x < nx; x++)
4078     {
4079         for (y = 0; y < ny; y++)
4080         {
4081             for (z = 0; z < nz; z++)
4082             {
4083                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4084                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4085             }
4086         }
4087     }
4088 }
4089
4090 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4091 {
4092     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4093
4094     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4095                                    local_fft_ndata,
4096                                    local_fft_offset,
4097                                    local_fft_size);
4098
4099     dump_grid(stderr,
4100               pme->pmegrid_start_ix,
4101               pme->pmegrid_start_iy,
4102               pme->pmegrid_start_iz,
4103               pme->pmegrid_nx-pme->pme_order+1,
4104               pme->pmegrid_ny-pme->pme_order+1,
4105               pme->pmegrid_nz-pme->pme_order+1,
4106               local_fft_size[YY],
4107               local_fft_size[ZZ],
4108               fftgrid);
4109 }
4110
4111
4112 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4113 {
4114     pme_atomcomm_t *atc;
4115     pmegrids_t *grid;
4116
4117     if (pme->nnodes > 1)
4118     {
4119         gmx_incons("gmx_pme_calc_energy called in parallel");
4120     }
4121     if (pme->bFEP)
4122     {
4123         gmx_incons("gmx_pme_calc_energy with free energy");
4124     }
4125
4126     atc            = &pme->atc_energy;
4127     atc->nthread   = 1;
4128     if (atc->spline == NULL)
4129     {
4130         snew(atc->spline, atc->nthread);
4131     }
4132     atc->nslab     = 1;
4133     atc->bSpread   = TRUE;
4134     atc->pme_order = pme->pme_order;
4135     atc->n         = n;
4136     pme_realloc_atomcomm_things(atc);
4137     atc->x         = x;
4138     atc->q         = q;
4139
4140     /* We only use the A-charges grid */
4141     grid = &pme->pmegridA;
4142
4143     /* Only calculate the spline coefficients, don't actually spread */
4144     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4145
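    /* spread_on_grid() was called with bSpread==FALSE, so only the B-spline
     * coefficients of the given atoms were computed; the energy is gathered
     * from the charge grid left behind by the last regular PME call
     * (presumably this is only used for test particle insertion, where that
     * grid is still valid).
     */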
4146     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4147 }
4148
4149
4150 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4151                                    gmx_walltime_accounting_t walltime_accounting,
4152                                    t_nrnb *nrnb, t_inputrec *ir,
4153                                    gmx_large_int_t step)
4154 {
4155     /* Reset all the counters related to performance over the run */
4156     wallcycle_stop(wcycle, ewcRUN);
4157     wallcycle_reset_all(wcycle);
4158     init_nrnb(nrnb);
4159     if (ir->nsteps >= 0)
4160     {
4161         /* ir->nsteps is not used here, but we update it for consistency */
4162         ir->nsteps -= step - ir->init_step;
4163     }
4164     ir->init_step = step;
4165     wallcycle_start(wcycle, ewcRUN);
4166     walltime_accounting_start(walltime_accounting);
4167 }
4168
4169
4170 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4171                                ivec grid_size,
4172                                t_commrec *cr, t_inputrec *ir,
4173                                gmx_pme_t *pme_ret)
4174 {
4175     int ind;
4176     gmx_pme_t pme = NULL;
4177
4178     ind = 0;
4179     while (ind < *npmedata)
4180     {
4181         pme = (*pmedata)[ind];
4182         if (pme->nkx == grid_size[XX] &&
4183             pme->nky == grid_size[YY] &&
4184             pme->nkz == grid_size[ZZ])
4185         {
4186             *pme_ret = pme;
4187
4188             return;
4189         }
4190
4191         ind++;
4192     }
4193
4194     (*npmedata)++;
4195     srenew(*pmedata, *npmedata);
4196
4197     /* Generate a new PME data structure, reusing (part of) the old grid storage */
4198     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4199
4200     *pme_ret = (*pmedata)[ind];
4201 }
4202
4203
4204 int gmx_pmeonly(gmx_pme_t pme,
4205                 t_commrec *cr,    t_nrnb *nrnb,
4206                 gmx_wallcycle_t wcycle,
4207                 gmx_walltime_accounting_t walltime_accounting,
4208                 real ewaldcoeff,
4209                 t_inputrec *ir)
4210 {
4211     int npmedata;
4212     gmx_pme_t *pmedata;
4213     gmx_pme_pp_t pme_pp;
4214     int  ret;
4215     int  natoms;
4216     matrix box;
4217     rvec *x_pp      = NULL, *f_pp = NULL;
4218     real *chargeA   = NULL, *chargeB = NULL;
4219     real lambda     = 0;
4220     int  maxshift_x = 0, maxshift_y = 0;
4221     real energy, dvdlambda;
4222     matrix vir;
4223     float cycles;
4224     int  count;
4225     gmx_bool bEnerVir;
4226     gmx_large_int_t step, step_rel;
4227     ivec grid_switch;
4228
4229     /* This data is only used with PME grid tuning, i.e. when switching between PME grids */
4230     npmedata = 1;
4231     snew(pmedata, npmedata);
4232     pmedata[0] = pme;
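    /* pmedata caches one gmx_pme_t per grid size encountered during tuning;
     * gmx_pmeonly_switch() appends a new entry the first time a particular
     * grid size is requested and reuses it afterwards.
     */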
4233
4234     pme_pp = gmx_pme_pp_init(cr);
4235
4236     init_nrnb(nrnb);
4237
4238     count = 0;
4239     do /****** this is a quasi-loop over time steps! */
4240     {
4241         /* The reason for having a loop here is PME grid tuning/switching */
4242         do
4243         {
4244             /* Domain decomposition */
4245             ret = gmx_pme_recv_q_x(pme_pp,
4246                                    &natoms,
4247                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4248                                    &maxshift_x, &maxshift_y,
4249                                    &pme->bFEP, &lambda,
4250                                    &bEnerVir,
4251                                    &step,
4252                                    grid_switch, &ewaldcoeff);
4253
4254             if (ret == pmerecvqxSWITCHGRID)
4255             {
4256                 /* Switch the PME grid to grid_switch */
4257                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4258             }
4259
4260             if (ret == pmerecvqxRESETCOUNTERS)
4261             {
4262                 /* Reset the cycle and flop counters */
4263                 reset_pmeonly_counters(wcycle, walltime_accounting, nrnb, ir, step);
4264             }
4265         }
4266         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4267
4268         if (ret == pmerecvqxFINISH)
4269         {
4270             /* We should stop: break out of the loop */
4271             break;
4272         }
4273
4274         step_rel = step - ir->init_step;
4275
4276         if (count == 0)
4277         {
4278             wallcycle_start(wcycle, ewcRUN);
4279             walltime_accounting_start(walltime_accounting);
4280         }
4281
4282         wallcycle_start(wcycle, ewcPMEMESH);
4283
4284         dvdlambda = 0;
4285         clear_mat(vir);
4286         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4287                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4288                    &energy, lambda, &dvdlambda,
4289                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4290
4291         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4292
4293         gmx_pme_send_force_vir_ener(pme_pp,
4294                                     f_pp, vir, energy, dvdlambda,
4295                                     cycles);
4296
4297         count++;
4298     } /***** end of quasi-loop, we stop with the break above */
4299     while (TRUE);
4300
4301     walltime_accounting_end(walltime_accounting);
4302
4303     return 0;
4304 }
4305
4306 int gmx_pme_do(gmx_pme_t pme,
4307                int start,       int homenr,
4308                rvec x[],        rvec f[],
4309                real *chargeA,   real *chargeB,
4310                matrix box, t_commrec *cr,
4311                int  maxshift_x, int maxshift_y,
4312                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4313                matrix vir,      real ewaldcoeff,
4314                real *energy,    real lambda,
4315                real *dvdlambda, int flags)
4316 {
4317     int     q, d, i, j, ntot, npme;
4318     int     nx, ny, nz;
4319     int     n_d, local_ny;
4320     pme_atomcomm_t *atc = NULL;
4321     pmegrids_t *pmegrid = NULL;
4322     real    *grid       = NULL;
4323     real    *ptr;
4324     rvec    *x_d, *f_d;
4325     real    *charge = NULL, *q_d;
4326     real    energy_AB[2];
4327     matrix  vir_AB[2];
4328     gmx_bool bClearF;
4329     gmx_parallel_3dfft_t pfft_setup;
4330     real *  fftgrid;
4331     t_complex * cfftgrid;
4332     int     thread;
4333     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4334     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4335
4336     assert(pme->nnodes > 0);
4337     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4338
4339     if (pme->nnodes > 1)
4340     {
4341         atc      = &pme->atc[0];
4342         atc->npd = homenr;
4343         if (atc->npd > atc->pd_nalloc)
4344         {
4345             atc->pd_nalloc = over_alloc_dd(atc->npd);
4346             srenew(atc->pd, atc->pd_nalloc);
4347         }
4348         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4349     }
4350     else
4351     {
4352         /* This could be necessary for TPI */
4353         pme->atc[0].n = homenr;
4354     }
4355
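    /* Loop over the charge states: q==0 uses the A charges and grids and,
     * with free-energy perturbation, q==1 repeats the full spread/FFT/solve/
     * gather cycle for the B charges, so that the two mesh energies can be
     * interpolated in lambda at the end of this routine.
     */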
4356     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4357     {
4358         if (q == 0)
4359         {
4360             pmegrid    = &pme->pmegridA;
4361             fftgrid    = pme->fftgridA;
4362             cfftgrid   = pme->cfftgridA;
4363             pfft_setup = pme->pfft_setupA;
4364             charge     = chargeA+start;
4365         }
4366         else
4367         {
4368             pmegrid    = &pme->pmegridB;
4369             fftgrid    = pme->fftgridB;
4370             cfftgrid   = pme->cfftgridB;
4371             pfft_setup = pme->pfft_setupB;
4372             charge     = chargeB+start;
4373         }
4374         grid = pmegrid->grid.grid;
4375         /* Unpack structure */
4376         if (debug)
4377         {
4378             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4379                     cr->nnodes, cr->nodeid);
4380             fprintf(debug, "Grid = %p\n", (void*)grid);
4381             if (grid == NULL)
4382             {
4383                 gmx_fatal(FARGS, "No grid!");
4384             }
4385         }
4386         where();
4387
4388         m_inv_ur0(box, pme->recipbox);
4389
4390         if (pme->nnodes == 1)
4391         {
4392             atc = &pme->atc[0];
4393             if (DOMAINDECOMP(cr))
4394             {
4395                 atc->n = homenr;
4396                 pme_realloc_atomcomm_things(atc);
4397             }
4398             atc->x = x;
4399             atc->q = charge;
4400             atc->f = f;
4401         }
4402         else
4403         {
4404             wallcycle_start(wcycle, ewcPME_REDISTXF);
4405             for (d = pme->ndecompdim-1; d >= 0; d--)
4406             {
4407                 if (d == pme->ndecompdim-1)
4408                 {
4409                     n_d = homenr;
4410                     x_d = x + start;
4411                     q_d = charge;
4412                 }
4413                 else
4414                 {
4415                     n_d = pme->atc[d+1].n;
4416                     x_d = atc->x;
4417                     q_d = atc->q;
4418                 }
4419                 atc      = &pme->atc[d];
4420                 atc->npd = n_d;
4421                 if (atc->npd > atc->pd_nalloc)
4422                 {
4423                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4424                     srenew(atc->pd, atc->pd_nalloc);
4425                 }
4426                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4427                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4428                 where();
4429
4430                 /* Redistribute x (only once) and qA or qB */
4431                 if (DOMAINDECOMP(cr))
4432                 {
4433                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4434                 }
4435                 else
4436                 {
4437                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4438                 }
4439             }
4440             where();
4441
4442             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4443         }
4444
4445         if (debug)
4446         {
4447             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4448                     cr->nodeid, atc->n);
4449         }
4450
4451         if (flags & GMX_PME_SPREAD_Q)
4452         {
4453             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4454
4455             /* Spread the charges on a grid */
4456             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4457
4458             if (q == 0)
4459             {
4460                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4461             }
4462             inc_nrnb(nrnb, eNR_SPREADQBSP,
4463                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4464
4465             if (!pme->bUseThreads)
4466             {
4467                 wrap_periodic_pmegrid(pme, grid);
4468
4469                 /* sum contributions to local grid from other nodes */
4470 #ifdef GMX_MPI
4471                 if (pme->nnodes > 1)
4472                 {
4473                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4474                     where();
4475                 }
4476 #endif
4477
4478                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4479             }
4480
4481             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4482
4483             /*
4484                dump_local_fftgrid(pme,fftgrid);
4485                exit(0);
4486              */
4487         }
4488
4489         /* Here we start a large thread parallel region */
4490 #pragma omp parallel num_threads(pme->nthread) private(thread)
4491         {
4492             thread = gmx_omp_get_thread_num();
4493             if (flags & GMX_PME_SOLVE)
4494             {
4495                 int loop_count;
4496
4497                 /* do 3d-fft */
4498                 if (thread == 0)
4499                 {
4500                     wallcycle_start(wcycle, ewcPME_FFT);
4501                 }
4502                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4503                                            thread, wcycle);
4504                 if (thread == 0)
4505                 {
4506                     wallcycle_stop(wcycle, ewcPME_FFT);
4507                 }
4508                 where();
4509
4510                 /* solve in k-space for our local cells */
4511                 if (thread == 0)
4512                 {
4513                     wallcycle_start(wcycle, ewcPME_SOLVE);
4514                 }
4515                 loop_count =
4516                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4517                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4518                                   bCalcEnerVir,
4519                                   pme->nthread, thread);
4520                 if (thread == 0)
4521                 {
4522                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4523                     where();
4524                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4525                 }
4526             }
4527
4528             if (bCalcF)
4529             {
4530                 /* do 3d-invfft */
4531                 if (thread == 0)
4532                 {
4533                     where();
4534                     wallcycle_start(wcycle, ewcPME_FFT);
4535                 }
4536                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4537                                            thread, wcycle);
4538                 if (thread == 0)
4539                 {
4540                     wallcycle_stop(wcycle, ewcPME_FFT);
4541
4542                     where();
4543
4544                     if (pme->nodeid == 0)
4545                     {
4546                         ntot  = pme->nkx*pme->nky*pme->nkz;
4547                         npme  = ntot*log((real)ntot)/log(2.0);
4548                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4549                     }
4550
4551                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4552                 }
4553
4554                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4555             }
4556         }
4557         /* End of thread parallel section.
4558          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4559          */
4560
4561         if (bCalcF)
4562         {
4563             /* distribute local grid to all nodes */
4564 #ifdef GMX_MPI
4565             if (pme->nnodes > 1)
4566             {
4567                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4568             }
4569 #endif
4570             where();
4571
4572             unwrap_periodic_pmegrid(pme, grid);
4573
4574             /* interpolate forces for our local atoms */
4575
4576             where();
4577
4578             /* If we are running without parallelization,
4579              * atc->f is the actual force array, not a buffer,
4580              * therefore we should not clear it.
4581              */
4582             bClearF = (q == 0 && PAR(cr));
4583 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4584             for (thread = 0; thread < pme->nthread; thread++)
4585             {
4586                 gather_f_bsplines(pme, grid, bClearF, atc,
4587                                   &atc->spline[thread],
4588                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4589             }
4590
4591             where();
4592
4593             inc_nrnb(nrnb, eNR_GATHERFBSP,
4594                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4595             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4596         }
4597
4598         if (bCalcEnerVir)
4599         {
4600             /* This should only be called on the master thread
4601              * and after the threads have synchronized.
4602              */
4603             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4604         }
4605     } /* of q-loop */
4606
4607     if (bCalcF && pme->nnodes > 1)
4608     {
4609         wallcycle_start(wcycle, ewcPME_REDISTXF);
4610         for (d = 0; d < pme->ndecompdim; d++)
4611         {
4612             atc = &pme->atc[d];
4613             if (d == pme->ndecompdim - 1)
4614             {
4615                 n_d = homenr;
4616                 f_d = f + start;
4617             }
4618             else
4619             {
4620                 n_d = pme->atc[d+1].n;
4621                 f_d = pme->atc[d+1].f;
4622             }
4623             if (DOMAINDECOMP(cr))
4624             {
4625                 dd_pmeredist_f(pme, atc, n_d, f_d,
4626                                d == pme->ndecompdim-1 && pme->bPPnode);
4627             }
4628             else
4629             {
4630                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4631             }
4632         }
4633
4634         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4635     }
4636     where();
4637
4638     if (bCalcEnerVir)
4639     {
4640         if (!pme->bFEP)
4641         {
4642             *energy = energy_AB[0];
4643             m_add(vir, vir_AB[0], vir);
4644         }
4645         else
4646         {
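            /* The A and B mesh energies are interpolated linearly in lambda,
             * E(lambda) = (1-lambda)*E_A + lambda*E_B, so the mesh
             * contribution to dE/dlambda is simply E_B - E_A; the virial is
             * mixed with the same weights.
             */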
4647             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4648             *dvdlambda += energy_AB[1] - energy_AB[0];
4649             for (i = 0; i < DIM; i++)
4650             {
4651                 for (j = 0; j < DIM; j++)
4652                 {
4653                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4654                         lambda*vir_AB[1][i][j];
4655                 }
4656             }
4657         }
4658     }
4659     else
4660     {
4661         *energy = 0;
4662     }
4663
4664     if (debug)
4665     {
4666         fprintf(debug, "PME mesh energy: %g\n", *energy);
4667     }
4668
4669     return 0;
4670 }