Fixed PME bug with high OpenMP thread count
[alexxy/gromacs.git] / src / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team,
6  * check out http://www.gromacs.org for more information.
7  * Copyright (c) 2012,2013, by the GROMACS development team, led by
8  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
9  * others, as listed in the AUTHORS file in the top-level source
10  * directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* IMPORTANT FOR DEVELOPERS:
39  *
40  * Triclinic pme stuff isn't entirely trivial, and we've experienced
41  * some bugs during development (many of them due to me). To avoid
42  * this in the future, please check the following things if you make
43  * changes in this file:
44  *
45  * 1. You should obtain identical (at least to the PME precision)
46  *    energies, forces, and virial for
47  *    a rectangular box and a triclinic one where the z (or y) axis is
48  *    tilted a whole box side. For instance you could use these boxes:
49  *
50  *    rectangular       triclinic
51  *     2  0  0           2  0  0
52  *     0  2  0           0  2  0
53  *     0  0  6           2  2  6
54  *
55  * 2. You should check the energy conservation in a triclinic box.
56  *
57  * It might seem like overkill, but better safe than sorry.
58  * /Erik 001109
59  */
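/* Illustrative note (not part of the original checklist): for the triclinic
 * test box above, with box vectors a = (2,0,0), b = (0,2,0), c = (2,2,6),
 * the fractional coordinates come out as
 *
 *   s_x = x/2 - z/6,   s_y = y/2 - z/6,   s_z = z/6,
 *
 * so the reciprocal-box entries recipbox[ZZ][XX] and recipbox[ZZ][YY]
 * (rzx and rzy in calc_interpolation_idx) become -1/6 instead of zero.
 * These off-diagonal terms are exactly what the triclinic checks above
 * are meant to exercise.
 */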
60
61 #ifdef HAVE_CONFIG_H
62 #include <config.h>
63 #endif
64
65 #ifdef GMX_LIB_MPI
66 #include <mpi.h>
67 #endif
68 #ifdef GMX_THREAD_MPI
69 #include "tmpi.h"
70 #endif
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93 #include "gmx_omp.h"
94
95 /* Include the SIMD macro file and then check for support */
96 #include "gmx_simd_macros.h"
97 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
98 /* Turn on arbitrary width SIMD intrinsics for PME solve */
99 #define PME_SIMD
100 #endif
101
102 /* Include the 4-wide SIMD macro file */
103 #include "gmx_simd4_macros.h"
104 /* Check if we have 4-wide SIMD macro support */
105 #ifdef GMX_HAVE_SIMD4_MACROS
106 /* Do PME spread and gather with 4-wide SIMD.
107  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
108  */
109 #define PME_SIMD4_SPREAD_GATHER
110
111 #ifdef GMX_SIMD4_HAVE_UNALIGNED
112 /* With PME-order=4 on x86, unaligned load+store is slightly faster
113  * than doubling all SIMD operations when using aligned load+store.
114  */
115 #define PME_SIMD4_UNALIGNED
116 #endif
117 #endif
118
119
120 #include "mpelogging.h"
121
122 #define DFT_TOL 1e-7
123 /* #define PRT_FORCE */
124 /* Conditions for on-the-fly time measurement */
125 /* #define TAKETIME (step > 1 && timesteps < 10) */
126 #define TAKETIME FALSE
127
128 /* #define PME_TIME_THREADS */
129
130 #ifdef GMX_DOUBLE
131 #define mpi_type MPI_DOUBLE
132 #else
133 #define mpi_type MPI_FLOAT
134 #endif
135
136 #ifdef PME_SIMD4_SPREAD_GATHER
137 #define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
138 #else
139 /* We can use any alignment, apart from 0, so we use 4 reals */
140 #define SIMD4_ALIGNMENT  (4*sizeof(real))
141 #endif
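/* Illustrative sketch (not part of the original code) of how SIMD4_ALIGNMENT
 * is meant to be passed to the aligned allocation macros from smalloc.h, as
 * done in realloc_splinevec() further down; buf, n and pad are made-up names:
 *
 *   real *buf;
 *   snew_aligned(buf, n + 2*pad, SIMD4_ALIGNMENT);
 *   ...
 *   sfree_aligned(buf);
 */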
142
143 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
144  * to preserve alignment.
145  */
146 #define GMX_CACHE_SEP 64
147
148 /* We only define a maximum to be able to use local arrays without allocation.
149  * An order larger than 12 should never be needed, even for test cases.
150  * If needed it can be changed here.
151  */
152 #define PME_ORDER_MAX 12
153
154 /* Internal datastructures */
155 typedef struct {
156     int send_index0;
157     int send_nindex;
158     int recv_index0;
159     int recv_nindex;
160     int recv_size;   /* Receive buffer width, used with OpenMP */
161 } pme_grid_comm_t;
162
163 typedef struct {
164 #ifdef GMX_MPI
165     MPI_Comm         mpi_comm;
166 #endif
167     int              nnodes, nodeid;
168     int             *s2g0;
169     int             *s2g1;
170     int              noverlap_nodes;
171     int             *send_id, *recv_id;
172     int              send_size; /* Send buffer width, used with OpenMP */
173     pme_grid_comm_t *comm_data;
174     real            *sendbuf;
175     real            *recvbuf;
176 } pme_overlap_t;
177
178 typedef struct {
179     int *n;      /* Cumulative counts of the number of particles per thread */
180     int  nalloc; /* Allocation size of i */
181     int *i;      /* Particle indices ordered on thread index (n) */
182 } thread_plist_t;
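/* Illustrative example of how calc_interpolation_idx() fills this struct:
 * with 2 threads and 5 particles, where particles 0, 2 and 4 spread on
 * thread 0 and particles 1 and 3 on thread 1, the cumulative counts in n
 * end up as {3, 5} and i as {0, 2, 4, 1, 3}, i.e. the particle indices
 * grouped by the thread that spreads them.
 */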
183
184 typedef struct {
185     int      *thread_one;
186     int       n;
187     int      *ind;
188     splinevec theta;
189     real     *ptr_theta_z;
190     splinevec dtheta;
191     real     *ptr_dtheta_z;
192 } splinedata_t;
193
194 typedef struct {
195     int      dimind;        /* The index of the dimension, 0=x, 1=y */
196     int      nslab;
197     int      nodeid;
198 #ifdef GMX_MPI
199     MPI_Comm mpi_comm;
200 #endif
201
202     int     *node_dest;     /* The nodes to send x and q to with DD */
203     int     *node_src;      /* The nodes to receive x and q from with DD */
204     int     *buf_index;     /* Index for commnode into the buffers */
205
206     int      maxshift;
207
208     int      npd;
209     int      pd_nalloc;
210     int     *pd;
211     int     *count;         /* The number of atoms to send to each node */
212     int    **count_thread;
213     int     *rcount;        /* The number of atoms to receive */
214
215     int      n;
216     int      nalloc;
217     rvec    *x;
218     real    *q;
219     rvec    *f;
220     gmx_bool bSpread;       /* These coordinates are used for spreading */
221     int      pme_order;
222     ivec    *idx;
223     rvec    *fractx;            /* Fractional coordinate relative to the
224                                  * lower cell boundary
225                                  */
226     int             nthread;
227     int            *thread_idx; /* Which thread should spread which charge */
228     thread_plist_t *thread_plist;
229     splinedata_t   *spline;
230 } pme_atomcomm_t;
231
232 #define FLBS  3
233 #define FLBSZ 4
234
235 typedef struct {
236     ivec  ci;     /* The spatial location of this grid         */
237     ivec  n;      /* The used size of *grid, including order-1 */
238     ivec  offset; /* The grid offset from the full node grid   */
239     int   order;  /* PME spreading order                       */
240     ivec  s;      /* The allocated size of *grid, s >= n       */
241     real *grid;   /* The local thread grid, size n             */
242 } pmegrid_t;
243
244 typedef struct {
245     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
246     int        nthread;      /* The number of threads operating on this grid     */
247     ivec       nc;           /* The local spatial decomposition over the threads */
248     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
249     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
250     int      **g2t;          /* The grid to thread index                         */
251     ivec       nthread_comm; /* The number of threads to communicate with        */
252 } pmegrids_t;
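/* Sizing example (illustrative): pmegrid_init() below sets
 * n[d] = x1 - x0 + pme_order - 1, so a thread owning grid points [16,32)
 * along x with pme_order = 4 gets n[XX] = 16 + 3 = 19; the extra
 * pme_order-1 points hold the spreading overlap with the next thread.
 */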
253
254
255 typedef struct {
256 #ifdef PME_SIMD4_SPREAD_GATHER
257     /* Masks for 4-wide SIMD aligned spreading and gathering */
258     gmx_simd4_pb mask_S0[6], mask_S1[6];
259 #else
260     int    dummy; /* C89 requires that struct has at least one member */
261 #endif
262 } pme_spline_work_t;
263
264 typedef struct {
265     /* work data for solve_pme */
266     int      nalloc;
267     real *   mhx;
268     real *   mhy;
269     real *   mhz;
270     real *   m2;
271     real *   denom;
272     real *   tmp1_alloc;
273     real *   tmp1;
274     real *   eterm;
275     real *   m2inv;
276
277     real     energy;
278     matrix   vir;
279 } pme_work_t;
280
281 typedef struct gmx_pme {
282     int           ndecompdim; /* The number of decomposition dimensions */
283     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
284     int           nodeid_major;
285     int           nodeid_minor;
286     int           nnodes;    /* The number of nodes doing PME */
287     int           nnodes_major;
288     int           nnodes_minor;
289
290     MPI_Comm      mpi_comm;
291     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
292 #ifdef GMX_MPI
293     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
294 #endif
295
296     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
297     int        nthread;       /* The number of threads doing PME on our rank */
298
299     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
300     gmx_bool   bFEP;          /* Compute Free energy contribution */
301     int        nkx, nky, nkz; /* Grid dimensions */
302     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
303     int        pme_order;
304     real       epsilon_r;
305
306     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
307     pmegrids_t pmegridB;
308     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
309     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
310     /* pmegrid_nz might be larger than strictly necessary to ensure
311      * memory alignment, pmegrid_nz_base gives the real base size.
312      */
313     int     pmegrid_nz_base;
314     /* The local PME grid starting indices */
315     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
316
317     /* Work data for spreading and gathering */
318     pme_spline_work_t    *spline_work;
319
320     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
321     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
322     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
323
324     t_complex            *cfftgridA;  /* Grids for complex FFT data */
325     t_complex            *cfftgridB;
326     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
327
328     gmx_parallel_3dfft_t  pfft_setupA;
329     gmx_parallel_3dfft_t  pfft_setupB;
330
331     int                  *nnx, *nny, *nnz;
332     real                 *fshx, *fshy, *fshz;
333
334     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
335     matrix                recipbox;
336     splinevec             bsp_mod;
337
338     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
339
340     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
341
342     rvec                 *bufv;       /* Communication buffer */
343     real                 *bufr;       /* Communication buffer */
344     int                   buf_nalloc; /* The communication buffer size */
345
346     /* thread local work data for solve_pme */
347     pme_work_t *work;
348
349     /* Work data for PME_redist */
350     gmx_bool redist_init;
351     int *    scounts;
352     int *    rcounts;
353     int *    sdispls;
354     int *    rdispls;
355     int *    sidx;
356     int *    idxa;
357     real *   redist_buf;
358     int      redist_buf_nalloc;
359
360     /* Work data for sum_qgrid */
361     real *   sum_qgrid_tmp;
362     real *   sum_qgrid_dd_tmp;
363 } t_gmx_pme;
364
365
366 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
367                                    int start, int end, int thread)
368 {
369     int             i;
370     int            *idxptr, tix, tiy, tiz;
371     real           *xptr, *fptr, tx, ty, tz;
372     real            rxx, ryx, ryy, rzx, rzy, rzz;
373     int             nx, ny, nz;
374     int             start_ix, start_iy, start_iz;
375     int            *g2tx, *g2ty, *g2tz;
376     gmx_bool        bThreads;
377     int            *thread_idx = NULL;
378     thread_plist_t *tpl        = NULL;
379     int            *tpl_n      = NULL;
380     int             thread_i;
381
382     nx  = pme->nkx;
383     ny  = pme->nky;
384     nz  = pme->nkz;
385
386     start_ix = pme->pmegrid_start_ix;
387     start_iy = pme->pmegrid_start_iy;
388     start_iz = pme->pmegrid_start_iz;
389
390     rxx = pme->recipbox[XX][XX];
391     ryx = pme->recipbox[YY][XX];
392     ryy = pme->recipbox[YY][YY];
393     rzx = pme->recipbox[ZZ][XX];
394     rzy = pme->recipbox[ZZ][YY];
395     rzz = pme->recipbox[ZZ][ZZ];
396
397     g2tx = pme->pmegridA.g2t[XX];
398     g2ty = pme->pmegridA.g2t[YY];
399     g2tz = pme->pmegridA.g2t[ZZ];
400
401     bThreads = (atc->nthread > 1);
402     if (bThreads)
403     {
404         thread_idx = atc->thread_idx;
405
406         tpl   = &atc->thread_plist[thread];
407         tpl_n = tpl->n;
408         for (i = 0; i < atc->nthread; i++)
409         {
410             tpl_n[i] = 0;
411         }
412     }
413
414     for (i = start; i < end; i++)
415     {
416         xptr   = atc->x[i];
417         idxptr = atc->idx[i];
418         fptr   = atc->fractx[i];
419
420         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
421         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
422         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
423         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
424
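        /* Worked example (illustrative): with nx = 32 and a particle just
         * outside a rectangular box at fractional coordinate -0.01 along x,
         * tx = 32*(-0.01 + 2.0) = 63.68, so tix = 63 and tx - tix = 0.68;
         * the nnx[] table (set up in the PME init code, not shown here)
         * then maps the shifted index back onto the periodic grid.
         */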
425         tix = (int)(tx);
426         tiy = (int)(ty);
427         tiz = (int)(tz);
428
429         /* Because decomposition only occurs in x and y,
430          * we never have a fraction correction in z.
431          */
432         fptr[XX] = tx - tix + pme->fshx[tix];
433         fptr[YY] = ty - tiy + pme->fshy[tiy];
434         fptr[ZZ] = tz - tiz;
435
436         idxptr[XX] = pme->nnx[tix];
437         idxptr[YY] = pme->nny[tiy];
438         idxptr[ZZ] = pme->nnz[tiz];
439
440 #ifdef DEBUG
441         range_check(idxptr[XX], 0, pme->pmegrid_nx);
442         range_check(idxptr[YY], 0, pme->pmegrid_ny);
443         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
444 #endif
445
446         if (bThreads)
447         {
448             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
449             thread_idx[i] = thread_i;
450             tpl_n[thread_i]++;
451         }
452     }
453
454     if (bThreads)
455     {
456         /* Make a list of particle indices sorted on thread */
457
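        /* This is a small counting sort on the thread index computed above.
         * Illustrative example with 2 threads and counts tpl_n = {3, 2}:
         * the prefix sum gives {3, 5}, the shift below turns that into the
         * start offsets {0, 3}, and the final loop scatters the particle
         * indices into tpl->i grouped per thread, leaving tpl_n = {3, 5}.
         */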
458         /* Get the cumulative count */
459         for (i = 1; i < atc->nthread; i++)
460         {
461             tpl_n[i] += tpl_n[i-1];
462         }
463         /* The current implementation distributes particles equally
464          * over the threads, so we could actually allocate for that
465          * in pme_realloc_atomcomm_things.
466          */
467         if (tpl_n[atc->nthread-1] > tpl->nalloc)
468         {
469             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
470             srenew(tpl->i, tpl->nalloc);
471         }
472         /* Set tpl_n to the cumulative start */
473         for (i = atc->nthread-1; i >= 1; i--)
474         {
475             tpl_n[i] = tpl_n[i-1];
476         }
477         tpl_n[0] = 0;
478
479         /* Fill our thread local array with indices sorted on thread */
480         for (i = start; i < end; i++)
481         {
482             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
483         }
484         /* Now tpl_n contains the cumulative count again */
485     }
486 }
487
488 static void make_thread_local_ind(pme_atomcomm_t *atc,
489                                   int thread, splinedata_t *spline)
490 {
491     int             n, t, i, start, end;
492     thread_plist_t *tpl;
493
494     /* Combine the indices made by each thread into one index */
495
496     n     = 0;
497     start = 0;
498     for (t = 0; t < atc->nthread; t++)
499     {
500         tpl = &atc->thread_plist[t];
501         /* Copy our part (start to end) from the list of thread t */
502         if (thread > 0)
503         {
504             start = tpl->n[thread-1];
505         }
506         end = tpl->n[thread];
507         for (i = start; i < end; i++)
508         {
509             spline->ind[n++] = tpl->i[i];
510         }
511     }
512
513     spline->n = n;
514 }
515
516
517 static void pme_calc_pidx(int start, int end,
518                           matrix recipbox, rvec x[],
519                           pme_atomcomm_t *atc, int *count)
520 {
521     int   nslab, i;
522     int   si;
523     real *xptr, s;
524     real  rxx, ryx, rzx, ryy, rzy;
525     int  *pd;
526
527     /* Calculate the PME task index (pidx) for each particle.
528      * Here we always assign equally sized slabs to each node
529      * for load balancing reasons (the PME grid spacing is not used).
530      */
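    /* Illustrative example: with nslab = 4 and a particle just below the
     * lower box edge, s = 4*(-0.02) = -0.08, (int)(s + 2*nslab) = 7 and
     * 7 % nslab = 3, so the slightly negative coordinate wraps to the last
     * slab instead of producing a negative index.
     */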
531
532     nslab = atc->nslab;
533     pd    = atc->pd;
534
535     /* Reset the count */
536     for (i = 0; i < nslab; i++)
537     {
538         count[i] = 0;
539     }
540
541     if (atc->dimind == 0)
542     {
543         rxx = recipbox[XX][XX];
544         ryx = recipbox[YY][XX];
545         rzx = recipbox[ZZ][XX];
546         /* Calculate the node index in x-dimension */
547         for (i = start; i < end; i++)
548         {
549             xptr   = x[i];
550             /* Fractional coordinates along box vectors */
551             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
552             si    = (int)(s + 2*nslab) % nslab;
553             pd[i] = si;
554             count[si]++;
555         }
556     }
557     else
558     {
559         ryy = recipbox[YY][YY];
560         rzy = recipbox[ZZ][YY];
561         /* Calculate the node index in y-dimension */
562         for (i = start; i < end; i++)
563         {
564             xptr   = x[i];
565             /* Fractional coordinates along box vectors */
566             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
567             si    = (int)(s + 2*nslab) % nslab;
568             pd[i] = si;
569             count[si]++;
570         }
571     }
572 }
573
574 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
575                                   pme_atomcomm_t *atc)
576 {
577     int nthread, thread, slab;
578
579     nthread = atc->nthread;
580
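    /* The atoms are split evenly over the threads as the ranges
     * [natoms*t/nthread, natoms*(t+1)/nthread); e.g. (illustrative numbers)
     * natoms = 10 with nthread = 4 gives [0,2), [2,5), [5,7) and [7,10).
     */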
581 #pragma omp parallel for num_threads(nthread) schedule(static)
582     for (thread = 0; thread < nthread; thread++)
583     {
584         pme_calc_pidx(natoms* thread   /nthread,
585                       natoms*(thread+1)/nthread,
586                       recipbox, x, atc, atc->count_thread[thread]);
587     }
588     /* Non-parallel reduction, since nslab is small */
589
590     for (thread = 1; thread < nthread; thread++)
591     {
592         for (slab = 0; slab < atc->nslab; slab++)
593         {
594             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
595         }
596     }
597 }
598
599 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
600 {
601     const int padding = 4;
602     int       i;
603
604     srenew(th[XX], nalloc);
605     srenew(th[YY], nalloc);
606     /* In z we add padding, this is only required for the aligned SIMD code */
607     sfree_aligned(*ptr_z);
608     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
609     th[ZZ] = *ptr_z + padding;
610
611     for (i = 0; i < padding; i++)
612     {
613         (*ptr_z)[               i] = 0;
614         (*ptr_z)[padding+nalloc+i] = 0;
615     }
616 }
617
618 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
619 {
620     int i, d;
621
622     srenew(spline->ind, atc->nalloc);
623     /* Initialize the index to identity so it works without threads */
624     for (i = 0; i < atc->nalloc; i++)
625     {
626         spline->ind[i] = i;
627     }
628
629     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
630                       atc->pme_order*atc->nalloc);
631     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
632                       atc->pme_order*atc->nalloc);
633 }
634
635 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
636 {
637     int nalloc_old, i, j, nalloc_tpl;
638
639     /* We must avoid a NULL pointer for atc->x to prevent
640      * possible fatal errors in MPI routines.
641      */
642     if (atc->n > atc->nalloc || atc->nalloc == 0)
643     {
644         nalloc_old  = atc->nalloc;
645         atc->nalloc = over_alloc_dd(max(atc->n, 1));
646
647         if (atc->nslab > 1)
648         {
649             srenew(atc->x, atc->nalloc);
650             srenew(atc->q, atc->nalloc);
651             srenew(atc->f, atc->nalloc);
652             for (i = nalloc_old; i < atc->nalloc; i++)
653             {
654                 clear_rvec(atc->f[i]);
655             }
656         }
657         if (atc->bSpread)
658         {
659             srenew(atc->fractx, atc->nalloc);
660             srenew(atc->idx, atc->nalloc);
661
662             if (atc->nthread > 1)
663             {
664                 srenew(atc->thread_idx, atc->nalloc);
665             }
666
667             for (i = 0; i < atc->nthread; i++)
668             {
669                 pme_realloc_splinedata(&atc->spline[i], atc);
670             }
671         }
672     }
673 }
674
675 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
676                          int n, gmx_bool bXF, rvec *x_f, real *charge,
677                          pme_atomcomm_t *atc)
678 /* Redistribute particle data for PME calculation */
679 /* domain decomposition by x coordinate           */
680 {
681     int *idxa;
682     int  i, ii;
683
684     if (FALSE == pme->redist_init)
685     {
686         snew(pme->scounts, atc->nslab);
687         snew(pme->rcounts, atc->nslab);
688         snew(pme->sdispls, atc->nslab);
689         snew(pme->rdispls, atc->nslab);
690         snew(pme->sidx, atc->nslab);
691         pme->redist_init = TRUE;
692     }
693     if (n > pme->redist_buf_nalloc)
694     {
695         pme->redist_buf_nalloc = over_alloc_dd(n);
696         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
697     }
698
699     pme->idxa = atc->pd;
700
701 #ifdef GMX_MPI
702     if (forw && bXF)
703     {
704         /* forward, redistribution from pp to pme */
705
706         /* Calculate send counts and exchange them with other nodes */
707         for (i = 0; (i < atc->nslab); i++)
708         {
709             pme->scounts[i] = 0;
710         }
711         for (i = 0; (i < n); i++)
712         {
713             pme->scounts[pme->idxa[i]]++;
714         }
715         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
716
717         /* Calculate send and receive displacements and index into send
718            buffer */
719         pme->sdispls[0] = 0;
720         pme->rdispls[0] = 0;
721         pme->sidx[0]    = 0;
722         for (i = 1; i < atc->nslab; i++)
723         {
724             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
725             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
726             pme->sidx[i]    = pme->sdispls[i];
727         }
728         /* Total # of particles to be received */
729         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
730
731         pme_realloc_atomcomm_things(atc);
732
733         /* Copy particle coordinates into send buffer and exchange */
734         for (i = 0; (i < n); i++)
735         {
736             ii = DIM*pme->sidx[pme->idxa[i]];
737             pme->sidx[pme->idxa[i]]++;
738             pme->redist_buf[ii+XX] = x_f[i][XX];
739             pme->redist_buf[ii+YY] = x_f[i][YY];
740             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
741         }
742         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
743                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
744                       pme->rvec_mpi, atc->mpi_comm);
745     }
746     if (forw)
747     {
748         /* Copy charges into send buffer and exchange */
749         for (i = 0; i < atc->nslab; i++)
750         {
751             pme->sidx[i] = pme->sdispls[i];
752         }
753         for (i = 0; (i < n); i++)
754         {
755             ii = pme->sidx[pme->idxa[i]];
756             pme->sidx[pme->idxa[i]]++;
757             pme->redist_buf[ii] = charge[i];
758         }
759         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
760                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
761                       atc->mpi_comm);
762     }
763     else   /* backward, redistribution from pme to pp */
764     {
765         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
766                       pme->redist_buf, pme->scounts, pme->sdispls,
767                       pme->rvec_mpi, atc->mpi_comm);
768
769         /* Copy data from receive buffer */
770         for (i = 0; i < atc->nslab; i++)
771         {
772             pme->sidx[i] = pme->sdispls[i];
773         }
774         for (i = 0; (i < n); i++)
775         {
776             ii          = DIM*pme->sidx[pme->idxa[i]];
777             x_f[i][XX] += pme->redist_buf[ii+XX];
778             x_f[i][YY] += pme->redist_buf[ii+YY];
779             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
780             pme->sidx[pme->idxa[i]]++;
781         }
782     }
783 #endif
784 }
785
786 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
787                             gmx_bool bBackward, int shift,
788                             void *buf_s, int nbyte_s,
789                             void *buf_r, int nbyte_r)
790 {
791 #ifdef GMX_MPI
792     int        dest, src;
793     MPI_Status stat;
794
795     if (bBackward == FALSE)
796     {
797         dest = atc->node_dest[shift];
798         src  = atc->node_src[shift];
799     }
800     else
801     {
802         dest = atc->node_src[shift];
803         src  = atc->node_dest[shift];
804     }
805
806     if (nbyte_s > 0 && nbyte_r > 0)
807     {
808         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
809                      dest, shift,
810                      buf_r, nbyte_r, MPI_BYTE,
811                      src, shift,
812                      atc->mpi_comm, &stat);
813     }
814     else if (nbyte_s > 0)
815     {
816         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
817                  dest, shift,
818                  atc->mpi_comm);
819     }
820     else if (nbyte_r > 0)
821     {
822         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
823                  src, shift,
824                  atc->mpi_comm, &stat);
825     }
826 #endif
827 }
828
829 static void dd_pmeredist_x_q(gmx_pme_t pme,
830                              int n, gmx_bool bX, rvec *x, real *charge,
831                              pme_atomcomm_t *atc)
832 {
833     int *commnode, *buf_index;
834     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
835
836     commnode  = atc->node_dest;
837     buf_index = atc->buf_index;
838
839     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
840
841     nsend = 0;
842     for (i = 0; i < nnodes_comm; i++)
843     {
844         buf_index[commnode[i]] = nsend;
845         nsend                 += atc->count[commnode[i]];
846     }
847     if (bX)
848     {
849         if (atc->count[atc->nodeid] + nsend != n)
850         {
851             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
852                       "This usually means that your system is not well equilibrated.",
853                       n - (atc->count[atc->nodeid] + nsend),
854                       pme->nodeid, 'x'+atc->dimind);
855         }
856
857         if (nsend > pme->buf_nalloc)
858         {
859             pme->buf_nalloc = over_alloc_dd(nsend);
860             srenew(pme->bufv, pme->buf_nalloc);
861             srenew(pme->bufr, pme->buf_nalloc);
862         }
863
864         atc->n = atc->count[atc->nodeid];
865         for (i = 0; i < nnodes_comm; i++)
866         {
867             scount = atc->count[commnode[i]];
868             /* Communicate the count */
869             if (debug)
870             {
871                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
872                         atc->dimind, atc->nodeid, commnode[i], scount);
873             }
874             pme_dd_sendrecv(atc, FALSE, i,
875                             &scount, sizeof(int),
876                             &atc->rcount[i], sizeof(int));
877             atc->n += atc->rcount[i];
878         }
879
880         pme_realloc_atomcomm_things(atc);
881     }
882
883     local_pos = 0;
884     for (i = 0; i < n; i++)
885     {
886         node = atc->pd[i];
887         if (node == atc->nodeid)
888         {
889             /* Copy directly to the receive buffer */
890             if (bX)
891             {
892                 copy_rvec(x[i], atc->x[local_pos]);
893             }
894             atc->q[local_pos] = charge[i];
895             local_pos++;
896         }
897         else
898         {
899             /* Copy to the send buffer */
900             if (bX)
901             {
902                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
903             }
904             pme->bufr[buf_index[node]] = charge[i];
905             buf_index[node]++;
906         }
907     }
908
909     buf_pos = 0;
910     for (i = 0; i < nnodes_comm; i++)
911     {
912         scount = atc->count[commnode[i]];
913         rcount = atc->rcount[i];
914         if (scount > 0 || rcount > 0)
915         {
916             if (bX)
917             {
918                 /* Communicate the coordinates */
919                 pme_dd_sendrecv(atc, FALSE, i,
920                                 pme->bufv[buf_pos], scount*sizeof(rvec),
921                                 atc->x[local_pos], rcount*sizeof(rvec));
922             }
923             /* Communicate the charges */
924             pme_dd_sendrecv(atc, FALSE, i,
925                             pme->bufr+buf_pos, scount*sizeof(real),
926                             atc->q+local_pos, rcount*sizeof(real));
927             buf_pos   += scount;
928             local_pos += atc->rcount[i];
929         }
930     }
931 }
932
933 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
934                            int n, rvec *f,
935                            gmx_bool bAddF)
936 {
937     int *commnode, *buf_index;
938     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
939
940     commnode  = atc->node_dest;
941     buf_index = atc->buf_index;
942
943     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
944
945     local_pos = atc->count[atc->nodeid];
946     buf_pos   = 0;
947     for (i = 0; i < nnodes_comm; i++)
948     {
949         scount = atc->rcount[i];
950         rcount = atc->count[commnode[i]];
951         if (scount > 0 || rcount > 0)
952         {
953             /* Communicate the forces */
954             pme_dd_sendrecv(atc, TRUE, i,
955                             atc->f[local_pos], scount*sizeof(rvec),
956                             pme->bufv[buf_pos], rcount*sizeof(rvec));
957             local_pos += scount;
958         }
959         buf_index[commnode[i]] = buf_pos;
960         buf_pos               += rcount;
961     }
962
963     local_pos = 0;
964     if (bAddF)
965     {
966         for (i = 0; i < n; i++)
967         {
968             node = atc->pd[i];
969             if (node == atc->nodeid)
970             {
971                 /* Add from the local force array */
972                 rvec_inc(f[i], atc->f[local_pos]);
973                 local_pos++;
974             }
975             else
976             {
977                 /* Add from the receive buffer */
978                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
979                 buf_index[node]++;
980             }
981         }
982     }
983     else
984     {
985         for (i = 0; i < n; i++)
986         {
987             node = atc->pd[i];
988             if (node == atc->nodeid)
989             {
990                 /* Copy from the local force array */
991                 copy_rvec(atc->f[local_pos], f[i]);
992                 local_pos++;
993             }
994             else
995             {
996                 /* Copy from the receive buffer */
997                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
998                 buf_index[node]++;
999             }
1000         }
1001     }
1002 }
1003
1004 #ifdef GMX_MPI
1005 static void
1006 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
1007 {
1008     pme_overlap_t *overlap;
1009     int            send_index0, send_nindex;
1010     int            recv_index0, recv_nindex;
1011     MPI_Status     stat;
1012     int            i, j, k, ix, iy, iz, icnt;
1013     int            ipulse, send_id, recv_id, datasize;
1014     real          *p;
1015     real          *sendptr, *recvptr;
1016
1017     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
1018     overlap = &pme->overlap[1];
1019
1020     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1021     {
1022         /* Since we have already (un)wrapped the overlap in the z-dimension,
1023          * we only have to communicate 0 to nkz (not pmegrid_nz).
1024          */
1025         if (direction == GMX_SUM_QGRID_FORWARD)
1026         {
1027             send_id       = overlap->send_id[ipulse];
1028             recv_id       = overlap->recv_id[ipulse];
1029             send_index0   = overlap->comm_data[ipulse].send_index0;
1030             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1031             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1032             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1033         }
1034         else
1035         {
1036             send_id       = overlap->recv_id[ipulse];
1037             recv_id       = overlap->send_id[ipulse];
1038             send_index0   = overlap->comm_data[ipulse].recv_index0;
1039             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1040             recv_index0   = overlap->comm_data[ipulse].send_index0;
1041             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1042         }
1043
1044         /* Copy data to contiguous send buffer */
1045         if (debug)
1046         {
1047             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1048                     pme->nodeid, overlap->nodeid, send_id,
1049                     pme->pmegrid_start_iy,
1050                     send_index0-pme->pmegrid_start_iy,
1051                     send_index0-pme->pmegrid_start_iy+send_nindex);
1052         }
1053         icnt = 0;
1054         for (i = 0; i < pme->pmegrid_nx; i++)
1055         {
1056             ix = i;
1057             for (j = 0; j < send_nindex; j++)
1058             {
1059                 iy = j + send_index0 - pme->pmegrid_start_iy;
1060                 for (k = 0; k < pme->nkz; k++)
1061                 {
1062                     iz = k;
1063                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1064                 }
1065             }
1066         }
1067
1068         datasize      = pme->pmegrid_nx * pme->nkz;
1069
1070         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1071                      send_id, ipulse,
1072                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1073                      recv_id, ipulse,
1074                      overlap->mpi_comm, &stat);
1075
1076         /* Get data from contiguous recv buffer */
1077         if (debug)
1078         {
1079             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1080                     pme->nodeid, overlap->nodeid, recv_id,
1081                     pme->pmegrid_start_iy,
1082                     recv_index0-pme->pmegrid_start_iy,
1083                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1084         }
1085         icnt = 0;
1086         for (i = 0; i < pme->pmegrid_nx; i++)
1087         {
1088             ix = i;
1089             for (j = 0; j < recv_nindex; j++)
1090             {
1091                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1092                 for (k = 0; k < pme->nkz; k++)
1093                 {
1094                     iz = k;
1095                     if (direction == GMX_SUM_QGRID_FORWARD)
1096                     {
1097                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1098                     }
1099                     else
1100                     {
1101                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1102                     }
1103                 }
1104             }
1105         }
1106     }
1107
1108     /* Major dimension is easier, no copying required,
1109      * but we might have to sum to a separate array.
1110      * Since we don't copy, we have to communicate up to pmegrid_nz,
1111      * not nkz as for the minor direction.
1112      */
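    /* In the major (x) dimension each slab is a contiguous block of
     * pmegrid_ny*pmegrid_nz reals, so sendptr (and, for the backward
     * direction, recvptr) below can point straight into the grid; only the
     * forward direction needs the separate receive buffer for summing.
     */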
1113     overlap = &pme->overlap[0];
1114
1115     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1116     {
1117         if (direction == GMX_SUM_QGRID_FORWARD)
1118         {
1119             send_id       = overlap->send_id[ipulse];
1120             recv_id       = overlap->recv_id[ipulse];
1121             send_index0   = overlap->comm_data[ipulse].send_index0;
1122             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1123             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1124             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1125             recvptr       = overlap->recvbuf;
1126         }
1127         else
1128         {
1129             send_id       = overlap->recv_id[ipulse];
1130             recv_id       = overlap->send_id[ipulse];
1131             send_index0   = overlap->comm_data[ipulse].recv_index0;
1132             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1133             recv_index0   = overlap->comm_data[ipulse].send_index0;
1134             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1135             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1136         }
1137
1138         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1139         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1140
1141         if (debug)
1142         {
1143             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1144                     pme->nodeid, overlap->nodeid, send_id,
1145                     pme->pmegrid_start_ix,
1146                     send_index0-pme->pmegrid_start_ix,
1147                     send_index0-pme->pmegrid_start_ix+send_nindex);
1148             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1149                     pme->nodeid, overlap->nodeid, recv_id,
1150                     pme->pmegrid_start_ix,
1151                     recv_index0-pme->pmegrid_start_ix,
1152                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1153         }
1154
1155         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1156                      send_id, ipulse,
1157                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1158                      recv_id, ipulse,
1159                      overlap->mpi_comm, &stat);
1160
1161         /* ADD data from contiguous recv buffer */
1162         if (direction == GMX_SUM_QGRID_FORWARD)
1163         {
1164             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1165             for (i = 0; i < recv_nindex*datasize; i++)
1166             {
1167                 p[i] += overlap->recvbuf[i];
1168             }
1169         }
1170     }
1171 }
1172 #endif
1173
1174
1175 static int
1176 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1177 {
1178     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1179     ivec    local_pme_size;
1180     int     i, ix, iy, iz;
1181     int     pmeidx, fftidx;
1182
1183     /* Dimensions should be identical for A/B grid, so we just use A here */
1184     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1185                                    local_fft_ndata,
1186                                    local_fft_offset,
1187                                    local_fft_size);
1188
1189     local_pme_size[0] = pme->pmegrid_nx;
1190     local_pme_size[1] = pme->pmegrid_ny;
1191     local_pme_size[2] = pme->pmegrid_nz;
1192
1193     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1194        the offset is identical, and the PME grid always has more data (due to overlap)
1195      */
1196     {
1197 #ifdef DEBUG_PME
1198         FILE *fp, *fp2;
1199         char  fn[STRLEN], format[STRLEN];
1200         real  val;
1201         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1202         fp = ffopen(fn, "w");
1203         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1204         fp2 = ffopen(fn, "w");
1205         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1206 #endif
1207
1208         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1209         {
1210             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1211             {
1212                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1213                 {
1214                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1215                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1216                     fftgrid[fftidx] = pmegrid[pmeidx];
1217 #ifdef DEBUG_PME
1218                     val = 100*pmegrid[pmeidx];
1219                     if (pmegrid[pmeidx] != 0)
1220                     {
1221                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1222                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1223                     }
1224                     if (pmegrid[pmeidx] != 0)
1225                     {
1226                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1227                                 "qgrid",
1228                                 pme->pmegrid_start_ix + ix,
1229                                 pme->pmegrid_start_iy + iy,
1230                                 pme->pmegrid_start_iz + iz,
1231                                 pmegrid[pmeidx]);
1232                     }
1233 #endif
1234                 }
1235             }
1236         }
1237 #ifdef DEBUG_PME
1238         ffclose(fp);
1239         ffclose(fp2);
1240 #endif
1241     }
1242     return 0;
1243 }
1244
1245
1246 static gmx_cycles_t omp_cyc_start()
1247 {
1248     return gmx_cycles_read();
1249 }
1250
1251 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1252 {
1253     return gmx_cycles_read() - c;
1254 }
1255
1256
1257 static int
1258 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1259                         int nthread, int thread)
1260 {
1261     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1262     ivec          local_pme_size;
1263     int           ixy0, ixy1, ixy, ix, iy, iz;
1264     int           pmeidx, fftidx;
1265 #ifdef PME_TIME_THREADS
1266     gmx_cycles_t  c1;
1267     static double cs1 = 0;
1268     static int    cnt = 0;
1269 #endif
1270
1271 #ifdef PME_TIME_THREADS
1272     c1 = omp_cyc_start();
1273 #endif
1274     /* Dimensions should be identical for A/B grid, so we just use A here */
1275     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1276                                    local_fft_ndata,
1277                                    local_fft_offset,
1278                                    local_fft_size);
1279
1280     local_pme_size[0] = pme->pmegrid_nx;
1281     local_pme_size[1] = pme->pmegrid_ny;
1282     local_pme_size[2] = pme->pmegrid_nz;
1283
1284     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1285        the offset is identical, and the PME grid always has more data (due to overlap)
1286      */
1287     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1288     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1289
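    /* Each thread handles a contiguous range of flattened x*y indices and
     * thus whole z-columns; e.g. (illustrative) with ndata = {20, 18, ...}
     * and 4 threads, thread 0 gets ixy = 0..89 of the 360 columns.
     */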
1290     for (ixy = ixy0; ixy < ixy1; ixy++)
1291     {
1292         ix = ixy/local_fft_ndata[YY];
1293         iy = ixy - ix*local_fft_ndata[YY];
1294
1295         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1296         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1297         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1298         {
1299             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1300         }
1301     }
1302
1303 #ifdef PME_TIME_THREADS
1304     c1   = omp_cyc_end(c1);
1305     cs1 += (double)c1;
1306     cnt++;
1307     if (cnt % 20 == 0)
1308     {
1309         printf("copy %.2f\n", cs1*1e-9);
1310     }
1311 #endif
1312
1313     return 0;
1314 }
1315
1316
1317 static void
1318 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1319 {
1320     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1321
1322     nx = pme->nkx;
1323     ny = pme->nky;
1324     nz = pme->nkz;
1325
1326     pnx = pme->pmegrid_nx;
1327     pny = pme->pmegrid_ny;
1328     pnz = pme->pmegrid_nz;
1329
1330     overlap = pme->pme_order - 1;
1331
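    /* With pme_order = 4 the overlap is 3: the three z-planes at nz..nz+2,
     * which received spreading contributions across the periodic boundary,
     * are added back onto planes 0..2 (and similarly for y and x below when
     * there is no decomposition in that dimension).
     */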
1332     /* Add periodic overlap in z */
1333     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1334     {
1335         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1336         {
1337             for (iz = 0; iz < overlap; iz++)
1338             {
1339                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1340                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1341             }
1342         }
1343     }
1344
1345     if (pme->nnodes_minor == 1)
1346     {
1347         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1348         {
1349             for (iy = 0; iy < overlap; iy++)
1350             {
1351                 for (iz = 0; iz < nz; iz++)
1352                 {
1353                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1354                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1355                 }
1356             }
1357         }
1358     }
1359
1360     if (pme->nnodes_major == 1)
1361     {
1362         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1363
1364         for (ix = 0; ix < overlap; ix++)
1365         {
1366             for (iy = 0; iy < ny_x; iy++)
1367             {
1368                 for (iz = 0; iz < nz; iz++)
1369                 {
1370                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1371                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1372                 }
1373             }
1374         }
1375     }
1376 }
1377
1378
1379 static void
1380 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1381 {
1382     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1383
1384     nx = pme->nkx;
1385     ny = pme->nky;
1386     nz = pme->nkz;
1387
1388     pnx = pme->pmegrid_nx;
1389     pny = pme->pmegrid_ny;
1390     pnz = pme->pmegrid_nz;
1391
1392     overlap = pme->pme_order - 1;
1393
1394     if (pme->nnodes_major == 1)
1395     {
1396         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1397
1398         for (ix = 0; ix < overlap; ix++)
1399         {
1400             int iy, iz;
1401
1402             for (iy = 0; iy < ny_x; iy++)
1403             {
1404                 for (iz = 0; iz < nz; iz++)
1405                 {
1406                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1407                         pmegrid[(ix*pny+iy)*pnz+iz];
1408                 }
1409             }
1410         }
1411     }
1412
1413     if (pme->nnodes_minor == 1)
1414     {
1415 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1416         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1417         {
1418             int iy, iz;
1419
1420             for (iy = 0; iy < overlap; iy++)
1421             {
1422                 for (iz = 0; iz < nz; iz++)
1423                 {
1424                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1425                         pmegrid[(ix*pny+iy)*pnz+iz];
1426                 }
1427             }
1428         }
1429     }
1430
1431     /* Copy periodic overlap in z */
1432 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1433     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1434     {
1435         int iy, iz;
1436
1437         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1438         {
1439             for (iz = 0; iz < overlap; iz++)
1440             {
1441                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1442                     pmegrid[(ix*pny+iy)*pnz+iz];
1443             }
1444         }
1445     }
1446 }
1447
1448 static void clear_grid(int nx, int ny, int nz, real *grid,
1449                        ivec fs, int *flag,
1450                        int fx, int fy, int fz,
1451                        int order)
1452 {
1453     int nc, ncz;
1454     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1455     int flind;
1456
1457     nc  = 2 + (order - 2)/FLBS;
1458     ncz = 2 + (order - 2)/FLBSZ;
1459
1460     for (fsx = fx; fsx < fx+nc; fsx++)
1461     {
1462         for (fsy = fy; fsy < fy+nc; fsy++)
1463         {
1464             for (fsz = fz; fsz < fz+ncz; fsz++)
1465             {
1466                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1467                 if (flag[flind] == 0)
1468                 {
1469                     gx  = fsx*FLBS;
1470                     gy  = fsy*FLBS;
1471                     gz  = fsz*FLBSZ;
1472                     g0x = (gx*ny + gy)*nz + gz;
1473                     for (x = 0; x < FLBS; x++)
1474                     {
1475                         g0y = g0x;
1476                         for (y = 0; y < FLBS; y++)
1477                         {
1478                             for (z = 0; z < FLBSZ; z++)
1479                             {
1480                                 grid[g0y+z] = 0;
1481                             }
1482                             g0y += nz;
1483                         }
1484                         g0x += ny*nz;
1485                     }
1486
1487                     flag[flind] = 1;
1488                 }
1489             }
1490         }
1491     }
1492 }
1493
1494 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1495 #define DO_BSPLINE(order)                            \
1496     for (ithx = 0; (ithx < order); ithx++)                    \
1497     {                                                    \
1498         index_x = (i0+ithx)*pny*pnz;                     \
1499         valx    = qn*thx[ithx];                          \
1500                                                      \
1501         for (ithy = 0; (ithy < order); ithy++)                \
1502         {                                                \
1503             valxy    = valx*thy[ithy];                   \
1504             index_xy = index_x+(j0+ithy)*pnz;            \
1505                                                      \
1506             for (ithz = 0; (ithz < order); ithz++)            \
1507             {                                            \
1508                 index_xyz        = index_xy+(k0+ithz);   \
1509                 grid[index_xyz] += valxy*thz[ithz];      \
1510             }                                            \
1511         }                                                \
1512     }
1513
1514
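/* For reference, DO_BSPLINE(order) above performs the tensor-product charge
 * spread. An equivalent plain-C sketch (illustrative only, using the local
 * variable names of spread_q_bsplines_thread() below) would be:
 *
 *   for (ithx = 0; ithx < order; ithx++)
 *   {
 *       for (ithy = 0; ithy < order; ithy++)
 *       {
 *           for (ithz = 0; ithz < order; ithz++)
 *           {
 *               grid[((i0+ithx)*pny + (j0+ithy))*pnz + (k0+ithz)] +=
 *                   qn*thx[ithx]*thy[ithy]*thz[ithz];
 *           }
 *       }
 *   }
 *
 * The macro form precomputes the partial index and value products per loop
 * level, which helps some compilers (see the note above).
 */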
1515 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1516                                      pme_atomcomm_t *atc, splinedata_t *spline,
1517                                      pme_spline_work_t *work)
1518 {
1519
1520     /* spread charges from home atoms to local grid */
1521     real          *grid;
1522     pme_overlap_t *ol;
1523     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1524     int       *    idxptr;
1525     int            order, norder, index_x, index_xy, index_xyz;
1526     real           valx, valxy, qn;
1527     real          *thx, *thy, *thz;
1528     int            localsize, bndsize;
1529     int            pnx, pny, pnz, ndatatot;
1530     int            offx, offy, offz;
1531
1532 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1533     real           thz_buffer[12], *thz_aligned;
1534
1535     thz_aligned = gmx_simd4_align_real(thz_buffer);
1536 #endif
1537
1538     pnx = pmegrid->s[XX];
1539     pny = pmegrid->s[YY];
1540     pnz = pmegrid->s[ZZ];
1541
1542     offx = pmegrid->offset[XX];
1543     offy = pmegrid->offset[YY];
1544     offz = pmegrid->offset[ZZ];
1545
1546     ndatatot = pnx*pny*pnz;
1547     grid     = pmegrid->grid;
1548     for (i = 0; i < ndatatot; i++)
1549     {
1550         grid[i] = 0;
1551     }
1552
1553     order = pmegrid->order;
1554
1555     for (nn = 0; nn < spline->n; nn++)
1556     {
1557         n  = spline->ind[nn];
1558         qn = atc->q[n];
1559
1560         if (qn != 0)
1561         {
1562             idxptr = atc->idx[n];
1563             norder = nn*order;
1564
1565             i0   = idxptr[XX] - offx;
1566             j0   = idxptr[YY] - offy;
1567             k0   = idxptr[ZZ] - offz;
1568
1569             thx = spline->theta[XX] + norder;
1570             thy = spline->theta[YY] + norder;
1571             thz = spline->theta[ZZ] + norder;
1572
1573             switch (order)
1574             {
1575                 case 4:
1576 #ifdef PME_SIMD4_SPREAD_GATHER
1577 #ifdef PME_SIMD4_UNALIGNED
1578 #define PME_SPREAD_SIMD4_ORDER4
1579 #else
1580 #define PME_SPREAD_SIMD4_ALIGNED
1581 #define PME_ORDER 4
1582 #endif
1583 #include "pme_simd4.h"
1584 #else
1585                     DO_BSPLINE(4);
1586 #endif
1587                     break;
1588                 case 5:
1589 #ifdef PME_SIMD4_SPREAD_GATHER
1590 #define PME_SPREAD_SIMD4_ALIGNED
1591 #define PME_ORDER 5
1592 #include "pme_simd4.h"
1593 #else
1594                     DO_BSPLINE(5);
1595 #endif
1596                     break;
1597                 default:
1598                     DO_BSPLINE(order);
1599                     break;
1600             }
1601         }
1602     }
1603 }
1604
1605 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1606 {
1607 #ifdef PME_SIMD4_SPREAD_GATHER
1608     if (pme_order == 5
1609 #ifndef PME_SIMD4_UNALIGNED
1610         || pme_order == 4
1611 #endif
1612         )
1613     {
1614         /* Round nz up to a multiple of 4 to ensure alignment */
1615         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1616     }
1617 #endif
1618 }
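
/* The rounding above, (n + 3) & ~3, rounds n up to the next multiple of 4,
 * e.g. 21 -> 24, 24 -> 24, 25 -> 28.  With a SIMD4-aligned base pointer this
 * makes every z-line of the grid start at a 4-real aligned address.
 */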
1619
1620 static void set_gridsize_alignment(int *gridsize, int pme_order)
1621 {
1622 #ifdef PME_SIMD4_SPREAD_GATHER
1623 #ifndef PME_SIMD4_UNALIGNED
1624     if (pme_order == 4)
1625     {
1626         /* Add extra elements to ensure that aligned operations do not go
1627          * beyond the allocated grid size.
1628          * Note that for pme_order=5, the pme grid z-size alignment
1629          * ensures that we will not go beyond the grid size.
1630          */
1631         *gridsize += 4;
1632     }
1633 #endif
1634 #endif
1635 }
1636
1637 static void pmegrid_init(pmegrid_t *grid,
1638                          int cx, int cy, int cz,
1639                          int x0, int y0, int z0,
1640                          int x1, int y1, int z1,
1641                          gmx_bool set_alignment,
1642                          int pme_order,
1643                          real *ptr)
1644 {
1645     int nz, gridsize;
1646
1647     grid->ci[XX]     = cx;
1648     grid->ci[YY]     = cy;
1649     grid->ci[ZZ]     = cz;
1650     grid->offset[XX] = x0;
1651     grid->offset[YY] = y0;
1652     grid->offset[ZZ] = z0;
1653     grid->n[XX]      = x1 - x0 + pme_order - 1;
1654     grid->n[YY]      = y1 - y0 + pme_order - 1;
1655     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1656     copy_ivec(grid->n, grid->s);
1657
1658     nz = grid->s[ZZ];
1659     set_grid_alignment(&nz, pme_order);
1660     if (set_alignment)
1661     {
1662         grid->s[ZZ] = nz;
1663     }
1664     else if (nz != grid->s[ZZ])
1665     {
1666         gmx_incons("pmegrid_init call with an unaligned z size");
1667     }
1668
1669     grid->order = pme_order;
1670     if (ptr == NULL)
1671     {
1672         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1673         set_gridsize_alignment(&gridsize, pme_order);
1674         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1675     }
1676     else
1677     {
1678         grid->grid = ptr;
1679     }
1680 }
1681
1682 static int div_round_up(int numerator, int denominator)
1683 {
1684     return (numerator + denominator - 1)/denominator;
1685 }
1686
1687 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1688                                   ivec nsub)
1689 {
1690     int gsize_opt, gsize;
1691     int nsx, nsy, nsz;
1692     char *env;
1693
1694     gsize_opt = -1;
1695     for (nsx = 1; nsx <= nthread; nsx++)
1696     {
1697         if (nthread % nsx == 0)
1698         {
1699             for (nsy = 1; nsy <= nthread; nsy++)
1700             {
1701                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1702                 {
1703                     nsz = nthread/(nsx*nsy);
1704
1705                     /* Determine the number of grid points per thread */
1706                     gsize =
1707                         (div_round_up(n[XX], nsx) + ovl)*
1708                         (div_round_up(n[YY], nsy) + ovl)*
1709                         (div_round_up(n[ZZ], nsz) + ovl);
1710
1711                     /* Minimize the number of grids points per thread
1712                      * and, secondarily, the number of cuts in minor dimensions.
1713                      */
1714                     if (gsize_opt == -1 ||
1715                         gsize < gsize_opt ||
1716                         (gsize == gsize_opt &&
1717                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1718                     {
1719                         nsub[XX]  = nsx;
1720                         nsub[YY]  = nsy;
1721                         nsub[ZZ]  = nsz;
1722                         gsize_opt = gsize;
1723                     }
1724                 }
1725             }
1726         }
1727     }
1728
1729     env = getenv("GMX_PME_THREAD_DIVISION");
1730     if (env != NULL)
1731     {
1732         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1733     }
1734
1735     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1736     {
1737         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1738     }
1739 }
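
/* Example: for nthread = 6 the candidate divisions (nsx,nsy,nsz) are the
 * nine factorizations 1x1x6, 1x2x3, 1x3x2, 1x6x1, 2x1x3, 2x3x1, 3x1x2,
 * 3x2x1 and 6x1x1; the one minimizing
 *     (ceil(nx/nsx)+ovl)*(ceil(ny/nsy)+ovl)*(ceil(nz/nsz)+ovl)
 * wins, with ties broken towards fewer cuts along z, then y.  The
 * GMX_PME_THREAD_DIVISION environment variable overrides this choice, but
 * its product must still equal the thread count.
 */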
1740
1741 static void pmegrids_init(pmegrids_t *grids,
1742                           int nx, int ny, int nz, int nz_base,
1743                           int pme_order,
1744                           gmx_bool bUseThreads,
1745                           int nthread,
1746                           int overlap_x,
1747                           int overlap_y)
1748 {
1749     ivec n, n_base, g0, g1;
1750     int t, x, y, z, d, i, tfac;
1751     int max_comm_lines = -1;
1752
1753     n[XX] = nx - (pme_order - 1);
1754     n[YY] = ny - (pme_order - 1);
1755     n[ZZ] = nz - (pme_order - 1);
1756
1757     copy_ivec(n, n_base);
1758     n_base[ZZ] = nz_base;
1759
1760     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1761                  NULL);
1762
1763     grids->nthread = nthread;
1764
1765     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1766
1767     if (bUseThreads)
1768     {
1769         ivec nst;
1770         int gridsize;
1771
1772         for (d = 0; d < DIM; d++)
1773         {
1774             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1775         }
1776         set_grid_alignment(&nst[ZZ], pme_order);
1777
1778         if (debug)
1779         {
1780             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1781                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1782             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1783                     nx, ny, nz,
1784                     nst[XX], nst[YY], nst[ZZ]);
1785         }
1786
1787         snew(grids->grid_th, grids->nthread);
1788         t        = 0;
1789         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1790         set_gridsize_alignment(&gridsize, pme_order);
1791         snew_aligned(grids->grid_all,
1792                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1793                      SIMD4_ALIGNMENT);
1794
1795         for (x = 0; x < grids->nc[XX]; x++)
1796         {
1797             for (y = 0; y < grids->nc[YY]; y++)
1798             {
1799                 for (z = 0; z < grids->nc[ZZ]; z++)
1800                 {
1801                     pmegrid_init(&grids->grid_th[t],
1802                                  x, y, z,
1803                                  (n[XX]*(x  ))/grids->nc[XX],
1804                                  (n[YY]*(y  ))/grids->nc[YY],
1805                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1806                                  (n[XX]*(x+1))/grids->nc[XX],
1807                                  (n[YY]*(y+1))/grids->nc[YY],
1808                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1809                                  TRUE,
1810                                  pme_order,
1811                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1812                     t++;
1813                 }
1814             }
1815         }
1816     }
1817     else
1818     {
1819         grids->grid_th = NULL;
1820     }
1821
1822     snew(grids->g2t, DIM);
1823     tfac = 1;
1824     for (d = DIM-1; d >= 0; d--)
1825     {
1826         snew(grids->g2t[d], n[d]);
1827         t = 0;
1828         for (i = 0; i < n[d]; i++)
1829         {
1830             /* The second check should match the parameters
1831              * of the pmegrid_init call above.
1832              */
1833             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1834             {
1835                 t++;
1836             }
1837             grids->g2t[d][i] = t*tfac;
1838         }
1839
1840         tfac *= grids->nc[d];
1841
1842         switch (d)
1843         {
1844             case XX: max_comm_lines = overlap_x;     break;
1845             case YY: max_comm_lines = overlap_y;     break;
1846             case ZZ: max_comm_lines = pme_order - 1; break;
1847         }
1848         grids->nthread_comm[d] = 0;
1849         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1850                grids->nthread_comm[d] < grids->nc[d])
1851         {
1852             grids->nthread_comm[d]++;
1853         }
1854         if (debug != NULL)
1855         {
1856             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1857                     'x'+d, grids->nthread_comm[d]);
1858         }
1859         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1860          * work, but this is not a problematic restriction.
1861          */
1862         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1863         {
1864             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1865         }
1866     }
1867 }
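
/* After this setup, g2t[d][i] holds the contribution of grid line i in
 * dimension d to the flattened index of the thread sub-grid owning that
 * line (the three contributions are summed to get the thread), and
 * nthread_comm[d] is the number of neighbouring thread sub-grids whose
 * overlap regions have to be reduced along dimension d when the thread
 * grids are summed into the full grid.
 */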
1868
1869
1870 static void pmegrids_destroy(pmegrids_t *grids)
1871 {
1872     int t;
1873
1874     if (grids->grid.grid != NULL)
1875     {
1876         sfree(grids->grid.grid);
1877
1878         if (grids->nthread > 0)
1879         {
1880             for (t = 0; t < grids->nthread; t++)
1881             {
1882                 sfree(grids->grid_th[t].grid);
1883             }
1884             sfree(grids->grid_th);
1885         }
1886     }
1887 }
1888
1889
1890 static void realloc_work(pme_work_t *work, int nkx)
1891 {
1892     int simd_width;
1893
1894     if (nkx > work->nalloc)
1895     {
1896         work->nalloc = nkx;
1897         srenew(work->mhx, work->nalloc);
1898         srenew(work->mhy, work->nalloc);
1899         srenew(work->mhz, work->nalloc);
1900         srenew(work->m2, work->nalloc);
1901         /* Allocate an aligned pointer for SIMD operations, including extra
1902          * elements at the end for padding.
1903          */
1904 #ifdef PME_SIMD
1905         simd_width = GMX_SIMD_WIDTH_HERE;
1906 #else
1907         /* We can use any alignment, apart from 0, so we use 4 */
1908         simd_width = 4;
1909 #endif
1910         sfree_aligned(work->denom);
1911         sfree_aligned(work->tmp1);
1912         sfree_aligned(work->eterm);
1913         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1914         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1915         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1916         srenew(work->m2inv, work->nalloc);
1917     }
1918 }
1919
1920
1921 static void free_work(pme_work_t *work)
1922 {
1923     sfree(work->mhx);
1924     sfree(work->mhy);
1925     sfree(work->mhz);
1926     sfree(work->m2);
1927     sfree_aligned(work->denom);
1928     sfree_aligned(work->tmp1);
1929     sfree_aligned(work->eterm);
1930     sfree(work->m2inv);
1931 }
1932
1933
1934 #ifdef PME_SIMD
1935 /* Calculate exponentials through SIMD */
1936 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1937 {
1938     {
1939         const gmx_mm_pr two = gmx_set1_pr(2.0);
1940         gmx_mm_pr f_simd;
1941         gmx_mm_pr lu;
1942         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1943         int kx;
1944         f_simd = gmx_set1_pr(f);
1945         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1946         {
1947             tmp_d1   = gmx_load_pr(d_aligned+kx);
1948             d_inv    = gmx_inv_pr(tmp_d1);
1949             tmp_r    = gmx_load_pr(r_aligned+kx);
1950             tmp_r    = gmx_exp_pr(tmp_r);
1951             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1952             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1953             gmx_store_pr(e_aligned+kx, tmp_e);
1954         }
1955     }
1956 }
1957 #else
1958 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1959 {
1960     int kx;
1961     for (kx = start; kx < end; kx++)
1962     {
1963         d[kx] = 1.0/d[kx];
1964     }
1965     for (kx = start; kx < end; kx++)
1966     {
1967         r[kx] = exp(r[kx]);
1968     }
1969     for (kx = start; kx < end; kx++)
1970     {
1971         e[kx] = f*r[kx]*d[kx];
1972     }
1973 }
1974 #endif
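
/* Both versions compute e[kx] = f*exp(r[kx])/d[kx].  Note that the SIMD
 * version starts at kx = 0 (ignoring 'start') and can run a few elements
 * past 'end'; this is harmless because realloc_work() pads the denom,
 * tmp1 and eterm arrays by one extra SIMD width.
 */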
1975
1976
1977 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1978                          real ewaldcoeff, real vol,
1979                          gmx_bool bEnerVir,
1980                          int nthread, int thread)
1981 {
1982     /* do recip sum over local cells in grid */
1983     /* y major, z middle, x minor or continuous */
1984     t_complex *p0;
1985     int     kx, ky, kz, maxkx, maxky, maxkz;
1986     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1987     real    mx, my, mz;
1988     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1989     real    ets2, struct2, vfactor, ets2vf;
1990     real    d1, d2, energy = 0;
1991     real    by, bz;
1992     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1993     real    rxx, ryx, ryy, rzx, rzy, rzz;
1994     pme_work_t *work;
1995     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1996     real    mhxk, mhyk, mhzk, m2k;
1997     real    corner_fac;
1998     ivec    complex_order;
1999     ivec    local_ndata, local_offset, local_size;
2000     real    elfac;
2001
2002     elfac = ONE_4PI_EPS0/pme->epsilon_r;
2003
2004     nx = pme->nkx;
2005     ny = pme->nky;
2006     nz = pme->nkz;
2007
2008     /* Dimensions should be identical for A/B grid, so we just use A here */
2009     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
2010                                       complex_order,
2011                                       local_ndata,
2012                                       local_offset,
2013                                       local_size);
2014
2015     rxx = pme->recipbox[XX][XX];
2016     ryx = pme->recipbox[YY][XX];
2017     ryy = pme->recipbox[YY][YY];
2018     rzx = pme->recipbox[ZZ][XX];
2019     rzy = pme->recipbox[ZZ][YY];
2020     rzz = pme->recipbox[ZZ][ZZ];
2021
2022     maxkx = (nx+1)/2;
2023     maxky = (ny+1)/2;
2024     maxkz = nz/2+1;
2025
2026     work  = &pme->work[thread];
2027     mhx   = work->mhx;
2028     mhy   = work->mhy;
2029     mhz   = work->mhz;
2030     m2    = work->m2;
2031     denom = work->denom;
2032     tmp1  = work->tmp1;
2033     eterm = work->eterm;
2034     m2inv = work->m2inv;
2035
2036     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2037     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2038
2039     for (iyz = iyz0; iyz < iyz1; iyz++)
2040     {
2041         iy = iyz/local_ndata[ZZ];
2042         iz = iyz - iy*local_ndata[ZZ];
2043
2044         ky = iy + local_offset[YY];
2045
2046         if (ky < maxky)
2047         {
2048             my = ky;
2049         }
2050         else
2051         {
2052             my = (ky - ny);
2053         }
2054
2055         by = M_PI*vol*pme->bsp_mod[YY][ky];
2056
2057         kz = iz + local_offset[ZZ];
2058
2059         mz = kz;
2060
2061         bz = pme->bsp_mod[ZZ][kz];
2062
2063         /* 0.5 correction for corner points */
2064         corner_fac = 1;
2065         if (kz == 0 || kz == (nz+1)/2)
2066         {
2067             corner_fac = 0.5;
2068         }
2069
2070         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2071
2072         /* We should skip the k-space point (0,0,0) */
2073         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2074         {
2075             kxstart = local_offset[XX];
2076         }
2077         else
2078         {
2079             kxstart = local_offset[XX] + 1;
2080             p0++;
2081         }
2082         kxend = local_offset[XX] + local_ndata[XX];
2083
2084         if (bEnerVir)
2085         {
2086             /* More expensive inner loop, especially because of the storage
2087              * of the mh elements in arrays.
2088              * Because x is the minor grid index, all mh elements
2089              * depend on kx for triclinic unit cells.
2090              */
2091
2092             /* Two explicit loops to avoid a conditional inside the loop */
2093             for (kx = kxstart; kx < maxkx; kx++)
2094             {
2095                 mx = kx;
2096
2097                 mhxk      = mx * rxx;
2098                 mhyk      = mx * ryx + my * ryy;
2099                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2100                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2101                 mhx[kx]   = mhxk;
2102                 mhy[kx]   = mhyk;
2103                 mhz[kx]   = mhzk;
2104                 m2[kx]    = m2k;
2105                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2106                 tmp1[kx]  = -factor*m2k;
2107             }
2108
2109             for (kx = maxkx; kx < kxend; kx++)
2110             {
2111                 mx = (kx - nx);
2112
2113                 mhxk      = mx * rxx;
2114                 mhyk      = mx * ryx + my * ryy;
2115                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2116                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2117                 mhx[kx]   = mhxk;
2118                 mhy[kx]   = mhyk;
2119                 mhz[kx]   = mhzk;
2120                 m2[kx]    = m2k;
2121                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2122                 tmp1[kx]  = -factor*m2k;
2123             }
2124
2125             for (kx = kxstart; kx < kxend; kx++)
2126             {
2127                 m2inv[kx] = 1.0/m2[kx];
2128             }
2129
2130             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2131
2132             for (kx = kxstart; kx < kxend; kx++, p0++)
2133             {
2134                 d1      = p0->re;
2135                 d2      = p0->im;
2136
2137                 p0->re  = d1*eterm[kx];
2138                 p0->im  = d2*eterm[kx];
2139
2140                 struct2 = 2.0*(d1*d1+d2*d2);
2141
2142                 tmp1[kx] = eterm[kx]*struct2;
2143             }
2144
2145             for (kx = kxstart; kx < kxend; kx++)
2146             {
2147                 ets2     = corner_fac*tmp1[kx];
2148                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2149                 energy  += ets2;
2150
2151                 ets2vf   = ets2*vfactor;
2152                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2153                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2154                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2155                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2156                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2157                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2158             }
2159         }
2160         else
2161         {
2162             /* We don't need to calculate the energy and the virial.
2163              * In this case the triclinic overhead is small.
2164              */
2165
2166             /* Two explicit loops to avoid a conditional inside the loop */
2167
2168             for (kx = kxstart; kx < maxkx; kx++)
2169             {
2170                 mx = kx;
2171
2172                 mhxk      = mx * rxx;
2173                 mhyk      = mx * ryx + my * ryy;
2174                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2175                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2176                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2177                 tmp1[kx]  = -factor*m2k;
2178             }
2179
2180             for (kx = maxkx; kx < kxend; kx++)
2181             {
2182                 mx = (kx - nx);
2183
2184                 mhxk      = mx * rxx;
2185                 mhyk      = mx * ryx + my * ryy;
2186                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2187                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2188                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2189                 tmp1[kx]  = -factor*m2k;
2190             }
2191
2192             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2193
2194             for (kx = kxstart; kx < kxend; kx++, p0++)
2195             {
2196                 d1      = p0->re;
2197                 d2      = p0->im;
2198
2199                 p0->re  = d1*eterm[kx];
2200                 p0->im  = d2*eterm[kx];
2201             }
2202         }
2203     }
2204
2205     if (bEnerVir)
2206     {
2207         /* Update virial with local values.
2208          * The virial is symmetric by definition.
2209          * This virial seems ok for isotropic scaling, but I'm
2210          * experiencing problems on semi-isotropic membranes.
2211          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2212          */
2213         work->vir[XX][XX] = 0.25*virxx;
2214         work->vir[YY][YY] = 0.25*viryy;
2215         work->vir[ZZ][ZZ] = 0.25*virzz;
2216         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2217         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2218         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2219
2220         /* This energy should be corrected for a charged system */
2221         work->energy = 0.5*energy;
2222     }
2223
2224     /* Return the loop count */
2225     return local_ndata[YY]*local_ndata[XX];
2226 }
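
/* The solve loop above applies the reciprocal-space part of the smooth PME
 * Ewald sum (cf. Essmann et al., JCP 103, 8577 (1995)): each k-vector m
 * except (0,0,0) gets the factor
 *
 *     eterm(m) = elfac*exp(-pi^2*m^2/ewaldcoeff^2)/(pi*V*m^2*B(m))
 *
 * with denom = pi*V*m^2*bsp_mod_x*bsp_mod_y*bsp_mod_z and tmp1 = -factor*m^2,
 * which scales the structure factor stored in the FFT grid.  With bEnerVir
 * the per-thread energy 0.5*sum(corner_fac*eterm*struct2), where
 * struct2 = 2*|S(m)|^2, and the matching virial are also accumulated.
 */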
2227
2228 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2229                              real *mesh_energy, matrix vir)
2230 {
2231     /* This function sums output over threads
2232      * and should therefore only be called after thread synchronization.
2233      */
2234     int thread;
2235
2236     *mesh_energy = pme->work[0].energy;
2237     copy_mat(pme->work[0].vir, vir);
2238
2239     for (thread = 1; thread < nthread; thread++)
2240     {
2241         *mesh_energy += pme->work[thread].energy;
2242         m_add(vir, pme->work[thread].vir, vir);
2243     }
2244 }
2245
2246 #define DO_FSPLINE(order)                      \
2247     for (ithx = 0; (ithx < order); ithx++)              \
2248     {                                              \
2249         index_x = (i0+ithx)*pny*pnz;               \
2250         tx      = thx[ithx];                       \
2251         dx      = dthx[ithx];                      \
2252                                                \
2253         for (ithy = 0; (ithy < order); ithy++)          \
2254         {                                          \
2255             index_xy = index_x+(j0+ithy)*pnz;      \
2256             ty       = thy[ithy];                  \
2257             dy       = dthy[ithy];                 \
2258             fxy1     = fz1 = 0;                    \
2259                                                \
2260             for (ithz = 0; (ithz < order); ithz++)      \
2261             {                                      \
2262                 gval  = grid[index_xy+(k0+ithz)];  \
2263                 fxy1 += thz[ithz]*gval;            \
2264                 fz1  += dthz[ithz]*gval;           \
2265             }                                      \
2266             fx += dx*ty*fxy1;                      \
2267             fy += tx*dy*fxy1;                      \
2268             fz += tx*ty*fz1;                       \
2269         }                                          \
2270     }
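
/* DO_FSPLINE is the gathering counterpart of DO_BSPLINE: it accumulates,
 * per atom, the three fractional-coordinate derivatives
 *
 *     fx = sum dthx*thy*thz*grid,  fy = sum thx*dthy*thz*grid,
 *     fz = sum thx*thy*dthz*grid
 *
 * over the order^3 surrounding grid points.  gather_f_bsplines() below
 * turns these into Cartesian forces by multiplying with -q, the grid
 * dimensions and the reciprocal box.
 */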
2271
2272
2273 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2274                               gmx_bool bClearF, pme_atomcomm_t *atc,
2275                               splinedata_t *spline,
2276                               real scale)
2277 {
2278     /* sum forces for local particles */
2279     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2280     int     index_x, index_xy;
2281     int     nx, ny, nz, pnx, pny, pnz;
2282     int *   idxptr;
2283     real    tx, ty, dx, dy, qn;
2284     real    fx, fy, fz, gval;
2285     real    fxy1, fz1;
2286     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2287     int     norder;
2288     real    rxx, ryx, ryy, rzx, rzy, rzz;
2289     int     order;
2290
2291     pme_spline_work_t *work;
2292
2293 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2294     real           thz_buffer[12],  *thz_aligned;
2295     real           dthz_buffer[12], *dthz_aligned;
2296
2297     thz_aligned  = gmx_simd4_align_real(thz_buffer);
2298     dthz_aligned = gmx_simd4_align_real(dthz_buffer);
2299 #endif
2300
2301     work = pme->spline_work;
2302
2303     order = pme->pme_order;
2304     thx   = spline->theta[XX];
2305     thy   = spline->theta[YY];
2306     thz   = spline->theta[ZZ];
2307     dthx  = spline->dtheta[XX];
2308     dthy  = spline->dtheta[YY];
2309     dthz  = spline->dtheta[ZZ];
2310     nx    = pme->nkx;
2311     ny    = pme->nky;
2312     nz    = pme->nkz;
2313     pnx   = pme->pmegrid_nx;
2314     pny   = pme->pmegrid_ny;
2315     pnz   = pme->pmegrid_nz;
2316
2317     rxx   = pme->recipbox[XX][XX];
2318     ryx   = pme->recipbox[YY][XX];
2319     ryy   = pme->recipbox[YY][YY];
2320     rzx   = pme->recipbox[ZZ][XX];
2321     rzy   = pme->recipbox[ZZ][YY];
2322     rzz   = pme->recipbox[ZZ][ZZ];
2323
2324     for (nn = 0; nn < spline->n; nn++)
2325     {
2326         n  = spline->ind[nn];
2327         qn = scale*atc->q[n];
2328
2329         if (bClearF)
2330         {
2331             atc->f[n][XX] = 0;
2332             atc->f[n][YY] = 0;
2333             atc->f[n][ZZ] = 0;
2334         }
2335         if (qn != 0)
2336         {
2337             fx     = 0;
2338             fy     = 0;
2339             fz     = 0;
2340             idxptr = atc->idx[n];
2341             norder = nn*order;
2342
2343             i0   = idxptr[XX];
2344             j0   = idxptr[YY];
2345             k0   = idxptr[ZZ];
2346
2347             /* Pointer arithmetic alert, next six statements */
2348             thx  = spline->theta[XX] + norder;
2349             thy  = spline->theta[YY] + norder;
2350             thz  = spline->theta[ZZ] + norder;
2351             dthx = spline->dtheta[XX] + norder;
2352             dthy = spline->dtheta[YY] + norder;
2353             dthz = spline->dtheta[ZZ] + norder;
2354
2355             switch (order)
2356             {
2357                 case 4:
2358 #ifdef PME_SIMD4_SPREAD_GATHER
2359 #ifdef PME_SIMD4_UNALIGNED
2360 #define PME_GATHER_F_SIMD4_ORDER4
2361 #else
2362 #define PME_GATHER_F_SIMD4_ALIGNED
2363 #define PME_ORDER 4
2364 #endif
2365 #include "pme_simd4.h"
2366 #else
2367                     DO_FSPLINE(4);
2368 #endif
2369                     break;
2370                 case 5:
2371 #ifdef PME_SIMD4_SPREAD_GATHER
2372 #define PME_GATHER_F_SIMD4_ALIGNED
2373 #define PME_ORDER 5
2374 #include "pme_simd4.h"
2375 #else
2376                     DO_FSPLINE(5);
2377 #endif
2378                     break;
2379                 default:
2380                     DO_FSPLINE(order);
2381                     break;
2382             }
2383
2384             atc->f[n][XX] += -qn*( fx*nx*rxx );
2385             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2386             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2387         }
2388     }
2389     /* Since the energy and not the forces are interpolated,
2390      * the net force might not be exactly zero.
2391      * This can be solved by also interpolating F, but
2392      * that comes at a cost.
2393      * A better hack is to remove the net force every
2394      * step, but that must be done at a higher level,
2395      * since this routine doesn't see all atoms if running
2396      * in parallel. It is unclear how important this is.  EL 990726
2397      */
2398 }
2399
2400
2401 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2402                                    pme_atomcomm_t *atc)
2403 {
2404     splinedata_t *spline;
2405     int     n, ithx, ithy, ithz, i0, j0, k0;
2406     int     index_x, index_xy;
2407     int *   idxptr;
2408     real    energy, pot, tx, ty, qn, gval;
2409     real    *thx, *thy, *thz;
2410     int     norder;
2411     int     order;
2412
2413     spline = &atc->spline[0];
2414
2415     order = pme->pme_order;
2416
2417     energy = 0;
2418     for (n = 0; (n < atc->n); n++)
2419     {
2420         qn      = atc->q[n];
2421
2422         if (qn != 0)
2423         {
2424             idxptr = atc->idx[n];
2425             norder = n*order;
2426
2427             i0   = idxptr[XX];
2428             j0   = idxptr[YY];
2429             k0   = idxptr[ZZ];
2430
2431             /* Pointer arithmetic alert, next three statements */
2432             thx  = spline->theta[XX] + norder;
2433             thy  = spline->theta[YY] + norder;
2434             thz  = spline->theta[ZZ] + norder;
2435
2436             pot = 0;
2437             for (ithx = 0; (ithx < order); ithx++)
2438             {
2439                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2440                 tx      = thx[ithx];
2441
2442                 for (ithy = 0; (ithy < order); ithy++)
2443                 {
2444                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2445                     ty       = thy[ithy];
2446
2447                     for (ithz = 0; (ithz < order); ithz++)
2448                     {
2449                         gval  = grid[index_xy+(k0+ithz)];
2450                         pot  += tx*ty*thz[ithz]*gval;
2451                     }
2452
2453                 }
2454             }
2455
2456             energy += pot*qn;
2457         }
2458     }
2459
2460     return energy;
2461 }
2462
2463 /* Macro to force loop unrolling by fixing order.
2464  * This gives a significant performance gain.
2465  */
2466 #define CALC_SPLINE(order)                     \
2467     {                                              \
2468         int j, k, l;                                 \
2469         real dr, div;                               \
2470         real data[PME_ORDER_MAX];                  \
2471         real ddata[PME_ORDER_MAX];                 \
2472                                                \
2473         for (j = 0; (j < DIM); j++)                     \
2474         {                                          \
2475             dr  = xptr[j];                         \
2476                                                \
2477             /* dr is relative offset from lower cell limit */ \
2478             data[order-1] = 0;                     \
2479             data[1]       = dr;                          \
2480             data[0]       = 1 - dr;                      \
2481                                                \
2482             for (k = 3; (k < order); k++)               \
2483             {                                      \
2484                 div       = 1.0/(k - 1.0);               \
2485                 data[k-1] = div*dr*data[k-2];      \
2486                 for (l = 1; (l < (k-1)); l++)           \
2487                 {                                  \
2488                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2489                                        data[k-l-1]);                \
2490                 }                                  \
2491                 data[0] = div*(1-dr)*data[0];      \
2492             }                                      \
2493             /* differentiate */                    \
2494             ddata[0] = -data[0];                   \
2495             for (k = 1; (k < order); k++)               \
2496             {                                      \
2497                 ddata[k] = data[k-1] - data[k];    \
2498             }                                      \
2499                                                \
2500             div           = 1.0/(order - 1);                 \
2501             data[order-1] = div*dr*data[order-2];  \
2502             for (l = 1; (l < (order-1)); l++)           \
2503             {                                      \
2504                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2505                                        (order-l-dr)*data[order-l-1]); \
2506             }                                      \
2507             data[0] = div*(1 - dr)*data[0];        \
2508                                                \
2509             for (k = 0; k < order; k++)                 \
2510             {                                      \
2511                 theta[j][i*order+k]  = data[k];    \
2512                 dtheta[j][i*order+k] = ddata[k];   \
2513             }                                      \
2514         }                                          \
2515     }
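
/* CALC_SPLINE evaluates the cardinal B-spline weights by the standard
 * recursion (cf. Essmann et al.):
 *
 *     M_2(u)  = 1 - |u - 1|                                   on [0,2]
 *     M_n(u)  = u/(n-1)*M_{n-1}(u) + (n-u)/(n-1)*M_{n-1}(u-1)
 *     M_n'(u) = M_{n-1}(u) - M_{n-1}(u-1)
 *
 * data[] ends up holding the order weights for the fractional offset dr
 * and ddata[] their derivatives, which are stored in theta and dtheta.
 */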
2516
2517 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2518                    rvec fractx[], int nr, int ind[], real charge[],
2519                    gmx_bool bFreeEnergy)
2520 {
2521     /* construct splines for local atoms */
2522     int  i, ii;
2523     real *xptr;
2524
2525     for (i = 0; i < nr; i++)
2526     {
2527         /* With free energy we do not use the charge check.
2528          * In most cases this will be more efficient than calling make_bsplines
2529          * twice, since usually more than half the particles have charges.
2530          */
2531         ii = ind[i];
2532         if (bFreeEnergy || charge[ii] != 0.0)
2533         {
2534             xptr = fractx[ii];
2535             switch (order)
2536             {
2537                 case 4:  CALC_SPLINE(4);     break;
2538                 case 5:  CALC_SPLINE(5);     break;
2539                 default: CALC_SPLINE(order); break;
2540             }
2541         }
2542     }
2543 }
2544
2545
2546 void make_dft_mod(real *mod, real *data, int ndata)
2547 {
2548     int i, j;
2549     real sc, ss, arg;
2550
2551     for (i = 0; i < ndata; i++)
2552     {
2553         sc = ss = 0;
2554         for (j = 0; j < ndata; j++)
2555         {
2556             arg = (2.0*M_PI*i*j)/ndata;
2557             sc += data[j]*cos(arg);
2558             ss += data[j]*sin(arg);
2559         }
2560         mod[i] = sc*sc+ss*ss;
2561     }
2562     for (i = 0; i < ndata; i++)
2563     {
2564         if (mod[i] < 1e-7)
2565         {
2566             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2567         }
2568     }
2569 }
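
/* make_dft_mod computes mod[k] = |sum_j data[j]*exp(2*pi*i*j*k/ndata)|^2,
 * the squared modulus of the DFT of the input (here the B-spline values).
 * Near-zero entries, which would blow up the denominator in the solver,
 * are replaced by the average of their two neighbours.
 */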
2570
2571
2572 static void make_bspline_moduli(splinevec bsp_mod,
2573                                 int nx, int ny, int nz, int order)
2574 {
2575     int nmax = max(nx, max(ny, nz));
2576     real *data, *ddata, *bsp_data;
2577     int i, k, l;
2578     real div;
2579
2580     snew(data, order);
2581     snew(ddata, order);
2582     snew(bsp_data, nmax);
2583
2584     data[order-1] = 0;
2585     data[1]       = 0;
2586     data[0]       = 1;
2587
2588     for (k = 3; k < order; k++)
2589     {
2590         div       = 1.0/(k-1.0);
2591         data[k-1] = 0;
2592         for (l = 1; l < (k-1); l++)
2593         {
2594             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2595         }
2596         data[0] = div*data[0];
2597     }
2598     /* differentiate */
2599     ddata[0] = -data[0];
2600     for (k = 1; k < order; k++)
2601     {
2602         ddata[k] = data[k-1]-data[k];
2603     }
2604     div           = 1.0/(order-1);
2605     data[order-1] = 0;
2606     for (l = 1; l < (order-1); l++)
2607     {
2608         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2609     }
2610     data[0] = div*data[0];
2611
2612     for (i = 0; i < nmax; i++)
2613     {
2614         bsp_data[i] = 0;
2615     }
2616     for (i = 1; i <= order; i++)
2617     {
2618         bsp_data[i] = data[i-1];
2619     }
2620
2621     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2622     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2623     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2624
2625     sfree(data);
2626     sfree(ddata);
2627     sfree(bsp_data);
2628 }
2629
2630
2631 /* Return the P3M optimal influence function */
2632 static double do_p3m_influence(double z, int order)
2633 {
2634     double z2, z4;
2635
2636     z2 = z*z;
2637     z4 = z2*z2;
2638
2639     /* The formula and most constants can be found in:
2640      * Ballenegger et al., JCTC 8, 936 (2012)
2641      */
2642     switch (order)
2643     {
2644         case 2:
2645             return 1.0 - 2.0*z2/3.0;
2646             break;
2647         case 3:
2648             return 1.0 - z2 + 2.0*z4/15.0;
2649             break;
2650         case 4:
2651             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2652             break;
2653         case 5:
2654             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2655             break;
2656         case 6:
2657             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2658             break;
2659         case 7:
2660             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2661         case 8:
2662             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2663             break;
2664     }
2665
2666     return 0.0;
2667 }
2668
2669 /* Calculate the P3M B-spline moduli for one dimension */
2670 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2671 {
2672     double zarg, zai, sinzai, infl;
2673     int    maxk, i;
2674
2675     if (order > 8)
2676     {
2677         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2678     }
2679
2680     zarg = M_PI/n;
2681
2682     maxk = (n + 1)/2;
2683
2684     for (i = -maxk; i < 0; i++)
2685     {
2686         zai          = zarg*i;
2687         sinzai       = sin(zai);
2688         infl         = do_p3m_influence(sinzai, order);
2689         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2690     }
2691     bsp_mod[0] = 1.0;
2692     for (i = 1; i < maxk; i++)
2693     {
2694         zai        = zarg*i;
2695         sinzai     = sin(zai);
2696         infl       = do_p3m_influence(sinzai, order);
2697         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2698     }
2699 }
2700
2701 /* Calculate the P3M B-spline moduli */
2702 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2703                                     int nx, int ny, int nz, int order)
2704 {
2705     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2706     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2707     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2708 }
2709
2710
2711 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2712 {
2713     int nslab, n, i;
2714     int fw, bw;
2715
2716     nslab = atc->nslab;
2717
2718     n = 0;
2719     for (i = 1; i <= nslab/2; i++)
2720     {
2721         fw = (atc->nodeid + i) % nslab;
2722         bw = (atc->nodeid - i + nslab) % nslab;
2723         if (n < nslab - 1)
2724         {
2725             atc->node_dest[n] = fw;
2726             atc->node_src[n]  = bw;
2727             n++;
2728         }
2729         if (n < nslab - 1)
2730         {
2731             atc->node_dest[n] = bw;
2732             atc->node_src[n]  = fw;
2733             n++;
2734         }
2735     }
2736 }
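
/* The pulses above are ordered by increasing slab distance, alternating
 * forward and backward.  E.g. for nslab = 4 and nodeid = 0 this gives
 * node_dest = {1, 3, 2} and node_src = {3, 1, 2}: nearest neighbours are
 * exchanged first, the opposite slab last.
 */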
2737
2738 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2739 {
2740     int thread;
2741
2742     if (NULL != log)
2743     {
2744         fprintf(log, "Destroying PME data structures.\n");
2745     }
2746
2747     sfree((*pmedata)->nnx);
2748     sfree((*pmedata)->nny);
2749     sfree((*pmedata)->nnz);
2750
2751     pmegrids_destroy(&(*pmedata)->pmegridA);
2752
2753     sfree((*pmedata)->fftgridA);
2754     sfree((*pmedata)->cfftgridA);
2755     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2756
2757     if ((*pmedata)->pmegridB.grid.grid != NULL)
2758     {
2759         pmegrids_destroy(&(*pmedata)->pmegridB);
2760         sfree((*pmedata)->fftgridB);
2761         sfree((*pmedata)->cfftgridB);
2762         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2763     }
2764     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2765     {
2766         free_work(&(*pmedata)->work[thread]);
2767     }
2768     sfree((*pmedata)->work);
2769
2770     sfree(*pmedata);
2771     *pmedata = NULL;
2772
2773     return 0;
2774 }
2775
2776 static int mult_up(int n, int f)
2777 {
2778     return ((n + f - 1)/f)*f;
2779 }
2780
2781
2782 static double pme_load_imbalance(gmx_pme_t pme)
2783 {
2784     int    nma, nmi;
2785     double n1, n2, n3;
2786
2787     nma = pme->nnodes_major;
2788     nmi = pme->nnodes_minor;
2789
2790     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2791     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2792     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2793
2794     /* pme_solve is roughly double the cost of an fft */
2795
2796     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2797 }
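
/* n1, n2 and n3 are the grid volumes in the three data layouts used during
 * the parallel FFT + solve, with the decomposed dimensions rounded up to
 * multiples of the node counts; the rounding remainder is work that lands
 * on only some of the ranks.  The n3 layout is weighted three times,
 * consistent with one FFT plus a solve costing roughly twice an FFT.
 * gmx_pme_init() prints a load-balancing note when this ratio reaches 1.2.
 */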
2798
2799 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2800                           int dimind, gmx_bool bSpread)
2801 {
2802     int nk, k, s, thread;
2803
2804     atc->dimind    = dimind;
2805     atc->nslab     = 1;
2806     atc->nodeid    = 0;
2807     atc->pd_nalloc = 0;
2808 #ifdef GMX_MPI
2809     if (pme->nnodes > 1)
2810     {
2811         atc->mpi_comm = pme->mpi_comm_d[dimind];
2812         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2813         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2814     }
2815     if (debug)
2816     {
2817         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2818     }
2819 #endif
2820
2821     atc->bSpread   = bSpread;
2822     atc->pme_order = pme->pme_order;
2823
2824     if (atc->nslab > 1)
2825     {
2826         /* These three allocations are not required for particle decomp. */
2827         snew(atc->node_dest, atc->nslab);
2828         snew(atc->node_src, atc->nslab);
2829         setup_coordinate_communication(atc);
2830
2831         snew(atc->count_thread, pme->nthread);
2832         for (thread = 0; thread < pme->nthread; thread++)
2833         {
2834             snew(atc->count_thread[thread], atc->nslab);
2835         }
2836         atc->count = atc->count_thread[0];
2837         snew(atc->rcount, atc->nslab);
2838         snew(atc->buf_index, atc->nslab);
2839     }
2840
2841     atc->nthread = pme->nthread;
2842     if (atc->nthread > 1)
2843     {
2844         snew(atc->thread_plist, atc->nthread);
2845     }
2846     snew(atc->spline, atc->nthread);
2847     for (thread = 0; thread < atc->nthread; thread++)
2848     {
2849         if (atc->nthread > 1)
2850         {
2851             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2852             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2853         }
2854         snew(atc->spline[thread].thread_one, pme->nthread);
2855         atc->spline[thread].thread_one[thread] = 1;
2856     }
2857 }
2858
2859 static void
2860 init_overlap_comm(pme_overlap_t *  ol,
2861                   int              norder,
2862 #ifdef GMX_MPI
2863                   MPI_Comm         comm,
2864 #endif
2865                   int              nnodes,
2866                   int              nodeid,
2867                   int              ndata,
2868                   int              commplainsize)
2869 {
2870     int lbnd, rbnd, maxlr, b, i;
2871     int exten;
2872     int nn, nk;
2873     pme_grid_comm_t *pgc;
2874     gmx_bool bCont;
2875     int fft_start, fft_end, send_index1, recv_index1;
2876 #ifdef GMX_MPI
2877     MPI_Status stat;
2878
2879     ol->mpi_comm = comm;
2880 #endif
2881
2882     ol->nnodes = nnodes;
2883     ol->nodeid = nodeid;
2884
2885     /* Linear translation of the PME grid won't affect reciprocal space
2886      * calculations, so to optimize we only interpolate "upwards",
2887      * which also means we only have to consider overlap in one direction.
2888      * I.e., particles on this node might also be spread to grid indices
2889      * that belong to higher nodes (modulo nnodes).
2890      */
2891
2892     snew(ol->s2g0, ol->nnodes+1);
2893     snew(ol->s2g1, ol->nnodes);
2894     if (debug)
2895     {
2896         fprintf(debug, "PME slab boundaries:");
2897     }
2898     for (i = 0; i < nnodes; i++)
2899     {
2900         /* s2g0 is the local interpolation grid start,
2901          * s2g1 the local interpolation grid end.
2902          * Because grid overlap communication only goes forward,
2903          * the grid slabs for the FFTs should be rounded down.
2904          */
2905         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2906         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2907
2908         if (debug)
2909         {
2910             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2911         }
2912     }
2913     ol->s2g0[nnodes] = ndata;
2914     if (debug)
2915     {
2916         fprintf(debug, "\n");
2917     }
2918
2919     /* Determine with how many nodes we need to communicate the grid overlap */
2920     b = 0;
2921     do
2922     {
2923         b++;
2924         bCont = FALSE;
2925         for (i = 0; i < nnodes; i++)
2926         {
2927             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2928                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2929             {
2930                 bCont = TRUE;
2931             }
2932         }
2933     }
2934     while (bCont && b < nnodes);
2935     ol->noverlap_nodes = b - 1;
2936
2937     snew(ol->send_id, ol->noverlap_nodes);
2938     snew(ol->recv_id, ol->noverlap_nodes);
2939     for (b = 0; b < ol->noverlap_nodes; b++)
2940     {
2941         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2942         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2943     }
2944     snew(ol->comm_data, ol->noverlap_nodes);
2945
2946     ol->send_size = 0;
2947     for (b = 0; b < ol->noverlap_nodes; b++)
2948     {
2949         pgc = &ol->comm_data[b];
2950         /* Send */
2951         fft_start        = ol->s2g0[ol->send_id[b]];
2952         fft_end          = ol->s2g0[ol->send_id[b]+1];
2953         if (ol->send_id[b] < nodeid)
2954         {
2955             fft_start += ndata;
2956             fft_end   += ndata;
2957         }
2958         send_index1       = ol->s2g1[nodeid];
2959         send_index1       = min(send_index1, fft_end);
2960         pgc->send_index0  = fft_start;
2961         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2962         ol->send_size    += pgc->send_nindex;
2963
2964         /* We always start receiving to the first index of our slab */
2965         fft_start        = ol->s2g0[ol->nodeid];
2966         fft_end          = ol->s2g0[ol->nodeid+1];
2967         recv_index1      = ol->s2g1[ol->recv_id[b]];
2968         if (ol->recv_id[b] > nodeid)
2969         {
2970             recv_index1 -= ndata;
2971         }
2972         recv_index1      = min(recv_index1, fft_end);
2973         pgc->recv_index0 = fft_start;
2974         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2975     }
2976
2977 #ifdef GMX_MPI
2978     /* Communicate the buffer sizes to receive */
2979     for (b = 0; b < ol->noverlap_nodes; b++)
2980     {
2981         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2982                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2983                      ol->mpi_comm, &stat);
2984     }
2985 #endif
2986
2987     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2988     snew(ol->sendbuf, norder*commplainsize);
2989     snew(ol->recvbuf, norder*commplainsize);
2990 }
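
/* Worked example: for ndata = 20 grid lines, nnodes = 4 and norder = 4,
 * s2g0 = {0, 5, 10, 15, 20} and s2g1 = {8, 13, 18, 23}, i.e. every slab
 * interpolates norder-1 = 3 lines into the next slab, so noverlap_nodes = 1
 * and each rank sends/receives a single overlap region of 3 grid planes.
 */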
2991
2992 static void
2993 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2994                               int **global_to_local,
2995                               real **fraction_shift)
2996 {
2997     int i;
2998     int * gtl;
2999     real * fsh;
3000
3001     snew(gtl, 5*n);
3002     snew(fsh, 5*n);
3003     for (i = 0; (i < 5*n); i++)
3004     {
3005         /* Determine the global to local grid index */
3006         gtl[i] = (i - local_start + n) % n;
3007         /* For coordinates that fall within the local grid the fraction
3008          * is correct, we don't need to shift it.
3009          */
3010         fsh[i] = 0;
3011         if (local_range < n)
3012         {
3013             /* Due to rounding issues i could be 1 beyond the lower or
3014              * upper boundary of the local grid. Correct the index for this.
3015              * If we shift the index, we need to shift the fraction by
3016              * the same amount in the other direction to not affect
3017              * the weights.
3018              * Note that due to this shifting the weights at the end of
3019              * the spline might change, but that will only involve values
3020              * between zero and values close to the precision of a real,
3021              * which is anyhow the accuracy of the whole mesh calculation.
3022              */
3023             /* With local_range=0 we should not change i=local_start */
3024             if (i % n != local_start)
3025             {
3026                 if (gtl[i] == n-1)
3027                 {
3028                     gtl[i] = 0;
3029                     fsh[i] = -1;
3030                 }
3031                 else if (gtl[i] == local_range)
3032                 {
3033                     gtl[i] = local_range - 1;
3034                     fsh[i] = 1;
3035                 }
3036             }
3037         }
3038     }
3039
3040     *global_to_local = gtl;
3041     *fraction_shift  = fsh;
3042 }
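
/* The tables are 5*n long so that the spread and gather inner loops can
 * look up i0+ithx, j0+ithy and k0+ithz directly, even when the B-spline
 * support wraps around the periodic grid, without a modulo operation;
 * fraction_shift compensates for the rare one-cell index shifts applied
 * at the local-grid boundaries above.
 */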
3043
3044 static pme_spline_work_t *make_pme_spline_work(int order)
3045 {
3046     pme_spline_work_t *work;
3047
3048 #ifdef PME_SIMD4_SPREAD_GATHER
3049     real         tmp[12], *tmp_aligned;
3050     gmx_simd4_pr zero_S;
3051     gmx_simd4_pr real_mask_S0, real_mask_S1;
3052     int          of, i;
3053
3054     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3055
3056     tmp_aligned = gmx_simd4_align_real(tmp);
3057
3058     zero_S = gmx_simd4_setzero_pr();
3059
3060     /* Generate bit masks to mask out the unused grid entries,
3061      * as we only operate on 'order' of the 8 grid entries that are
3062      * loaded into 2 SIMD registers.
3063      */
3064     for (of = 0; of < 8-(order-1); of++)
3065     {
3066         for (i = 0; i < 8; i++)
3067         {
3068             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3069         }
3070         real_mask_S0      = gmx_simd4_load_pr(tmp_aligned);
3071         real_mask_S1      = gmx_simd4_load_pr(tmp_aligned+4);
3072         work->mask_S0[of] = gmx_simd4_cmplt_pr(real_mask_S0, zero_S);
3073         work->mask_S1[of] = gmx_simd4_cmplt_pr(real_mask_S1, zero_S);
3074     }
3075 #else
3076     work = NULL;
3077 #endif
3078
3079     return work;
3080 }
3081
3082 void gmx_pme_check_restrictions(int pme_order,
3083                                 int nkx, int nky, int nkz,
3084                                 int nnodes_major,
3085                                 int nnodes_minor,
3086                                 gmx_bool bUseThreads,
3087                                 gmx_bool bFatal,
3088                                 gmx_bool *bValidSettings)
3089 {
3090     if (pme_order > PME_ORDER_MAX)
3091     {
3092         if (!bFatal)
3093         {
3094             *bValidSettings = FALSE;
3095             return;
3096         }
3097         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3098                   pme_order, PME_ORDER_MAX);
3099     }
3100
3101     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3102         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3103         nkz <= pme_order)
3104     {
3105         if (!bFatal)
3106         {
3107             *bValidSettings = FALSE;
3108             return;
3109         }
3110         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3111                   pme_order);
3112     }
3113
3114     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3115      * We only allow multiple communication pulses in dim 1, not in dim 0.
3116      */
3117     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3118                         nkx != nnodes_major*(pme_order - 1)))
3119     {
3120         if (!bFatal)
3121         {
3122             *bValidSettings = FALSE;
3123             return;
3124         }
3125         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3126                   nkx/(double)nnodes_major, pme_order);
3127     }
3128
3129     if (bValidSettings != NULL)
3130     {
3131         *bValidSettings = TRUE;
3132     }
3133
3134     return;
3135 }
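
/* A caller can probe a candidate setup without aborting by passing
 * bFatal = FALSE and a non-NULL bValidSettings (illustrative values):
 *
 *     gmx_bool bValid;
 *     gmx_pme_check_restrictions(pme_order, nkx, nky, nkz,
 *                                nnodes_major, nnodes_minor,
 *                                bUseThreads, FALSE, &bValid);
 *
 * gmx_pme_init() below instead calls it with bFatal = TRUE and NULL,
 * so invalid settings are fatal there.
 */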
3136
3137 int gmx_pme_init(gmx_pme_t *         pmedata,
3138                  t_commrec *         cr,
3139                  int                 nnodes_major,
3140                  int                 nnodes_minor,
3141                  t_inputrec *        ir,
3142                  int                 homenr,
3143                  gmx_bool            bFreeEnergy,
3144                  gmx_bool            bReproducible,
3145                  int                 nthread)
3146 {
3147     gmx_pme_t pme = NULL;
3148
3149     int  use_threads, sum_use_threads;
3150     ivec ndata;
3151
3152     if (debug)
3153     {
3154         fprintf(debug, "Creating PME data structures.\n");
3155     }
3156     snew(pme, 1);
3157
3158     pme->redist_init         = FALSE;
3159     pme->sum_qgrid_tmp       = NULL;
3160     pme->sum_qgrid_dd_tmp    = NULL;
3161     pme->buf_nalloc          = 0;
3162     pme->redist_buf_nalloc   = 0;
3163
3164     pme->nnodes              = 1;
3165     pme->bPPnode             = TRUE;
3166
3167     pme->nnodes_major        = nnodes_major;
3168     pme->nnodes_minor        = nnodes_minor;
3169
3170 #ifdef GMX_MPI
3171     if (nnodes_major*nnodes_minor > 1)
3172     {
3173         pme->mpi_comm = cr->mpi_comm_mygroup;
3174
3175         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3176         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3177         if (pme->nnodes != nnodes_major*nnodes_minor)
3178         {
3179             gmx_incons("PME node count mismatch");
3180         }
3181     }
3182     else
3183     {
3184         pme->mpi_comm = MPI_COMM_NULL;
3185     }
3186 #endif
3187
3188     if (pme->nnodes == 1)
3189     {
3190 #ifdef GMX_MPI
3191         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3192         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3193 #endif
3194         pme->ndecompdim   = 0;
3195         pme->nodeid_major = 0;
3196         pme->nodeid_minor = 0;
3200     }
3201     else
3202     {
3203         if (nnodes_minor == 1)
3204         {
3205 #ifdef GMX_MPI
3206             pme->mpi_comm_d[0] = pme->mpi_comm;
3207             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3208 #endif
3209             pme->ndecompdim   = 1;
3210             pme->nodeid_major = pme->nodeid;
3211             pme->nodeid_minor = 0;
3212
3213         }
3214         else if (nnodes_major == 1)
3215         {
3216 #ifdef GMX_MPI
3217             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3218             pme->mpi_comm_d[1] = pme->mpi_comm;
3219 #endif
3220             pme->ndecompdim   = 1;
3221             pme->nodeid_major = 0;
3222             pme->nodeid_minor = pme->nodeid;
3223         }
3224         else
3225         {
3226             if (pme->nnodes % nnodes_major != 0)
3227             {
3228                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3229             }
3230             pme->ndecompdim = 2;
3231
3232 #ifdef GMX_MPI
3233             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3234                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3235             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3236                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3237
3238             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3239             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3240             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3241             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3242 #endif
3243         }
3244         pme->bPPnode = (cr->duty & DUTY_PP);
3245     }
3246
3247     pme->nthread = nthread;
3248
3249     /* Check if any of the PME MPI ranks uses threads */
3250     use_threads = (pme->nthread > 1 ? 1 : 0);
3251 #ifdef GMX_MPI
3252     if (pme->nnodes > 1)
3253     {
3254         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3255                       MPI_SUM, pme->mpi_comm);
3256     }
3257     else
3258 #endif
3259     {
3260         sum_use_threads = use_threads;
3261     }
3262     pme->bUseThreads = (sum_use_threads > 0);
3263
3264     if (ir->ePBC == epbcSCREW)
3265     {
3266         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3267     }
3268
3269     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3270     pme->nkx         = ir->nkx;
3271     pme->nky         = ir->nky;
3272     pme->nkz         = ir->nkz;
3273     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3274     pme->pme_order   = ir->pme_order;
3275     pme->epsilon_r   = ir->epsilon_r;
3276
3277     /* If we violate restrictions, generate a fatal error here */
3278     gmx_pme_check_restrictions(pme->pme_order,
3279                                pme->nkx, pme->nky, pme->nkz,
3280                                pme->nnodes_major,
3281                                pme->nnodes_minor,
3282                                pme->bUseThreads,
3283                                TRUE,
3284                                NULL);
3285
3286     if (pme->nnodes > 1)
3287     {
3288         double imbal;
3289
3290 #ifdef GMX_MPI
3291         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3292         MPI_Type_commit(&(pme->rvec_mpi));
3293 #endif
3294
3295         /* Note that the charge spreading and force gathering, which usually
3296          * take about the same amount of time as FFT+solve_pme,
3297          * are always fully load balanced
3298          * (unless the charge distribution is inhomogeneous).
3299          */
3300
3301         imbal = pme_load_imbalance(pme);
3302         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3303         {
3304             fprintf(stderr,
3305                     "\n"
3306                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3307                     "      For optimal PME load balancing\n"
3308                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3309                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3310                     "\n",
3311                     (int)((imbal-1)*100 + 0.5),
3312                     pme->nkx, pme->nky, pme->nnodes_major,
3313                     pme->nky, pme->nkz, pme->nnodes_minor);
3314         }
3315     }
3316
3317     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3318     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3319      * y is always copied through a buffer: we don't need padding in z,
3320      * but we do need the overlap in x because of the communication order.
3321      */
3322     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3323 #ifdef GMX_MPI
3324                       pme->mpi_comm_d[0],
3325 #endif
3326                       pme->nnodes_major, pme->nodeid_major,
3327                       pme->nkx,
3328                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3329
3330     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3331      * We do this with an offset buffer of equal size, so we need to allocate
3332      * extra for the offset. That's what the (+1)*pme->nkz is for.
3333      */
3334     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3335 #ifdef GMX_MPI
3336                       pme->mpi_comm_d[1],
3337 #endif
3338                       pme->nnodes_minor, pme->nodeid_minor,
3339                       pme->nky,
3340                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3341
3342     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3343      * Note that gmx_pme_check_restrictions checked for this already.
3344      */
3345     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3346     {
3347         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3348     }
3349
3350     snew(pme->bsp_mod[XX], pme->nkx);
3351     snew(pme->bsp_mod[YY], pme->nky);
3352     snew(pme->bsp_mod[ZZ], pme->nkz);
3353
3354     /* The required size of the interpolation grid, including overlap.
3355      * The allocated size (pmegrid_n?) might be slightly larger.
3356      */
3357     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3358         pme->overlap[0].s2g0[pme->nodeid_major];
3359     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3360         pme->overlap[1].s2g0[pme->nodeid_minor];
3361     pme->pmegrid_nz_base = pme->nkz;
3362     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3363     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3364
3365     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3366     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3367     pme->pmegrid_start_iz = 0;
3368
3369     make_gridindex5_to_localindex(pme->nkx,
3370                                   pme->pmegrid_start_ix,
3371                                   pme->pmegrid_nx - (pme->pme_order-1),
3372                                   &pme->nnx, &pme->fshx);
3373     make_gridindex5_to_localindex(pme->nky,
3374                                   pme->pmegrid_start_iy,
3375                                   pme->pmegrid_ny - (pme->pme_order-1),
3376                                   &pme->nny, &pme->fshy);
3377     make_gridindex5_to_localindex(pme->nkz,
3378                                   pme->pmegrid_start_iz,
3379                                   pme->pmegrid_nz_base,
3380                                   &pme->nnz, &pme->fshz);
3381
3382     pmegrids_init(&pme->pmegridA,
3383                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3384                   pme->pmegrid_nz_base,
3385                   pme->pme_order,
3386                   pme->bUseThreads,
3387                   pme->nthread,
3388                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3389                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3390
3391     pme->spline_work = make_pme_spline_work(pme->pme_order);
3392
3393     ndata[0] = pme->nkx;
3394     ndata[1] = pme->nky;
3395     ndata[2] = pme->nkz;
3396
3397     /* This routine will allocate the grid data to fit the FFTs */
3398     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3399                             &pme->fftgridA, &pme->cfftgridA,
3400                             pme->mpi_comm_d,
3401                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3402                             bReproducible, pme->nthread);
3403
3404     if (bFreeEnergy)
3405     {
3406         pmegrids_init(&pme->pmegridB,
3407                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3408                       pme->pmegrid_nz_base,
3409                       pme->pme_order,
3410                       pme->bUseThreads,
3411                       pme->nthread,
3412                       pme->nkx % pme->nnodes_major != 0,
3413                       pme->nky % pme->nnodes_minor != 0);
3414
3415         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3416                                 &pme->fftgridB, &pme->cfftgridB,
3417                                 pme->mpi_comm_d,
3418                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3419                                 bReproducible, pme->nthread);
3420     }
3421     else
3422     {
3423         pme->pmegridB.grid.grid = NULL;
3424         pme->fftgridB           = NULL;
3425         pme->cfftgridB          = NULL;
3426     }
3427
3428     if (!pme->bP3M)
3429     {
3430         /* Use plain SPME B-spline interpolation */
3431         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3432     }
3433     else
3434     {
3435         /* Use the P3M grid-optimized influence function */
3436         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3437     }
3438
3439     /* Use atc[0] for spreading */
3440     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3441     if (pme->ndecompdim >= 2)
3442     {
3443         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3444     }
3445
3446     if (pme->nnodes == 1)
3447     {
3448         pme->atc[0].n = homenr;
3449         pme_realloc_atomcomm_things(&pme->atc[0]);
3450     }
3451
3452     {
3453         int thread;
3454
3455         /* Use fft5d, order after FFT is y major, z, x minor */
3456
3457         snew(pme->work, pme->nthread);
3458         for (thread = 0; thread < pme->nthread; thread++)
3459         {
3460             realloc_work(&pme->work[thread], pme->nkx);
3461         }
3462     }
3463
3464     *pmedata = pme;
3465
3466     return 0;
3467 }
3468
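/* Hand the aligned grid storage of old over to new, provided the new
 * grids are not larger along any dimension; the per-thread grids are
 * reused as well when the thread counts match. Otherwise new is left
 * untouched.
 */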
3469 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3470 {
3471     int d, t;
3472
3473     for (d = 0; d < DIM; d++)
3474     {
3475         if (new->grid.n[d] > old->grid.n[d])
3476         {
3477             return;
3478         }
3479     }
3480
3481     sfree_aligned(new->grid.grid);
3482     new->grid.grid = old->grid.grid;
3483
3484     if (new->grid_th != NULL && new->nthread == old->nthread)
3485     {
3486         sfree_aligned(new->grid_all);
3487         for (t = 0; t < new->nthread; t++)
3488         {
3489             new->grid_th[t].grid = old->grid_th[t].grid;
3490         }
3491     }
3492 }
3493
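/* Initialize a new PME data structure for the grid size given in
 * grid_size, taking all other settings from pme_src and reusing its
 * grid storage where possible.
 */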
3494 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3495                    t_commrec *         cr,
3496                    gmx_pme_t           pme_src,
3497                    const t_inputrec *  ir,
3498                    ivec                grid_size)
3499 {
3500     t_inputrec irc;
3501     int homenr;
3502     int ret;
3503
3504     irc     = *ir;
3505     irc.nkx = grid_size[XX];
3506     irc.nky = grid_size[YY];
3507     irc.nkz = grid_size[ZZ];
3508
3509     if (pme_src->nnodes == 1)
3510     {
3511         homenr = pme_src->atc[0].n;
3512     }
3513     else
3514     {
3515         homenr = -1;
3516     }
3517
3518     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3519                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3520
3521     if (ret == 0)
3522     {
3523         /* We can easily reuse the allocated pme grids in pme_src */
3524         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3525         /* We would like to reuse the fft grids, but that's harder */
3526     }
3527
3528     return ret;
3529 }
3530
3531
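/* Copy the non-overlapping part of the spreading grid of the given
 * thread directly into the node-local fftgrid.
 */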
3532 static void copy_local_grid(gmx_pme_t pme,
3533                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3534 {
3535     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3536     int  fft_my, fft_mz;
3537     int  nsx, nsy, nsz;
3538     ivec nf;
3539     int  offx, offy, offz, x, y, z, i0, i0t;
3540     int  d;
3541     pmegrid_t *pmegrid;
3542     real *grid_th;
3543
3544     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3545                                    local_fft_ndata,
3546                                    local_fft_offset,
3547                                    local_fft_size);
3548     fft_my = local_fft_size[YY];
3549     fft_mz = local_fft_size[ZZ];
3550
3551     pmegrid = &pmegrids->grid_th[thread];
3552
3553     nsx = pmegrid->s[XX];
3554     nsy = pmegrid->s[YY];
3555     nsz = pmegrid->s[ZZ];
3556
3557     for (d = 0; d < DIM; d++)
3558     {
3559         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3560                     local_fft_ndata[d] - pmegrid->offset[d]);
3561     }
3562
3563     offx = pmegrid->offset[XX];
3564     offy = pmegrid->offset[YY];
3565     offz = pmegrid->offset[ZZ];
3566
3567     /* Directly copy the non-overlapping parts of the local grids.
3568      * This also initializes the full grid.
3569      */
3570     grid_th = pmegrid->grid;
3571     for (x = 0; x < nf[XX]; x++)
3572     {
3573         for (y = 0; y < nf[YY]; y++)
3574         {
3575             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3576             i0t = (x*nsy + y)*nsz;
3577             for (z = 0; z < nf[ZZ]; z++)
3578             {
3579                 fftgrid[i0+z] = grid_th[i0t+z];
3580             }
3581         }
3582     }
3583 }
3584
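/* Reduce the spreading contributions of neighboring threads into the
 * part of fftgrid owned by this thread. Contributions that belong to
 * other MPI ranks are accumulated into commbuf_x/commbuf_y instead,
 * for later communication in sum_fftgrid_dd.
 */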
3585 static void
3586 reduce_threadgrid_overlap(gmx_pme_t pme,
3587                           const pmegrids_t *pmegrids, int thread,
3588                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3589 {
3590     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3591     int  fft_nx, fft_ny, fft_nz;
3592     int  fft_my, fft_mz;
3593     int  buf_my = -1;
3594     int  nsx, nsy, nsz;
3595     ivec ne;
3596     int  offx, offy, offz, x, y, z, i0, i0t;
3597     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3598     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3599     gmx_bool bCommX, bCommY;
3600     int  d;
3601     int  thread_f;
3602     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3603     const real *grid_th;
3604     real *commbuf = NULL;
3605
3606     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3607                                    local_fft_ndata,
3608                                    local_fft_offset,
3609                                    local_fft_size);
3610     fft_nx = local_fft_ndata[XX];
3611     fft_ny = local_fft_ndata[YY];
3612     fft_nz = local_fft_ndata[ZZ];
3613
3614     fft_my = local_fft_size[YY];
3615     fft_mz = local_fft_size[ZZ];
3616
3617     /* This routine is called when all threads have finished spreading.
3618      * Here each thread sums grid contributions calculated by other threads
3619      * to the thread-local grid volume.
3620      * To minimize the number of grid copying operations,
3621      * this routine sums immediately from the pmegrid to the fftgrid.
3622      */
3623
3624     /* Determine which part of the full node grid we should operate on;
3625      * this is our thread-local part of the full grid.
3626      */
3627     pmegrid = &pmegrids->grid_th[thread];
3628
3629     for (d = 0; d < DIM; d++)
3630     {
3631         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3632                     local_fft_ndata[d]);
3633     }
3634
3635     offx = pmegrid->offset[XX];
3636     offy = pmegrid->offset[YY];
3637     offz = pmegrid->offset[ZZ];
3638
3639
3640     bClearBufX  = TRUE;
3641     bClearBufY  = TRUE;
3642     bClearBufXY = TRUE;
3643
3644     /* Now loop over all the thread data blocks that contribute
3645      * to the grid region we (our thread) are operating on.
3646      */
3647     /* Note that fft_nx/y is equal to the number of grid points
3648      * between the first point of our node's grid and that of the next node.
3649      */
3650     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3651     {
3652         fx     = pmegrid->ci[XX] + sx;
3653         ox     = 0;
3654         bCommX = FALSE;
3655         if (fx < 0)
3656         {
3657             fx    += pmegrids->nc[XX];
3658             ox    -= fft_nx;
3659             bCommX = (pme->nnodes_major > 1);
3660         }
3661         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3662         ox       += pmegrid_g->offset[XX];
3663         /* Determine the end of our part of the source grid */
3664         tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3665
3666         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3667         {
3668             fy     = pmegrid->ci[YY] + sy;
3669             oy     = 0;
3670             bCommY = FALSE;
3671             if (fy < 0)
3672             {
3673                 fy    += pmegrids->nc[YY];
3674                 oy    -= fft_ny;
3675                 bCommY = (pme->nnodes_minor > 1);
3676             }
3677             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3678             oy       += pmegrid_g->offset[YY];
3679             /* Determine the end of our part of the source grid */
3680             ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3681
3682             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3683             {
3684                 fz = pmegrid->ci[ZZ] + sz;
3685                 oz = 0;
3686                 if (fz < 0)
3687                 {
3688                     fz += pmegrids->nc[ZZ];
3689                     oz -= fft_nz;
3690                 }
3691                 pmegrid_g = &pmegrids->grid_th[fz];
3692                 oz       += pmegrid_g->offset[ZZ];
3693                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3694
3695                 if (sx == 0 && sy == 0 && sz == 0)
3696                 {
3697                     /* We have already added our local contribution
3698                      * before calling this routine, so skip it here.
3699                      */
3700                     continue;
3701                 }
3702
3703                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3704
3705                 pmegrid_f = &pmegrids->grid_th[thread_f];
3706
3707                 grid_th = pmegrid_f->grid;
3708
3709                 nsx = pmegrid_f->s[XX];
3710                 nsy = pmegrid_f->s[YY];
3711                 nsz = pmegrid_f->s[ZZ];
3712
3713 #ifdef DEBUG_PME_REDUCE
3714                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3715                        pme->nodeid, thread, thread_f,
3716                        pme->pmegrid_start_ix,
3717                        pme->pmegrid_start_iy,
3718                        pme->pmegrid_start_iz,
3719                        sx, sy, sz,
3720                        offx-ox, tx1-ox, offx, tx1,
3721                        offy-oy, ty1-oy, offy, ty1,
3722                        offz-oz, tz1-oz, offz, tz1);
3723 #endif
3724
3725                 if (!(bCommX || bCommY))
3726                 {
3727                     /* Copy from the thread local grid to the node grid */
3728                     for (x = offx; x < tx1; x++)
3729                     {
3730                         for (y = offy; y < ty1; y++)
3731                         {
3732                             i0  = (x*fft_my + y)*fft_mz;
3733                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3734                             for (z = offz; z < tz1; z++)
3735                             {
3736                                 fftgrid[i0+z] += grid_th[i0t+z];
3737                             }
3738                         }
3739                     }
3740                 }
3741                 else
3742                 {
3743                     /* The order of this conditional decides
3744                      * where the corner volume gets stored with x+y decomp.
3745                      */
3746                     if (bCommY)
3747                     {
3748                         commbuf = commbuf_y;
3749                         /* The y-size of the communication buffer is order-1 */
3750                         buf_my  = pmegrid->order - 1;
3751                         if (bCommX)
3752                         {
3753                             /* We index commbuf modulo the local grid size */
3754                             commbuf += buf_my*fft_nx*fft_nz;
3755
3756                             bClearBuf   = bClearBufXY;
3757                             bClearBufXY = FALSE;
3758                         }
3759                         else
3760                         {
3761                             bClearBuf  = bClearBufY;
3762                             bClearBufY = FALSE;
3763                         }
3764                     }
3765                     else
3766                     {
3767                         commbuf    = commbuf_x;
3768                         buf_my     = fft_ny;
3769                         bClearBuf  = bClearBufX;
3770                         bClearBufX = FALSE;
3771                     }
3772
3773                     /* Copy to the communication buffer */
3774                     for (x = offx; x < tx1; x++)
3775                     {
3776                         for (y = offy; y < ty1; y++)
3777                         {
3778                             i0  = (x*buf_my + y)*fft_nz;
3779                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3780
3781                             if (bClearBuf)
3782                             {
3783                                 /* First access of commbuf, initialize it */
3784                                 for (z = offz; z < tz1; z++)
3785                                 {
3786                                     commbuf[i0+z]  = grid_th[i0t+z];
3787                                 }
3788                             }
3789                             else
3790                             {
3791                                 for (z = offz; z < tz1; z++)
3792                                 {
3793                                     commbuf[i0+z] += grid_th[i0t+z];
3794                                 }
3795                             }
3796                         }
3797                     }
3798                 }
3799             }
3800         }
3801     }
3802 }
3803
3804
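/* Sum the fftgrid overlap regions over the MPI ranks after threaded
 * spreading: first along the minor dimension (possibly in several
 * pulses), then in a single pulse along the major dimension.
 */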
3805 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3806 {
3807     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3808     pme_overlap_t *overlap;
3809     int  send_index0, send_nindex;
3810     int  recv_nindex;
3811 #ifdef GMX_MPI
3812     MPI_Status stat;
3813 #endif
3814     int  send_size_y, recv_size_y;
3815     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3816     real *sendptr, *recvptr;
3817     int  x, y, z, indg, indb;
3818
3819     /* Note that this routine is only used for forward communication.
3820      * Since the force gathering, unlike the charge spreading,
3821      * can be trivially parallelized over the particles,
3822      * the backwards process is much simpler and can use the "old"
3823      * communication setup.
3824      */
3825
3826     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3827                                    local_fft_ndata,
3828                                    local_fft_offset,
3829                                    local_fft_size);
3830
3831     if (pme->nnodes_minor > 1)
3832     {
3833         /* Minor dimension */
3834         overlap = &pme->overlap[1];
3835
3836         if (pme->nnodes_major > 1)
3837         {
3838             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3839         }
3840         else
3841         {
3842             size_yx = 0;
3843         }
3844         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3845
3846         send_size_y = overlap->send_size;
3847
3848         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3849         {
3850             send_id       = overlap->send_id[ipulse];
3851             recv_id       = overlap->recv_id[ipulse];
3852             send_index0   =
3853                 overlap->comm_data[ipulse].send_index0 -
3854                 overlap->comm_data[0].send_index0;
3855             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3856             /* We don't use recv_index0, as we always receive starting at 0 */
3857             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3858             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3859
3860             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3861             recvptr = overlap->recvbuf;
3862
3863 #ifdef GMX_MPI
3864             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3865                          send_id, ipulse,
3866                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3867                          recv_id, ipulse,
3868                          overlap->mpi_comm, &stat);
3869 #endif
3870
3871             for (x = 0; x < local_fft_ndata[XX]; x++)
3872             {
3873                 for (y = 0; y < recv_nindex; y++)
3874                 {
3875                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3876                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3877                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3878                     {
3879                         fftgrid[indg+z] += recvptr[indb+z];
3880                     }
3881                 }
3882             }
3883
3884             if (pme->nnodes_major > 1)
3885             {
3886                 /* Copy from the received buffer to the send buffer for dim 0 */
3887                 sendptr = pme->overlap[0].sendbuf;
3888                 for (x = 0; x < size_yx; x++)
3889                 {
3890                     for (y = 0; y < recv_nindex; y++)
3891                     {
3892                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3893                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3894                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3895                         {
3896                             sendptr[indg+z] += recvptr[indb+z];
3897                         }
3898                     }
3899                 }
3900             }
3901         }
3902     }
3903
3904     /* We only support a single pulse here.
3905      * This is not a severe limitation, as this code is only used
3906      * with OpenMP and with OpenMP the (PME) domains can be larger.
3907      */
3908     if (pme->nnodes_major > 1)
3909     {
3910         /* Major dimension */
3911         overlap = &pme->overlap[0];
3912
3913         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3914         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3915
3916         ipulse = 0;
3917
3918         send_id       = overlap->send_id[ipulse];
3919         recv_id       = overlap->recv_id[ipulse];
3920         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3921         /* We don't use recv_index0, as we always receive starting at 0 */
3922         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3923
3924         sendptr = overlap->sendbuf;
3925         recvptr = overlap->recvbuf;
3926
3927         if (debug != NULL)
3928         {
3929             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3930                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3931         }
3932
3933 #ifdef GMX_MPI
3934         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3935                      send_id, ipulse,
3936                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3937                      recv_id, ipulse,
3938                      overlap->mpi_comm, &stat);
3939 #endif
3940
3941         for (x = 0; x < recv_nindex; x++)
3942         {
3943             for (y = 0; y < local_fft_ndata[YY]; y++)
3944             {
3945                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3946                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3947                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3948                 {
3949                     fftgrid[indg+z] += recvptr[indb+z];
3950                 }
3951             }
3952         }
3953     }
3954 }
3955
3956
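/* Calculate interpolation indices and B-spline coefficients for the
 * atoms in atc (when bCalcSplines) and spread their charges on the
 * grid(s) (when bSpread), using pme->nthread OpenMP threads. With
 * thread-local grids the results are reduced into fftgrid afterwards.
 */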
3957 static void spread_on_grid(gmx_pme_t pme,
3958                            pme_atomcomm_t *atc, pmegrids_t *grids,
3959                            gmx_bool bCalcSplines, gmx_bool bSpread,
3960                            real *fftgrid)
3961 {
3962     int nthread, thread;
3963 #ifdef PME_TIME_THREADS
3964     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3965     static double cs1     = 0, cs2 = 0, cs3 = 0;
3966     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3967     static int cnt        = 0;
3968 #endif
3969
3970     nthread = pme->nthread;
3971     assert(nthread > 0);
3972
3973 #ifdef PME_TIME_THREADS
3974     c1 = omp_cyc_start();
3975 #endif
3976     if (bCalcSplines)
3977     {
3978 #pragma omp parallel for num_threads(nthread) schedule(static)
3979         for (thread = 0; thread < nthread; thread++)
3980         {
3981             int start, end;
3982
3983             start = atc->n* thread   /nthread;
3984             end   = atc->n*(thread+1)/nthread;
3985
3986             /* Compute fftgrid index for all atoms,
3987              * with help of some extra variables.
3988              */
3989             calc_interpolation_idx(pme, atc, start, end, thread);
3990         }
3991     }
3992 #ifdef PME_TIME_THREADS
3993     c1   = omp_cyc_end(c1);
3994     cs1 += (double)c1;
3995 #endif
3996
3997 #ifdef PME_TIME_THREADS
3998     c2 = omp_cyc_start();
3999 #endif
4000 #pragma omp parallel for num_threads(nthread) schedule(static)
4001     for (thread = 0; thread < nthread; thread++)
4002     {
4003         splinedata_t *spline;
4004         pmegrid_t *grid = NULL;
4005
4006         /* make local bsplines  */
4007         if (grids == NULL || !pme->bUseThreads)
4008         {
4009             spline = &atc->spline[0];
4010
4011             spline->n = atc->n;
4012
4013             if (bSpread)
4014             {
4015                 grid = &grids->grid;
4016             }
4017         }
4018         else
4019         {
4020             spline = &atc->spline[thread];
4021
4022             if (grids->nthread == 1)
4023             {
4024                 /* One thread, we operate on all charges */
4025                 spline->n = atc->n;
4026             }
4027             else
4028             {
4029                 /* Get the indices our thread should operate on */
4030                 make_thread_local_ind(atc, thread, spline);
4031             }
4032
4033             grid = &grids->grid_th[thread];
4034         }
4035
4036         if (bCalcSplines)
4037         {
4038             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4039                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
4040         }
4041
4042         if (bSpread)
4043         {
4044             /* put local atoms on grid. */
4045 #ifdef PME_TIME_SPREAD
4046             ct1a = omp_cyc_start();
4047 #endif
4048             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4049
4050             if (pme->bUseThreads)
4051             {
4052                 copy_local_grid(pme, grids, thread, fftgrid);
4053             }
4054 #ifdef PME_TIME_SPREAD
4055             ct1a          = omp_cyc_end(ct1a);
4056             cs1a[thread] += (double)ct1a;
4057 #endif
4058         }
4059     }
4060 #ifdef PME_TIME_THREADS
4061     c2   = omp_cyc_end(c2);
4062     cs2 += (double)c2;
4063 #endif
4064
4065     if (bSpread && pme->bUseThreads)
4066     {
4067 #ifdef PME_TIME_THREADS
4068         c3 = omp_cyc_start();
4069 #endif
4070 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4071         for (thread = 0; thread < grids->nthread; thread++)
4072         {
4073             reduce_threadgrid_overlap(pme, grids, thread,
4074                                       fftgrid,
4075                                       pme->overlap[0].sendbuf,
4076                                       pme->overlap[1].sendbuf);
4077         }
4078 #ifdef PME_TIME_THREADS
4079         c3   = omp_cyc_end(c3);
4080         cs3 += (double)c3;
4081 #endif
4082
4083         if (pme->nnodes > 1)
4084         {
4085             /* Communicate the overlapping part of the fftgrid.
4086              * For this communication call we need to check pme->bUseThreads
4087              * to have all ranks communicate here, regardless of pme->nthread.
4088              */
4089             sum_fftgrid_dd(pme, fftgrid);
4090         }
4091     }
4092
4093 #ifdef PME_TIME_THREADS
4094     cnt++;
4095     if (cnt % 20 == 0)
4096     {
4097         printf("idx %.2f spread %.2f red %.2f",
4098                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4099 #ifdef PME_TIME_SPREAD
4100         for (thread = 0; thread < nthread; thread++)
4101         {
4102             printf(" %.2f", cs1a[thread]*1e-9);
4103         }
4104 #endif
4105         printf("\n");
4106     }
4107 #endif
4108 }
4109
4110
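/* Debugging helper: print an nx*ny*nz block of grid values with global
 * start indices (sx, sy, sz) and y/z strides my and mz.
 */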
4111 static void dump_grid(FILE *fp,
4112                       int sx, int sy, int sz, int nx, int ny, int nz,
4113                       int my, int mz, const real *g)
4114 {
4115     int x, y, z;
4116
4117     for (x = 0; x < nx; x++)
4118     {
4119         for (y = 0; y < ny; y++)
4120         {
4121             for (z = 0; z < nz; z++)
4122             {
4123                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4124                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4125             }
4126         }
4127     }
4128 }
4129
4130 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4131 {
4132     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4133
4134     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4135                                    local_fft_ndata,
4136                                    local_fft_offset,
4137                                    local_fft_size);
4138
4139     dump_grid(stderr,
4140               pme->pmegrid_start_ix,
4141               pme->pmegrid_start_iy,
4142               pme->pmegrid_start_iz,
4143               pme->pmegrid_nx-pme->pme_order+1,
4144               pme->pmegrid_ny-pme->pme_order+1,
4145               pme->pmegrid_nz-pme->pme_order+1,
4146               local_fft_size[YY],
4147               local_fft_size[ZZ],
4148               fftgrid);
4149 }
4150
4151
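/* Calculate the reciprocal-space energy of n charges q at positions x
 * by evaluating B-splines on the existing A-charge grid, without
 * spreading a new grid; only supported in serial runs.
 */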
4152 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4153 {
4154     pme_atomcomm_t *atc;
4155     pmegrids_t *grid;
4156
4157     if (pme->nnodes > 1)
4158     {
4159         gmx_incons("gmx_pme_calc_energy called in parallel");
4160     }
4161     if (pme->bFEP)
4162     {
4163         gmx_incons("gmx_pme_calc_energy with free energy");
4164     }
4165
4166     atc            = &pme->atc_energy;
4167     atc->nthread   = 1;
4168     if (atc->spline == NULL)
4169     {
4170         snew(atc->spline, atc->nthread);
4171     }
4172     atc->nslab     = 1;
4173     atc->bSpread   = TRUE;
4174     atc->pme_order = pme->pme_order;
4175     atc->n         = n;
4176     pme_realloc_atomcomm_things(atc);
4177     atc->x         = x;
4178     atc->q         = q;
4179
4180     /* We only use the A-charges grid */
4181     grid = &pme->pmegridA;
4182
4183     /* Only calculate the spline coefficients, don't actually spread */
4184     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4185
4186     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4187 }
4188
4189
4190 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4191                                    gmx_runtime_t *runtime,
4192                                    t_nrnb *nrnb, t_inputrec *ir,
4193                                    gmx_large_int_t step)
4194 {
4195     /* Reset all the counters related to performance over the run */
4196     wallcycle_stop(wcycle, ewcRUN);
4197     wallcycle_reset_all(wcycle);
4198     init_nrnb(nrnb);
4199     if (ir->nsteps >= 0)
4200     {
4201         /* ir->nsteps is not used here, but we update it for consistency */
4202         ir->nsteps -= step - ir->init_step;
4203     }
4204     ir->init_step = step;
4205     wallcycle_start(wcycle, ewcRUN);
4206     runtime_start(runtime);
4207 }
4208
4209
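/* Set *pme_ret to a PME data structure matching grid_size: either an
 * existing entry in *pmedata, or a new one created with gmx_pme_reinit
 * and appended to the list.
 */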
4210 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4211                                ivec grid_size,
4212                                t_commrec *cr, t_inputrec *ir,
4213                                gmx_pme_t *pme_ret)
4214 {
4215     int ind;
4216     gmx_pme_t pme = NULL;
4217
4218     ind = 0;
4219     while (ind < *npmedata)
4220     {
4221         pme = (*pmedata)[ind];
4222         if (pme->nkx == grid_size[XX] &&
4223             pme->nky == grid_size[YY] &&
4224             pme->nkz == grid_size[ZZ])
4225         {
4226             *pme_ret = pme;
4227
4228             return;
4229         }
4230
4231         ind++;
4232     }
4233
4234     (*npmedata)++;
4235     srenew(*pmedata, *npmedata);
4236
4237     /* Generate a new PME data structure, copying part of the old pointers */
4238     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4239
4240     *pme_ret = (*pmedata)[ind];
4241 }
4242
4243
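/* The work loop of a PME-only rank: receive coordinates and charges
 * from the PP ranks, perform the PME mesh calculation, and send back
 * forces, energy and virial, until a finish message arrives. Grid
 * switch and counter reset requests are handled within the loop.
 */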
4244 int gmx_pmeonly(gmx_pme_t pme,
4245                 t_commrec *cr,    t_nrnb *nrnb,
4246                 gmx_wallcycle_t wcycle,
4247                 gmx_runtime_t *runtime,
4248                 real ewaldcoeff,  gmx_bool bGatherOnly,
4249                 t_inputrec *ir)
4250 {
4251     int npmedata;
4252     gmx_pme_t *pmedata;
4253     gmx_pme_pp_t pme_pp;
4254     int  ret;
4255     int  natoms;
4256     matrix box;
4257     rvec *x_pp      = NULL, *f_pp = NULL;
4258     real *chargeA   = NULL, *chargeB = NULL;
4259     real lambda     = 0;
4260     int  maxshift_x = 0, maxshift_y = 0;
4261     real energy, dvdlambda;
4262     matrix vir;
4263     float cycles;
4264     int  count;
4265     gmx_bool bEnerVir;
4266     gmx_large_int_t step, step_rel;
4267     ivec grid_switch;
4268
4269     /* This data is only used with PME tuning, i.e. switching PME grids */
4270     npmedata = 1;
4271     snew(pmedata, npmedata);
4272     pmedata[0] = pme;
4273
4274     pme_pp = gmx_pme_pp_init(cr);
4275
4276     init_nrnb(nrnb);
4277
4278     count = 0;
4279     do /****** this is a quasi-loop over time steps! */
4280     {
4281         /* The reason for having a loop here is PME grid tuning/switching */
4282         do
4283         {
4284             /* Domain decomposition */
4285             ret = gmx_pme_recv_q_x(pme_pp,
4286                                    &natoms,
4287                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4288                                    &maxshift_x, &maxshift_y,
4289                                    &pme->bFEP, &lambda,
4290                                    &bEnerVir,
4291                                    &step,
4292                                    grid_switch, &ewaldcoeff);
4293
4294             if (ret == pmerecvqxSWITCHGRID)
4295             {
4296                 /* Switch the PME grid to grid_switch */
4297                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4298             }
4299
4300             if (ret == pmerecvqxRESETCOUNTERS)
4301             {
4302                 /* Reset the cycle and flop counters */
4303                 reset_pmeonly_counters(cr, wcycle, runtime, nrnb, ir, step);
4304             }
4305         }
4306         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4307
4308         if (ret == pmerecvqxFINISH)
4309         {
4310             /* We should stop: break out of the loop */
4311             break;
4312         }
4313
4314         step_rel = step - ir->init_step;
4315
4316         if (count == 0)
4317         {
4318             wallcycle_start(wcycle, ewcRUN);
4319             runtime_start(runtime);
4320         }
4321
4322         wallcycle_start(wcycle, ewcPMEMESH);
4323
4324         dvdlambda = 0;
4325         clear_mat(vir);
4326         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4327                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4328                    &energy, lambda, &dvdlambda,
4329                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4330
4331         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4332
4333         gmx_pme_send_force_vir_ener(pme_pp,
4334                                     f_pp, vir, energy, dvdlambda,
4335                                     cycles);
4336
4337         count++;
4338     } /***** end of quasi-loop, we stop with the break above */
4339     while (TRUE);
4340
4341     runtime_end(runtime);
4342
4343     return 0;
4344 }
4345
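/* Perform the full PME mesh calculation for the local atoms:
 * redistribute coordinates and charges over the PME ranks if needed,
 * spread, forward FFT, solve in k-space, backward FFT and gather the
 * forces; with free energy this is done for both the A and B charges.
 */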
4346 int gmx_pme_do(gmx_pme_t pme,
4347                int start,       int homenr,
4348                rvec x[],        rvec f[],
4349                real *chargeA,   real *chargeB,
4350                matrix box, t_commrec *cr,
4351                int  maxshift_x, int maxshift_y,
4352                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4353                matrix vir,      real ewaldcoeff,
4354                real *energy,    real lambda,
4355                real *dvdlambda, int flags)
4356 {
4357     int     q, d, i, j, ntot, npme;
4358     int     nx, ny, nz;
4359     int     n_d, local_ny;
4360     pme_atomcomm_t *atc = NULL;
4361     pmegrids_t *pmegrid = NULL;
4362     real    *grid       = NULL;
4363     real    *ptr;
4364     rvec    *x_d, *f_d;
4365     real    *charge = NULL, *q_d;
4366     real    energy_AB[2];
4367     matrix  vir_AB[2];
4368     gmx_bool bClearF;
4369     gmx_parallel_3dfft_t pfft_setup;
4370     real *  fftgrid;
4371     t_complex * cfftgrid;
4372     int     thread;
4373     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4374     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4375
4376     assert(pme->nnodes > 0);
4377     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4378
4379     if (pme->nnodes > 1)
4380     {
4381         atc      = &pme->atc[0];
4382         atc->npd = homenr;
4383         if (atc->npd > atc->pd_nalloc)
4384         {
4385             atc->pd_nalloc = over_alloc_dd(atc->npd);
4386             srenew(atc->pd, atc->pd_nalloc);
4387         }
4388         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4389     }
4390     else
4391     {
4392         /* This could be necessary for TPI */
4393         pme->atc[0].n = homenr;
4394     }
4395
4396     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4397     {
4398         if (q == 0)
4399         {
4400             pmegrid    = &pme->pmegridA;
4401             fftgrid    = pme->fftgridA;
4402             cfftgrid   = pme->cfftgridA;
4403             pfft_setup = pme->pfft_setupA;
4404             charge     = chargeA+start;
4405         }
4406         else
4407         {
4408             pmegrid    = &pme->pmegridB;
4409             fftgrid    = pme->fftgridB;
4410             cfftgrid   = pme->cfftgridB;
4411             pfft_setup = pme->pfft_setupB;
4412             charge     = chargeB+start;
4413         }
4414         grid = pmegrid->grid.grid;
4415         /* Unpack structure */
4416         if (debug)
4417         {
4418             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4419                     cr->nnodes, cr->nodeid);
4420             fprintf(debug, "Grid = %p\n", (void*)grid);
4421             if (grid == NULL)
4422             {
4423                 gmx_fatal(FARGS, "No grid!");
4424             }
4425         }
4426         where();
4427
4428         m_inv_ur0(box, pme->recipbox);
4429
4430         if (pme->nnodes == 1)
4431         {
4432             atc = &pme->atc[0];
4433             if (DOMAINDECOMP(cr))
4434             {
4435                 atc->n = homenr;
4436                 pme_realloc_atomcomm_things(atc);
4437             }
4438             atc->x = x;
4439             atc->q = charge;
4440             atc->f = f;
4441         }
4442         else
4443         {
4444             wallcycle_start(wcycle, ewcPME_REDISTXF);
4445             for (d = pme->ndecompdim-1; d >= 0; d--)
4446             {
4447                 if (d == pme->ndecompdim-1)
4448                 {
4449                     n_d = homenr;
4450                     x_d = x + start;
4451                     q_d = charge;
4452                 }
4453                 else
4454                 {
4455                     n_d = pme->atc[d+1].n;
4456                     x_d = atc->x;
4457                     q_d = atc->q;
4458                 }
4459                 atc      = &pme->atc[d];
4460                 atc->npd = n_d;
4461                 if (atc->npd > atc->pd_nalloc)
4462                 {
4463                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4464                     srenew(atc->pd, atc->pd_nalloc);
4465                 }
4466                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4467                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4468                 where();
4469
4470                 GMX_BARRIER(cr->mpi_comm_mygroup);
4471                 /* Redistribute x (only once) and qA or qB */
4472                 if (DOMAINDECOMP(cr))
4473                 {
4474                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4475                 }
4476                 else
4477                 {
4478                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4479                 }
4480             }
4481             where();
4482
4483             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4484         }
4485
4486         if (debug)
4487         {
4488             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4489                     cr->nodeid, atc->n);
4490         }
4491
4492         if (flags & GMX_PME_SPREAD_Q)
4493         {
4494             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4495
4496             /* Spread the charges on a grid */
4497             GMX_MPE_LOG(ev_spread_on_grid_start);
4500             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4501             GMX_MPE_LOG(ev_spread_on_grid_finish);
4502
4503             if (q == 0)
4504             {
4505                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4506             }
4507             inc_nrnb(nrnb, eNR_SPREADQBSP,
4508                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4509
4510             if (!pme->bUseThreads)
4511             {
4512                 wrap_periodic_pmegrid(pme, grid);
4513
4514                 /* sum contributions to local grid from other nodes */
4515 #ifdef GMX_MPI
4516                 if (pme->nnodes > 1)
4517                 {
4518                     GMX_BARRIER(cr->mpi_comm_mygroup);
4519                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4520                     where();
4521                 }
4522 #endif
4523
4524                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4525             }
4526
4527             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4528
4529             /*
4530                dump_local_fftgrid(pme,fftgrid);
4531                exit(0);
4532              */
4533         }
4534
4535         /* Here we start a large thread parallel region */
4536 #pragma omp parallel num_threads(pme->nthread) private(thread)
4537         {
4538             thread = gmx_omp_get_thread_num();
4539             if (flags & GMX_PME_SOLVE)
4540             {
4541                 int loop_count;
4542
4543                 /* do 3d-fft */
4544                 if (thread == 0)
4545                 {
4546                     GMX_BARRIER(cr->mpi_comm_mygroup);
4547                     GMX_MPE_LOG(ev_gmxfft3d_start);
4548                     wallcycle_start(wcycle, ewcPME_FFT);
4549                 }
4550                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4551                                            fftgrid, cfftgrid, thread, wcycle);
4552                 if (thread == 0)
4553                 {
4554                     wallcycle_stop(wcycle, ewcPME_FFT);
4555                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4556                 }
4557                 where();
4558
4559                 /* solve in k-space for our local cells */
4560                 if (thread == 0)
4561                 {
4562                     GMX_BARRIER(cr->mpi_comm_mygroup);
4563                     GMX_MPE_LOG(ev_solve_pme_start);
4564                     wallcycle_start(wcycle, ewcPME_SOLVE);
4565                 }
4566                 loop_count =
4567                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4568                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4569                                   bCalcEnerVir,
4570                                   pme->nthread, thread);
4571                 if (thread == 0)
4572                 {
4573                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4574                     where();
4575                     GMX_MPE_LOG(ev_solve_pme_finish);
4576                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4577                 }
4578             }
4579
4580             if (bCalcF)
4581             {
4582                 /* do 3d-invfft */
4583                 if (thread == 0)
4584                 {
4585                     GMX_BARRIER(cr->mpi_comm_mygroup);
4586                     GMX_MPE_LOG(ev_gmxfft3d_start);
4587                     where();
4588                     wallcycle_start(wcycle, ewcPME_FFT);
4589                 }
4590                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4591                                            cfftgrid, fftgrid, thread, wcycle);
4592                 if (thread == 0)
4593                 {
4594                     wallcycle_stop(wcycle, ewcPME_FFT);
4595
4596                     where();
4597                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4598
4599                     if (pme->nodeid == 0)
4600                     {
4601                         ntot  = pme->nkx*pme->nky*pme->nkz;
4602                         npme  = ntot*log((real)ntot)/log(2.0);
4603                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4604                     }
4605
4606                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4607                 }
4608
4609                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4610             }
4611         }
4612         /* End of thread parallel section.
4613          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4614          */
4615
4616         if (bCalcF)
4617         {
4618             /* distribute local grid to all nodes */
4619 #ifdef GMX_MPI
4620             if (pme->nnodes > 1)
4621             {
4622                 GMX_BARRIER(cr->mpi_comm_mygroup);
4623                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4624             }
4625 #endif
4626             where();
4627
4628             unwrap_periodic_pmegrid(pme, grid);
4629
4630             /* interpolate forces for our local atoms */
4631             GMX_BARRIER(cr->mpi_comm_mygroup);
4632             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4633
4634             where();
4635
4636             /* If we are running without parallelization,
4637              * atc->f is the actual force array, not a buffer;
4638              * therefore we should not clear it.
4639              */
4640             bClearF = (q == 0 && PAR(cr));
4641 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4642             for (thread = 0; thread < pme->nthread; thread++)
4643             {
4644                 gather_f_bsplines(pme, grid, bClearF, atc,
4645                                   &atc->spline[thread],
4646                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4647             }
4648
4649             where();
4650
4651             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4652
4653             inc_nrnb(nrnb, eNR_GATHERFBSP,
4654                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4655             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4656         }
4657
4658         if (bCalcEnerVir)
4659         {
4660             /* This should only be called on the master thread
4661              * and after the threads have synchronized.
4662              */
4663             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4664         }
4665     } /* of q-loop */
4666
4667     if (bCalcF && pme->nnodes > 1)
4668     {
4669         wallcycle_start(wcycle, ewcPME_REDISTXF);
4670         for (d = 0; d < pme->ndecompdim; d++)
4671         {
4672             atc = &pme->atc[d];
4673             if (d == pme->ndecompdim - 1)
4674             {
4675                 n_d = homenr;
4676                 f_d = f + start;
4677             }
4678             else
4679             {
4680                 n_d = pme->atc[d+1].n;
4681                 f_d = pme->atc[d+1].f;
4682             }
4683             GMX_BARRIER(cr->mpi_comm_mygroup);
4684             if (DOMAINDECOMP(cr))
4685             {
4686                 dd_pmeredist_f(pme, atc, n_d, f_d,
4687                                d == pme->ndecompdim-1 && pme->bPPnode);
4688             }
4689             else
4690             {
4691                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4692             }
4693         }
4694
4695         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4696     }
4697     where();
4698
4699     if (bCalcEnerVir)
4700     {
4701         if (!pme->bFEP)
4702         {
4703             *energy = energy_AB[0];
4704             m_add(vir, vir_AB[0], vir);
4705         }
4706         else
4707         {
4708             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4709             *dvdlambda += energy_AB[1] - energy_AB[0];
4710             for (i = 0; i < DIM; i++)
4711             {
4712                 for (j = 0; j < DIM; j++)
4713                 {
4714                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4715                         lambda*vir_AB[1][i][j];
4716                 }
4717             }
4718         }
4719     }
4720     else
4721     {
4722         *energy = 0;
4723     }
4724
4725     if (debug)
4726     {
4727         fprintf(debug, "PME mesh energy: %g\n", *energy);
4728     }
4729
4730     return 0;
4731 }