Re-fixed PME bug with high OpenMP thread count
[alexxy/gromacs.git] / src / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted by a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
56  * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #ifdef HAVE_CONFIG_H
61 #include <config.h>
62 #endif
63
64 #ifdef GMX_LIB_MPI
65 #include <mpi.h>
66 #endif
67 #ifdef GMX_THREAD_MPI
68 #include "tmpi.h"
69 #endif
70
71 #include <stdio.h>
72 #include <string.h>
73 #include <math.h>
74 #include <assert.h>
75 #include "typedefs.h"
76 #include "txtdump.h"
77 #include "vec.h"
78 #include "gmxcomplex.h"
79 #include "smalloc.h"
80 #include "futil.h"
81 #include "coulomb.h"
82 #include "gmx_fatal.h"
83 #include "pme.h"
84 #include "network.h"
85 #include "physics.h"
86 #include "nrnb.h"
87 #include "copyrite.h"
88 #include "gmx_wallcycle.h"
89 #include "gmx_parallel_3dfft.h"
90 #include "pdbio.h"
91 #include "gmx_cyclecounter.h"
92 #include "gmx_omp.h"
93
94 /* Include the SIMD macro file and then check for support */
95 #include "gmx_simd_macros.h"
96 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
97 /* Turn on arbitrary width SIMD intrinsics for PME solve */
98 #define PME_SIMD
99 #endif
100
101 /* Include the 4-wide SIMD macro file */
102 #include "gmx_simd4_macros.h"
103 /* Check if we have 4-wide SIMD macro support */
104 #ifdef GMX_HAVE_SIMD4_MACROS
105 /* Do PME spread and gather with 4-wide SIMD.
106  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
107  */
108 #define PME_SIMD4_SPREAD_GATHER
109
110 #ifdef GMX_SIMD4_HAVE_UNALIGNED
111 /* With PME-order=4 on x86, unaligned load+store is slightly faster
112  * than doubling all SIMD operations when using aligned load+store.
113  */
114 #define PME_SIMD4_UNALIGNED
115 #endif
116 #endif
117
118
119 #include "mpelogging.h"
120
121 #define DFT_TOL 1e-7
122 /* #define PRT_FORCE */
123 /* conditions for on-the-fly time measurement */
124 /* #define TAKETIME (step > 1 && timesteps < 10) */
125 #define TAKETIME FALSE
126
127 /* #define PME_TIME_THREADS */
128
129 #ifdef GMX_DOUBLE
130 #define mpi_type MPI_DOUBLE
131 #else
132 #define mpi_type MPI_FLOAT
133 #endif
134
135 #ifdef PME_SIMD4_SPREAD_GATHER
136 #define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
137 #else
138 /* We can use any alignment, apart from 0, so we use 4 reals */
139 #define SIMD4_ALIGNMENT  (4*sizeof(real))
140 #endif
141
142 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
143  * to preserve alignment.
144  */
145 #define GMX_CACHE_SEP 64
146
147 /* We only define a maximum to be able to use local arrays without allocation.
148  * An order larger than 12 should never be needed, even for test cases.
149  * If needed it can be changed here.
150  */
151 #define PME_ORDER_MAX 12
152
153 /* Internal datastructures */
154 typedef struct {
155     int send_index0;
156     int send_nindex;
157     int recv_index0;
158     int recv_nindex;
159     int recv_size;   /* Receive buffer width, used with OpenMP */
160 } pme_grid_comm_t;
161
162 typedef struct {
163 #ifdef GMX_MPI
164     MPI_Comm         mpi_comm;
165 #endif
166     int              nnodes, nodeid;
167     int             *s2g0;
168     int             *s2g1;
169     int              noverlap_nodes;
170     int             *send_id, *recv_id;
171     int              send_size; /* Send buffer width, used with OpenMP */
172     pme_grid_comm_t *comm_data;
173     real            *sendbuf;
174     real            *recvbuf;
175 } pme_overlap_t;
176
177 typedef struct {
178     int *n;      /* Cumulative counts of the number of particles per thread */
179     int  nalloc; /* Allocation size of i */
180     int *i;      /* Particle indices ordered on thread index (n) */
181 } thread_plist_t;
182
183 typedef struct {
184     int      *thread_one;
185     int       n;
186     int      *ind;
187     splinevec theta;
188     real     *ptr_theta_z;
189     splinevec dtheta;
190     real     *ptr_dtheta_z;
191 } splinedata_t;
192
193 typedef struct {
194     int      dimind;        /* The index of the dimension, 0=x, 1=y */
195     int      nslab;
196     int      nodeid;
197 #ifdef GMX_MPI
198     MPI_Comm mpi_comm;
199 #endif
200
201     int     *node_dest;     /* The nodes to send x and q to with DD */
202     int     *node_src;      /* The nodes to receive x and q from with DD */
203     int     *buf_index;     /* Index for commnode into the buffers */
204
205     int      maxshift;
206
207     int      npd;
208     int      pd_nalloc;
209     int     *pd;
210     int     *count;         /* The number of atoms to send to each node */
211     int    **count_thread;
212     int     *rcount;        /* The number of atoms to receive */
213
214     int      n;
215     int      nalloc;
216     rvec    *x;
217     real    *q;
218     rvec    *f;
219     gmx_bool bSpread;       /* These coordinates are used for spreading */
220     int      pme_order;
221     ivec    *idx;
222     rvec    *fractx;            /* Fractional coordinate relative to the
223                                  * lower cell boundary
224                                  */
225     int             nthread;
226     int            *thread_idx; /* Which thread should spread which charge */
227     thread_plist_t *thread_plist;
228     splinedata_t   *spline;
229 } pme_atomcomm_t;
230
231 #define FLBS  3
232 #define FLBSZ 4
233
234 typedef struct {
235     ivec  ci;     /* The spatial location of this grid         */
236     ivec  n;      /* The used size of *grid, including order-1 */
237     ivec  offset; /* The grid offset from the full node grid   */
238     int   order;  /* PME spreading order                       */
239     ivec  s;      /* The allocated size of *grid, s >= n       */
240     real *grid;   /* The local grid of this thread, size n     */
241 } pmegrid_t;
242
243 typedef struct {
244     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
245     int        nthread;      /* The number of threads operating on this grid     */
246     ivec       nc;           /* The local spatial decomposition over the threads */
247     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
248     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
249     int      **g2t;          /* The grid to thread index                         */
250     ivec       nthread_comm; /* The number of threads to communicate with        */
251 } pmegrids_t;
252
253
254 typedef struct {
255 #ifdef PME_SIMD4_SPREAD_GATHER
256     /* Masks for 4-wide SIMD aligned spreading and gathering */
257     gmx_simd4_pb mask_S0[6], mask_S1[6];
258 #else
259     int    dummy; /* C89 requires that struct has at least one member */
260 #endif
261 } pme_spline_work_t;
262
263 typedef struct {
264     /* work data for solve_pme */
265     int      nalloc;
266     real *   mhx;
267     real *   mhy;
268     real *   mhz;
269     real *   m2;
270     real *   denom;
271     real *   tmp1_alloc;
272     real *   tmp1;
273     real *   eterm;
274     real *   m2inv;
275
276     real     energy;
277     matrix   vir;
278 } pme_work_t;
279
280 typedef struct gmx_pme {
281     int           ndecompdim; /* The number of decomposition dimensions */
282     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
283     int           nodeid_major;
284     int           nodeid_minor;
285     int           nnodes;    /* The number of nodes doing PME */
286     int           nnodes_major;
287     int           nnodes_minor;
288
289     MPI_Comm      mpi_comm;
290     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
291 #ifdef GMX_MPI
292     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
293 #endif
294
295     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
296     int        nthread;       /* The number of threads doing PME on our rank */
297
298     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
299     gmx_bool   bFEP;          /* Compute Free energy contribution */
300     int        nkx, nky, nkz; /* Grid dimensions */
301     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
302     int        pme_order;
303     real       epsilon_r;
304
305     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
306     pmegrids_t pmegridB;
307     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
308     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
309     /* pmegrid_nz might be larger than strictly necessary to ensure
310      * memory alignment, pmegrid_nz_base gives the real base size.
311      */
312     int     pmegrid_nz_base;
313     /* The local PME grid starting indices */
314     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
315
316     /* Work data for spreading and gathering */
317     pme_spline_work_t    *spline_work;
318
319     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
320     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
321     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
322
323     t_complex            *cfftgridA;  /* Grids for complex FFT data */
324     t_complex            *cfftgridB;
325     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
326
327     gmx_parallel_3dfft_t  pfft_setupA;
328     gmx_parallel_3dfft_t  pfft_setupB;
329
330     int                  *nnx, *nny, *nnz;
331     real                 *fshx, *fshy, *fshz;
332
333     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
334     matrix                recipbox;
335     splinevec             bsp_mod;
336
337     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
338
339     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
340
341     rvec                 *bufv;       /* Communication buffer */
342     real                 *bufr;       /* Communication buffer */
343     int                   buf_nalloc; /* The communication buffer size */
344
345     /* thread local work data for solve_pme */
346     pme_work_t *work;
347
348     /* Work data for PME_redist */
349     gmx_bool redist_init;
350     int *    scounts;
351     int *    rcounts;
352     int *    sdispls;
353     int *    rdispls;
354     int *    sidx;
355     int *    idxa;
356     real *   redist_buf;
357     int      redist_buf_nalloc;
358
359     /* Work data for sum_qgrid */
360     real *   sum_qgrid_tmp;
361     real *   sum_qgrid_dd_tmp;
362 } t_gmx_pme;
363
364
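/* Compute, for each atom in [start,end), the PME grid cell index (atc->idx)
 * and the fractional offset within that cell (atc->fractx), using the
 * reciprocal box so triclinic cells are handled. With more than one thread,
 * also record which spreading thread owns each atom and build a list of
 * atom indices sorted on thread.
 */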
365 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
366                                    int start, int end, int thread)
367 {
368     int             i;
369     int            *idxptr, tix, tiy, tiz;
370     real           *xptr, *fptr, tx, ty, tz;
371     real            rxx, ryx, ryy, rzx, rzy, rzz;
372     int             nx, ny, nz;
373     int             start_ix, start_iy, start_iz;
374     int            *g2tx, *g2ty, *g2tz;
375     gmx_bool        bThreads;
376     int            *thread_idx = NULL;
377     thread_plist_t *tpl        = NULL;
378     int            *tpl_n      = NULL;
379     int             thread_i;
380
381     nx  = pme->nkx;
382     ny  = pme->nky;
383     nz  = pme->nkz;
384
385     start_ix = pme->pmegrid_start_ix;
386     start_iy = pme->pmegrid_start_iy;
387     start_iz = pme->pmegrid_start_iz;
388
389     rxx = pme->recipbox[XX][XX];
390     ryx = pme->recipbox[YY][XX];
391     ryy = pme->recipbox[YY][YY];
392     rzx = pme->recipbox[ZZ][XX];
393     rzy = pme->recipbox[ZZ][YY];
394     rzz = pme->recipbox[ZZ][ZZ];
395
396     g2tx = pme->pmegridA.g2t[XX];
397     g2ty = pme->pmegridA.g2t[YY];
398     g2tz = pme->pmegridA.g2t[ZZ];
399
400     bThreads = (atc->nthread > 1);
401     if (bThreads)
402     {
403         thread_idx = atc->thread_idx;
404
405         tpl   = &atc->thread_plist[thread];
406         tpl_n = tpl->n;
407         for (i = 0; i < atc->nthread; i++)
408         {
409             tpl_n[i] = 0;
410         }
411     }
412
413     for (i = start; i < end; i++)
414     {
415         xptr   = atc->x[i];
416         idxptr = atc->idx[i];
417         fptr   = atc->fractx[i];
418
419         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
420         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
421         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
422         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
423
424         tix = (int)(tx);
425         tiy = (int)(ty);
426         tiz = (int)(tz);
427
428         /* Because decomposition only occurs in x and y,
429          * we never have a fraction correction in z.
430          */
431         fptr[XX] = tx - tix + pme->fshx[tix];
432         fptr[YY] = ty - tiy + pme->fshy[tiy];
433         fptr[ZZ] = tz - tiz;
434
435         idxptr[XX] = pme->nnx[tix];
436         idxptr[YY] = pme->nny[tiy];
437         idxptr[ZZ] = pme->nnz[tiz];
438
439 #ifdef DEBUG
440         range_check(idxptr[XX], 0, pme->pmegrid_nx);
441         range_check(idxptr[YY], 0, pme->pmegrid_ny);
442         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
443 #endif
444
445         if (bThreads)
446         {
447             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
448             thread_idx[i] = thread_i;
449             tpl_n[thread_i]++;
450         }
451     }
452
453     if (bThreads)
454     {
455         /* Make a list of particle indices sorted on thread */
456
457         /* Get the cumulative count */
458         for (i = 1; i < atc->nthread; i++)
459         {
460             tpl_n[i] += tpl_n[i-1];
461         }
462         /* The current implementation distributes particles equally
463          * over the threads, so we could actually allocate for that
464          * in pme_realloc_atomcomm_things.
465          */
466         if (tpl_n[atc->nthread-1] > tpl->nalloc)
467         {
468             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
469             srenew(tpl->i, tpl->nalloc);
470         }
471         /* Set tpl_n to the cumulative start */
472         for (i = atc->nthread-1; i >= 1; i--)
473         {
474             tpl_n[i] = tpl_n[i-1];
475         }
476         tpl_n[0] = 0;
477
478         /* Fill our thread local array with indices sorted on thread */
479         for (i = start; i < end; i++)
480         {
481             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
482         }
483         /* Now tpl_n contains the cumulative count again */
484     }
485 }
486
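/* Gather this thread's atom indices: concatenate the [thread] segments of
 * all threads' particle lists into spline->ind and set spline->n.
 */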
487 static void make_thread_local_ind(pme_atomcomm_t *atc,
488                                   int thread, splinedata_t *spline)
489 {
490     int             n, t, i, start, end;
491     thread_plist_t *tpl;
492
493     /* Combine the indices made by each thread into one index */
494
495     n     = 0;
496     start = 0;
497     for (t = 0; t < atc->nthread; t++)
498     {
499         tpl = &atc->thread_plist[t];
500         /* Copy our part (start - end) from the list of thread t */
501         if (thread > 0)
502         {
503             start = tpl->n[thread-1];
504         }
505         end = tpl->n[thread];
506         for (i = start; i < end; i++)
507         {
508             spline->ind[n++] = tpl->i[i];
509         }
510     }
511
512     spline->n = n;
513 }
514
515
516 static void pme_calc_pidx(int start, int end,
517                           matrix recipbox, rvec x[],
518                           pme_atomcomm_t *atc, int *count)
519 {
520     int   nslab, i;
521     int   si;
522     real *xptr, s;
523     real  rxx, ryx, rzx, ryy, rzy;
524     int  *pd;
525
526     /* Calculate the PME task index (pidx) for each particle.
527      * Here we always assign equally sized slabs to each node
528      * for load balancing reasons (the PME grid spacing is not used).
529      */
530
531     nslab = atc->nslab;
532     pd    = atc->pd;
533
534     /* Reset the count */
535     for (i = 0; i < nslab; i++)
536     {
537         count[i] = 0;
538     }
539
540     if (atc->dimind == 0)
541     {
542         rxx = recipbox[XX][XX];
543         ryx = recipbox[YY][XX];
544         rzx = recipbox[ZZ][XX];
545         /* Calculate the node index in x-dimension */
546         for (i = start; i < end; i++)
547         {
548             xptr   = x[i];
549             /* Fractional coordinates along box vectors */
550             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
551             si    = (int)(s + 2*nslab) % nslab;
552             pd[i] = si;
553             count[si]++;
554         }
555     }
556     else
557     {
558         ryy = recipbox[YY][YY];
559         rzy = recipbox[ZZ][YY];
560         /* Calculate the node index in y-dimension */
561         for (i = start; i < end; i++)
562         {
563             xptr   = x[i];
564             /* Fractional coordinates along box vectors */
565             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
566             si    = (int)(s + 2*nslab) % nslab;
567             pd[i] = si;
568             count[si]++;
569         }
570     }
571 }
572
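/* Thread-parallel wrapper around pme_calc_pidx: each OpenMP thread handles
 * an equal share of the atoms and fills its own slab-count array; the
 * per-thread counts are then reduced serially into count_thread[0].
 */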
573 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
574                                   pme_atomcomm_t *atc)
575 {
576     int nthread, thread, slab;
577
578     nthread = atc->nthread;
579
580 #pragma omp parallel for num_threads(nthread) schedule(static)
581     for (thread = 0; thread < nthread; thread++)
582     {
583         pme_calc_pidx(natoms* thread   /nthread,
584                       natoms*(thread+1)/nthread,
585                       recipbox, x, atc, atc->count_thread[thread]);
586     }
587     /* Non-parallel reduction, since nslab is small */
588
589     for (thread = 1; thread < nthread; thread++)
590     {
591         for (slab = 0; slab < atc->nslab; slab++)
592         {
593             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
594         }
595     }
596 }
597
598 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
599 {
600     const int padding = 4;
601     int       i;
602
603     srenew(th[XX], nalloc);
604     srenew(th[YY], nalloc);
605     /* In z we add padding; this is only required for the aligned SIMD code */
606     sfree_aligned(*ptr_z);
607     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
608     th[ZZ] = *ptr_z + padding;
609
610     for (i = 0; i < padding; i++)
611     {
612         (*ptr_z)[               i] = 0;
613         (*ptr_z)[padding+nalloc+i] = 0;
614     }
615 }
616
617 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
618 {
619     int i, d;
620
621     srenew(spline->ind, atc->nalloc);
622     /* Initialize the index to identity so it works without threads */
623     for (i = 0; i < atc->nalloc; i++)
624     {
625         spline->ind[i] = i;
626     }
627
628     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
629                       atc->pme_order*atc->nalloc);
630     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
631                       atc->pme_order*atc->nalloc);
632 }
633
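/* (Re)allocate the per-atom arrays in atc when atc->n exceeds the current
 * allocation. At least one element is always allocated so that MPI routines
 * never receive a NULL buffer.
 */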
634 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
635 {
636     int nalloc_old, i, j, nalloc_tpl;
637
638     /* We have to avoid a NULL pointer for atc->x to avoid
639      * possible fatal errors in MPI routines.
640      */
641     if (atc->n > atc->nalloc || atc->nalloc == 0)
642     {
643         nalloc_old  = atc->nalloc;
644         atc->nalloc = over_alloc_dd(max(atc->n, 1));
645
646         if (atc->nslab > 1)
647         {
648             srenew(atc->x, atc->nalloc);
649             srenew(atc->q, atc->nalloc);
650             srenew(atc->f, atc->nalloc);
651             for (i = nalloc_old; i < atc->nalloc; i++)
652             {
653                 clear_rvec(atc->f[i]);
654             }
655         }
656         if (atc->bSpread)
657         {
658             srenew(atc->fractx, atc->nalloc);
659             srenew(atc->idx, atc->nalloc);
660
661             if (atc->nthread > 1)
662             {
663                 srenew(atc->thread_idx, atc->nalloc);
664             }
665
666             for (i = 0; i < atc->nthread; i++)
667             {
668                 pme_realloc_splinedata(&atc->spline[i], atc);
669             }
670         }
671     }
672 }
673
674 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
675                          int n, gmx_bool bXF, rvec *x_f, real *charge,
676                          pme_atomcomm_t *atc)
677 /* Redistribute particle data for PME calculation */
678 /* domain decomposition by x coordinate           */
679 {
680     int *idxa;
681     int  i, ii;
682
683     if (FALSE == pme->redist_init)
684     {
685         snew(pme->scounts, atc->nslab);
686         snew(pme->rcounts, atc->nslab);
687         snew(pme->sdispls, atc->nslab);
688         snew(pme->rdispls, atc->nslab);
689         snew(pme->sidx, atc->nslab);
690         pme->redist_init = TRUE;
691     }
692     if (n > pme->redist_buf_nalloc)
693     {
694         pme->redist_buf_nalloc = over_alloc_dd(n);
695         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
696     }
697
698     pme->idxa = atc->pd;
699
700 #ifdef GMX_MPI
701     if (forw && bXF)
702     {
703         /* forward, redistribution from pp to pme */
704
705         /* Calculate send counts and exchange them with other nodes */
706         for (i = 0; (i < atc->nslab); i++)
707         {
708             pme->scounts[i] = 0;
709         }
710         for (i = 0; (i < n); i++)
711         {
712             pme->scounts[pme->idxa[i]]++;
713         }
714         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
715
716         /* Calculate send and receive displacements and index into send
717            buffer */
718         pme->sdispls[0] = 0;
719         pme->rdispls[0] = 0;
720         pme->sidx[0]    = 0;
721         for (i = 1; i < atc->nslab; i++)
722         {
723             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
724             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
725             pme->sidx[i]    = pme->sdispls[i];
726         }
727         /* Total # of particles to be received */
728         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
729
730         pme_realloc_atomcomm_things(atc);
731
732         /* Copy particle coordinates into send buffer and exchange */
733         for (i = 0; (i < n); i++)
734         {
735             ii = DIM*pme->sidx[pme->idxa[i]];
736             pme->sidx[pme->idxa[i]]++;
737             pme->redist_buf[ii+XX] = x_f[i][XX];
738             pme->redist_buf[ii+YY] = x_f[i][YY];
739             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
740         }
741         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
742                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
743                       pme->rvec_mpi, atc->mpi_comm);
744     }
745     if (forw)
746     {
747         /* Copy charge into send buffer and exchange */
748         for (i = 0; i < atc->nslab; i++)
749         {
750             pme->sidx[i] = pme->sdispls[i];
751         }
752         for (i = 0; (i < n); i++)
753         {
754             ii = pme->sidx[pme->idxa[i]];
755             pme->sidx[pme->idxa[i]]++;
756             pme->redist_buf[ii] = charge[i];
757         }
758         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
759                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
760                       atc->mpi_comm);
761     }
762     else   /* backward, redistribution from pme to pp */
763     {
764         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
765                       pme->redist_buf, pme->scounts, pme->sdispls,
766                       pme->rvec_mpi, atc->mpi_comm);
767
768         /* Copy data from receive buffer */
769         for (i = 0; i < atc->nslab; i++)
770         {
771             pme->sidx[i] = pme->sdispls[i];
772         }
773         for (i = 0; (i < n); i++)
774         {
775             ii          = DIM*pme->sidx[pme->idxa[i]];
776             x_f[i][XX] += pme->redist_buf[ii+XX];
777             x_f[i][YY] += pme->redist_buf[ii+YY];
778             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
779             pme->sidx[pme->idxa[i]]++;
780         }
781     }
782 #endif
783 }
784
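/* Exchange nbyte_s/nbyte_r bytes with the domain-decomposition neighbors at
 * the given shift; falls back to a plain MPI_Send or MPI_Recv when only one
 * direction has data to transfer.
 */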
785 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
786                             gmx_bool bBackward, int shift,
787                             void *buf_s, int nbyte_s,
788                             void *buf_r, int nbyte_r)
789 {
790 #ifdef GMX_MPI
791     int        dest, src;
792     MPI_Status stat;
793
794     if (bBackward == FALSE)
795     {
796         dest = atc->node_dest[shift];
797         src  = atc->node_src[shift];
798     }
799     else
800     {
801         dest = atc->node_src[shift];
802         src  = atc->node_dest[shift];
803     }
804
805     if (nbyte_s > 0 && nbyte_r > 0)
806     {
807         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
808                      dest, shift,
809                      buf_r, nbyte_r, MPI_BYTE,
810                      src, shift,
811                      atc->mpi_comm, &stat);
812     }
813     else if (nbyte_s > 0)
814     {
815         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
816                  dest, shift,
817                  atc->mpi_comm);
818     }
819     else if (nbyte_r > 0)
820     {
821         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
822                  src, shift,
823                  atc->mpi_comm, &stat);
824     }
825 #endif
826 }
827
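/* Redistribute coordinates (if bX) and charges over the PME slabs of this
 * decomposition dimension: atoms belonging to our own slab are copied
 * directly, the rest are packed per destination node and exchanged with
 * pme_dd_sendrecv.
 */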
828 static void dd_pmeredist_x_q(gmx_pme_t pme,
829                              int n, gmx_bool bX, rvec *x, real *charge,
830                              pme_atomcomm_t *atc)
831 {
832     int *commnode, *buf_index;
833     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
834
835     commnode  = atc->node_dest;
836     buf_index = atc->buf_index;
837
838     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
839
840     nsend = 0;
841     for (i = 0; i < nnodes_comm; i++)
842     {
843         buf_index[commnode[i]] = nsend;
844         nsend                 += atc->count[commnode[i]];
845     }
846     if (bX)
847     {
848         if (atc->count[atc->nodeid] + nsend != n)
849         {
850             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
851                       "This usually means that your system is not well equilibrated.",
852                       n - (atc->count[atc->nodeid] + nsend),
853                       pme->nodeid, 'x'+atc->dimind);
854         }
855
856         if (nsend > pme->buf_nalloc)
857         {
858             pme->buf_nalloc = over_alloc_dd(nsend);
859             srenew(pme->bufv, pme->buf_nalloc);
860             srenew(pme->bufr, pme->buf_nalloc);
861         }
862
863         atc->n = atc->count[atc->nodeid];
864         for (i = 0; i < nnodes_comm; i++)
865         {
866             scount = atc->count[commnode[i]];
867             /* Communicate the count */
868             if (debug)
869             {
870                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
871                         atc->dimind, atc->nodeid, commnode[i], scount);
872             }
873             pme_dd_sendrecv(atc, FALSE, i,
874                             &scount, sizeof(int),
875                             &atc->rcount[i], sizeof(int));
876             atc->n += atc->rcount[i];
877         }
878
879         pme_realloc_atomcomm_things(atc);
880     }
881
882     local_pos = 0;
883     for (i = 0; i < n; i++)
884     {
885         node = atc->pd[i];
886         if (node == atc->nodeid)
887         {
888             /* Copy direct to the receive buffer */
889             if (bX)
890             {
891                 copy_rvec(x[i], atc->x[local_pos]);
892             }
893             atc->q[local_pos] = charge[i];
894             local_pos++;
895         }
896         else
897         {
898             /* Copy to the send buffer */
899             if (bX)
900             {
901                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
902             }
903             pme->bufr[buf_index[node]] = charge[i];
904             buf_index[node]++;
905         }
906     }
907
908     buf_pos = 0;
909     for (i = 0; i < nnodes_comm; i++)
910     {
911         scount = atc->count[commnode[i]];
912         rcount = atc->rcount[i];
913         if (scount > 0 || rcount > 0)
914         {
915             if (bX)
916             {
917                 /* Communicate the coordinates */
918                 pme_dd_sendrecv(atc, FALSE, i,
919                                 pme->bufv[buf_pos], scount*sizeof(rvec),
920                                 atc->x[local_pos], rcount*sizeof(rvec));
921             }
922             /* Communicate the charges */
923             pme_dd_sendrecv(atc, FALSE, i,
924                             pme->bufr+buf_pos, scount*sizeof(real),
925                             atc->q+local_pos, rcount*sizeof(real));
926             buf_pos   += scount;
927             local_pos += atc->rcount[i];
928         }
929     }
930 }
931
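/* The reverse of dd_pmeredist_x_q for the forces: send back the forces of
 * the atoms we received, then add them to (bAddF) or copy them into the
 * caller's force array f.
 */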
932 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
933                            int n, rvec *f,
934                            gmx_bool bAddF)
935 {
936     int *commnode, *buf_index;
937     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
938
939     commnode  = atc->node_dest;
940     buf_index = atc->buf_index;
941
942     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
943
944     local_pos = atc->count[atc->nodeid];
945     buf_pos   = 0;
946     for (i = 0; i < nnodes_comm; i++)
947     {
948         scount = atc->rcount[i];
949         rcount = atc->count[commnode[i]];
950         if (scount > 0 || rcount > 0)
951         {
952             /* Communicate the forces */
953             pme_dd_sendrecv(atc, TRUE, i,
954                             atc->f[local_pos], scount*sizeof(rvec),
955                             pme->bufv[buf_pos], rcount*sizeof(rvec));
956             local_pos += scount;
957         }
958         buf_index[commnode[i]] = buf_pos;
959         buf_pos               += rcount;
960     }
961
962     local_pos = 0;
963     if (bAddF)
964     {
965         for (i = 0; i < n; i++)
966         {
967             node = atc->pd[i];
968             if (node == atc->nodeid)
969             {
970                 /* Add from the local force array */
971                 rvec_inc(f[i], atc->f[local_pos]);
972                 local_pos++;
973             }
974             else
975             {
976                 /* Add from the receive buffer */
977                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
978                 buf_index[node]++;
979             }
980         }
981     }
982     else
983     {
984         for (i = 0; i < n; i++)
985         {
986             node = atc->pd[i];
987             if (node == atc->nodeid)
988             {
989                 /* Copy from the local force array */
990                 copy_rvec(atc->f[local_pos], f[i]);
991                 local_pos++;
992             }
993             else
994             {
995                 /* Copy from the receive buffer */
996                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
997                 buf_index[node]++;
998             }
999         }
1000     }
1001 }
1002
1003 #ifdef GMX_MPI
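/* Sum (forward) or distribute (backward) the charge-grid overlap regions
 * between neighboring PME nodes. The minor (y) dimension is not contiguous
 * in memory and is packed into send/receive buffers; the major (x) dimension
 * is communicated directly from/into the grid.
 */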
1004 static void
1005 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
1006 {
1007     pme_overlap_t *overlap;
1008     int            send_index0, send_nindex;
1009     int            recv_index0, recv_nindex;
1010     MPI_Status     stat;
1011     int            i, j, k, ix, iy, iz, icnt;
1012     int            ipulse, send_id, recv_id, datasize;
1013     real          *p;
1014     real          *sendptr, *recvptr;
1015
1016     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
1017     overlap = &pme->overlap[1];
1018
1019     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1020     {
1021         /* Since we have already (un)wrapped the overlap in the z-dimension,
1022          * we only have to communicate 0 to nkz (not pmegrid_nz).
1023          */
1024         if (direction == GMX_SUM_QGRID_FORWARD)
1025         {
1026             send_id       = overlap->send_id[ipulse];
1027             recv_id       = overlap->recv_id[ipulse];
1028             send_index0   = overlap->comm_data[ipulse].send_index0;
1029             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1030             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1031             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1032         }
1033         else
1034         {
1035             send_id       = overlap->recv_id[ipulse];
1036             recv_id       = overlap->send_id[ipulse];
1037             send_index0   = overlap->comm_data[ipulse].recv_index0;
1038             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1039             recv_index0   = overlap->comm_data[ipulse].send_index0;
1040             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1041         }
1042
1043         /* Copy data to contiguous send buffer */
1044         if (debug)
1045         {
1046             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1047                     pme->nodeid, overlap->nodeid, send_id,
1048                     pme->pmegrid_start_iy,
1049                     send_index0-pme->pmegrid_start_iy,
1050                     send_index0-pme->pmegrid_start_iy+send_nindex);
1051         }
1052         icnt = 0;
1053         for (i = 0; i < pme->pmegrid_nx; i++)
1054         {
1055             ix = i;
1056             for (j = 0; j < send_nindex; j++)
1057             {
1058                 iy = j + send_index0 - pme->pmegrid_start_iy;
1059                 for (k = 0; k < pme->nkz; k++)
1060                 {
1061                     iz = k;
1062                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1063                 }
1064             }
1065         }
1066
1067         datasize      = pme->pmegrid_nx * pme->nkz;
1068
1069         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1070                      send_id, ipulse,
1071                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1072                      recv_id, ipulse,
1073                      overlap->mpi_comm, &stat);
1074
1075         /* Get data from contiguous recv buffer */
1076         if (debug)
1077         {
1078             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1079                     pme->nodeid, overlap->nodeid, recv_id,
1080                     pme->pmegrid_start_iy,
1081                     recv_index0-pme->pmegrid_start_iy,
1082                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1083         }
1084         icnt = 0;
1085         for (i = 0; i < pme->pmegrid_nx; i++)
1086         {
1087             ix = i;
1088             for (j = 0; j < recv_nindex; j++)
1089             {
1090                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1091                 for (k = 0; k < pme->nkz; k++)
1092                 {
1093                     iz = k;
1094                     if (direction == GMX_SUM_QGRID_FORWARD)
1095                     {
1096                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1097                     }
1098                     else
1099                     {
1100                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1101                     }
1102                 }
1103             }
1104         }
1105     }
1106
1107     /* Major dimension is easier, no copying required,
1108      * but we might have to sum to a separate array.
1109      * Since we don't copy, we have to communicate up to pmegrid_nz,
1110      * not nkz as for the minor direction.
1111      */
1112     overlap = &pme->overlap[0];
1113
1114     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1115     {
1116         if (direction == GMX_SUM_QGRID_FORWARD)
1117         {
1118             send_id       = overlap->send_id[ipulse];
1119             recv_id       = overlap->recv_id[ipulse];
1120             send_index0   = overlap->comm_data[ipulse].send_index0;
1121             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1122             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1123             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1124             recvptr       = overlap->recvbuf;
1125         }
1126         else
1127         {
1128             send_id       = overlap->recv_id[ipulse];
1129             recv_id       = overlap->send_id[ipulse];
1130             send_index0   = overlap->comm_data[ipulse].recv_index0;
1131             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1132             recv_index0   = overlap->comm_data[ipulse].send_index0;
1133             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1134             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1135         }
1136
1137         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1138         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1139
1140         if (debug)
1141         {
1142             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1143                     pme->nodeid, overlap->nodeid, send_id,
1144                     pme->pmegrid_start_ix,
1145                     send_index0-pme->pmegrid_start_ix,
1146                     send_index0-pme->pmegrid_start_ix+send_nindex);
1147             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1148                     pme->nodeid, overlap->nodeid, recv_id,
1149                     pme->pmegrid_start_ix,
1150                     recv_index0-pme->pmegrid_start_ix,
1151                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1152         }
1153
1154         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1155                      send_id, ipulse,
1156                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1157                      recv_id, ipulse,
1158                      overlap->mpi_comm, &stat);
1159
1160         /* ADD data from contiguous recv buffer */
1161         if (direction == GMX_SUM_QGRID_FORWARD)
1162         {
1163             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1164             for (i = 0; i < recv_nindex*datasize; i++)
1165             {
1166                 p[i] += overlap->recvbuf[i];
1167             }
1168         }
1169     }
1170 }
1171 #endif
1172
1173
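/* Copy the interior part of the (overlap-padded) PME grid into the FFT grid,
 * which can have different strides; only the local_fft_ndata region is copied.
 */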
1174 static int
1175 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1176 {
1177     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1178     ivec    local_pme_size;
1179     int     i, ix, iy, iz;
1180     int     pmeidx, fftidx;
1181
1182     /* Dimensions should be identical for A/B grid, so we just use A here */
1183     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1184                                    local_fft_ndata,
1185                                    local_fft_offset,
1186                                    local_fft_size);
1187
1188     local_pme_size[0] = pme->pmegrid_nx;
1189     local_pme_size[1] = pme->pmegrid_ny;
1190     local_pme_size[2] = pme->pmegrid_nz;
1191
1192     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1193        the offset is identical, and the PME grid always has more data (due to overlap)
1194      */
1195     {
1196 #ifdef DEBUG_PME
1197         FILE *fp, *fp2;
1198         char  fn[STRLEN], format[STRLEN];
1199         real  val;
1200         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1201         fp = ffopen(fn, "w");
1202         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1203         fp2 = ffopen(fn, "w");
1204         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1205 #endif
1206
1207         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1208         {
1209             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1210             {
1211                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1212                 {
1213                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1214                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1215                     fftgrid[fftidx] = pmegrid[pmeidx];
1216 #ifdef DEBUG_PME
1217                     val = 100*pmegrid[pmeidx];
1218                     if (pmegrid[pmeidx] != 0)
1219                     {
1220                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1221                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1222                     }
1223                     if (pmegrid[pmeidx] != 0)
1224                     {
1225                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1226                                 "qgrid",
1227                                 pme->pmegrid_start_ix + ix,
1228                                 pme->pmegrid_start_iy + iy,
1229                                 pme->pmegrid_start_iz + iz,
1230                                 pmegrid[pmeidx]);
1231                     }
1232 #endif
1233                 }
1234             }
1235         }
1236 #ifdef DEBUG_PME
1237         ffclose(fp);
1238         ffclose(fp2);
1239 #endif
1240     }
1241     return 0;
1242 }
1243
1244
1245 static gmx_cycles_t omp_cyc_start()
1246 {
1247     return gmx_cycles_read();
1248 }
1249
1250 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1251 {
1252     return gmx_cycles_read() - c;
1253 }
1254
1255
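/* Copy this thread's share of the FFT grid back into the PME grid after the
 * backward FFT; the x*y columns are divided evenly over the threads.
 */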
1256 static int
1257 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1258                         int nthread, int thread)
1259 {
1260     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1261     ivec          local_pme_size;
1262     int           ixy0, ixy1, ixy, ix, iy, iz;
1263     int           pmeidx, fftidx;
1264 #ifdef PME_TIME_THREADS
1265     gmx_cycles_t  c1;
1266     static double cs1 = 0;
1267     static int    cnt = 0;
1268 #endif
1269
1270 #ifdef PME_TIME_THREADS
1271     c1 = omp_cyc_start();
1272 #endif
1273     /* Dimensions should be identical for A/B grid, so we just use A here */
1274     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1275                                    local_fft_ndata,
1276                                    local_fft_offset,
1277                                    local_fft_size);
1278
1279     local_pme_size[0] = pme->pmegrid_nx;
1280     local_pme_size[1] = pme->pmegrid_ny;
1281     local_pme_size[2] = pme->pmegrid_nz;
1282
1283     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1284        the offset is identical, and the PME grid always has more data (due to overlap)
1285      */
1286     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1287     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1288
1289     for (ixy = ixy0; ixy < ixy1; ixy++)
1290     {
1291         ix = ixy/local_fft_ndata[YY];
1292         iy = ixy - ix*local_fft_ndata[YY];
1293
1294         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1295         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1296         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1297         {
1298             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1299         }
1300     }
1301
1302 #ifdef PME_TIME_THREADS
1303     c1   = omp_cyc_end(c1);
1304     cs1 += (double)c1;
1305     cnt++;
1306     if (cnt % 20 == 0)
1307     {
1308         printf("copy %.2f\n", cs1*1e-9);
1309     }
1310 #endif
1311
1312     return 0;
1313 }
1314
1315
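/* Fold the periodic overlap regions (pme_order-1 wide) back onto the
 * principal grid: always in z, and also in y and x when there is no node
 * decomposition in that dimension.
 */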
1316 static void
1317 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1318 {
1319     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1320
1321     nx = pme->nkx;
1322     ny = pme->nky;
1323     nz = pme->nkz;
1324
1325     pnx = pme->pmegrid_nx;
1326     pny = pme->pmegrid_ny;
1327     pnz = pme->pmegrid_nz;
1328
1329     overlap = pme->pme_order - 1;
1330
1331     /* Add periodic overlap in z */
1332     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1333     {
1334         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1335         {
1336             for (iz = 0; iz < overlap; iz++)
1337             {
1338                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1339                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1340             }
1341         }
1342     }
1343
1344     if (pme->nnodes_minor == 1)
1345     {
1346         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1347         {
1348             for (iy = 0; iy < overlap; iy++)
1349             {
1350                 for (iz = 0; iz < nz; iz++)
1351                 {
1352                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1353                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1354                 }
1355             }
1356         }
1357     }
1358
1359     if (pme->nnodes_major == 1)
1360     {
1361         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1362
1363         for (ix = 0; ix < overlap; ix++)
1364         {
1365             for (iy = 0; iy < ny_x; iy++)
1366             {
1367                 for (iz = 0; iz < nz; iz++)
1368                 {
1369                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1370                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1371                 }
1372             }
1373         }
1374     }
1375 }
1376
1377
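/* The inverse of wrap_periodic_pmegrid: replicate the principal grid values
 * into the periodic overlap regions before force gathering, OpenMP-parallel
 * over x slices.
 */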
1378 static void
1379 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1380 {
1381     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1382
1383     nx = pme->nkx;
1384     ny = pme->nky;
1385     nz = pme->nkz;
1386
1387     pnx = pme->pmegrid_nx;
1388     pny = pme->pmegrid_ny;
1389     pnz = pme->pmegrid_nz;
1390
1391     overlap = pme->pme_order - 1;
1392
1393     if (pme->nnodes_major == 1)
1394     {
1395         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1396
1397         for (ix = 0; ix < overlap; ix++)
1398         {
1399             int iy, iz;
1400
1401             for (iy = 0; iy < ny_x; iy++)
1402             {
1403                 for (iz = 0; iz < nz; iz++)
1404                 {
1405                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1406                         pmegrid[(ix*pny+iy)*pnz+iz];
1407                 }
1408             }
1409         }
1410     }
1411
1412     if (pme->nnodes_minor == 1)
1413     {
1414 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1415         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1416         {
1417             int iy, iz;
1418
1419             for (iy = 0; iy < overlap; iy++)
1420             {
1421                 for (iz = 0; iz < nz; iz++)
1422                 {
1423                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1424                         pmegrid[(ix*pny+iy)*pnz+iz];
1425                 }
1426             }
1427         }
1428     }
1429
1430     /* Copy periodic overlap in z */
1431 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1432     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1433     {
1434         int iy, iz;
1435
1436         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1437         {
1438             for (iz = 0; iz < overlap; iz++)
1439             {
1440                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1441                     pmegrid[(ix*pny+iy)*pnz+iz];
1442             }
1443         }
1444     }
1445 }
1446
1447 static void clear_grid(int nx, int ny, int nz, real *grid,
1448                        ivec fs, int *flag,
1449                        int fx, int fy, int fz,
1450                        int order)
1451 {
1452     int nc, ncz;
1453     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1454     int flind;
1455
1456     nc  = 2 + (order - 2)/FLBS;
1457     ncz = 2 + (order - 2)/FLBSZ;
1458
1459     for (fsx = fx; fsx < fx+nc; fsx++)
1460     {
1461         for (fsy = fy; fsy < fy+nc; fsy++)
1462         {
1463             for (fsz = fz; fsz < fz+ncz; fsz++)
1464             {
1465                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1466                 if (flag[flind] == 0)
1467                 {
1468                     gx  = fsx*FLBS;
1469                     gy  = fsy*FLBS;
1470                     gz  = fsz*FLBSZ;
1471                     g0x = (gx*ny + gy)*nz + gz;
1472                     for (x = 0; x < FLBS; x++)
1473                     {
1474                         g0y = g0x;
1475                         for (y = 0; y < FLBS; y++)
1476                         {
1477                             for (z = 0; z < FLBSZ; z++)
1478                             {
1479                                 grid[g0y+z] = 0;
1480                             }
1481                             g0y += nz;
1482                         }
1483                         g0x += ny*nz;
1484                     }
1485
1486                     flag[flind] = 1;
1487                 }
1488             }
1489         }
1490     }
1491 }
1492
1493 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1494 #define DO_BSPLINE(order)                            \
1495     for (ithx = 0; (ithx < order); ithx++)                    \
1496     {                                                    \
1497         index_x = (i0+ithx)*pny*pnz;                     \
1498         valx    = qn*thx[ithx];                          \
1499                                                      \
1500         for (ithy = 0; (ithy < order); ithy++)                \
1501         {                                                \
1502             valxy    = valx*thy[ithy];                   \
1503             index_xy = index_x+(j0+ithy)*pnz;            \
1504                                                      \
1505             for (ithz = 0; (ithz < order); ithz++)            \
1506             {                                            \
1507                 index_xyz        = index_xy+(k0+ithz);   \
1508                 grid[index_xyz] += valxy*thz[ithz];      \
1509             }                                            \
1510         }                                                \
1511     }
1512
1513
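/* Spread the charges assigned to this thread onto its local grid: the grid
 * is first cleared, then each non-zero charge is added to order^3 grid
 * points weighted by the precomputed B-spline coefficients, using SIMD4
 * kernels for order 4 and 5 where available.
 */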
1514 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1515                                      pme_atomcomm_t *atc, splinedata_t *spline,
1516                                      pme_spline_work_t *work)
1517 {
1518
1519     /* spread charges from home atoms to local grid */
1520     real          *grid;
1521     pme_overlap_t *ol;
1522     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1523     int       *    idxptr;
1524     int            order, norder, index_x, index_xy, index_xyz;
1525     real           valx, valxy, qn;
1526     real          *thx, *thy, *thz;
1527     int            localsize, bndsize;
1528     int            pnx, pny, pnz, ndatatot;
1529     int            offx, offy, offz;
1530
1531 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1532     real           thz_buffer[12], *thz_aligned;
1533
1534     thz_aligned = gmx_simd4_align_real(thz_buffer);
1535 #endif
1536
1537     pnx = pmegrid->s[XX];
1538     pny = pmegrid->s[YY];
1539     pnz = pmegrid->s[ZZ];
1540
1541     offx = pmegrid->offset[XX];
1542     offy = pmegrid->offset[YY];
1543     offz = pmegrid->offset[ZZ];
1544
1545     ndatatot = pnx*pny*pnz;
1546     grid     = pmegrid->grid;
1547     for (i = 0; i < ndatatot; i++)
1548     {
1549         grid[i] = 0;
1550     }
1551
1552     order = pmegrid->order;
1553
1554     for (nn = 0; nn < spline->n; nn++)
1555     {
1556         n  = spline->ind[nn];
1557         qn = atc->q[n];
1558
1559         if (qn != 0)
1560         {
1561             idxptr = atc->idx[n];
1562             norder = nn*order;
1563
1564             i0   = idxptr[XX] - offx;
1565             j0   = idxptr[YY] - offy;
1566             k0   = idxptr[ZZ] - offz;
1567
1568             thx = spline->theta[XX] + norder;
1569             thy = spline->theta[YY] + norder;
1570             thz = spline->theta[ZZ] + norder;
1571
1572             switch (order)
1573             {
1574                 case 4:
1575 #ifdef PME_SIMD4_SPREAD_GATHER
1576 #ifdef PME_SIMD4_UNALIGNED
1577 #define PME_SPREAD_SIMD4_ORDER4
1578 #else
1579 #define PME_SPREAD_SIMD4_ALIGNED
1580 #define PME_ORDER 4
1581 #endif
1582 #include "pme_simd4.h"
1583 #else
1584                     DO_BSPLINE(4);
1585 #endif
1586                     break;
1587                 case 5:
1588 #ifdef PME_SIMD4_SPREAD_GATHER
1589 #define PME_SPREAD_SIMD4_ALIGNED
1590 #define PME_ORDER 5
1591 #include "pme_simd4.h"
1592 #else
1593                     DO_BSPLINE(5);
1594 #endif
1595                     break;
1596                 default:
1597                     DO_BSPLINE(order);
1598                     break;
1599             }
1600         }
1601     }
1602 }
1603
1604 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1605 {
1606 #ifdef PME_SIMD4_SPREAD_GATHER
1607     if (pme_order == 5
1608 #ifndef PME_SIMD4_UNALIGNED
1609         || pme_order == 4
1610 #endif
1611         )
1612     {
1613         /* Round nz up to a multiple of 4 to ensure alignment */
1614         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1615     }
1616 #endif
1617 }
1618
1619 static void set_gridsize_alignment(int *gridsize, int pme_order)
1620 {
1621 #ifdef PME_SIMD4_SPREAD_GATHER
1622 #ifndef PME_SIMD4_UNALIGNED
1623     if (pme_order == 4)
1624     {
1625         /* Add extra elements to ensure that aligned operations do not go
1626          * beyond the allocated grid size.
1627          * Note that for pme_order=5, the pme grid z-size alignment
1628          * ensures that we will not go beyond the grid size.
1629          */
1630         *gridsize += 4;
1631     }
1632 #endif
1633 #endif
1634 }
1635
1636 static void pmegrid_init(pmegrid_t *grid,
1637                          int cx, int cy, int cz,
1638                          int x0, int y0, int z0,
1639                          int x1, int y1, int z1,
1640                          gmx_bool set_alignment,
1641                          int pme_order,
1642                          real *ptr)
1643 {
1644     int nz, gridsize;
1645
1646     grid->ci[XX]     = cx;
1647     grid->ci[YY]     = cy;
1648     grid->ci[ZZ]     = cz;
1649     grid->offset[XX] = x0;
1650     grid->offset[YY] = y0;
1651     grid->offset[ZZ] = z0;
1652     grid->n[XX]      = x1 - x0 + pme_order - 1;
1653     grid->n[YY]      = y1 - y0 + pme_order - 1;
1654     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1655     copy_ivec(grid->n, grid->s);
1656
1657     nz = grid->s[ZZ];
1658     set_grid_alignment(&nz, pme_order);
1659     if (set_alignment)
1660     {
1661         grid->s[ZZ] = nz;
1662     }
1663     else if (nz != grid->s[ZZ])
1664     {
1665         gmx_incons("pmegrid_init call with an unaligned z size");
1666     }
1667
1668     grid->order = pme_order;
1669     if (ptr == NULL)
1670     {
1671         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1672         set_gridsize_alignment(&gridsize, pme_order);
1673         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1674     }
1675     else
1676     {
1677         grid->grid = ptr;
1678     }
1679 }
1680
1681 static int div_round_up(int enumerator, int denominator)
1682 {
1683     return (enumerator + denominator - 1)/denominator;
1684 }
1685
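/* Choose a thread sub-grid division nsx x nsy x nsz with nsx*nsy*nsz == nthread
 * by trying all factorizations and keeping the one that minimizes the
 * per-thread grid volume including the spreading overlap (see the loop below).
 * For example, with 8 threads on a large, roughly cubic grid this will
 * typically select a 2 x 2 x 2 division.
 */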
1686 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1687                                   ivec nsub)
1688 {
1689     int gsize_opt, gsize;
1690     int nsx, nsy, nsz;
1691     char *env;
1692
1693     gsize_opt = -1;
1694     for (nsx = 1; nsx <= nthread; nsx++)
1695     {
1696         if (nthread % nsx == 0)
1697         {
1698             for (nsy = 1; nsy <= nthread; nsy++)
1699             {
1700                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1701                 {
1702                     nsz = nthread/(nsx*nsy);
1703
1704                     /* Determine the number of grid points per thread */
1705                     gsize =
1706                         (div_round_up(n[XX], nsx) + ovl)*
1707                         (div_round_up(n[YY], nsy) + ovl)*
1708                         (div_round_up(n[ZZ], nsz) + ovl);
1709
1710                     /* Minimize the number of grid points per thread
1711                      * and, secondarily, the number of cuts in minor dimensions.
1712                      */
1713                     if (gsize_opt == -1 ||
1714                         gsize < gsize_opt ||
1715                         (gsize == gsize_opt &&
1716                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1717                     {
1718                         nsub[XX]  = nsx;
1719                         nsub[YY]  = nsy;
1720                         nsub[ZZ]  = nsz;
1721                         gsize_opt = gsize;
1722                     }
1723                 }
1724             }
1725         }
1726     }
1727
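    /* Allow a manual override of the division through the environment,
     * e.g. (in a shell): export GMX_PME_THREAD_DIVISION="2 2 1".
     * The three integers must multiply to the total number of PME threads,
     * otherwise the check below exits with a fatal error.
     */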
1728     env = getenv("GMX_PME_THREAD_DIVISION");
1729     if (env != NULL)
1730     {
1731         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1732     }
1733
1734     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1735     {
1736         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1737     }
1738 }
1739
1740 static void pmegrids_init(pmegrids_t *grids,
1741                           int nx, int ny, int nz, int nz_base,
1742                           int pme_order,
1743                           gmx_bool bUseThreads,
1744                           int nthread,
1745                           int overlap_x,
1746                           int overlap_y)
1747 {
1748     ivec n, n_base, g0, g1;
1749     int t, x, y, z, d, i, tfac;
1750     int max_comm_lines = -1;
1751
1752     n[XX] = nx - (pme_order - 1);
1753     n[YY] = ny - (pme_order - 1);
1754     n[ZZ] = nz - (pme_order - 1);
1755
1756     copy_ivec(n, n_base);
1757     n_base[ZZ] = nz_base;
1758
1759     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1760                  NULL);
1761
1762     grids->nthread = nthread;
1763
1764     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1765
1766     if (bUseThreads)
1767     {
1768         ivec nst;
1769         int gridsize;
1770
1771         for (d = 0; d < DIM; d++)
1772         {
1773             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1774         }
1775         set_grid_alignment(&nst[ZZ], pme_order);
1776
1777         if (debug)
1778         {
1779             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1780                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1781             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1782                     nx, ny, nz,
1783                     nst[XX], nst[YY], nst[ZZ]);
1784         }
1785
1786         snew(grids->grid_th, grids->nthread);
1787         t        = 0;
1788         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1789         set_gridsize_alignment(&gridsize, pme_order);
1790         snew_aligned(grids->grid_all,
1791                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1792                      SIMD4_ALIGNMENT);
1793
1794         for (x = 0; x < grids->nc[XX]; x++)
1795         {
1796             for (y = 0; y < grids->nc[YY]; y++)
1797             {
1798                 for (z = 0; z < grids->nc[ZZ]; z++)
1799                 {
1800                     pmegrid_init(&grids->grid_th[t],
1801                                  x, y, z,
1802                                  (n[XX]*(x  ))/grids->nc[XX],
1803                                  (n[YY]*(y  ))/grids->nc[YY],
1804                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1805                                  (n[XX]*(x+1))/grids->nc[XX],
1806                                  (n[YY]*(y+1))/grids->nc[YY],
1807                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1808                                  TRUE,
1809                                  pme_order,
1810                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1811                     t++;
1812                 }
1813             }
1814         }
1815     }
1816     else
1817     {
1818         grids->grid_th = NULL;
1819     }
1820
1821     snew(grids->g2t, DIM);
1822     tfac = 1;
1823     for (d = DIM-1; d >= 0; d--)
1824     {
1825         snew(grids->g2t[d], n[d]);
1826         t = 0;
1827         for (i = 0; i < n[d]; i++)
1828         {
1829             /* The second check should match the parameters
1830              * of the pmegrid_init call above.
1831              */
1832             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1833             {
1834                 t++;
1835             }
1836             grids->g2t[d][i] = t*tfac;
1837         }
1838
1839         tfac *= grids->nc[d];
1840
1841         switch (d)
1842         {
1843             case XX: max_comm_lines = overlap_x;     break;
1844             case YY: max_comm_lines = overlap_y;     break;
1845             case ZZ: max_comm_lines = pme_order - 1; break;
1846         }
1847         grids->nthread_comm[d] = 0;
1848         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1849                grids->nthread_comm[d] < grids->nc[d])
1850         {
1851             grids->nthread_comm[d]++;
1852         }
1853         if (debug != NULL)
1854         {
1855             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1856                     'x'+d, grids->nthread_comm[d]);
1857         }
1858         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1859          * work, but this is not a problematic restriction.
1860          */
1861         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1862         {
1863             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1864         }
1865     }
1866 }
1867
1868
1869 static void pmegrids_destroy(pmegrids_t *grids)
1870 {
1871     int t;
1872
1873     if (grids->grid.grid != NULL)
1874     {
1875         sfree(grids->grid.grid);
1876
1877         if (grids->nthread > 0)
1878         {
1879             for (t = 0; t < grids->nthread; t++)
1880             {
1881                 sfree(grids->grid_th[t].grid);
1882             }
1883             sfree(grids->grid_th);
1884         }
1885     }
1886 }
1887
1888
1889 static void realloc_work(pme_work_t *work, int nkx)
1890 {
1891     int simd_width;
1892
1893     if (nkx > work->nalloc)
1894     {
1895         work->nalloc = nkx;
1896         srenew(work->mhx, work->nalloc);
1897         srenew(work->mhy, work->nalloc);
1898         srenew(work->mhz, work->nalloc);
1899         srenew(work->m2, work->nalloc);
1900         /* Allocate an aligned pointer for SIMD operations, including extra
1901          * elements at the end for padding.
1902          */
1903 #ifdef PME_SIMD
1904         simd_width = GMX_SIMD_WIDTH_HERE;
1905 #else
1906         /* We can use any alignment, apart from 0, so we use 4 */
1907         simd_width = 4;
1908 #endif
1909         sfree_aligned(work->denom);
1910         sfree_aligned(work->tmp1);
1911         sfree_aligned(work->eterm);
1912         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1913         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1914         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1915         srenew(work->m2inv, work->nalloc);
1916     }
1917 }
1918
1919
1920 static void free_work(pme_work_t *work)
1921 {
1922     sfree(work->mhx);
1923     sfree(work->mhy);
1924     sfree(work->mhz);
1925     sfree(work->m2);
1926     sfree_aligned(work->denom);
1927     sfree_aligned(work->tmp1);
1928     sfree_aligned(work->eterm);
1929     sfree(work->m2inv);
1930 }
1931
1932
1933 #ifdef PME_SIMD
1934 /* Calculate exponentials through SIMD */
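/* Note that the SIMD version processes the full [0,end) range in chunks of
 * GMX_SIMD_WIDTH_HERE and ignores start; this relies on the aligned, padded
 * allocation done in realloc_work() above.
 */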
1935 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1936 {
1937     {
1938         const gmx_mm_pr two = gmx_set1_pr(2.0);
1939         gmx_mm_pr f_simd;
1940         gmx_mm_pr lu;
1941         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1942         int kx;
1943         f_simd = gmx_set1_pr(f);
1944         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1945         {
1946             tmp_d1   = gmx_load_pr(d_aligned+kx);
1947             d_inv    = gmx_inv_pr(tmp_d1);
1948             tmp_r    = gmx_load_pr(r_aligned+kx);
1949             tmp_r    = gmx_exp_pr(tmp_r);
1950             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1951             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1952             gmx_store_pr(e_aligned+kx, tmp_e);
1953         }
1954     }
1955 }
1956 #else
1957 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1958 {
1959     int kx;
1960     for (kx = start; kx < end; kx++)
1961     {
1962         d[kx] = 1.0/d[kx];
1963     }
1964     for (kx = start; kx < end; kx++)
1965     {
1966         r[kx] = exp(r[kx]);
1967     }
1968     for (kx = start; kx < end; kx++)
1969     {
1970         e[kx] = f*r[kx]*d[kx];
1971     }
1972 }
1973 #endif
1974
1975
1976 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1977                          real ewaldcoeff, real vol,
1978                          gmx_bool bEnerVir,
1979                          int nthread, int thread)
1980 {
1981     /* do recip sum over local cells in grid */
1982     /* y major, z middle, x minor or continuous */
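    /* Each complex grid value is scaled by
     *   eterm(m) = elfac * exp(-pi^2 m^2 / ewaldcoeff^2) / (pi V m^2 B(m)),
     * where B(m) is the product of the B-spline moduli (bsp_mod) for the
     * three dimensions. Below, denom holds the denominator and tmp1 the
     * exponent; calc_exponentials() combines them into eterm.
     */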
1983     t_complex *p0;
1984     int     kx, ky, kz, maxkx, maxky, maxkz;
1985     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1986     real    mx, my, mz;
1987     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1988     real    ets2, struct2, vfactor, ets2vf;
1989     real    d1, d2, energy = 0;
1990     real    by, bz;
1991     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1992     real    rxx, ryx, ryy, rzx, rzy, rzz;
1993     pme_work_t *work;
1994     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1995     real    mhxk, mhyk, mhzk, m2k;
1996     real    corner_fac;
1997     ivec    complex_order;
1998     ivec    local_ndata, local_offset, local_size;
1999     real    elfac;
2000
2001     elfac = ONE_4PI_EPS0/pme->epsilon_r;
2002
2003     nx = pme->nkx;
2004     ny = pme->nky;
2005     nz = pme->nkz;
2006
2007     /* Dimensions should be identical for A/B grid, so we just use A here */
2008     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
2009                                       complex_order,
2010                                       local_ndata,
2011                                       local_offset,
2012                                       local_size);
2013
2014     rxx = pme->recipbox[XX][XX];
2015     ryx = pme->recipbox[YY][XX];
2016     ryy = pme->recipbox[YY][YY];
2017     rzx = pme->recipbox[ZZ][XX];
2018     rzy = pme->recipbox[ZZ][YY];
2019     rzz = pme->recipbox[ZZ][ZZ];
2020
2021     maxkx = (nx+1)/2;
2022     maxky = (ny+1)/2;
2023     maxkz = nz/2+1;
2024
2025     work  = &pme->work[thread];
2026     mhx   = work->mhx;
2027     mhy   = work->mhy;
2028     mhz   = work->mhz;
2029     m2    = work->m2;
2030     denom = work->denom;
2031     tmp1  = work->tmp1;
2032     eterm = work->eterm;
2033     m2inv = work->m2inv;
2034
2035     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2036     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2037
2038     for (iyz = iyz0; iyz < iyz1; iyz++)
2039     {
2040         iy = iyz/local_ndata[ZZ];
2041         iz = iyz - iy*local_ndata[ZZ];
2042
2043         ky = iy + local_offset[YY];
2044
2045         if (ky < maxky)
2046         {
2047             my = ky;
2048         }
2049         else
2050         {
2051             my = (ky - ny);
2052         }
2053
2054         by = M_PI*vol*pme->bsp_mod[YY][ky];
2055
2056         kz = iz + local_offset[ZZ];
2057
2058         mz = kz;
2059
2060         bz = pme->bsp_mod[ZZ][kz];
2061
2062         /* 0.5 correction for corner points */
2063         corner_fac = 1;
2064         if (kz == 0 || kz == (nz+1)/2)
2065         {
2066             corner_fac = 0.5;
2067         }
2068
2069         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2070
2071         /* We should skip the k-space point (0,0,0) */
2072         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2073         {
2074             kxstart = local_offset[XX];
2075         }
2076         else
2077         {
2078             kxstart = local_offset[XX] + 1;
2079             p0++;
2080         }
2081         kxend = local_offset[XX] + local_ndata[XX];
2082
2083         if (bEnerVir)
2084         {
2085             /* More expensive inner loop, especially because of the storage
2086              * of the mh elements in arrays.
2087              * Because x is the minor grid index, all mh elements
2088              * depend on kx for triclinic unit cells.
2089              */
2090
2091             /* Two explicit loops to avoid a conditional inside the loop */
2092             for (kx = kxstart; kx < maxkx; kx++)
2093             {
2094                 mx = kx;
2095
2096                 mhxk      = mx * rxx;
2097                 mhyk      = mx * ryx + my * ryy;
2098                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2099                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2100                 mhx[kx]   = mhxk;
2101                 mhy[kx]   = mhyk;
2102                 mhz[kx]   = mhzk;
2103                 m2[kx]    = m2k;
2104                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2105                 tmp1[kx]  = -factor*m2k;
2106             }
2107
2108             for (kx = maxkx; kx < kxend; kx++)
2109             {
2110                 mx = (kx - nx);
2111
2112                 mhxk      = mx * rxx;
2113                 mhyk      = mx * ryx + my * ryy;
2114                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2115                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2116                 mhx[kx]   = mhxk;
2117                 mhy[kx]   = mhyk;
2118                 mhz[kx]   = mhzk;
2119                 m2[kx]    = m2k;
2120                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2121                 tmp1[kx]  = -factor*m2k;
2122             }
2123
2124             for (kx = kxstart; kx < kxend; kx++)
2125             {
2126                 m2inv[kx] = 1.0/m2[kx];
2127             }
2128
2129             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2130
2131             for (kx = kxstart; kx < kxend; kx++, p0++)
2132             {
2133                 d1      = p0->re;
2134                 d2      = p0->im;
2135
2136                 p0->re  = d1*eterm[kx];
2137                 p0->im  = d2*eterm[kx];
2138
2139                 struct2 = 2.0*(d1*d1+d2*d2);
2140
2141                 tmp1[kx] = eterm[kx]*struct2;
2142             }
2143
2144             for (kx = kxstart; kx < kxend; kx++)
2145             {
2146                 ets2     = corner_fac*tmp1[kx];
2147                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2148                 energy  += ets2;
2149
2150                 ets2vf   = ets2*vfactor;
2151                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2152                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2153                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2154                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2155                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2156                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2157             }
2158         }
2159         else
2160         {
2161             /* We don't need to calculate the energy and the virial.
2162              * In this case the triclinic overhead is small.
2163              */
2164
2165             /* Two explicit loops to avoid a conditional inside the loop */
2166
2167             for (kx = kxstart; kx < maxkx; kx++)
2168             {
2169                 mx = kx;
2170
2171                 mhxk      = mx * rxx;
2172                 mhyk      = mx * ryx + my * ryy;
2173                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2174                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2175                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2176                 tmp1[kx]  = -factor*m2k;
2177             }
2178
2179             for (kx = maxkx; kx < kxend; kx++)
2180             {
2181                 mx = (kx - nx);
2182
2183                 mhxk      = mx * rxx;
2184                 mhyk      = mx * ryx + my * ryy;
2185                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2186                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2187                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2188                 tmp1[kx]  = -factor*m2k;
2189             }
2190
2191             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2192
2193             for (kx = kxstart; kx < kxend; kx++, p0++)
2194             {
2195                 d1      = p0->re;
2196                 d2      = p0->im;
2197
2198                 p0->re  = d1*eterm[kx];
2199                 p0->im  = d2*eterm[kx];
2200             }
2201         }
2202     }
2203
2204     if (bEnerVir)
2205     {
2206         /* Update virial with local values.
2207          * The virial is symmetric by definition.
2208          * This virial seems OK for isotropic scaling, but I'm
2209          * experiencing problems with semi-isotropic membranes.
2210          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2211          */
2212         work->vir[XX][XX] = 0.25*virxx;
2213         work->vir[YY][YY] = 0.25*viryy;
2214         work->vir[ZZ][ZZ] = 0.25*virzz;
2215         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2216         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2217         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2218
2219         /* This energy should be corrected for a charged system */
2220         work->energy = 0.5*energy;
2221     }
2222
2223     /* Return the loop count */
2224     return local_ndata[YY]*local_ndata[XX];
2225 }
2226
2227 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2228                              real *mesh_energy, matrix vir)
2229 {
2230     /* This function sums output over threads
2231      * and should therefore only be called after thread synchronization.
2232      */
2233     int thread;
2234
2235     *mesh_energy = pme->work[0].energy;
2236     copy_mat(pme->work[0].vir, vir);
2237
2238     for (thread = 1; thread < nthread; thread++)
2239     {
2240         *mesh_energy += pme->work[thread].energy;
2241         m_add(vir, pme->work[thread].vir, vir);
2242     }
2243 }
2244
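/* DO_FSPLINE accumulates, for one atom, the derivatives (fx,fy,fz) of the
 * interpolated potential with respect to the fractional grid coordinates,
 * contracting the grid values over an order^3 stencil with the B-spline
 * weights (thx,thy,thz) and their derivatives (dthx,dthy,dthz).
 */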
2245 #define DO_FSPLINE(order)                      \
2246     for (ithx = 0; (ithx < order); ithx++)              \
2247     {                                              \
2248         index_x = (i0+ithx)*pny*pnz;               \
2249         tx      = thx[ithx];                       \
2250         dx      = dthx[ithx];                      \
2251                                                \
2252         for (ithy = 0; (ithy < order); ithy++)          \
2253         {                                          \
2254             index_xy = index_x+(j0+ithy)*pnz;      \
2255             ty       = thy[ithy];                  \
2256             dy       = dthy[ithy];                 \
2257             fxy1     = fz1 = 0;                    \
2258                                                \
2259             for (ithz = 0; (ithz < order); ithz++)      \
2260             {                                      \
2261                 gval  = grid[index_xy+(k0+ithz)];  \
2262                 fxy1 += thz[ithz]*gval;            \
2263                 fz1  += dthz[ithz]*gval;           \
2264             }                                      \
2265             fx += dx*ty*fxy1;                      \
2266             fy += tx*dy*fxy1;                      \
2267             fz += tx*ty*fz1;                       \
2268         }                                          \
2269     }
2270
2271
2272 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2273                               gmx_bool bClearF, pme_atomcomm_t *atc,
2274                               splinedata_t *spline,
2275                               real scale)
2276 {
2277     /* sum forces for local particles */
2278     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2279     int     index_x, index_xy;
2280     int     nx, ny, nz, pnx, pny, pnz;
2281     int *   idxptr;
2282     real    tx, ty, dx, dy, qn;
2283     real    fx, fy, fz, gval;
2284     real    fxy1, fz1;
2285     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2286     int     norder;
2287     real    rxx, ryx, ryy, rzx, rzy, rzz;
2288     int     order;
2289
2290     pme_spline_work_t *work;
2291
2292 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2293     real           thz_buffer[12],  *thz_aligned;
2294     real           dthz_buffer[12], *dthz_aligned;
2295
2296     thz_aligned  = gmx_simd4_align_real(thz_buffer);
2297     dthz_aligned = gmx_simd4_align_real(dthz_buffer);
2298 #endif
2299
2300     work = pme->spline_work;
2301
2302     order = pme->pme_order;
2303     thx   = spline->theta[XX];
2304     thy   = spline->theta[YY];
2305     thz   = spline->theta[ZZ];
2306     dthx  = spline->dtheta[XX];
2307     dthy  = spline->dtheta[YY];
2308     dthz  = spline->dtheta[ZZ];
2309     nx    = pme->nkx;
2310     ny    = pme->nky;
2311     nz    = pme->nkz;
2312     pnx   = pme->pmegrid_nx;
2313     pny   = pme->pmegrid_ny;
2314     pnz   = pme->pmegrid_nz;
2315
2316     rxx   = pme->recipbox[XX][XX];
2317     ryx   = pme->recipbox[YY][XX];
2318     ryy   = pme->recipbox[YY][YY];
2319     rzx   = pme->recipbox[ZZ][XX];
2320     rzy   = pme->recipbox[ZZ][YY];
2321     rzz   = pme->recipbox[ZZ][ZZ];
2322
2323     for (nn = 0; nn < spline->n; nn++)
2324     {
2325         n  = spline->ind[nn];
2326         qn = scale*atc->q[n];
2327
2328         if (bClearF)
2329         {
2330             atc->f[n][XX] = 0;
2331             atc->f[n][YY] = 0;
2332             atc->f[n][ZZ] = 0;
2333         }
2334         if (qn != 0)
2335         {
2336             fx     = 0;
2337             fy     = 0;
2338             fz     = 0;
2339             idxptr = atc->idx[n];
2340             norder = nn*order;
2341
2342             i0   = idxptr[XX];
2343             j0   = idxptr[YY];
2344             k0   = idxptr[ZZ];
2345
2346             /* Pointer arithmetic alert, next six statements */
2347             thx  = spline->theta[XX] + norder;
2348             thy  = spline->theta[YY] + norder;
2349             thz  = spline->theta[ZZ] + norder;
2350             dthx = spline->dtheta[XX] + norder;
2351             dthy = spline->dtheta[YY] + norder;
2352             dthz = spline->dtheta[ZZ] + norder;
2353
2354             switch (order)
2355             {
2356                 case 4:
2357 #ifdef PME_SIMD4_SPREAD_GATHER
2358 #ifdef PME_SIMD4_UNALIGNED
2359 #define PME_GATHER_F_SIMD4_ORDER4
2360 #else
2361 #define PME_GATHER_F_SIMD4_ALIGNED
2362 #define PME_ORDER 4
2363 #endif
2364 #include "pme_simd4.h"
2365 #else
2366                     DO_FSPLINE(4);
2367 #endif
2368                     break;
2369                 case 5:
2370 #ifdef PME_SIMD4_SPREAD_GATHER
2371 #define PME_GATHER_F_SIMD4_ALIGNED
2372 #define PME_ORDER 5
2373 #include "pme_simd4.h"
2374 #else
2375                     DO_FSPLINE(5);
2376 #endif
2377                     break;
2378                 default:
2379                     DO_FSPLINE(order);
2380                     break;
2381             }
2382
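            /* Chain rule: (fx,fy,fz) are the derivatives of the interpolated
             * potential with respect to the fractional grid coordinates;
             * multiplying by the grid dimensions and the reciprocal box
             * converts them into Cartesian force components.
             */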
2383             atc->f[n][XX] += -qn*( fx*nx*rxx );
2384             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2385             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2386         }
2387     }
2388     /* Since the energy, and not the forces, is interpolated,
2389      * the net force might not be exactly zero.
2390      * This can be solved by also interpolating F, but
2391      * that comes at a cost.
2392      * A better hack is to remove the net force every
2393      * step, but that must be done at a higher level
2394      * since this routine doesn't see all atoms if running
2395      * in parallel. It is unclear how important this is.  EL 990726
2396      */
2397 }
2398
2399
2400 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2401                                    pme_atomcomm_t *atc)
2402 {
2403     splinedata_t *spline;
2404     int     n, ithx, ithy, ithz, i0, j0, k0;
2405     int     index_x, index_xy;
2406     int *   idxptr;
2407     real    energy, pot, tx, ty, qn, gval;
2408     real    *thx, *thy, *thz;
2409     int     norder;
2410     int     order;
2411
2412     spline = &atc->spline[0];
2413
2414     order = pme->pme_order;
2415
2416     energy = 0;
2417     for (n = 0; (n < atc->n); n++)
2418     {
2419         qn      = atc->q[n];
2420
2421         if (qn != 0)
2422         {
2423             idxptr = atc->idx[n];
2424             norder = n*order;
2425
2426             i0   = idxptr[XX];
2427             j0   = idxptr[YY];
2428             k0   = idxptr[ZZ];
2429
2430             /* Pointer arithmetic alert, next three statements */
2431             thx  = spline->theta[XX] + norder;
2432             thy  = spline->theta[YY] + norder;
2433             thz  = spline->theta[ZZ] + norder;
2434
2435             pot = 0;
2436             for (ithx = 0; (ithx < order); ithx++)
2437             {
2438                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2439                 tx      = thx[ithx];
2440
2441                 for (ithy = 0; (ithy < order); ithy++)
2442                 {
2443                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2444                     ty       = thy[ithy];
2445
2446                     for (ithz = 0; (ithz < order); ithz++)
2447                     {
2448                         gval  = grid[index_xy+(k0+ithz)];
2449                         pot  += tx*ty*thz[ithz]*gval;
2450                     }
2451
2452                 }
2453             }
2454
2455             energy += pot*qn;
2456         }
2457     }
2458
2459     return energy;
2460 }
2461
2462 /* Macro to force loop unrolling by fixing order.
2463  * This gives a significant performance gain.
2464  */
2465 #define CALC_SPLINE(order)                     \
2466     {                                              \
2467         int j, k, l;                                 \
2468         real dr, div;                               \
2469         real data[PME_ORDER_MAX];                  \
2470         real ddata[PME_ORDER_MAX];                 \
2471                                                \
2472         for (j = 0; (j < DIM); j++)                     \
2473         {                                          \
2474             dr  = xptr[j];                         \
2475                                                \
2476             /* dr is relative offset from lower cell limit */ \
2477             data[order-1] = 0;                     \
2478             data[1]       = dr;                          \
2479             data[0]       = 1 - dr;                      \
2480                                                \
2481             for (k = 3; (k < order); k++)               \
2482             {                                      \
2483                 div       = 1.0/(k - 1.0);               \
2484                 data[k-1] = div*dr*data[k-2];      \
2485                 for (l = 1; (l < (k-1)); l++)           \
2486                 {                                  \
2487                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2488                                        data[k-l-1]);                \
2489                 }                                  \
2490                 data[0] = div*(1-dr)*data[0];      \
2491             }                                      \
2492             /* differentiate */                    \
2493             ddata[0] = -data[0];                   \
2494             for (k = 1; (k < order); k++)               \
2495             {                                      \
2496                 ddata[k] = data[k-1] - data[k];    \
2497             }                                      \
2498                                                \
2499             div           = 1.0/(order - 1);                 \
2500             data[order-1] = div*dr*data[order-2];  \
2501             for (l = 1; (l < (order-1)); l++)           \
2502             {                                      \
2503                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2504                                        (order-l-dr)*data[order-l-1]); \
2505             }                                      \
2506             data[0] = div*(1 - dr)*data[0];        \
2507                                                \
2508             for (k = 0; k < order; k++)                 \
2509             {                                      \
2510                 theta[j][i*order+k]  = data[k];    \
2511                 dtheta[j][i*order+k] = ddata[k];   \
2512             }                                      \
2513         }                                          \
2514     }
2515
2516 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2517                    rvec fractx[], int nr, int ind[], real charge[],
2518                    gmx_bool bFreeEnergy)
2519 {
2520     /* construct splines for local atoms */
2521     int  i, ii;
2522     real *xptr;
2523
2524     for (i = 0; i < nr; i++)
2525     {
2526         /* With free energy we do not use the charge check.
2527          * In most cases this will be more efficient than calling make_bsplines
2528          * twice, since usually more than half the particles have charges.
2529          */
2530         ii = ind[i];
2531         if (bFreeEnergy || charge[ii] != 0.0)
2532         {
2533             xptr = fractx[ii];
2534             switch (order)
2535             {
2536                 case 4:  CALC_SPLINE(4);     break;
2537                 case 5:  CALC_SPLINE(5);     break;
2538                 default: CALC_SPLINE(order); break;
2539             }
2540         }
2541     }
2542 }
2543
2544
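/* Compute the squared modulus of the discrete Fourier transform of data:
 *   mod[i] = |sum_j data[j]*exp(2*pi*I*i*j/ndata)|^2.
 * Entries that are numerically zero are replaced by the average of their
 * neighbours, since these moduli end up in denominators in the solver.
 */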
2545 void make_dft_mod(real *mod, real *data, int ndata)
2546 {
2547     int i, j;
2548     real sc, ss, arg;
2549
2550     for (i = 0; i < ndata; i++)
2551     {
2552         sc = ss = 0;
2553         for (j = 0; j < ndata; j++)
2554         {
2555             arg = (2.0*M_PI*i*j)/ndata;
2556             sc += data[j]*cos(arg);
2557             ss += data[j]*sin(arg);
2558         }
2559         mod[i] = sc*sc+ss*ss;
2560     }
2561     for (i = 0; i < ndata; i++)
2562     {
2563         if (mod[i] < 1e-7)
2564         {
2565             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2566         }
2567     }
2568 }
2569
2570
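/* Compute the standard smooth-PME B-spline moduli: evaluate the order-n
 * B-spline at the integer grid offsets (the dr=0 case of the CALC_SPLINE
 * recursion above) and take the squared DFT of these values for each grid
 * dimension via make_dft_mod().
 */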
2571 static void make_bspline_moduli(splinevec bsp_mod,
2572                                 int nx, int ny, int nz, int order)
2573 {
2574     int nmax = max(nx, max(ny, nz));
2575     real *data, *ddata, *bsp_data;
2576     int i, k, l;
2577     real div;
2578
2579     snew(data, order);
2580     snew(ddata, order);
2581     snew(bsp_data, nmax);
2582
2583     data[order-1] = 0;
2584     data[1]       = 0;
2585     data[0]       = 1;
2586
2587     for (k = 3; k < order; k++)
2588     {
2589         div       = 1.0/(k-1.0);
2590         data[k-1] = 0;
2591         for (l = 1; l < (k-1); l++)
2592         {
2593             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2594         }
2595         data[0] = div*data[0];
2596     }
2597     /* differentiate */
2598     ddata[0] = -data[0];
2599     for (k = 1; k < order; k++)
2600     {
2601         ddata[k] = data[k-1]-data[k];
2602     }
2603     div           = 1.0/(order-1);
2604     data[order-1] = 0;
2605     for (l = 1; l < (order-1); l++)
2606     {
2607         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2608     }
2609     data[0] = div*data[0];
2610
2611     for (i = 0; i < nmax; i++)
2612     {
2613         bsp_data[i] = 0;
2614     }
2615     for (i = 1; i <= order; i++)
2616     {
2617         bsp_data[i] = data[i-1];
2618     }
2619
2620     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2621     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2622     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2623
2624     sfree(data);
2625     sfree(ddata);
2626     sfree(bsp_data);
2627 }
2628
2629
2630 /* Return the P3M optimal influence function */
2631 static double do_p3m_influence(double z, int order)
2632 {
2633     double z2, z4;
2634
2635     z2 = z*z;
2636     z4 = z2*z2;
2637
2638     /* The formula and most constants can be found in:
2639      * Ballenegger et al., JCTC 8, 936 (2012)
2640      */
2641     switch (order)
2642     {
2643         case 2:
2644             return 1.0 - 2.0*z2/3.0;
2645             break;
2646         case 3:
2647             return 1.0 - z2 + 2.0*z4/15.0;
2648             break;
2649         case 4:
2650             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2651             break;
2652         case 5:
2653             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2654             break;
2655         case 6:
2656             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2657             break;
2658         case 7:
2659             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2660         case 8:
2661             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2662             break;
2663     }
2664
2665     return 0.0;
2666 }
2667
2668 /* Calculate the P3M B-spline moduli for one dimension */
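/* The modulus for wave number i is infl(sin(pi*i/n))^2 * sinc(pi*i/n)^(-2*order),
 * with the influence-function polynomials infl() given by do_p3m_influence()
 * above (Ballenegger et al., JCTC 8, 936 (2012)).
 */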
2669 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2670 {
2671     double zarg, zai, sinzai, infl;
2672     int    maxk, i;
2673
2674     if (order > 8)
2675     {
2676         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2677     }
2678
2679     zarg = M_PI/n;
2680
2681     maxk = (n + 1)/2;
2682
2683     for (i = -maxk; i < 0; i++)
2684     {
2685         zai          = zarg*i;
2686         sinzai       = sin(zai);
2687         infl         = do_p3m_influence(sinzai, order);
2688         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2689     }
2690     bsp_mod[0] = 1.0;
2691     for (i = 1; i < maxk; i++)
2692     {
2693         zai        = zarg*i;
2694         sinzai     = sin(zai);
2695         infl       = do_p3m_influence(sinzai, order);
2696         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2697     }
2698 }
2699
2700 /* Calculate the P3M B-spline moduli */
2701 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2702                                     int nx, int ny, int nz, int order)
2703 {
2704     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2705     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2706     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2707 }
2708
2709
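/* Set up the ring of slab neighbours for coordinate/force redistribution:
 * destinations and sources are ordered by increasing slab distance,
 * alternating forward and backward around the ring, so each communication
 * pulse pairs a send in one direction with a receive from the opposite one.
 */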
2710 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2711 {
2712     int nslab, n, i;
2713     int fw, bw;
2714
2715     nslab = atc->nslab;
2716
2717     n = 0;
2718     for (i = 1; i <= nslab/2; i++)
2719     {
2720         fw = (atc->nodeid + i) % nslab;
2721         bw = (atc->nodeid - i + nslab) % nslab;
2722         if (n < nslab - 1)
2723         {
2724             atc->node_dest[n] = fw;
2725             atc->node_src[n]  = bw;
2726             n++;
2727         }
2728         if (n < nslab - 1)
2729         {
2730             atc->node_dest[n] = bw;
2731             atc->node_src[n]  = fw;
2732             n++;
2733         }
2734     }
2735 }
2736
2737 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2738 {
2739     int thread;
2740
2741     if (NULL != log)
2742     {
2743         fprintf(log, "Destroying PME data structures.\n");
2744     }
2745
2746     sfree((*pmedata)->nnx);
2747     sfree((*pmedata)->nny);
2748     sfree((*pmedata)->nnz);
2749
2750     pmegrids_destroy(&(*pmedata)->pmegridA);
2751
2752     sfree((*pmedata)->fftgridA);
2753     sfree((*pmedata)->cfftgridA);
2754     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2755
2756     if ((*pmedata)->pmegridB.grid.grid != NULL)
2757     {
2758         pmegrids_destroy(&(*pmedata)->pmegridB);
2759         sfree((*pmedata)->fftgridB);
2760         sfree((*pmedata)->cfftgridB);
2761         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2762     }
2763     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2764     {
2765         free_work(&(*pmedata)->work[thread]);
2766     }
2767     sfree((*pmedata)->work);
2768
2769     sfree(*pmedata);
2770     *pmedata = NULL;
2771
2772     return 0;
2773 }
2774
2775 static int mult_up(int n, int f)
2776 {
2777     return ((n + f - 1)/f)*f;
2778 }
2779
2780
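/* Estimate the load imbalance in the PME FFTs and solve caused by rounding
 * the grid dimensions up to multiples of the node counts: the ratio of the
 * padded work (with solve weighted as roughly twice one FFT pass, see the
 * comment in the function body) to the work on an ideally divisible grid.
 */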
2781 static double pme_load_imbalance(gmx_pme_t pme)
2782 {
2783     int    nma, nmi;
2784     double n1, n2, n3;
2785
2786     nma = pme->nnodes_major;
2787     nmi = pme->nnodes_minor;
2788
2789     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2790     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2791     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2792
2793     /* pme_solve is roughly double the cost of an fft */
2794
2795     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2796 }
2797
2798 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2799                           int dimind, gmx_bool bSpread)
2800 {
2801     int nk, k, s, thread;
2802
2803     atc->dimind    = dimind;
2804     atc->nslab     = 1;
2805     atc->nodeid    = 0;
2806     atc->pd_nalloc = 0;
2807 #ifdef GMX_MPI
2808     if (pme->nnodes > 1)
2809     {
2810         atc->mpi_comm = pme->mpi_comm_d[dimind];
2811         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2812         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2813     }
2814     if (debug)
2815     {
2816         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2817     }
2818 #endif
2819
2820     atc->bSpread   = bSpread;
2821     atc->pme_order = pme->pme_order;
2822
2823     if (atc->nslab > 1)
2824     {
2825         /* These three allocations are not required for particle decomp. */
2826         snew(atc->node_dest, atc->nslab);
2827         snew(atc->node_src, atc->nslab);
2828         setup_coordinate_communication(atc);
2829
2830         snew(atc->count_thread, pme->nthread);
2831         for (thread = 0; thread < pme->nthread; thread++)
2832         {
2833             snew(atc->count_thread[thread], atc->nslab);
2834         }
2835         atc->count = atc->count_thread[0];
2836         snew(atc->rcount, atc->nslab);
2837         snew(atc->buf_index, atc->nslab);
2838     }
2839
2840     atc->nthread = pme->nthread;
2841     if (atc->nthread > 1)
2842     {
2843         snew(atc->thread_plist, atc->nthread);
2844     }
2845     snew(atc->spline, atc->nthread);
2846     for (thread = 0; thread < atc->nthread; thread++)
2847     {
2848         if (atc->nthread > 1)
2849         {
2850             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2851             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2852         }
2853         snew(atc->spline[thread].thread_one, pme->nthread);
2854         atc->spline[thread].thread_one[thread] = 1;
2855     }
2856 }
2857
2858 static void
2859 init_overlap_comm(pme_overlap_t *  ol,
2860                   int              norder,
2861 #ifdef GMX_MPI
2862                   MPI_Comm         comm,
2863 #endif
2864                   int              nnodes,
2865                   int              nodeid,
2866                   int              ndata,
2867                   int              commplainsize)
2868 {
2869     int lbnd, rbnd, maxlr, b, i;
2870     int exten;
2871     int nn, nk;
2872     pme_grid_comm_t *pgc;
2873     gmx_bool bCont;
2874     int fft_start, fft_end, send_index1, recv_index1;
2875 #ifdef GMX_MPI
2876     MPI_Status stat;
2877
2878     ol->mpi_comm = comm;
2879 #endif
2880
2881     ol->nnodes = nnodes;
2882     ol->nodeid = nodeid;
2883
2884     /* Linear translation of the PME grid won't affect reciprocal space
2885      * calculations, so to optimize we only interpolate "upwards",
2886      * which also means we only have to consider overlap in one direction.
2887      * I.e., particles on this node might also be spread to grid indices
2888      * that belong to higher nodes (modulo nnodes)
2889      */
2890
2891     snew(ol->s2g0, ol->nnodes+1);
2892     snew(ol->s2g1, ol->nnodes);
2893     if (debug)
2894     {
2895         fprintf(debug, "PME slab boundaries:");
2896     }
2897     for (i = 0; i < nnodes; i++)
2898     {
2899         /* s2g0 the local interpolation grid start.
2900          * s2g1 the local interpolation grid end.
2901          * Since in calc_pidx we divide particles, and not grid lines,
2902          * spatially uniformly along dimension x or y, we need to round
2903          * s2g0 down and s2g1 up.
2904          */
2905         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2906         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2907
2908         if (debug)
2909         {
2910             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2911         }
2912     }
2913     ol->s2g0[nnodes] = ndata;
2914     if (debug)
2915     {
2916         fprintf(debug, "\n");
2917     }
2918
2919     /* Determine with how many nodes we need to communicate the grid overlap */
2920     b = 0;
2921     do
2922     {
2923         b++;
2924         bCont = FALSE;
2925         for (i = 0; i < nnodes; i++)
2926         {
2927             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2928                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2929             {
2930                 bCont = TRUE;
2931             }
2932         }
2933     }
2934     while (bCont && b < nnodes);
2935     ol->noverlap_nodes = b - 1;
2936
2937     snew(ol->send_id, ol->noverlap_nodes);
2938     snew(ol->recv_id, ol->noverlap_nodes);
2939     for (b = 0; b < ol->noverlap_nodes; b++)
2940     {
2941         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2942         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2943     }
2944     snew(ol->comm_data, ol->noverlap_nodes);
2945
2946     ol->send_size = 0;
2947     for (b = 0; b < ol->noverlap_nodes; b++)
2948     {
2949         pgc = &ol->comm_data[b];
2950         /* Send */
2951         fft_start        = ol->s2g0[ol->send_id[b]];
2952         fft_end          = ol->s2g0[ol->send_id[b]+1];
2953         if (ol->send_id[b] < nodeid)
2954         {
2955             fft_start += ndata;
2956             fft_end   += ndata;
2957         }
2958         send_index1       = ol->s2g1[nodeid];
2959         send_index1       = min(send_index1, fft_end);
2960         pgc->send_index0  = fft_start;
2961         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2962         ol->send_size    += pgc->send_nindex;
2963
2964         /* We always start receiving at the first index of our slab */
2965         fft_start        = ol->s2g0[ol->nodeid];
2966         fft_end          = ol->s2g0[ol->nodeid+1];
2967         recv_index1      = ol->s2g1[ol->recv_id[b]];
2968         if (ol->recv_id[b] > nodeid)
2969         {
2970             recv_index1 -= ndata;
2971         }
2972         recv_index1      = min(recv_index1, fft_end);
2973         pgc->recv_index0 = fft_start;
2974         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2975     }
2976
2977 #ifdef GMX_MPI
2978     /* Communicate the buffer sizes to receive */
2979     for (b = 0; b < ol->noverlap_nodes; b++)
2980     {
2981         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2982                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2983                      ol->mpi_comm, &stat);
2984     }
2985 #endif
2986
2987     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2988     snew(ol->sendbuf, norder*commplainsize);
2989     snew(ol->recvbuf, norder*commplainsize);
2990 }
2991
2992 static void
2993 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2994                               int **global_to_local,
2995                               real **fraction_shift)
2996 {
2997     int i;
2998     int * gtl;
2999     real * fsh;
3000
3001     snew(gtl, 5*n);
3002     snew(fsh, 5*n);
3003     for (i = 0; (i < 5*n); i++)
3004     {
3005         /* Determine the global to local grid index */
3006         gtl[i] = (i - local_start + n) % n;
3007         /* For coordinates that fall within the local grid the fraction
3008          * is correct, we don't need to shift it.
3009          */
3010         fsh[i] = 0;
3011         if (local_range < n)
3012         {
3013             /* Due to rounding issues i could be 1 beyond the lower or
3014              * upper boundary of the local grid. Correct the index for this.
3015              * If we shift the index, we need to shift the fraction by
3016              * the same amount in the other direction to not affect
3017              * the weights.
3018              * Note that due to this shifting the weights at the end of
3019              * the spline might change, but that will only involve values
3020              * between zero and values close to the precision of a real,
3021              * which is in any case the accuracy of the whole mesh calculation.
3022              */
3023             /* With local_range=0 we should not change i=local_start */
3024             if (i % n != local_start)
3025             {
3026                 if (gtl[i] == n-1)
3027                 {
3028                     gtl[i] = 0;
3029                     fsh[i] = -1;
3030                 }
3031                 else if (gtl[i] == local_range)
3032                 {
3033                     gtl[i] = local_range - 1;
3034                     fsh[i] = 1;
3035                 }
3036             }
3037         }
3038     }
3039
3040     *global_to_local = gtl;
3041     *fraction_shift  = fsh;
3042 }
3043
3044 static pme_spline_work_t *make_pme_spline_work(int order)
3045 {
3046     pme_spline_work_t *work;
3047
3048 #ifdef PME_SIMD4_SPREAD_GATHER
3049     real         tmp[12], *tmp_aligned;
3050     gmx_simd4_pr zero_S;
3051     gmx_simd4_pr real_mask_S0, real_mask_S1;
3052     int          of, i;
3053
3054     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3055
3056     tmp_aligned = gmx_simd4_align_real(tmp);
3057
3058     zero_S = gmx_simd4_setzero_pr();
3059
3060     /* Generate bit masks to mask out the unused grid entries,
3061      * as we only operate on 'order' of the 8 grid entries that are
3062      * loaded into 2 SIMD registers.
3063      */
3064     for (of = 0; of < 8-(order-1); of++)
3065     {
3066         for (i = 0; i < 8; i++)
3067         {
3068             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3069         }
3070         real_mask_S0      = gmx_simd4_load_pr(tmp_aligned);
3071         real_mask_S1      = gmx_simd4_load_pr(tmp_aligned+4);
3072         work->mask_S0[of] = gmx_simd4_cmplt_pr(real_mask_S0, zero_S);
3073         work->mask_S1[of] = gmx_simd4_cmplt_pr(real_mask_S1, zero_S);
3074     }
3075 #else
3076     work = NULL;
3077 #endif
3078
3079     return work;
3080 }
3081
3082 void gmx_pme_check_restrictions(int pme_order,
3083                                 int nkx, int nky, int nkz,
3084                                 int nnodes_major,
3085                                 int nnodes_minor,
3086                                 gmx_bool bUseThreads,
3087                                 gmx_bool bFatal,
3088                                 gmx_bool *bValidSettings)
3089 {
3090     if (pme_order > PME_ORDER_MAX)
3091     {
3092         if (!bFatal)
3093         {
3094             *bValidSettings = FALSE;
3095             return;
3096         }
3097         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3098                   pme_order, PME_ORDER_MAX);
3099     }
3100
3101     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3102         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3103         nkz <= pme_order)
3104     {
3105         if (!bFatal)
3106         {
3107             *bValidSettings = FALSE;
3108             return;
3109         }
3110         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3111                   pme_order);
3112     }
3113
3114     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3115      * We only allow multiple communication pulses in dim 1, not in dim 0.
3116      */
3117     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3118                         nkx != nnodes_major*(pme_order - 1)))
3119     {
3120         if (!bFatal)
3121         {
3122             *bValidSettings = FALSE;
3123             return;
3124         }
3125         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3126                   nkx/(double)nnodes_major, pme_order);
3127     }
3128
3129     if (bValidSettings != NULL)
3130     {
3131         *bValidSettings = TRUE;
3132     }
3133
3134     return;
3135 }
3136
3137 int gmx_pme_init(gmx_pme_t *         pmedata,
3138                  t_commrec *         cr,
3139                  int                 nnodes_major,
3140                  int                 nnodes_minor,
3141                  t_inputrec *        ir,
3142                  int                 homenr,
3143                  gmx_bool            bFreeEnergy,
3144                  gmx_bool            bReproducible,
3145                  int                 nthread)
3146 {
3147     gmx_pme_t pme = NULL;
3148
3149     int  use_threads, sum_use_threads;
3150     ivec ndata;
3151
3152     if (debug)
3153     {
3154         fprintf(debug, "Creating PME data structures.\n");
3155     }
3156     snew(pme, 1);
3157
3158     pme->redist_init         = FALSE;
3159     pme->sum_qgrid_tmp       = NULL;
3160     pme->sum_qgrid_dd_tmp    = NULL;
3161     pme->buf_nalloc          = 0;
3162     pme->redist_buf_nalloc   = 0;
3163
3164     pme->nnodes              = 1;
3165     pme->bPPnode             = TRUE;
3166
3167     pme->nnodes_major        = nnodes_major;
3168     pme->nnodes_minor        = nnodes_minor;
3169
3170 #ifdef GMX_MPI
3171     if (nnodes_major*nnodes_minor > 1)
3172     {
3173         pme->mpi_comm = cr->mpi_comm_mygroup;
3174
3175         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3176         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3177         if (pme->nnodes != nnodes_major*nnodes_minor)
3178         {
3179             gmx_incons("PME node count mismatch");
3180         }
3181     }
3182     else
3183     {
3184         pme->mpi_comm = MPI_COMM_NULL;
3185     }
3186 #endif
3187
3188     if (pme->nnodes == 1)
3189     {
3190 #ifdef GMX_MPI
3191         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3192         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3193 #endif
3194         pme->ndecompdim   = 0;
3195         pme->nodeid_major = 0;
3196         pme->nodeid_minor = 0;
3197 #ifdef GMX_MPI
3198         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3199 #endif
3200     }
3201     else
3202     {
3203         if (nnodes_minor == 1)
3204         {
3205 #ifdef GMX_MPI
3206             pme->mpi_comm_d[0] = pme->mpi_comm;
3207             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3208 #endif
3209             pme->ndecompdim   = 1;
3210             pme->nodeid_major = pme->nodeid;
3211             pme->nodeid_minor = 0;
3212
3213         }
3214         else if (nnodes_major == 1)
3215         {
3216 #ifdef GMX_MPI
3217             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3218             pme->mpi_comm_d[1] = pme->mpi_comm;
3219 #endif
3220             pme->ndecompdim   = 1;
3221             pme->nodeid_major = 0;
3222             pme->nodeid_minor = pme->nodeid;
3223         }
3224         else
3225         {
3226             if (pme->nnodes % nnodes_major != 0)
3227             {
3228                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3229             }
3230             pme->ndecompdim = 2;
3231
3232 #ifdef GMX_MPI
3233             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3234                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3235             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3236                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3237
3238             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3239             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3240             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3241             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3242 #endif
3243         }
3244         pme->bPPnode = (cr->duty & DUTY_PP);
3245     }
3246
3247     pme->nthread = nthread;
3248
3249     /* Check if any of the PME MPI ranks uses threads */
3250     use_threads = (pme->nthread > 1 ? 1 : 0);
3251 #ifdef GMX_MPI
3252     if (pme->nnodes > 1)
3253     {
3254         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3255                       MPI_SUM, pme->mpi_comm);
3256     }
3257     else
3258 #endif
3259     {
3260         sum_use_threads = use_threads;
3261     }
3262     pme->bUseThreads = (sum_use_threads > 0);
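     /* Note: bUseThreads has to be identical on all PME ranks, since it
      * selects the grid overlap communication path (sum_fftgrid_dd instead
      * of gmx_sum_qgrid_dd); hence the collective reduction above instead
      * of a per-rank decision.
      */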
3263
3264     if (ir->ePBC == epbcSCREW)
3265     {
3266         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3267     }
3268
3269     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3270     pme->nkx         = ir->nkx;
3271     pme->nky         = ir->nky;
3272     pme->nkz         = ir->nkz;
3273     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3274     pme->pme_order   = ir->pme_order;
3275     pme->epsilon_r   = ir->epsilon_r;
3276
3277     /* If we violate restrictions, generate a fatal error here */
3278     gmx_pme_check_restrictions(pme->pme_order,
3279                                pme->nkx, pme->nky, pme->nkz,
3280                                pme->nnodes_major,
3281                                pme->nnodes_minor,
3282                                pme->bUseThreads,
3283                                TRUE,
3284                                NULL);
3285
3286     if (pme->nnodes > 1)
3287     {
3288         double imbal;
3289
3290 #ifdef GMX_MPI
3291         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3292         MPI_Type_commit(&(pme->rvec_mpi));
3293 #endif
3294
3295         /* Note that the charge spreading and force gathering, which usually
3296          * take about the same amount of time as FFT+solve_pme,
3297          * are always fully load balanced
3298          * (unless the charge distribution is inhomogeneous).
3299          */
3300
3301         imbal = pme_load_imbalance(pme);
3302         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3303         {
3304             fprintf(stderr,
3305                     "\n"
3306                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3307                     "      For optimal PME load balancing\n"
3308                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3309                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3310                     "\n",
3311                     (int)((imbal-1)*100 + 0.5),
3312                     pme->nkx, pme->nky, pme->nnodes_major,
3313                     pme->nky, pme->nkz, pme->nnodes_minor);
3314         }
3315     }
3316
3317     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3318     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3319      * y is always copied through a buffer: we don't need padding in z,
3320      * but we do need the overlap in x because of the communication order.
3321      */
3322     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3323 #ifdef GMX_MPI
3324                       pme->mpi_comm_d[0],
3325 #endif
3326                       pme->nnodes_major, pme->nodeid_major,
3327                       pme->nkx,
3328                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3329
3330     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3331      * We do this with an offset buffer of equal size, so we need to allocate
3332      * extra for the offset. That's what the (+1)*pme->nkz is for.
3333      */
3334     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3335 #ifdef GMX_MPI
3336                       pme->mpi_comm_d[1],
3337 #endif
3338                       pme->nnodes_minor, pme->nodeid_minor,
3339                       pme->nky,
3340                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3341
3342     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3343      * Note that gmx_pme_check_restrictions checked for this already.
3344      */
3345     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3346     {
3347         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3348     }
3349
3350     snew(pme->bsp_mod[XX], pme->nkx);
3351     snew(pme->bsp_mod[YY], pme->nky);
3352     snew(pme->bsp_mod[ZZ], pme->nkz);
3353
3354     /* The required size of the interpolation grid, including overlap.
3355      * The allocated size (pmegrid_n?) might be slightly larger.
3356      */
3357     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3358         pme->overlap[0].s2g0[pme->nodeid_major];
3359     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3360         pme->overlap[1].s2g0[pme->nodeid_minor];
3361     pme->pmegrid_nz_base = pme->nkz;
3362     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3363     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
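     /* z is not decomposed, so pmegrid_nz only needs the pme_order-1 spreading
      * overlap; set_grid_alignment then rounds it up, presumably so that the
      * innermost z-loops can use aligned SIMD loads/stores.
      */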
3364
3365     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3366     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3367     pme->pmegrid_start_iz = 0;
3368
3369     make_gridindex5_to_localindex(pme->nkx,
3370                                   pme->pmegrid_start_ix,
3371                                   pme->pmegrid_nx - (pme->pme_order-1),
3372                                   &pme->nnx, &pme->fshx);
3373     make_gridindex5_to_localindex(pme->nky,
3374                                   pme->pmegrid_start_iy,
3375                                   pme->pmegrid_ny - (pme->pme_order-1),
3376                                   &pme->nny, &pme->fshy);
3377     make_gridindex5_to_localindex(pme->nkz,
3378                                   pme->pmegrid_start_iz,
3379                                   pme->pmegrid_nz_base,
3380                                   &pme->nnz, &pme->fshz);
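     /* The nn? tables map (possibly periodically shifted) global grid indices
      * to local grid indices and the fsh? tables hold the matching fractional
      * shifts; they are used in calc_interpolation_idx to avoid branching in
      * the inner loop over atoms.
      */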
3381
3382     pmegrids_init(&pme->pmegridA,
3383                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3384                   pme->pmegrid_nz_base,
3385                   pme->pme_order,
3386                   pme->bUseThreads,
3387                   pme->nthread,
3388                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3389                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3390
3391     pme->spline_work = make_pme_spline_work(pme->pme_order);
3392
3393     ndata[0] = pme->nkx;
3394     ndata[1] = pme->nky;
3395     ndata[2] = pme->nkz;
3396
3397     /* This routine will allocate the grid data to fit the FFTs */
3398     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3399                             &pme->fftgridA, &pme->cfftgridA,
3400                             pme->mpi_comm_d,
3401                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3402                             bReproducible, pme->nthread);
3403
3404     if (bFreeEnergy)
3405     {
3406         pmegrids_init(&pme->pmegridB,
3407                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3408                       pme->pmegrid_nz_base,
3409                       pme->pme_order,
3410                       pme->bUseThreads,
3411                       pme->nthread,
3412                       pme->nkx % pme->nnodes_major != 0,
3413                       pme->nky % pme->nnodes_minor != 0);
3414
3415         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3416                                 &pme->fftgridB, &pme->cfftgridB,
3417                                 pme->mpi_comm_d,
3418                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3419                                 bReproducible, pme->nthread);
3420     }
3421     else
3422     {
3423         pme->pmegridB.grid.grid = NULL;
3424         pme->fftgridB           = NULL;
3425         pme->cfftgridB          = NULL;
3426     }
3427
3428     if (!pme->bP3M)
3429     {
3430         /* Use plain SPME B-spline interpolation */
3431         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3432     }
3433     else
3434     {
3435         /* Use the P3M grid-optimized influence function */
3436         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3437     }
3438
3439     /* Use atc[0] for spreading */
3440     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3441     if (pme->ndecompdim >= 2)
3442     {
3443         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3444     }
3445
3446     if (pme->nnodes == 1)
3447     {
3448         pme->atc[0].n = homenr;
3449         pme_realloc_atomcomm_things(&pme->atc[0]);
3450     }
3451
3452     {
3453         int thread;
3454
3455         /* Use fft5d, order after FFT is y major, z, x minor */
3456
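         /* Each thread gets its own work arrays for solve_pme_yzx,
          * sized along x (nkx), which is the contiguous minor dimension
          * after the transposed parallel FFT.
          */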
3457         snew(pme->work, pme->nthread);
3458         for (thread = 0; thread < pme->nthread; thread++)
3459         {
3460             realloc_work(&pme->work[thread], pme->nkx);
3461         }
3462     }
3463
3464     *pmedata = pme;
3465
3466     return 0;
3467 }
3468
3469 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3470 {
3471     int d, t;
3472
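     /* Only reuse the old grid storage when it is at least as large in every
      * dimension; otherwise keep the grids that were just allocated.
      */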
3473     for (d = 0; d < DIM; d++)
3474     {
3475         if (new->grid.n[d] > old->grid.n[d])
3476         {
3477             return;
3478         }
3479     }
3480
3481     sfree_aligned(new->grid.grid);
3482     new->grid.grid = old->grid.grid;
3483
3484     if (new->grid_th != NULL && new->nthread == old->nthread)
3485     {
3486         sfree_aligned(new->grid_all);
3487         for (t = 0; t < new->nthread; t++)
3488         {
3489             new->grid_th[t].grid = old->grid_th[t].grid;
3490         }
3491     }
3492 }
3493
3494 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3495                    t_commrec *         cr,
3496                    gmx_pme_t           pme_src,
3497                    const t_inputrec *  ir,
3498                    ivec                grid_size)
3499 {
3500     t_inputrec irc;
3501     int homenr;
3502     int ret;
3503
3504     irc     = *ir;
3505     irc.nkx = grid_size[XX];
3506     irc.nky = grid_size[YY];
3507     irc.nkz = grid_size[ZZ];
3508
3509     if (pme_src->nnodes == 1)
3510     {
3511         homenr = pme_src->atc[0].n;
3512     }
3513     else
3514     {
3515         homenr = -1;
3516     }
3517
3518     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3519                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3520
3521     if (ret == 0)
3522     {
3523         /* We can easily reuse the allocated pme grids in pme_src */
3524         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3525         /* We would like to reuse the fft grids, but that's harder */
3526     }
3527
3528     return ret;
3529 }
3530
3531
3532 static void copy_local_grid(gmx_pme_t pme,
3533                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3534 {
3535     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3536     int  fft_my, fft_mz;
3537     int  nsx, nsy, nsz;
3538     ivec nf;
3539     int  offx, offy, offz, x, y, z, i0, i0t;
3540     int  d;
3541     pmegrid_t *pmegrid;
3542     real *grid_th;
3543
3544     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3545                                    local_fft_ndata,
3546                                    local_fft_offset,
3547                                    local_fft_size);
3548     fft_my = local_fft_size[YY];
3549     fft_mz = local_fft_size[ZZ];
3550
3551     pmegrid = &pmegrids->grid_th[thread];
3552
3553     nsx = pmegrid->s[XX];
3554     nsy = pmegrid->s[YY];
3555     nsz = pmegrid->s[ZZ];
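     /* s[] holds the allocated (padded) sizes of the thread-local spreading
      * grid and thus its strides, while fft_my/fft_mz are the y/z strides of
      * the rank-local real-space FFT grid.
      */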
3556
3557     for (d = 0; d < DIM; d++)
3558     {
3559         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3560                     local_fft_ndata[d] - pmegrid->offset[d]);
3561     }
3562
3563     offx = pmegrid->offset[XX];
3564     offy = pmegrid->offset[YY];
3565     offz = pmegrid->offset[ZZ];
3566
3567     /* Directly copy the non-overlapping parts of the local grids.
3568      * This also initializes the full grid.
3569      */
3570     grid_th = pmegrid->grid;
3571     for (x = 0; x < nf[XX]; x++)
3572     {
3573         for (y = 0; y < nf[YY]; y++)
3574         {
3575             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3576             i0t = (x*nsy + y)*nsz;
3577             for (z = 0; z < nf[ZZ]; z++)
3578             {
3579                 fftgrid[i0+z] = grid_th[i0t+z];
3580             }
3581         }
3582     }
3583 }
3584
3585 static void
3586 reduce_threadgrid_overlap(gmx_pme_t pme,
3587                           const pmegrids_t *pmegrids, int thread,
3588                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3589 {
3590     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3591     int  fft_nx, fft_ny, fft_nz;
3592     int  fft_my, fft_mz;
3593     int  buf_my = -1;
3594     int  nsx, nsy, nsz;
3595     ivec localcopy_end, commcopy_end;
3596     int  offx, offy, offz, x, y, z, i0, i0t;
3597     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3598     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3599     gmx_bool bCommX, bCommY;
3600     int  d;
3601     int  thread_f;
3602     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3603     const real *grid_th;
3604     real *commbuf = NULL;
3605
3606     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3607                                    local_fft_ndata,
3608                                    local_fft_offset,
3609                                    local_fft_size);
3610     fft_nx = local_fft_ndata[XX];
3611     fft_ny = local_fft_ndata[YY];
3612     fft_nz = local_fft_ndata[ZZ];
3613
3614     fft_my = local_fft_size[YY];
3615     fft_mz = local_fft_size[ZZ];
3616
3617     /* This routine is called when all threads have finished spreading.
3618      * Here each thread sums the grid contributions that other threads
3619      * calculated for its thread-local grid volume.
3620      * To minimize the number of grid copying operations,
3621      * this routine sums directly from the pmegrid into the fftgrid.
3622      */
3623
3624     /* Determine which part of the full node grid we should operate on:
3625      * this is our thread-local part of the full grid.
3626      */
3627     pmegrid = &pmegrids->grid_th[thread];
3628
3629     for (d = 0; d < DIM; d++)
3630     {
3631         /* Determine up to where our thread needs to copy from the
3632          * thread-local charge spreading grid to the rank-local FFT grid.
3633          * This is up to our spreading grid end minus order-1 and
3634          * not beyond the local FFT grid.
3635          */
3636         localcopy_end[d] =
3637             min(pmegrid->offset[d] + pmegrid->n[d] - (pmegrid->order - 1),
3638                 local_fft_ndata[d]);
3639
3640         /* Determine up to where our thread needs to copy from the
3641          * thread-local charge spreading grid to the communication buffer.
3642          * Note: only relevant with communication, ignored otherwise.
3643          */
3644         commcopy_end[d]  = localcopy_end[d];
3645         if (pmegrid->ci[d] == pmegrids->nc[d] - 1)
3646         {
3647             /* The last thread should copy up to the last pme grid line.
3648              * When the rank-local FFT grid is narrower than pme-order,
3649              * we need the max below to ensure copying of all data.
3650              */
3651             commcopy_end[d] = max(commcopy_end[d], pme->pme_order);
3652         }
3653     }
3654
3655     offx = pmegrid->offset[XX];
3656     offy = pmegrid->offset[YY];
3657     offz = pmegrid->offset[ZZ];
3658
3659
3660     bClearBufX  = TRUE;
3661     bClearBufY  = TRUE;
3662     bClearBufXY = TRUE;
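     /* Each communication buffer (x, y and the x+y corner) is written by
      * assignment on its first contribution and accumulated into afterwards;
      * the bClearBuf* flags track this first access.
      */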
3663
3664     /* Now loop over all the thread data blocks that contribute
3665      * to the grid region we (our thread) are operating on.
3666      */
3667     /* Note that fft_nx/y is equal to the number of grid points
3668      * between the first point of our node grid and the one of the next node.
3669      */
3670     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3671     {
3672         fx     = pmegrid->ci[XX] + sx;
3673         ox     = 0;
3674         bCommX = FALSE;
3675         if (fx < 0)
3676         {
3677             fx    += pmegrids->nc[XX];
3678             ox    -= fft_nx;
3679             bCommX = (pme->nnodes_major > 1);
3680         }
3681         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3682         ox       += pmegrid_g->offset[XX];
3683         /* Determine the end of our part of the source grid.
3684          * Use our thread local source grid and target grid part
3685          */
3686         tx1 = min(ox + pmegrid_g->n[XX],
3687                   !bCommX ? localcopy_end[XX] : commcopy_end[XX]);
3688
3689         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3690         {
3691             fy     = pmegrid->ci[YY] + sy;
3692             oy     = 0;
3693             bCommY = FALSE;
3694             if (fy < 0)
3695             {
3696                 fy    += pmegrids->nc[YY];
3697                 oy    -= fft_ny;
3698                 bCommY = (pme->nnodes_minor > 1);
3699             }
3700             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3701             oy       += pmegrid_g->offset[YY];
3702             /* Determine the end of our part of the source grid.
3703              * Use our thread local source grid and target grid part
3704              */
3705             ty1 = min(oy + pmegrid_g->n[YY],
3706                       !bCommY ? localcopy_end[YY] : commcopy_end[YY]);
3707
3708             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3709             {
3710                 fz = pmegrid->ci[ZZ] + sz;
3711                 oz = 0;
3712                 if (fz < 0)
3713                 {
3714                     fz += pmegrids->nc[ZZ];
3715                     oz -= fft_nz;
3716                 }
3717                 pmegrid_g = &pmegrids->grid_th[fz];
3718                 oz       += pmegrid_g->offset[ZZ];
3719                 tz1       = min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);
3720
3721                 if (sx == 0 && sy == 0 && sz == 0)
3722                 {
3723                     /* We have already added our local contribution
3724                      * before calling this routine, so skip it here.
3725                      */
3726                     continue;
3727                 }
3728
3729                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3730
3731                 pmegrid_f = &pmegrids->grid_th[thread_f];
3732
3733                 grid_th = pmegrid_f->grid;
3734
3735                 nsx = pmegrid_f->s[XX];
3736                 nsy = pmegrid_f->s[YY];
3737                 nsz = pmegrid_f->s[ZZ];
3738
3739 #ifdef DEBUG_PME_REDUCE
3740                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3741                        pme->nodeid, thread, thread_f,
3742                        pme->pmegrid_start_ix,
3743                        pme->pmegrid_start_iy,
3744                        pme->pmegrid_start_iz,
3745                        sx, sy, sz,
3746                        offx-ox, tx1-ox, offx, tx1,
3747                        offy-oy, ty1-oy, offy, ty1,
3748                        offz-oz, tz1-oz, offz, tz1);
3749 #endif
3750
3751                 if (!(bCommX || bCommY))
3752                 {
3753                     /* Copy from the thread local grid to the node grid */
3754                     for (x = offx; x < tx1; x++)
3755                     {
3756                         for (y = offy; y < ty1; y++)
3757                         {
3758                             i0  = (x*fft_my + y)*fft_mz;
3759                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3760                             for (z = offz; z < tz1; z++)
3761                             {
3762                                 fftgrid[i0+z] += grid_th[i0t+z];
3763                             }
3764                         }
3765                     }
3766                 }
3767                 else
3768                 {
3769                     /* The order of this conditional decides
3770                      * where the corner volume gets stored with x+y decomp.
3771                      */
3772                     if (bCommY)
3773                     {
3774                         commbuf = commbuf_y;
3775                         /* The y-size of the communication buffer is set by
3776                          * the overlap of the grid part of our local slab
3777                          * with the part starting at the next slab.
3778                          */
3779                         buf_my  =
3780                             pme->overlap[1].s2g1[pme->nodeid_minor] -
3781                             pme->overlap[1].s2g0[pme->nodeid_minor+1];
3782                         if (bCommX)
3783                         {
3784                             /* We index commbuf modulo the local grid size */
3785                             commbuf += buf_my*fft_nx*fft_nz;
3786
3787                             bClearBuf   = bClearBufXY;
3788                             bClearBufXY = FALSE;
3789                         }
3790                         else
3791                         {
3792                             bClearBuf  = bClearBufY;
3793                             bClearBufY = FALSE;
3794                         }
3795                     }
3796                     else
3797                     {
3798                         commbuf    = commbuf_x;
3799                         buf_my     = fft_ny;
3800                         bClearBuf  = bClearBufX;
3801                         bClearBufX = FALSE;
3802                     }
3803
3804                     /* Copy to the communication buffer */
3805                     for (x = offx; x < tx1; x++)
3806                     {
3807                         for (y = offy; y < ty1; y++)
3808                         {
3809                             i0  = (x*buf_my + y)*fft_nz;
3810                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3811
3812                             if (bClearBuf)
3813                             {
3814                                 /* First access of commbuf, initialize it */
3815                                 for (z = offz; z < tz1; z++)
3816                                 {
3817                                     commbuf[i0+z]  = grid_th[i0t+z];
3818                                 }
3819                             }
3820                             else
3821                             {
3822                                 for (z = offz; z < tz1; z++)
3823                                 {
3824                                     commbuf[i0+z] += grid_th[i0t+z];
3825                                 }
3826                             }
3827                         }
3828                     }
3829                 }
3830             }
3831         }
3832     }
3833 }
3834
3835
3836 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3837 {
3838     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3839     pme_overlap_t *overlap;
3840     int  send_index0, send_nindex;
3841     int  recv_nindex;
3842 #ifdef GMX_MPI
3843     MPI_Status stat;
3844 #endif
3845     int  send_size_y, recv_size_y;
3846     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3847     real *sendptr, *recvptr;
3848     int  x, y, z, indg, indb;
3849
3850     /* Note that this routine is only used for forward communication.
3851      * Since the force gathering, unlike the charge spreading,
3852      * can be trivially parallelized over the particles,
3853      * the backwards process is much simpler and can use the "old"
3854      * communication setup.
3855      */
3856
3857     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3858                                    local_fft_ndata,
3859                                    local_fft_offset,
3860                                    local_fft_size);
3861
3862     if (pme->nnodes_minor > 1)
3863     {
3864         /* Minor dimension */
3865         overlap = &pme->overlap[1];
3866
3867         if (pme->nnodes_major > 1)
3868         {
3869             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3870         }
3871         else
3872         {
3873             size_yx = 0;
3874         }
3875         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
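         /* With 2D decomposition each transferred y-line contains size_yx
          * extra x-rows beyond the local x-range; this extra part is not
          * added to fftgrid here, but forwarded below into the send buffer
          * for the major dimension.
          */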
3876
3877         send_size_y = overlap->send_size;
3878
3879         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3880         {
3881             send_id       = overlap->send_id[ipulse];
3882             recv_id       = overlap->recv_id[ipulse];
3883             send_index0   =
3884                 overlap->comm_data[ipulse].send_index0 -
3885                 overlap->comm_data[0].send_index0;
3886             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3887             /* We don't use recv_index0, as we always receive starting at 0 */
3888             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3889             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3890
3891             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3892             recvptr = overlap->recvbuf;
3893
3894             if (debug != NULL)
3895             {
3896                 fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n",
3897                         local_fft_ndata[XX], send_nindex, local_fft_ndata[ZZ]);
3898             }
3899
3900 #ifdef GMX_MPI
3901             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3902                          send_id, ipulse,
3903                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3904                          recv_id, ipulse,
3905                          overlap->mpi_comm, &stat);
3906 #endif
3907
3908             for (x = 0; x < local_fft_ndata[XX]; x++)
3909             {
3910                 for (y = 0; y < recv_nindex; y++)
3911                 {
3912                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3913                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3914                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3915                     {
3916                         fftgrid[indg+z] += recvptr[indb+z];
3917                     }
3918                 }
3919             }
3920
3921             if (pme->nnodes_major > 1)
3922             {
3923                 /* Copy from the received buffer to the send buffer for dim 0 */
3924                 sendptr = pme->overlap[0].sendbuf;
3925                 for (x = 0; x < size_yx; x++)
3926                 {
3927                     for (y = 0; y < recv_nindex; y++)
3928                     {
3929                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3930                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3931                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3932                         {
3933                             sendptr[indg+z] += recvptr[indb+z];
3934                         }
3935                     }
3936                 }
3937             }
3938         }
3939     }
3940
3941     /* We only support a single pulse here.
3942      * This is not a severe limitation, as this code is only used
3943      * with OpenMP, in which case the (PME) domains can be larger.
3944      */
3945     if (pme->nnodes_major > 1)
3946     {
3947         /* Major dimension */
3948         overlap = &pme->overlap[0];
3949
3950         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3951         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3952
3953         ipulse = 0;
3954
3955         send_id       = overlap->send_id[ipulse];
3956         recv_id       = overlap->recv_id[ipulse];
3957         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3958         /* We don't use recv_index0, as we always receive starting at 0 */
3959         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3960
3961         sendptr = overlap->sendbuf;
3962         recvptr = overlap->recvbuf;
3963
3964         if (debug != NULL)
3965         {
3966             fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n",
3967                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3968         }
3969
3970 #ifdef GMX_MPI
3971         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3972                      send_id, ipulse,
3973                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3974                      recv_id, ipulse,
3975                      overlap->mpi_comm, &stat);
3976 #endif
3977
3978         for (x = 0; x < recv_nindex; x++)
3979         {
3980             for (y = 0; y < local_fft_ndata[YY]; y++)
3981             {
3982                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3983                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3984                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3985                 {
3986                     fftgrid[indg+z] += recvptr[indb+z];
3987                 }
3988             }
3989         }
3990     }
3991 }
3992
3993
3994 static void spread_on_grid(gmx_pme_t pme,
3995                            pme_atomcomm_t *atc, pmegrids_t *grids,
3996                            gmx_bool bCalcSplines, gmx_bool bSpread,
3997                            real *fftgrid)
3998 {
3999     int nthread, thread;
4000 #ifdef PME_TIME_THREADS
4001     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
4002     static double cs1     = 0, cs2 = 0, cs3 = 0;
4003     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
4004     static int cnt        = 0;
4005 #endif
4006
4007     nthread = pme->nthread;
4008     assert(nthread > 0);
4009
4010 #ifdef PME_TIME_THREADS
4011     c1 = omp_cyc_start();
4012 #endif
4013     if (bCalcSplines)
4014     {
4015 #pragma omp parallel for num_threads(nthread) schedule(static)
4016         for (thread = 0; thread < nthread; thread++)
4017         {
4018             int start, end;
4019
4020             start = atc->n* thread   /nthread;
4021             end   = atc->n*(thread+1)/nthread;
4022
4023             /* Compute fftgrid index for all atoms,
4024              * with help of some extra variables.
4025              */
4026             calc_interpolation_idx(pme, atc, start, end, thread);
4027         }
4028     }
4029 #ifdef PME_TIME_THREADS
4030     c1   = omp_cyc_end(c1);
4031     cs1 += (double)c1;
4032 #endif
4033
4034 #ifdef PME_TIME_THREADS
4035     c2 = omp_cyc_start();
4036 #endif
4037 #pragma omp parallel for num_threads(nthread) schedule(static)
4038     for (thread = 0; thread < nthread; thread++)
4039     {
4040         splinedata_t *spline;
4041         pmegrid_t *grid = NULL;
4042
4043         /* make local bsplines  */
4044         if (grids == NULL || !pme->bUseThreads)
4045         {
4046             spline = &atc->spline[0];
4047
4048             spline->n = atc->n;
4049
4050             if (bSpread)
4051             {
4052                 grid = &grids->grid;
4053             }
4054         }
4055         else
4056         {
4057             spline = &atc->spline[thread];
4058
4059             if (grids->nthread == 1)
4060             {
4061                 /* One thread, we operate on all charges */
4062                 spline->n = atc->n;
4063             }
4064             else
4065             {
4066                 /* Get the indices our thread should operate on */
4067                 make_thread_local_ind(atc, thread, spline);
4068             }
4069
4070             grid = &grids->grid_th[thread];
4071         }
4072
4073         if (bCalcSplines)
4074         {
4075             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4076                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
4077         }
4078
4079         if (bSpread)
4080         {
4081             /* put local atoms on grid. */
4082 #ifdef PME_TIME_SPREAD
4083             ct1a = omp_cyc_start();
4084 #endif
4085             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4086
4087             if (pme->bUseThreads)
4088             {
4089                 copy_local_grid(pme, grids, thread, fftgrid);
4090             }
4091 #ifdef PME_TIME_SPREAD
4092             ct1a          = omp_cyc_end(ct1a);
4093             cs1a[thread] += (double)ct1a;
4094 #endif
4095         }
4096     }
4097 #ifdef PME_TIME_THREADS
4098     c2   = omp_cyc_end(c2);
4099     cs2 += (double)c2;
4100 #endif
4101
4102     if (bSpread && pme->bUseThreads)
4103     {
4104 #ifdef PME_TIME_THREADS
4105         c3 = omp_cyc_start();
4106 #endif
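         /* Each thread sums, into its own region of fftgrid (or of the
          * overlap send buffers), the contributions that neighbouring
          * threads spread there; the target regions are disjoint, so no
          * locking is needed.
          */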
4107 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4108         for (thread = 0; thread < grids->nthread; thread++)
4109         {
4110             reduce_threadgrid_overlap(pme, grids, thread,
4111                                       fftgrid,
4112                                       pme->overlap[0].sendbuf,
4113                                       pme->overlap[1].sendbuf);
4114         }
4115 #ifdef PME_TIME_THREADS
4116         c3   = omp_cyc_end(c3);
4117         cs3 += (double)c3;
4118 #endif
4119
4120         if (pme->nnodes > 1)
4121         {
4122             /* Communicate the overlapping part of the fftgrid.
4123              * For this communication call we need to check pme->bUseThreads
4124              * to have all ranks communicate here, regardless of pme->nthread.
4125              */
4126             sum_fftgrid_dd(pme, fftgrid);
4127         }
4128     }
4129
4130 #ifdef PME_TIME_THREADS
4131     cnt++;
4132     if (cnt % 20 == 0)
4133     {
4134         printf("idx %.2f spread %.2f red %.2f",
4135                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4136 #ifdef PME_TIME_SPREAD
4137         for (thread = 0; thread < nthread; thread++)
4138         {
4139             printf(" %.2f", cs1a[thread]*1e-9);
4140         }
4141 #endif
4142         printf("\n");
4143     }
4144 #endif
4145 }
4146
4147
4148 static void dump_grid(FILE *fp,
4149                       int sx, int sy, int sz, int nx, int ny, int nz,
4150                       int my, int mz, const real *g)
4151 {
4152     int x, y, z;
4153
4154     for (x = 0; x < nx; x++)
4155     {
4156         for (y = 0; y < ny; y++)
4157         {
4158             for (z = 0; z < nz; z++)
4159             {
4160                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4161                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4162             }
4163         }
4164     }
4165 }
4166
4167 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4168 {
4169     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4170
4171     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4172                                    local_fft_ndata,
4173                                    local_fft_offset,
4174                                    local_fft_size);
4175
4176     dump_grid(stderr,
4177               pme->pmegrid_start_ix,
4178               pme->pmegrid_start_iy,
4179               pme->pmegrid_start_iz,
4180               pme->pmegrid_nx-pme->pme_order+1,
4181               pme->pmegrid_ny-pme->pme_order+1,
4182               pme->pmegrid_nz-pme->pme_order+1,
4183               local_fft_size[YY],
4184               local_fft_size[ZZ],
4185               fftgrid);
4186 }
4187
4188
4189 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4190 {
4191     pme_atomcomm_t *atc;
4192     pmegrids_t *grid;
4193
4194     if (pme->nnodes > 1)
4195     {
4196         gmx_incons("gmx_pme_calc_energy called in parallel");
4197     }
4198     if (pme->bFEP)
4199     {
4200         gmx_incons("gmx_pme_calc_energy with free energy");
4201     }
4202
4203     atc            = &pme->atc_energy;
4204     atc->nthread   = 1;
4205     if (atc->spline == NULL)
4206     {
4207         snew(atc->spline, atc->nthread);
4208     }
4209     atc->nslab     = 1;
4210     atc->bSpread   = TRUE;
4211     atc->pme_order = pme->pme_order;
4212     atc->n         = n;
4213     pme_realloc_atomcomm_things(atc);
4214     atc->x         = x;
4215     atc->q         = q;
4216
4217     /* We only use the A-charges grid */
4218     grid = &pme->pmegridA;
4219
4220     /* Only calculate the spline coefficients, don't actually spread */
4221     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4222
4223     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4224 }
4225
4226
4227 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4228                                    gmx_runtime_t *runtime,
4229                                    t_nrnb *nrnb, t_inputrec *ir,
4230                                    gmx_large_int_t step)
4231 {
4232     /* Reset all the counters related to performance over the run */
4233     wallcycle_stop(wcycle, ewcRUN);
4234     wallcycle_reset_all(wcycle);
4235     init_nrnb(nrnb);
4236     if (ir->nsteps >= 0)
4237     {
4238         /* ir->nsteps is not used here, but we update it for consistency */
4239         ir->nsteps -= step - ir->init_step;
4240     }
4241     ir->init_step = step;
4242     wallcycle_start(wcycle, ewcRUN);
4243     runtime_start(runtime);
4244 }
4245
4246
4247 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4248                                ivec grid_size,
4249                                t_commrec *cr, t_inputrec *ir,
4250                                gmx_pme_t *pme_ret)
4251 {
4252     int ind;
4253     gmx_pme_t pme = NULL;
4254
4255     ind = 0;
4256     while (ind < *npmedata)
4257     {
4258         pme = (*pmedata)[ind];
4259         if (pme->nkx == grid_size[XX] &&
4260             pme->nky == grid_size[YY] &&
4261             pme->nkz == grid_size[ZZ])
4262         {
4263             *pme_ret = pme;
4264
4265             return;
4266         }
4267
4268         ind++;
4269     }
4270
4271     (*npmedata)++;
4272     srenew(*pmedata, *npmedata);
4273
4274     /* Generate a new PME data structure, copying part of the old pointers */
4275     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4276
4277     *pme_ret = (*pmedata)[ind];
4278 }
4279
4280
4281 int gmx_pmeonly(gmx_pme_t pme,
4282                 t_commrec *cr,    t_nrnb *nrnb,
4283                 gmx_wallcycle_t wcycle,
4284                 gmx_runtime_t *runtime,
4285                 real ewaldcoeff,  gmx_bool bGatherOnly,
4286                 t_inputrec *ir)
4287 {
4288     int npmedata;
4289     gmx_pme_t *pmedata;
4290     gmx_pme_pp_t pme_pp;
4291     int  ret;
4292     int  natoms;
4293     matrix box;
4294     rvec *x_pp      = NULL, *f_pp = NULL;
4295     real *chargeA   = NULL, *chargeB = NULL;
4296     real lambda     = 0;
4297     int  maxshift_x = 0, maxshift_y = 0;
4298     real energy, dvdlambda;
4299     matrix vir;
4300     float cycles;
4301     int  count;
4302     gmx_bool bEnerVir;
4303     gmx_large_int_t step, step_rel;
4304     ivec grid_switch;
4305
4306     /* These data are only used with PME tuning, i.e. when switching PME grids */
4307     npmedata = 1;
4308     snew(pmedata, npmedata);
4309     pmedata[0] = pme;
4310
4311     pme_pp = gmx_pme_pp_init(cr);
4312
4313     init_nrnb(nrnb);
4314
4315     count = 0;
4316     do /****** this is a quasi-loop over time steps! */
4317     {
4318         /* The reason for having a loop here is PME grid tuning/switching */
4319         do
4320         {
4321             /* Domain decomposition */
4322             ret = gmx_pme_recv_q_x(pme_pp,
4323                                    &natoms,
4324                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4325                                    &maxshift_x, &maxshift_y,
4326                                    &pme->bFEP, &lambda,
4327                                    &bEnerVir,
4328                                    &step,
4329                                    grid_switch, &ewaldcoeff);
4330
4331             if (ret == pmerecvqxSWITCHGRID)
4332             {
4333                 /* Switch the PME grid to grid_switch */
4334                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4335             }
4336
4337             if (ret == pmerecvqxRESETCOUNTERS)
4338             {
4339                 /* Reset the cycle and flop counters */
4340                 reset_pmeonly_counters(cr, wcycle, runtime, nrnb, ir, step);
4341             }
4342         }
4343         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4344
4345         if (ret == pmerecvqxFINISH)
4346         {
4347             /* We should stop: break out of the loop */
4348             break;
4349         }
4350
4351         step_rel = step - ir->init_step;
4352
4353         if (count == 0)
4354         {
4355             wallcycle_start(wcycle, ewcRUN);
4356             runtime_start(runtime);
4357         }
4358
4359         wallcycle_start(wcycle, ewcPMEMESH);
4360
4361         dvdlambda = 0;
4362         clear_mat(vir);
4363         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4364                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4365                    &energy, lambda, &dvdlambda,
4366                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4367
4368         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4369
4370         gmx_pme_send_force_vir_ener(pme_pp,
4371                                     f_pp, vir, energy, dvdlambda,
4372                                     cycles);
4373
4374         count++;
4375     } /***** end of quasi-loop, we stop with the break above */
4376     while (TRUE);
4377
4378     runtime_end(runtime);
4379
4380     return 0;
4381 }
4382
4383 int gmx_pme_do(gmx_pme_t pme,
4384                int start,       int homenr,
4385                rvec x[],        rvec f[],
4386                real *chargeA,   real *chargeB,
4387                matrix box, t_commrec *cr,
4388                int  maxshift_x, int maxshift_y,
4389                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4390                matrix vir,      real ewaldcoeff,
4391                real *energy,    real lambda,
4392                real *dvdlambda, int flags)
4393 {
4394     int     q, d, i, j, ntot, npme;
4395     int     nx, ny, nz;
4396     int     n_d, local_ny;
4397     pme_atomcomm_t *atc = NULL;
4398     pmegrids_t *pmegrid = NULL;
4399     real    *grid       = NULL;
4400     real    *ptr;
4401     rvec    *x_d, *f_d;
4402     real    *charge = NULL, *q_d;
4403     real    energy_AB[2];
4404     matrix  vir_AB[2];
4405     gmx_bool bClearF;
4406     gmx_parallel_3dfft_t pfft_setup;
4407     real *  fftgrid;
4408     t_complex * cfftgrid;
4409     int     thread;
4410     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4411     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4412
4413     assert(pme->nnodes > 0);
4414     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4415
4416     if (pme->nnodes > 1)
4417     {
4418         atc      = &pme->atc[0];
4419         atc->npd = homenr;
4420         if (atc->npd > atc->pd_nalloc)
4421         {
4422             atc->pd_nalloc = over_alloc_dd(atc->npd);
4423             srenew(atc->pd, atc->pd_nalloc);
4424         }
4425         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4426     }
4427     else
4428     {
4429         /* This could be necessary for TPI */
4430         pme->atc[0].n = homenr;
4431     }
4432
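     /* With free energy we make two passes: q==0 uses the A charges and
      * grids, q==1 the B charges and grids; energies and virials are
      * combined with lambda at the end of this routine.
      */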
4433     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4434     {
4435         if (q == 0)
4436         {
4437             pmegrid    = &pme->pmegridA;
4438             fftgrid    = pme->fftgridA;
4439             cfftgrid   = pme->cfftgridA;
4440             pfft_setup = pme->pfft_setupA;
4441             charge     = chargeA+start;
4442         }
4443         else
4444         {
4445             pmegrid    = &pme->pmegridB;
4446             fftgrid    = pme->fftgridB;
4447             cfftgrid   = pme->cfftgridB;
4448             pfft_setup = pme->pfft_setupB;
4449             charge     = chargeB+start;
4450         }
4451         grid = pmegrid->grid.grid;
4452         /* Unpack structure */
4453         if (debug)
4454         {
4455             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4456                     cr->nnodes, cr->nodeid);
4457             fprintf(debug, "Grid = %p\n", (void*)grid);
4458             if (grid == NULL)
4459             {
4460                 gmx_fatal(FARGS, "No grid!");
4461             }
4462         }
4463         where();
4464
4465         m_inv_ur0(box, pme->recipbox);
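         /* recipbox is used to convert coordinates to fractional grid
          * coordinates for spreading and to scale the k-vectors in the
          * reciprocal-space solver.
          */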
4466
4467         if (pme->nnodes == 1)
4468         {
4469             atc = &pme->atc[0];
4470             if (DOMAINDECOMP(cr))
4471             {
4472                 atc->n = homenr;
4473                 pme_realloc_atomcomm_things(atc);
4474             }
4475             atc->x = x;
4476             atc->q = charge;
4477             atc->f = f;
4478         }
4479         else
4480         {
4481             wallcycle_start(wcycle, ewcPME_REDISTXF);
4482             for (d = pme->ndecompdim-1; d >= 0; d--)
4483             {
4484                 if (d == pme->ndecompdim-1)
4485                 {
4486                     n_d = homenr;
4487                     x_d = x + start;
4488                     q_d = charge;
4489                 }
4490                 else
4491                 {
4492                     n_d = pme->atc[d+1].n;
4493                     x_d = atc->x;
4494                     q_d = atc->q;
4495                 }
4496                 atc      = &pme->atc[d];
4497                 atc->npd = n_d;
4498                 if (atc->npd > atc->pd_nalloc)
4499                 {
4500                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4501                     srenew(atc->pd, atc->pd_nalloc);
4502                 }
4503                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4504                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4505                 where();
4506
4507                 GMX_BARRIER(cr->mpi_comm_mygroup);
4508                 /* Redistribute x (only once) and qA or qB */
4509                 if (DOMAINDECOMP(cr))
4510                 {
4511                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4512                 }
4513                 else
4514                 {
4515                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4516                 }
4517             }
4518             where();
4519
4520             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4521         }
4522
4523         if (debug)
4524         {
4525             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4526                     cr->nodeid, atc->n);
4527         }
4528
4529         if (flags & GMX_PME_SPREAD_Q)
4530         {
4531             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4532
4533             /* Spread the charges on a grid */
4534             GMX_MPE_LOG(ev_spread_on_grid_start);
4537             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4538             GMX_MPE_LOG(ev_spread_on_grid_finish);
4539
4540             if (q == 0)
4541             {
4542                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4543             }
4544             inc_nrnb(nrnb, eNR_SPREADQBSP,
4545                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4546
4547             if (!pme->bUseThreads)
4548             {
4549                 wrap_periodic_pmegrid(pme, grid);
4550
4551                 /* sum contributions to local grid from other nodes */
4552 #ifdef GMX_MPI
4553                 if (pme->nnodes > 1)
4554                 {
4555                     GMX_BARRIER(cr->mpi_comm_mygroup);
4556                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4557                     where();
4558                 }
4559 #endif
4560
4561                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4562             }
4563
4564             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4565
4566             /*
4567                dump_local_fftgrid(pme,fftgrid);
4568                exit(0);
4569              */
4570         }
4571
4572         /* Here we start a large thread parallel region */
4573 #pragma omp parallel num_threads(pme->nthread) private(thread)
4574         {
4575             thread = gmx_omp_get_thread_num();
4576             if (flags & GMX_PME_SOLVE)
4577             {
4578                 int loop_count;
4579
4580                 /* do 3d-fft */
4581                 if (thread == 0)
4582                 {
4583                     GMX_BARRIER(cr->mpi_comm_mygroup);
4584                     GMX_MPE_LOG(ev_gmxfft3d_start);
4585                     wallcycle_start(wcycle, ewcPME_FFT);
4586                 }
4587                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4588                                            fftgrid, cfftgrid, thread, wcycle);
4589                 if (thread == 0)
4590                 {
4591                     wallcycle_stop(wcycle, ewcPME_FFT);
4592                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4593                 }
4594                 where();
4595
4596                 /* solve in k-space for our local cells */
4597                 if (thread == 0)
4598                 {
4599                     GMX_BARRIER(cr->mpi_comm_mygroup);
4600                     GMX_MPE_LOG(ev_solve_pme_start);
4601                     wallcycle_start(wcycle, ewcPME_SOLVE);
4602                 }
4603                 loop_count =
4604                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4605                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4606                                   bCalcEnerVir,
4607                                   pme->nthread, thread);
4608                 if (thread == 0)
4609                 {
4610                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4611                     where();
4612                     GMX_MPE_LOG(ev_solve_pme_finish);
4613                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4614                 }
4615             }
4616
4617             if (bCalcF)
4618             {
4619                 /* do 3d-invfft */
4620                 if (thread == 0)
4621                 {
4622                     GMX_BARRIER(cr->mpi_comm_mygroup);
4623                     GMX_MPE_LOG(ev_gmxfft3d_start);
4624                     where();
4625                     wallcycle_start(wcycle, ewcPME_FFT);
4626                 }
4627                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4628                                            cfftgrid, fftgrid, thread, wcycle);
4629                 if (thread == 0)
4630                 {
4631                     wallcycle_stop(wcycle, ewcPME_FFT);
4632
4633                     where();
4634                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4635
4636                     if (pme->nodeid == 0)
4637                     {
4638                         ntot  = pme->nkx*pme->nky*pme->nkz;
4639                         npme  = ntot*log((real)ntot)/log(2.0);
4640                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4641                     }
4642
4643                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4644                 }
4645
4646                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4647             }
4648         }
4649         /* End of thread parallel section.
4650          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4651          */
4652
4653         if (bCalcF)
4654         {
4655             /* distribute local grid to all nodes */
4656 #ifdef GMX_MPI
4657             if (pme->nnodes > 1)
4658             {
4659                 GMX_BARRIER(cr->mpi_comm_mygroup);
4660                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4661             }
4662 #endif
4663             where();
4664
4665             unwrap_periodic_pmegrid(pme, grid);
4666
4667             /* interpolate forces for our local atoms */
4668             GMX_BARRIER(cr->mpi_comm_mygroup);
4669             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4670
4671             where();
4672
4673             /* If we are running without parallelization,
4674              * atc->f is the actual force array, not a buffer,
4675              * therefore we should not clear it.
4676              */
4677             bClearF = (q == 0 && PAR(cr));
4678 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4679             for (thread = 0; thread < pme->nthread; thread++)
4680             {
4681                 gather_f_bsplines(pme, grid, bClearF, atc,
4682                                   &atc->spline[thread],
4683                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4684             }
4685
4686             where();
4687
4688             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4689
4690             inc_nrnb(nrnb, eNR_GATHERFBSP,
4691                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4692             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4693         }
4694
4695         if (bCalcEnerVir)
4696         {
4697             /* This should only be called on the master thread
4698              * and after the threads have synchronized.
4699              */
4700             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4701         }
4702     } /* of q-loop */
4703
4704     if (bCalcF && pme->nnodes > 1)
4705     {
4706         wallcycle_start(wcycle, ewcPME_REDISTXF);
4707         for (d = 0; d < pme->ndecompdim; d++)
4708         {
4709             atc = &pme->atc[d];
4710             if (d == pme->ndecompdim - 1)
4711             {
4712                 n_d = homenr;
4713                 f_d = f + start;
4714             }
4715             else
4716             {
4717                 n_d = pme->atc[d+1].n;
4718                 f_d = pme->atc[d+1].f;
4719             }
4720             GMX_BARRIER(cr->mpi_comm_mygroup);
4721             if (DOMAINDECOMP(cr))
4722             {
4723                 dd_pmeredist_f(pme, atc, n_d, f_d,
4724                                d == pme->ndecompdim-1 && pme->bPPnode);
4725             }
4726             else
4727             {
4728                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4729             }
4730         }
4731
4732         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4733     }
4734     where();
4735
4736     if (bCalcEnerVir)
4737     {
4738         if (!pme->bFEP)
4739         {
4740             *energy = energy_AB[0];
4741             m_add(vir, vir_AB[0], vir);
4742         }
4743         else
4744         {
4745             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4746             *dvdlambda += energy_AB[1] - energy_AB[0];
4747             for (i = 0; i < DIM; i++)
4748             {
4749                 for (j = 0; j < DIM; j++)
4750                 {
4751                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4752                         lambda*vir_AB[1][i][j];
4753                 }
4754             }
4755         }
4756     }
4757     else
4758     {
4759         *energy = 0;
4760     }
4761
4762     if (debug)
4763     {
4764         fprintf(debug, "PME mesh energy: %g\n", *energy);
4765     }
4766
4767     return 0;
4768 }