1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team,
6  * check out http://www.gromacs.org for more information.
7  * Copyright (c) 2012,2013, by the GROMACS development team, led by
8  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
9  * others, as listed in the AUTHORS file in the top-level source
10  * directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* IMPORTANT FOR DEVELOPERS:
39  *
40  * Triclinic pme stuff isn't entirely trivial, and we've experienced
41  * some bugs during development (many of them due to me). To avoid
42  * this in the future, please check the following things if you make
43  * changes in this file:
44  *
45  * 1. You should obtain identical (at least to the PME precision)
46  *    energies, forces, and virial for
47  *    a rectangular box and a triclinic one where the z (or y) axis is
48  *    tilted by a whole box side. For instance, you could use these boxes:
49  *
50  *    rectangular       triclinic
51  *     2  0  0           2  0  0
52  *     0  2  0           0  2  0
53  *     0  0  6           2  2  6
54  *
55  * 2. You should check the energy conservation in a triclinic box.
56  *
57  * It might seem like overkill, but better safe than sorry.
58  * /Erik 001109
59  */
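/* A minimal illustration of check 1 above, with the boxes written as
 * row matrices of box vectors (GROMACS convention); the matrices are
 * illustrative only. The triclinic box is the rectangular one with the
 * z vector tilted by a whole box side in both x and y, so identical
 * coordinates and PME settings should give energies, forces and a
 * virial that agree to the PME precision.
 *
 *   static const real box_rect[3][3] = { { 2, 0, 0 },
 *                                        { 0, 2, 0 },
 *                                        { 0, 0, 6 } };
 *   static const real box_tric[3][3] = { { 2, 0, 0 },
 *                                        { 0, 2, 0 },
 *                                        { 2, 2, 6 } };
 */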
60
61 #ifdef HAVE_CONFIG_H
62 #include <config.h>
63 #endif
64
65 #ifdef GMX_LIB_MPI
66 #include <mpi.h>
67 #endif
68 #ifdef GMX_THREAD_MPI
69 #include "tmpi.h"
70 #endif
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93 #include "gmx_omp.h"
94
95
96 /* Include the SIMD macro file and then check for support */
97 #include "gmx_simd_macros.h"
98 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
99 /* Turn on SIMD intrinsics for PME solve */
100 #define PME_SIMD
101 #endif
102
103 /* SIMD spread+gather is only available in single precision with SSE2 or higher.
104  * We might want to switch to using gmx_simd_macros.h, but this is somewhat
105  * complicated, as we use unaligned and/or 4-wide-only loads.
106  */
107 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
108 #define PME_SSE_SPREAD_GATHER
109 #include <emmintrin.h>
110 /* Some old AMD processors could have problems with unaligned loads+stores */
111 #ifndef GMX_FAHCORE
112 #define PME_SSE_UNALIGNED
113 #endif
114 #endif
115
116 #include "mpelogging.h"
117
118 #define DFT_TOL 1e-7
119 /* #define PRT_FORCE */
120 /* conditions for on-the-fly time measurement */
121 /* #define TAKETIME (step > 1 && timesteps < 10) */
122 #define TAKETIME FALSE
123
124 /* #define PME_TIME_THREADS */
125
126 #ifdef GMX_DOUBLE
127 #define mpi_type MPI_DOUBLE
128 #else
129 #define mpi_type MPI_FLOAT
130 #endif
131
132 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
133 #define GMX_CACHE_SEP 64
134
135 /* We only define a maximum to be able to use local arrays without allocation.
136  * An order larger than 12 should never be needed, even for test cases.
137  * If needed it can be changed here.
138  */
139 #define PME_ORDER_MAX 12
140
141 /* Internal data structures */
142 typedef struct {
143     int send_index0;
144     int send_nindex;
145     int recv_index0;
146     int recv_nindex;
147     int recv_size;   /* Receive buffer width, used with OpenMP */
148 } pme_grid_comm_t;
149
150 typedef struct {
151 #ifdef GMX_MPI
152     MPI_Comm         mpi_comm;
153 #endif
154     int              nnodes, nodeid;
155     int             *s2g0;
156     int             *s2g1;
157     int              noverlap_nodes;
158     int             *send_id, *recv_id;
159     int              send_size; /* Send buffer width, used with OpenMP */
160     pme_grid_comm_t *comm_data;
161     real            *sendbuf;
162     real            *recvbuf;
163 } pme_overlap_t;
164
165 typedef struct {
166     int *n;      /* Cumulative counts of the number of particles per thread */
167     int  nalloc; /* Allocation size of i */
168     int *i;      /* Particle indices ordered on thread index (n) */
169 } thread_plist_t;
170
171 typedef struct {
172     int      *thread_one;
173     int       n;
174     int      *ind;
175     splinevec theta;
176     real     *ptr_theta_z;
177     splinevec dtheta;
178     real     *ptr_dtheta_z;
179 } splinedata_t;
180
181 typedef struct {
182     int      dimind;        /* The index of the dimension, 0=x, 1=y */
183     int      nslab;
184     int      nodeid;
185 #ifdef GMX_MPI
186     MPI_Comm mpi_comm;
187 #endif
188
189     int     *node_dest;     /* The nodes to send x and q to with DD */
190     int     *node_src;      /* The nodes to receive x and q from with DD */
191     int     *buf_index;     /* Index for commnode into the buffers */
192
193     int      maxshift;
194
195     int      npd;
196     int      pd_nalloc;
197     int     *pd;
198     int     *count;         /* The number of atoms to send to each node */
199     int    **count_thread;
200     int     *rcount;        /* The number of atoms to receive */
201
202     int      n;
203     int      nalloc;
204     rvec    *x;
205     real    *q;
206     rvec    *f;
207     gmx_bool bSpread;       /* These coordinates are used for spreading */
208     int      pme_order;
209     ivec    *idx;
210     rvec    *fractx;            /* Fractional coordinate relative to the
211                                  * lower cell boundary
212                                  */
213     int             nthread;
214     int            *thread_idx; /* Which thread should spread which charge */
215     thread_plist_t *thread_plist;
216     splinedata_t   *spline;
217 } pme_atomcomm_t;
218
219 #define FLBS  3
220 #define FLBSZ 4
221
222 typedef struct {
223     ivec  ci;     /* The spatial location of this grid         */
224     ivec  n;      /* The used size of *grid, including order-1 */
225     ivec  offset; /* The grid offset from the full node grid   */
226     int   order;  /* PME spreading order                       */
227     ivec  s;      /* The allocated size of *grid, s >= n       */
228     real *grid;   /* The grid of the local thread, size n      */
229 } pmegrid_t;
230
231 typedef struct {
232     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
233     int        nthread;      /* The number of threads operating on this grid     */
234     ivec       nc;           /* The local spatial decomposition over the threads */
235     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
236     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
237     int      **g2t;          /* The grid to thread index                         */
238     ivec       nthread_comm; /* The number of threads to communicate with        */
239 } pmegrids_t;
240
241
242 typedef struct {
243 #ifdef PME_SSE_SPREAD_GATHER
244     /* Masks for SSE aligned spreading and gathering */
245     __m128 mask_SSE0[6], mask_SSE1[6];
246 #else
247     int    dummy; /* C89 requires that struct has at least one member */
248 #endif
249 } pme_spline_work_t;
250
251 typedef struct {
252     /* work data for solve_pme */
253     int      nalloc;
254     real *   mhx;
255     real *   mhy;
256     real *   mhz;
257     real *   m2;
258     real *   denom;
259     real *   tmp1_alloc;
260     real *   tmp1;
261     real *   eterm;
262     real *   m2inv;
263
264     real     energy;
265     matrix   vir;
266 } pme_work_t;
267
268 typedef struct gmx_pme {
269     int           ndecompdim; /* The number of decomposition dimensions */
270     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
271     int           nodeid_major;
272     int           nodeid_minor;
273     int           nnodes;    /* The number of nodes doing PME */
274     int           nnodes_major;
275     int           nnodes_minor;
276
277     MPI_Comm      mpi_comm;
278     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
279 #ifdef GMX_MPI
280     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
281 #endif
282
283     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
284     int        nthread;       /* The number of threads doing PME on our rank */
285
286     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
287     gmx_bool   bFEP;          /* Compute Free energy contribution */
288     int        nkx, nky, nkz; /* Grid dimensions */
289     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
290     int        pme_order;
291     real       epsilon_r;
292
293     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
294     pmegrids_t pmegridB;
295     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
296     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
297     /* pmegrid_nz might be larger than strictly necessary to ensure
298      * memory alignment, pmegrid_nz_base gives the real base size.
299      */
300     int     pmegrid_nz_base;
301     /* The local PME grid starting indices */
302     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
303
304     /* Work data for spreading and gathering */
305     pme_spline_work_t    *spline_work;
306
307     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
308     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
309     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
310
311     t_complex            *cfftgridA;  /* Grids for complex FFT data */
312     t_complex            *cfftgridB;
313     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
314
315     gmx_parallel_3dfft_t  pfft_setupA;
316     gmx_parallel_3dfft_t  pfft_setupB;
317
318     int                  *nnx, *nny, *nnz;
319     real                 *fshx, *fshy, *fshz;
320
321     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
322     matrix                recipbox;
323     splinevec             bsp_mod;
324
325     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
326
327     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
328
329     rvec                 *bufv;       /* Communication buffer */
330     real                 *bufr;       /* Communication buffer */
331     int                   buf_nalloc; /* The communication buffer size */
332
333     /* thread local work data for solve_pme */
334     pme_work_t *work;
335
336     /* Work data for PME_redist */
337     gmx_bool redist_init;
338     int *    scounts;
339     int *    rcounts;
340     int *    sdispls;
341     int *    rdispls;
342     int *    sidx;
343     int *    idxa;
344     real *   redist_buf;
345     int      redist_buf_nalloc;
346
347     /* Work data for sum_qgrid */
348     real *   sum_qgrid_tmp;
349     real *   sum_qgrid_dd_tmp;
350 } t_gmx_pme;
351
352
353 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
354                                    int start, int end, int thread)
355 {
356     int             i;
357     int            *idxptr, tix, tiy, tiz;
358     real           *xptr, *fptr, tx, ty, tz;
359     real            rxx, ryx, ryy, rzx, rzy, rzz;
360     int             nx, ny, nz;
361     int             start_ix, start_iy, start_iz;
362     int            *g2tx, *g2ty, *g2tz;
363     gmx_bool        bThreads;
364     int            *thread_idx = NULL;
365     thread_plist_t *tpl        = NULL;
366     int            *tpl_n      = NULL;
367     int             thread_i;
368
369     nx  = pme->nkx;
370     ny  = pme->nky;
371     nz  = pme->nkz;
372
373     start_ix = pme->pmegrid_start_ix;
374     start_iy = pme->pmegrid_start_iy;
375     start_iz = pme->pmegrid_start_iz;
376
377     rxx = pme->recipbox[XX][XX];
378     ryx = pme->recipbox[YY][XX];
379     ryy = pme->recipbox[YY][YY];
380     rzx = pme->recipbox[ZZ][XX];
381     rzy = pme->recipbox[ZZ][YY];
382     rzz = pme->recipbox[ZZ][ZZ];
383
384     g2tx = pme->pmegridA.g2t[XX];
385     g2ty = pme->pmegridA.g2t[YY];
386     g2tz = pme->pmegridA.g2t[ZZ];
387
388     bThreads = (atc->nthread > 1);
389     if (bThreads)
390     {
391         thread_idx = atc->thread_idx;
392
393         tpl   = &atc->thread_plist[thread];
394         tpl_n = tpl->n;
395         for (i = 0; i < atc->nthread; i++)
396         {
397             tpl_n[i] = 0;
398         }
399     }
400
401     for (i = start; i < end; i++)
402     {
403         xptr   = atc->x[i];
404         idxptr = atc->idx[i];
405         fptr   = atc->fractx[i];
406
407         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
408         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
409         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
410         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
411
412         tix = (int)(tx);
413         tiy = (int)(ty);
414         tiz = (int)(tz);
415
416         /* Because decomposition only occurs in x and y,
417          * we never have a fraction correction in z.
418          */
419         fptr[XX] = tx - tix + pme->fshx[tix];
420         fptr[YY] = ty - tiy + pme->fshy[tiy];
421         fptr[ZZ] = tz - tiz;
422
423         idxptr[XX] = pme->nnx[tix];
424         idxptr[YY] = pme->nny[tiy];
425         idxptr[ZZ] = pme->nnz[tiz];
426
427 #ifdef DEBUG
428         range_check(idxptr[XX], 0, pme->pmegrid_nx);
429         range_check(idxptr[YY], 0, pme->pmegrid_ny);
430         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
431 #endif
432
433         if (bThreads)
434         {
435             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
436             thread_idx[i] = thread_i;
437             tpl_n[thread_i]++;
438         }
439     }
440
441     if (bThreads)
442     {
443         /* Make a list of particle indices sorted on thread */
444
445         /* Get the cumulative count */
446         for (i = 1; i < atc->nthread; i++)
447         {
448             tpl_n[i] += tpl_n[i-1];
449         }
450         /* The current implementation distributes particles equally
451          * over the threads, so we could actually allocate for that
452          * in pme_realloc_atomcomm_things.
453          */
454         if (tpl_n[atc->nthread-1] > tpl->nalloc)
455         {
456             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
457             srenew(tpl->i, tpl->nalloc);
458         }
459         /* Set tpl_n to the cumulative start */
460         for (i = atc->nthread-1; i >= 1; i--)
461         {
462             tpl_n[i] = tpl_n[i-1];
463         }
464         tpl_n[0] = 0;
465
466         /* Fill our thread local array with indices sorted on thread */
467         for (i = start; i < end; i++)
468         {
469             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
470         }
471         /* Now tpl_n contains the cumulative count again */
472     }
473 }
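/* A minimal, self-contained sketch of the counting sort that
 * calc_interpolation_idx() uses above to build the per-thread particle
 * index lists (the names bucket, count and out are illustrative, not
 * part of pme.c): count the entries per bucket, convert the counts to
 * cumulative start offsets, then scatter the indices.
 */
static void counting_sort_by_bucket(const int *bucket, int n, int nbucket,
                                    int *count, int *out)
{
    int i, b, sum, tmp;

    for (b = 0; b < nbucket; b++)
    {
        count[b] = 0;
    }
    for (i = 0; i < n; i++)
    {
        count[bucket[i]]++;
    }
    /* Exclusive prefix sum: count[b] becomes the start offset of bucket b */
    sum = 0;
    for (b = 0; b < nbucket; b++)
    {
        tmp      = count[b];
        count[b] = sum;
        sum     += tmp;
    }
    for (i = 0; i < n; i++)
    {
        out[count[bucket[i]]++] = i;
    }
    /* On return count[b] again holds the cumulative count through
     * bucket b, just as tpl->n does above.
     */
}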
474
475 static void make_thread_local_ind(pme_atomcomm_t *atc,
476                                   int thread, splinedata_t *spline)
477 {
478     int             n, t, i, start, end;
479     thread_plist_t *tpl;
480
481     /* Combine the indices made by each thread into one index */
482
483     n     = 0;
484     start = 0;
485     for (t = 0; t < atc->nthread; t++)
486     {
487         tpl = &atc->thread_plist[t];
488         /* Copy our part (start - end) from the list of thread t */
489         if (thread > 0)
490         {
491             start = tpl->n[thread-1];
492         }
493         end = tpl->n[thread];
494         for (i = start; i < end; i++)
495         {
496             spline->ind[n++] = tpl->i[i];
497         }
498     }
499
500     spline->n = n;
501 }
502
503
504 static void pme_calc_pidx(int start, int end,
505                           matrix recipbox, rvec x[],
506                           pme_atomcomm_t *atc, int *count)
507 {
508     int   nslab, i;
509     int   si;
510     real *xptr, s;
511     real  rxx, ryx, rzx, ryy, rzy;
512     int  *pd;
513
514     /* Calculate the PME task index (pidx) for each particle.
515      * Here we always assign equally sized slabs to each node
516      * for load balancing reasons (the PME grid spacing is not used).
517      */
518
519     nslab = atc->nslab;
520     pd    = atc->pd;
521
522     /* Reset the count */
523     for (i = 0; i < nslab; i++)
524     {
525         count[i] = 0;
526     }
527
528     if (atc->dimind == 0)
529     {
530         rxx = recipbox[XX][XX];
531         ryx = recipbox[YY][XX];
532         rzx = recipbox[ZZ][XX];
533         /* Calculate the node index in x-dimension */
534         for (i = start; i < end; i++)
535         {
536             xptr   = x[i];
537             /* Fractional coordinates along box vectors */
538             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
539             si    = (int)(s + 2*nslab) % nslab;
540             pd[i] = si;
541             count[si]++;
542         }
543     }
544     else
545     {
546         ryy = recipbox[YY][YY];
547         rzy = recipbox[ZZ][YY];
548         /* Calculate the node index in y-dimension */
549         for (i = start; i < end; i++)
550         {
551             xptr   = x[i];
552             /* Fractional coordinates along box vectors */
553             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
554             si    = (int)(s + 2*nslab) % nslab;
555             pd[i] = si;
556             count[si]++;
557         }
558     }
559 }
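/* A small sketch of the slab assignment above (slab_index is an
 * illustrative helper, not part of pme.c). s is the fractional
 * coordinate along the decomposed dimension scaled by nslab; adding
 * 2*nslab keeps slightly negative values (possible for atoms just
 * outside a triclinic box) positive before truncation, and the modulo
 * wraps back into [0, nslab). E.g. nslab = 4, s = -0.3:
 * (int)(-0.3 + 8) = 7 and 7 % 4 = 3, the last slab.
 */
static int slab_index(real s, int nslab)
{
    return ((int)(s + 2*nslab)) % nslab;
}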
560
561 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
562                                   pme_atomcomm_t *atc)
563 {
564     int nthread, thread, slab;
565
566     nthread = atc->nthread;
567
568 #pragma omp parallel for num_threads(nthread) schedule(static)
569     for (thread = 0; thread < nthread; thread++)
570     {
571         pme_calc_pidx(natoms* thread   /nthread,
572                       natoms*(thread+1)/nthread,
573                       recipbox, x, atc, atc->count_thread[thread]);
574     }
575     /* Non-parallel reduction, since nslab is small */
576
577     for (thread = 1; thread < nthread; thread++)
578     {
579         for (slab = 0; slab < atc->nslab; slab++)
580         {
581             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
582         }
583     }
584 }
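/* A worked example of the static range split used above: thread t of
 * nthread handles particles [n*t/nthread, n*(t+1)/nthread), which
 * covers all n particles exactly once even when n is not divisible by
 * nthread; n = 10, nthread = 4 gives [0,2), [2,5), [5,7), [7,10).
 * thread_range is an illustrative helper, not part of pme.c.
 */
static void thread_range(int n, int nthread, int thread,
                         int *start, int *end)
{
    *start = (n* thread   )/nthread;
    *end   = (n*(thread+1))/nthread;
}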
585
586 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
587 {
588     const int padding = 4;
589     int       i;
590
591     srenew(th[XX], nalloc);
592     srenew(th[YY], nalloc);
593     /* In z we add padding; this is only required for the aligned SSE code */
594     srenew(*ptr_z, nalloc+2*padding);
595     th[ZZ] = *ptr_z + padding;
596
597     for (i = 0; i < padding; i++)
598     {
599         (*ptr_z)[               i] = 0;
600         (*ptr_z)[padding+nalloc+i] = 0;
601     }
602 }
603
604 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
605 {
606     int i, d;
607
608     srenew(spline->ind, atc->nalloc);
609     /* Initialize the index to identity so it works without threads */
610     for (i = 0; i < atc->nalloc; i++)
611     {
612         spline->ind[i] = i;
613     }
614
615     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
616                       atc->pme_order*atc->nalloc);
617     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
618                       atc->pme_order*atc->nalloc);
619 }
620
621 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
622 {
623     int nalloc_old, i, j, nalloc_tpl;
624
625     /* We have to avoid a NULL pointer for atc->x to avoid
626      * possible fatal errors in MPI routines.
627      */
628     if (atc->n > atc->nalloc || atc->nalloc == 0)
629     {
630         nalloc_old  = atc->nalloc;
631         atc->nalloc = over_alloc_dd(max(atc->n, 1));
632
633         if (atc->nslab > 1)
634         {
635             srenew(atc->x, atc->nalloc);
636             srenew(atc->q, atc->nalloc);
637             srenew(atc->f, atc->nalloc);
638             for (i = nalloc_old; i < atc->nalloc; i++)
639             {
640                 clear_rvec(atc->f[i]);
641             }
642         }
643         if (atc->bSpread)
644         {
645             srenew(atc->fractx, atc->nalloc);
646             srenew(atc->idx, atc->nalloc);
647
648             if (atc->nthread > 1)
649             {
650                 srenew(atc->thread_idx, atc->nalloc);
651             }
652
653             for (i = 0; i < atc->nthread; i++)
654             {
655                 pme_realloc_splinedata(&atc->spline[i], atc);
656             }
657         }
658     }
659 }
660
661 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
662                          int n, gmx_bool bXF, rvec *x_f, real *charge,
663                          pme_atomcomm_t *atc)
664 /* Redistribute particle data for PME calculation */
665 /* domain decomposition by x coordinate           */
666 {
667     int *idxa;
668     int  i, ii;
669
670     if (FALSE == pme->redist_init)
671     {
672         snew(pme->scounts, atc->nslab);
673         snew(pme->rcounts, atc->nslab);
674         snew(pme->sdispls, atc->nslab);
675         snew(pme->rdispls, atc->nslab);
676         snew(pme->sidx, atc->nslab);
677         pme->redist_init = TRUE;
678     }
679     if (n > pme->redist_buf_nalloc)
680     {
681         pme->redist_buf_nalloc = over_alloc_dd(n);
682         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
683     }
684
685     pme->idxa = atc->pd;
686
687 #ifdef GMX_MPI
688     if (forw && bXF)
689     {
690         /* forward, redistribution from pp to pme */
691
692         /* Calculate send counts and exchange them with other nodes */
693         for (i = 0; (i < atc->nslab); i++)
694         {
695             pme->scounts[i] = 0;
696         }
697         for (i = 0; (i < n); i++)
698         {
699             pme->scounts[pme->idxa[i]]++;
700         }
701         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
702
703         /* Calculate send and receive displacements and index into send
704            buffer */
705         pme->sdispls[0] = 0;
706         pme->rdispls[0] = 0;
707         pme->sidx[0]    = 0;
708         for (i = 1; i < atc->nslab; i++)
709         {
710             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
711             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
712             pme->sidx[i]    = pme->sdispls[i];
713         }
714         /* Total # of particles to be received */
715         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
716
717         pme_realloc_atomcomm_things(atc);
718
719         /* Copy particle coordinates into send buffer and exchange */
720         for (i = 0; (i < n); i++)
721         {
722             ii = DIM*pme->sidx[pme->idxa[i]];
723             pme->sidx[pme->idxa[i]]++;
724             pme->redist_buf[ii+XX] = x_f[i][XX];
725             pme->redist_buf[ii+YY] = x_f[i][YY];
726             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
727         }
728         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
729                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
730                       pme->rvec_mpi, atc->mpi_comm);
731     }
732     if (forw)
733     {
734         /* Copy charge into send buffer and exchange */
735         for (i = 0; i < atc->nslab; i++)
736         {
737             pme->sidx[i] = pme->sdispls[i];
738         }
739         for (i = 0; (i < n); i++)
740         {
741             ii = pme->sidx[pme->idxa[i]];
742             pme->sidx[pme->idxa[i]]++;
743             pme->redist_buf[ii] = charge[i];
744         }
745         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
746                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
747                       atc->mpi_comm);
748     }
749     else   /* backward, redistribution from pme to pp */
750     {
751         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
752                       pme->redist_buf, pme->scounts, pme->sdispls,
753                       pme->rvec_mpi, atc->mpi_comm);
754
755         /* Copy data from receive buffer */
756         for (i = 0; i < atc->nslab; i++)
757         {
758             pme->sidx[i] = pme->sdispls[i];
759         }
760         for (i = 0; (i < n); i++)
761         {
762             ii          = DIM*pme->sidx[pme->idxa[i]];
763             x_f[i][XX] += pme->redist_buf[ii+XX];
764             x_f[i][YY] += pme->redist_buf[ii+YY];
765             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
766             pme->sidx[pme->idxa[i]]++;
767         }
768     }
769 #endif
770 }
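/* A minimal sketch of the MPI_Alltoallv bookkeeping used above
 * (make_alltoallv_displs is illustrative, not part of pme.c): the
 * per-slab send counts are turned into send displacements by an
 * exclusive prefix sum, the receive counts come back from
 * MPI_Alltoall(scounts, 1, MPI_INT, rcounts, 1, MPI_INT, comm), and
 * the receive displacements are built the same way. The total number
 * of elements to receive is then rdispls[nslab-1] + rcounts[nslab-1].
 */
static void make_alltoallv_displs(const int *scounts, const int *rcounts,
                                  int nslab, int *sdispls, int *rdispls)
{
    int i;

    sdispls[0] = 0;
    rdispls[0] = 0;
    for (i = 1; i < nslab; i++)
    {
        sdispls[i] = sdispls[i-1] + scounts[i-1];
        rdispls[i] = rdispls[i-1] + rcounts[i-1];
    }
}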
771
772 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
773                             gmx_bool bBackward, int shift,
774                             void *buf_s, int nbyte_s,
775                             void *buf_r, int nbyte_r)
776 {
777 #ifdef GMX_MPI
778     int        dest, src;
779     MPI_Status stat;
780
781     if (bBackward == FALSE)
782     {
783         dest = atc->node_dest[shift];
784         src  = atc->node_src[shift];
785     }
786     else
787     {
788         dest = atc->node_src[shift];
789         src  = atc->node_dest[shift];
790     }
791
792     if (nbyte_s > 0 && nbyte_r > 0)
793     {
794         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
795                      dest, shift,
796                      buf_r, nbyte_r, MPI_BYTE,
797                      src, shift,
798                      atc->mpi_comm, &stat);
799     }
800     else if (nbyte_s > 0)
801     {
802         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
803                  dest, shift,
804                  atc->mpi_comm);
805     }
806     else if (nbyte_r > 0)
807     {
808         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
809                  src, shift,
810                  atc->mpi_comm, &stat);
811     }
812 #endif
813 }
814
815 static void dd_pmeredist_x_q(gmx_pme_t pme,
816                              int n, gmx_bool bX, rvec *x, real *charge,
817                              pme_atomcomm_t *atc)
818 {
819     int *commnode, *buf_index;
820     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
821
822     commnode  = atc->node_dest;
823     buf_index = atc->buf_index;
824
825     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
826
827     nsend = 0;
828     for (i = 0; i < nnodes_comm; i++)
829     {
830         buf_index[commnode[i]] = nsend;
831         nsend                 += atc->count[commnode[i]];
832     }
833     if (bX)
834     {
835         if (atc->count[atc->nodeid] + nsend != n)
836         {
837             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
838                       "This usually means that your system is not well equilibrated.",
839                       n - (atc->count[atc->nodeid] + nsend),
840                       pme->nodeid, 'x'+atc->dimind);
841         }
842
843         if (nsend > pme->buf_nalloc)
844         {
845             pme->buf_nalloc = over_alloc_dd(nsend);
846             srenew(pme->bufv, pme->buf_nalloc);
847             srenew(pme->bufr, pme->buf_nalloc);
848         }
849
850         atc->n = atc->count[atc->nodeid];
851         for (i = 0; i < nnodes_comm; i++)
852         {
853             scount = atc->count[commnode[i]];
854             /* Communicate the count */
855             if (debug)
856             {
857                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
858                         atc->dimind, atc->nodeid, commnode[i], scount);
859             }
860             pme_dd_sendrecv(atc, FALSE, i,
861                             &scount, sizeof(int),
862                             &atc->rcount[i], sizeof(int));
863             atc->n += atc->rcount[i];
864         }
865
866         pme_realloc_atomcomm_things(atc);
867     }
868
869     local_pos = 0;
870     for (i = 0; i < n; i++)
871     {
872         node = atc->pd[i];
873         if (node == atc->nodeid)
874         {
875             /* Copy direct to the receive buffer */
876             if (bX)
877             {
878                 copy_rvec(x[i], atc->x[local_pos]);
879             }
880             atc->q[local_pos] = charge[i];
881             local_pos++;
882         }
883         else
884         {
885             /* Copy to the send buffer */
886             if (bX)
887             {
888                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
889             }
890             pme->bufr[buf_index[node]] = charge[i];
891             buf_index[node]++;
892         }
893     }
894
895     buf_pos = 0;
896     for (i = 0; i < nnodes_comm; i++)
897     {
898         scount = atc->count[commnode[i]];
899         rcount = atc->rcount[i];
900         if (scount > 0 || rcount > 0)
901         {
902             if (bX)
903             {
904                 /* Communicate the coordinates */
905                 pme_dd_sendrecv(atc, FALSE, i,
906                                 pme->bufv[buf_pos], scount*sizeof(rvec),
907                                 atc->x[local_pos], rcount*sizeof(rvec));
908             }
909             /* Communicate the charges */
910             pme_dd_sendrecv(atc, FALSE, i,
911                             pme->bufr+buf_pos, scount*sizeof(real),
912                             atc->q+local_pos, rcount*sizeof(real));
913             buf_pos   += scount;
914             local_pos += atc->rcount[i];
915         }
916     }
917 }
918
919 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
920                            int n, rvec *f,
921                            gmx_bool bAddF)
922 {
923     int *commnode, *buf_index;
924     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
925
926     commnode  = atc->node_dest;
927     buf_index = atc->buf_index;
928
929     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
930
931     local_pos = atc->count[atc->nodeid];
932     buf_pos   = 0;
933     for (i = 0; i < nnodes_comm; i++)
934     {
935         scount = atc->rcount[i];
936         rcount = atc->count[commnode[i]];
937         if (scount > 0 || rcount > 0)
938         {
939             /* Communicate the forces */
940             pme_dd_sendrecv(atc, TRUE, i,
941                             atc->f[local_pos], scount*sizeof(rvec),
942                             pme->bufv[buf_pos], rcount*sizeof(rvec));
943             local_pos += scount;
944         }
945         buf_index[commnode[i]] = buf_pos;
946         buf_pos               += rcount;
947     }
948
949     local_pos = 0;
950     if (bAddF)
951     {
952         for (i = 0; i < n; i++)
953         {
954             node = atc->pd[i];
955             if (node == atc->nodeid)
956             {
957                 /* Add from the local force array */
958                 rvec_inc(f[i], atc->f[local_pos]);
959                 local_pos++;
960             }
961             else
962             {
963                 /* Add from the receive buffer */
964                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
965                 buf_index[node]++;
966             }
967         }
968     }
969     else
970     {
971         for (i = 0; i < n; i++)
972         {
973             node = atc->pd[i];
974             if (node == atc->nodeid)
975             {
976                 /* Copy from the local force array */
977                 copy_rvec(atc->f[local_pos], f[i]);
978                 local_pos++;
979             }
980             else
981             {
982                 /* Copy from the receive buffer */
983                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
984                 buf_index[node]++;
985             }
986         }
987     }
988 }
989
990 #ifdef GMX_MPI
991 static void
992 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
993 {
994     pme_overlap_t *overlap;
995     int            send_index0, send_nindex;
996     int            recv_index0, recv_nindex;
997     MPI_Status     stat;
998     int            i, j, k, ix, iy, iz, icnt;
999     int            ipulse, send_id, recv_id, datasize;
1000     real          *p;
1001     real          *sendptr, *recvptr;
1002
1003     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
1004     overlap = &pme->overlap[1];
1005
1006     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1007     {
1008         /* Since we have already (un)wrapped the overlap in the z-dimension,
1009          * we only have to communicate 0 to nkz (not pmegrid_nz).
1010          */
1011         if (direction == GMX_SUM_QGRID_FORWARD)
1012         {
1013             send_id       = overlap->send_id[ipulse];
1014             recv_id       = overlap->recv_id[ipulse];
1015             send_index0   = overlap->comm_data[ipulse].send_index0;
1016             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1017             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1018             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1019         }
1020         else
1021         {
1022             send_id       = overlap->recv_id[ipulse];
1023             recv_id       = overlap->send_id[ipulse];
1024             send_index0   = overlap->comm_data[ipulse].recv_index0;
1025             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1026             recv_index0   = overlap->comm_data[ipulse].send_index0;
1027             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1028         }
1029
1030         /* Copy data to contiguous send buffer */
1031         if (debug)
1032         {
1033             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1034                     pme->nodeid, overlap->nodeid, send_id,
1035                     pme->pmegrid_start_iy,
1036                     send_index0-pme->pmegrid_start_iy,
1037                     send_index0-pme->pmegrid_start_iy+send_nindex);
1038         }
1039         icnt = 0;
1040         for (i = 0; i < pme->pmegrid_nx; i++)
1041         {
1042             ix = i;
1043             for (j = 0; j < send_nindex; j++)
1044             {
1045                 iy = j + send_index0 - pme->pmegrid_start_iy;
1046                 for (k = 0; k < pme->nkz; k++)
1047                 {
1048                     iz = k;
1049                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1050                 }
1051             }
1052         }
1053
1054         datasize      = pme->pmegrid_nx * pme->nkz;
1055
1056         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1057                      send_id, ipulse,
1058                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1059                      recv_id, ipulse,
1060                      overlap->mpi_comm, &stat);
1061
1062         /* Get data from contiguous recv buffer */
1063         if (debug)
1064         {
1065             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1066                     pme->nodeid, overlap->nodeid, recv_id,
1067                     pme->pmegrid_start_iy,
1068                     recv_index0-pme->pmegrid_start_iy,
1069                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1070         }
1071         icnt = 0;
1072         for (i = 0; i < pme->pmegrid_nx; i++)
1073         {
1074             ix = i;
1075             for (j = 0; j < recv_nindex; j++)
1076             {
1077                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1078                 for (k = 0; k < pme->nkz; k++)
1079                 {
1080                     iz = k;
1081                     if (direction == GMX_SUM_QGRID_FORWARD)
1082                     {
1083                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1084                     }
1085                     else
1086                     {
1087                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1088                     }
1089                 }
1090             }
1091         }
1092     }
1093
1094     /* Major dimension is easier, no copying required,
1095      * but we might have to sum to a separate array.
1096      * Since we don't copy, we have to communicate up to pmegrid_nz,
1097      * not nkz as for the minor direction.
1098      */
1099     overlap = &pme->overlap[0];
1100
1101     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1102     {
1103         if (direction == GMX_SUM_QGRID_FORWARD)
1104         {
1105             send_id       = overlap->send_id[ipulse];
1106             recv_id       = overlap->recv_id[ipulse];
1107             send_index0   = overlap->comm_data[ipulse].send_index0;
1108             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1109             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1110             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1111             recvptr       = overlap->recvbuf;
1112         }
1113         else
1114         {
1115             send_id       = overlap->recv_id[ipulse];
1116             recv_id       = overlap->send_id[ipulse];
1117             send_index0   = overlap->comm_data[ipulse].recv_index0;
1118             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1119             recv_index0   = overlap->comm_data[ipulse].send_index0;
1120             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1121             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1122         }
1123
1124         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1125         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1126
1127         if (debug)
1128         {
1129             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1130                     pme->nodeid, overlap->nodeid, send_id,
1131                     pme->pmegrid_start_ix,
1132                     send_index0-pme->pmegrid_start_ix,
1133                     send_index0-pme->pmegrid_start_ix+send_nindex);
1134             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1135                     pme->nodeid, overlap->nodeid, recv_id,
1136                     pme->pmegrid_start_ix,
1137                     recv_index0-pme->pmegrid_start_ix,
1138                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1139         }
1140
1141         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1142                      send_id, ipulse,
1143                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1144                      recv_id, ipulse,
1145                      overlap->mpi_comm, &stat);
1146
1147         /* ADD data from contiguous recv buffer */
1148         if (direction == GMX_SUM_QGRID_FORWARD)
1149         {
1150             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1151             for (i = 0; i < recv_nindex*datasize; i++)
1152             {
1153                 p[i] += overlap->recvbuf[i];
1154             }
1155         }
1156     }
1157 }
1158 #endif
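/* A minimal sketch of the minor-dimension packing done above
 * (pack_y_slab is illustrative, not part of pme.c). With the grid
 * stored as grid[(ix*ny + iy)*nz + iz], the niy y-planes starting at
 * iy0 are not contiguous in memory, so they are copied into a
 * contiguous buffer before the single MPI_Sendrecv; the receiving side
 * unpacks (and, for the forward direction, accumulates) the same way.
 */
static void pack_y_slab(const real *grid, int nx, int ny, int nz,
                        int iy0, int niy, real *buf)
{
    int ix, iy, iz, icnt;

    icnt = 0;
    for (ix = 0; ix < nx; ix++)
    {
        for (iy = iy0; iy < iy0 + niy; iy++)
        {
            for (iz = 0; iz < nz; iz++)
            {
                buf[icnt++] = grid[(ix*ny + iy)*nz + iz];
            }
        }
    }
}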
1159
1160
1161 static int
1162 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1163 {
1164     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1165     ivec    local_pme_size;
1166     int     i, ix, iy, iz;
1167     int     pmeidx, fftidx;
1168
1169     /* Dimensions should be identical for A/B grid, so we just use A here */
1170     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1171                                    local_fft_ndata,
1172                                    local_fft_offset,
1173                                    local_fft_size);
1174
1175     local_pme_size[0] = pme->pmegrid_nx;
1176     local_pme_size[1] = pme->pmegrid_ny;
1177     local_pme_size[2] = pme->pmegrid_nz;
1178
1179     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1180        the offset is identical, and the PME grid always has more data (due to overlap)
1181      */
1182     {
1183 #ifdef DEBUG_PME
1184         FILE *fp, *fp2;
1185         char  fn[STRLEN], format[STRLEN];
1186         real  val;
1187         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1188         fp = ffopen(fn, "w");
1189         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1190         fp2 = ffopen(fn, "w");
1191         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1192 #endif
1193
1194         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1195         {
1196             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1197             {
1198                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1199                 {
1200                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1201                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1202                     fftgrid[fftidx] = pmegrid[pmeidx];
1203 #ifdef DEBUG_PME
1204                     val = 100*pmegrid[pmeidx];
1205                     if (pmegrid[pmeidx] != 0)
1206                     {
1207                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1208                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1209                     }
1210                     if (pmegrid[pmeidx] != 0)
1211                     {
1212                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1213                                 "qgrid",
1214                                 pme->pmegrid_start_ix + ix,
1215                                 pme->pmegrid_start_iy + iy,
1216                                 pme->pmegrid_start_iz + iz,
1217                                 pmegrid[pmeidx]);
1218                     }
1219 #endif
1220                 }
1221             }
1222         }
1223 #ifdef DEBUG_PME
1224         ffclose(fp);
1225         ffclose(fp2);
1226 #endif
1227     }
1228     return 0;
1229 }
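/* A minimal sketch of the index remapping above (copy_subvolume is
 * illustrative, not part of pme.c): the same (ix,iy,iz) grid point has
 * different flat indices in the padded PME grid and in the FFT grid
 * because the two arrays are allocated with different sizes; only the
 * strides differ, while the used data region starts at the origin of
 * both.
 */
static void copy_subvolume(const real *src, const ivec src_size,
                           real *dst, const ivec dst_size,
                           const ivec ndata)
{
    int ix, iy, iz;

    for (ix = 0; ix < ndata[XX]; ix++)
    {
        for (iy = 0; iy < ndata[YY]; iy++)
        {
            for (iz = 0; iz < ndata[ZZ]; iz++)
            {
                dst[(ix*dst_size[YY] + iy)*dst_size[ZZ] + iz] =
                    src[(ix*src_size[YY] + iy)*src_size[ZZ] + iz];
            }
        }
    }
}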
1230
1231
1232 static gmx_cycles_t omp_cyc_start()
1233 {
1234     return gmx_cycles_read();
1235 }
1236
1237 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1238 {
1239     return gmx_cycles_read() - c;
1240 }
1241
1242
1243 static int
1244 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1245                         int nthread, int thread)
1246 {
1247     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1248     ivec          local_pme_size;
1249     int           ixy0, ixy1, ixy, ix, iy, iz;
1250     int           pmeidx, fftidx;
1251 #ifdef PME_TIME_THREADS
1252     gmx_cycles_t  c1;
1253     static double cs1 = 0;
1254     static int    cnt = 0;
1255 #endif
1256
1257 #ifdef PME_TIME_THREADS
1258     c1 = omp_cyc_start();
1259 #endif
1260     /* Dimensions should be identical for A/B grid, so we just use A here */
1261     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1262                                    local_fft_ndata,
1263                                    local_fft_offset,
1264                                    local_fft_size);
1265
1266     local_pme_size[0] = pme->pmegrid_nx;
1267     local_pme_size[1] = pme->pmegrid_ny;
1268     local_pme_size[2] = pme->pmegrid_nz;
1269
1270     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1271        the offset is identical, and the PME grid always has more data (due to overlap)
1272      */
1273     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1274     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1275
1276     for (ixy = ixy0; ixy < ixy1; ixy++)
1277     {
1278         ix = ixy/local_fft_ndata[YY];
1279         iy = ixy - ix*local_fft_ndata[YY];
1280
1281         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1282         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1283         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1284         {
1285             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1286         }
1287     }
1288
1289 #ifdef PME_TIME_THREADS
1290     c1   = omp_cyc_end(c1);
1291     cs1 += (double)c1;
1292     cnt++;
1293     if (cnt % 20 == 0)
1294     {
1295         printf("copy %.2f\n", cs1*1e-9);
1296     }
1297 #endif
1298
1299     return 0;
1300 }
1301
1302
1303 static void
1304 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1305 {
1306     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1307
1308     nx = pme->nkx;
1309     ny = pme->nky;
1310     nz = pme->nkz;
1311
1312     pnx = pme->pmegrid_nx;
1313     pny = pme->pmegrid_ny;
1314     pnz = pme->pmegrid_nz;
1315
1316     overlap = pme->pme_order - 1;
1317
1318     /* Add periodic overlap in z */
1319     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1320     {
1321         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1322         {
1323             for (iz = 0; iz < overlap; iz++)
1324             {
1325                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1326                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1327             }
1328         }
1329     }
1330
1331     if (pme->nnodes_minor == 1)
1332     {
1333         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1334         {
1335             for (iy = 0; iy < overlap; iy++)
1336             {
1337                 for (iz = 0; iz < nz; iz++)
1338                 {
1339                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1340                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1341                 }
1342             }
1343         }
1344     }
1345
1346     if (pme->nnodes_major == 1)
1347     {
1348         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1349
1350         for (ix = 0; ix < overlap; ix++)
1351         {
1352             for (iy = 0; iy < ny_x; iy++)
1353             {
1354                 for (iz = 0; iz < nz; iz++)
1355                 {
1356                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1357                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1358                 }
1359             }
1360         }
1361     }
1362 }
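/* A 1D sketch of the periodic wrapping above (wrap_1d is illustrative,
 * not part of pme.c): charge spread into the order-1 overlap points past
 * the upper edge of the n used grid points is folded back onto the first
 * points. unwrap_periodic_pmegrid() below performs the inverse copy
 * before the forces are gathered.
 */
static void wrap_1d(real *grid, int n, int order)
{
    int i;

    for (i = 0; i < order - 1; i++)
    {
        grid[i] += grid[n + i];
    }
}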
1363
1364
1365 static void
1366 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1367 {
1368     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1369
1370     nx = pme->nkx;
1371     ny = pme->nky;
1372     nz = pme->nkz;
1373
1374     pnx = pme->pmegrid_nx;
1375     pny = pme->pmegrid_ny;
1376     pnz = pme->pmegrid_nz;
1377
1378     overlap = pme->pme_order - 1;
1379
1380     if (pme->nnodes_major == 1)
1381     {
1382         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1383
1384         for (ix = 0; ix < overlap; ix++)
1385         {
1386             int iy, iz;
1387
1388             for (iy = 0; iy < ny_x; iy++)
1389             {
1390                 for (iz = 0; iz < nz; iz++)
1391                 {
1392                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1393                         pmegrid[(ix*pny+iy)*pnz+iz];
1394                 }
1395             }
1396         }
1397     }
1398
1399     if (pme->nnodes_minor == 1)
1400     {
1401 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1402         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1403         {
1404             int iy, iz;
1405
1406             for (iy = 0; iy < overlap; iy++)
1407             {
1408                 for (iz = 0; iz < nz; iz++)
1409                 {
1410                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1411                         pmegrid[(ix*pny+iy)*pnz+iz];
1412                 }
1413             }
1414         }
1415     }
1416
1417     /* Copy periodic overlap in z */
1418 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1419     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1420     {
1421         int iy, iz;
1422
1423         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1424         {
1425             for (iz = 0; iz < overlap; iz++)
1426             {
1427                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1428                     pmegrid[(ix*pny+iy)*pnz+iz];
1429             }
1430         }
1431     }
1432 }
1433
1434 static void clear_grid(int nx, int ny, int nz, real *grid,
1435                        ivec fs, int *flag,
1436                        int fx, int fy, int fz,
1437                        int order)
1438 {
1439     int nc, ncz;
1440     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1441     int flind;
1442
1443     nc  = 2 + (order - 2)/FLBS;
1444     ncz = 2 + (order - 2)/FLBSZ;
1445
1446     for (fsx = fx; fsx < fx+nc; fsx++)
1447     {
1448         for (fsy = fy; fsy < fy+nc; fsy++)
1449         {
1450             for (fsz = fz; fsz < fz+ncz; fsz++)
1451             {
1452                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1453                 if (flag[flind] == 0)
1454                 {
1455                     gx  = fsx*FLBS;
1456                     gy  = fsy*FLBS;
1457                     gz  = fsz*FLBSZ;
1458                     g0x = (gx*ny + gy)*nz + gz;
1459                     for (x = 0; x < FLBS; x++)
1460                     {
1461                         g0y = g0x;
1462                         for (y = 0; y < FLBS; y++)
1463                         {
1464                             for (z = 0; z < FLBSZ; z++)
1465                             {
1466                                 grid[g0y+z] = 0;
1467                             }
1468                             g0y += nz;
1469                         }
1470                         g0x += ny*nz;
1471                     }
1472
1473                     flag[flind] = 1;
1474                 }
1475             }
1476         }
1477     }
1478 }
1479
1480 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1481 #define DO_BSPLINE(order)                            \
1482     for (ithx = 0; (ithx < order); ithx++)                    \
1483     {                                                    \
1484         index_x = (i0+ithx)*pny*pnz;                     \
1485         valx    = qn*thx[ithx];                          \
1486                                                      \
1487         for (ithy = 0; (ithy < order); ithy++)                \
1488         {                                                \
1489             valxy    = valx*thy[ithy];                   \
1490             index_xy = index_x+(j0+ithy)*pnz;            \
1491                                                      \
1492             for (ithz = 0; (ithz < order); ithz++)            \
1493             {                                            \
1494                 index_xyz        = index_xy+(k0+ithz);   \
1495                 grid[index_xyz] += valxy*thz[ithz];      \
1496             }                                            \
1497         }                                                \
1498     }
1499
1500
1501 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1502                                      pme_atomcomm_t *atc, splinedata_t *spline,
1503                                      pme_spline_work_t *work)
1504 {
1505
1506     /* spread charges from home atoms to local grid */
1507     real          *grid;
1508     pme_overlap_t *ol;
1509     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1510     int       *    idxptr;
1511     int            order, norder, index_x, index_xy, index_xyz;
1512     real           valx, valxy, qn;
1513     real          *thx, *thy, *thz;
1514     int            localsize, bndsize;
1515     int            pnx, pny, pnz, ndatatot;
1516     int            offx, offy, offz;
1517
1518     pnx = pmegrid->s[XX];
1519     pny = pmegrid->s[YY];
1520     pnz = pmegrid->s[ZZ];
1521
1522     offx = pmegrid->offset[XX];
1523     offy = pmegrid->offset[YY];
1524     offz = pmegrid->offset[ZZ];
1525
1526     ndatatot = pnx*pny*pnz;
1527     grid     = pmegrid->grid;
1528     for (i = 0; i < ndatatot; i++)
1529     {
1530         grid[i] = 0;
1531     }
1532
1533     order = pmegrid->order;
1534
1535     for (nn = 0; nn < spline->n; nn++)
1536     {
1537         n  = spline->ind[nn];
1538         qn = atc->q[n];
1539
1540         if (qn != 0)
1541         {
1542             idxptr = atc->idx[n];
1543             norder = nn*order;
1544
1545             i0   = idxptr[XX] - offx;
1546             j0   = idxptr[YY] - offy;
1547             k0   = idxptr[ZZ] - offz;
1548
1549             thx = spline->theta[XX] + norder;
1550             thy = spline->theta[YY] + norder;
1551             thz = spline->theta[ZZ] + norder;
1552
1553             switch (order)
1554             {
1555                 case 4:
1556 #ifdef PME_SSE_SPREAD_GATHER
1557 #ifdef PME_SSE_UNALIGNED
1558 #define PME_SPREAD_SSE_ORDER4
1559 #else
1560 #define PME_SPREAD_SSE_ALIGNED
1561 #define PME_ORDER 4
1562 #endif
1563 #include "pme_sse_single.h"
1564 #else
1565                     DO_BSPLINE(4);
1566 #endif
1567                     break;
1568                 case 5:
1569 #ifdef PME_SSE_SPREAD_GATHER
1570 #define PME_SPREAD_SSE_ALIGNED
1571 #define PME_ORDER 5
1572 #include "pme_sse_single.h"
1573 #else
1574                     DO_BSPLINE(5);
1575 #endif
1576                     break;
1577                 default:
1578                     DO_BSPLINE(order);
1579                     break;
1580             }
1581         }
1582     }
1583 }
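/* In formula form, the spreading above accumulates, for each charge q_n
 * with B-spline weights thx, thy, thz of length 'order':
 *
 *   grid[i0+tx][j0+ty][k0+tz] += q_n * thx[tx] * thy[ty] * thz[tz]
 *
 * for all 0 <= tx, ty, tz < order, i.e. order^3 grid points per charge
 * (64 points for the default pme_order = 4).
 */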
1584
1585 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1586 {
1587 #ifdef PME_SSE_SPREAD_GATHER
1588     if (pme_order == 5
1589 #ifndef PME_SSE_UNALIGNED
1590         || pme_order == 4
1591 #endif
1592         )
1593     {
1594         /* Round nz up to a multiple of 4 to ensure alignment */
1595         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1596     }
1597 #endif
1598 }
1599
1600 static void set_gridsize_alignment(int *gridsize, int pme_order)
1601 {
1602 #ifdef PME_SSE_SPREAD_GATHER
1603 #ifndef PME_SSE_UNALIGNED
1604     if (pme_order == 4)
1605     {
1606         /* Add extra elements to ensure that aligned operations do not go
1607          * beyond the allocated grid size.
1608          * Note that for pme_order=5, the pme grid z-size alignment
1609          * ensures that we will not go beyond the grid size.
1610          */
1611         *gridsize += 4;
1612     }
1613 #endif
1614 #endif
1615 }
1616
1617 static void pmegrid_init(pmegrid_t *grid,
1618                          int cx, int cy, int cz,
1619                          int x0, int y0, int z0,
1620                          int x1, int y1, int z1,
1621                          gmx_bool set_alignment,
1622                          int pme_order,
1623                          real *ptr)
1624 {
1625     int nz, gridsize;
1626
1627     grid->ci[XX]     = cx;
1628     grid->ci[YY]     = cy;
1629     grid->ci[ZZ]     = cz;
1630     grid->offset[XX] = x0;
1631     grid->offset[YY] = y0;
1632     grid->offset[ZZ] = z0;
1633     grid->n[XX]      = x1 - x0 + pme_order - 1;
1634     grid->n[YY]      = y1 - y0 + pme_order - 1;
1635     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1636     copy_ivec(grid->n, grid->s);
1637
1638     nz = grid->s[ZZ];
1639     set_grid_alignment(&nz, pme_order);
1640     if (set_alignment)
1641     {
1642         grid->s[ZZ] = nz;
1643     }
1644     else if (nz != grid->s[ZZ])
1645     {
1646         gmx_incons("pmegrid_init call with an unaligned z size");
1647     }
1648
1649     grid->order = pme_order;
1650     if (ptr == NULL)
1651     {
1652         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1653         set_gridsize_alignment(&gridsize, pme_order);
1654         snew_aligned(grid->grid, gridsize, 16);
1655     }
1656     else
1657     {
1658         grid->grid = ptr;
1659     }
1660 }
1661
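/* Integer division rounding up, e.g. div_round_up(10, 4) = 3 and
 * div_round_up(8, 4) = 2.
 */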
1662 static int div_round_up(int enumerator, int denominator)
1663 {
1664     return (enumerator + denominator - 1)/denominator;
1665 }
1666
1667 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1668                                   ivec nsub)
1669 {
1670     int gsize_opt, gsize;
1671     int nsx, nsy, nsz;
1672     char *env;
1673
1674     gsize_opt = -1;
1675     for (nsx = 1; nsx <= nthread; nsx++)
1676     {
1677         if (nthread % nsx == 0)
1678         {
1679             for (nsy = 1; nsy <= nthread; nsy++)
1680             {
1681                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1682                 {
1683                     nsz = nthread/(nsx*nsy);
1684
1685                     /* Determine the number of grid points per thread */
1686                     gsize =
1687                         (div_round_up(n[XX], nsx) + ovl)*
1688                         (div_round_up(n[YY], nsy) + ovl)*
1689                         (div_round_up(n[ZZ], nsz) + ovl);
1690
1691                     /* Minimize the number of grid points per thread
1692                      * and, secondarily, the number of cuts in minor dimensions.
1693                      */
1694                     if (gsize_opt == -1 ||
1695                         gsize < gsize_opt ||
1696                         (gsize == gsize_opt &&
1697                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1698                     {
1699                         nsub[XX]  = nsx;
1700                         nsub[YY]  = nsy;
1701                         nsub[ZZ]  = nsz;
1702                         gsize_opt = gsize;
1703                     }
1704                 }
1705             }
1706         }
1707     }
1708
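    /* The automatic division can be overridden for tuning: e.g. setting
     * GMX_PME_THREAD_DIVISION="2 2 1" requests a 2x2x1 subgrid, whose product
     * must equal the number of PME threads (checked below).
     */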
1709     env = getenv("GMX_PME_THREAD_DIVISION");
1710     if (env != NULL)
1711     {
1712         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1713     }
1714
1715     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1716     {
1717         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1718     }
1719 }
1720
1721 static void pmegrids_init(pmegrids_t *grids,
1722                           int nx, int ny, int nz, int nz_base,
1723                           int pme_order,
1724                           gmx_bool bUseThreads,
1725                           int nthread,
1726                           int overlap_x,
1727                           int overlap_y)
1728 {
1729     ivec n, n_base, g0, g1;
1730     int t, x, y, z, d, i, tfac;
1731     int max_comm_lines = -1;
1732
1733     n[XX] = nx - (pme_order - 1);
1734     n[YY] = ny - (pme_order - 1);
1735     n[ZZ] = nz - (pme_order - 1);
1736
1737     copy_ivec(n, n_base);
1738     n_base[ZZ] = nz_base;
1739
1740     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1741                  NULL);
1742
1743     grids->nthread = nthread;
1744
1745     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1746
1747     if (bUseThreads)
1748     {
1749         ivec nst;
1750         int gridsize;
1751
1752         for (d = 0; d < DIM; d++)
1753         {
1754             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1755         }
1756         set_grid_alignment(&nst[ZZ], pme_order);
1757
1758         if (debug)
1759         {
1760             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1761                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1762             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1763                     nx, ny, nz,
1764                     nst[XX], nst[YY], nst[ZZ]);
1765         }
1766
1767         snew(grids->grid_th, grids->nthread);
1768         t        = 0;
1769         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1770         set_gridsize_alignment(&gridsize, pme_order);
1771         snew_aligned(grids->grid_all,
1772                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1773                      16);
1774
1775         for (x = 0; x < grids->nc[XX]; x++)
1776         {
1777             for (y = 0; y < grids->nc[YY]; y++)
1778             {
1779                 for (z = 0; z < grids->nc[ZZ]; z++)
1780                 {
1781                     pmegrid_init(&grids->grid_th[t],
1782                                  x, y, z,
1783                                  (n[XX]*(x  ))/grids->nc[XX],
1784                                  (n[YY]*(y  ))/grids->nc[YY],
1785                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1786                                  (n[XX]*(x+1))/grids->nc[XX],
1787                                  (n[YY]*(y+1))/grids->nc[YY],
1788                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1789                                  TRUE,
1790                                  pme_order,
1791                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1792                     t++;
1793                 }
1794             }
1795         }
1796     }
1797     else
1798     {
1799         grids->grid_th = NULL;
1800     }
1801
1802     snew(grids->g2t, DIM);
1803     tfac = 1;
1804     for (d = DIM-1; d >= 0; d--)
1805     {
1806         snew(grids->g2t[d], n[d]);
1807         t = 0;
1808         for (i = 0; i < n[d]; i++)
1809         {
1810             /* The second check should match the parameters
1811              * of the pmegrid_init call above.
1812              */
1813             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1814             {
1815                 t++;
1816             }
1817             grids->g2t[d][i] = t*tfac;
1818         }
1819
1820         tfac *= grids->nc[d];
1821
1822         switch (d)
1823         {
1824             case XX: max_comm_lines = overlap_x;     break;
1825             case YY: max_comm_lines = overlap_y;     break;
1826             case ZZ: max_comm_lines = pme_order - 1; break;
1827         }
1828         grids->nthread_comm[d] = 0;
1829         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1830                grids->nthread_comm[d] < grids->nc[d])
1831         {
1832             grids->nthread_comm[d]++;
1833         }
1834         if (debug != NULL)
1835         {
1836             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1837                     'x'+d, grids->nthread_comm[d]);
1838         }
1839         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1840          * work, but this is not a problematic restriction.
1841          */
1842         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1843         {
1844             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1845         }
1846     }
1847 }
1848
1849
1850 static void pmegrids_destroy(pmegrids_t *grids)
1851 {
1852     int t;
1853
1854     if (grids->grid.grid != NULL)
1855     {
1856         sfree(grids->grid.grid);
1857
1858         if (grids->nthread > 0)
1859         {
1860             for (t = 0; t < grids->nthread; t++)
1861             {
1862                 sfree(grids->grid_th[t].grid);
1863             }
1864             sfree(grids->grid_th);
1865         }
1866     }
1867 }
1868
1869
1870 static void realloc_work(pme_work_t *work, int nkx)
1871 {
1872     if (nkx > work->nalloc)
1873     {
1874         work->nalloc = nkx;
1875         srenew(work->mhx, work->nalloc);
1876         srenew(work->mhy, work->nalloc);
1877         srenew(work->mhz, work->nalloc);
1878         srenew(work->m2, work->nalloc);
1879         /* Allocate an aligned pointer for SIMD operations, including extra
1880          * elements at the end for padding.
1881          */
1882 #ifdef PME_SIMD
1883 #define ALIGN_HERE  GMX_SIMD_WIDTH_HERE
1884 #else
1885 /* We can use any alignment, apart from 0, so we use 4 */
1886 #define ALIGN_HERE  4
1887 #endif
1888         sfree_aligned(work->denom);
1889         sfree_aligned(work->tmp1);
1890         sfree_aligned(work->eterm);
1891         snew_aligned(work->denom, work->nalloc+ALIGN_HERE, ALIGN_HERE*sizeof(real));
1892         snew_aligned(work->tmp1,  work->nalloc+ALIGN_HERE, ALIGN_HERE*sizeof(real));
1893         snew_aligned(work->eterm, work->nalloc+ALIGN_HERE, ALIGN_HERE*sizeof(real));
1894         srenew(work->m2inv, work->nalloc);
1895     }
1896 }
1897
1898
1899 static void free_work(pme_work_t *work)
1900 {
1901     sfree(work->mhx);
1902     sfree(work->mhy);
1903     sfree(work->mhz);
1904     sfree(work->m2);
1905     sfree_aligned(work->denom);
1906     sfree_aligned(work->tmp1);
1907     sfree_aligned(work->eterm);
1908     sfree(work->m2inv);
1909 }
1910
1911
1912 #ifdef PME_SIMD
1913 /* Calculate exponentials through SIMD */
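/* Note: the start argument is ignored here, the loop always begins at 0 and
 * processes whole SIMD widths, so d, r and e must be aligned and padded
 * as done in realloc_work().
 */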
1914 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1915 {
1916     {
1917         const gmx_mm_pr two = gmx_set1_pr(2.0);
1918         gmx_mm_pr f_simd;
1919         gmx_mm_pr lu;
1920         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1921         int kx;
1922         f_simd = gmx_load1_pr(&f);
1923         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1924         {
1925             tmp_d1   = gmx_load_pr(d_aligned+kx);
1926             d_inv    = gmx_inv_pr(tmp_d1);
1927             tmp_r    = gmx_load_pr(r_aligned+kx);
1928             tmp_r    = gmx_exp_pr(tmp_r);
1929             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1930             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1931             gmx_store_pr(e_aligned+kx, tmp_e);
1932         }
1933     }
1934 }
1935 #else
1936 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1937 {
1938     int kx;
1939     for (kx = start; kx < end; kx++)
1940     {
1941         d[kx] = 1.0/d[kx];
1942     }
1943     for (kx = start; kx < end; kx++)
1944     {
1945         r[kx] = exp(r[kx]);
1946     }
1947     for (kx = start; kx < end; kx++)
1948     {
1949         e[kx] = f*r[kx]*d[kx];
1950     }
1951 }
1952 #endif
1953
1954
1955 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1956                          real ewaldcoeff, real vol,
1957                          gmx_bool bEnerVir,
1958                          int nthread, int thread)
1959 {
1960     /* do recip sum over local cells in grid */
1961     /* y major, z middle, x minor or continuous */
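    /* Each reciprocal vector m is scaled below by
     * eterm = elfac*exp(-pi^2 m^2/ewaldcoeff^2)/(pi*vol*m^2*bsp_mod_x*bsp_mod_y*bsp_mod_z),
     * the smooth-PME reciprocal-space convolution factor; energy and virial
     * are accumulated from the same terms when bEnerVir is set.
     */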
1962     t_complex *p0;
1963     int     kx, ky, kz, maxkx, maxky, maxkz;
1964     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1965     real    mx, my, mz;
1966     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1967     real    ets2, struct2, vfactor, ets2vf;
1968     real    d1, d2, energy = 0;
1969     real    by, bz;
1970     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1971     real    rxx, ryx, ryy, rzx, rzy, rzz;
1972     pme_work_t *work;
1973     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1974     real    mhxk, mhyk, mhzk, m2k;
1975     real    corner_fac;
1976     ivec    complex_order;
1977     ivec    local_ndata, local_offset, local_size;
1978     real    elfac;
1979
1980     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1981
1982     nx = pme->nkx;
1983     ny = pme->nky;
1984     nz = pme->nkz;
1985
1986     /* Dimensions should be identical for A/B grid, so we just use A here */
1987     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1988                                       complex_order,
1989                                       local_ndata,
1990                                       local_offset,
1991                                       local_size);
1992
1993     rxx = pme->recipbox[XX][XX];
1994     ryx = pme->recipbox[YY][XX];
1995     ryy = pme->recipbox[YY][YY];
1996     rzx = pme->recipbox[ZZ][XX];
1997     rzy = pme->recipbox[ZZ][YY];
1998     rzz = pme->recipbox[ZZ][ZZ];
1999
2000     maxkx = (nx+1)/2;
2001     maxky = (ny+1)/2;
2002     maxkz = nz/2+1;
2003
2004     work  = &pme->work[thread];
2005     mhx   = work->mhx;
2006     mhy   = work->mhy;
2007     mhz   = work->mhz;
2008     m2    = work->m2;
2009     denom = work->denom;
2010     tmp1  = work->tmp1;
2011     eterm = work->eterm;
2012     m2inv = work->m2inv;
2013
2014     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2015     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2016
2017     for (iyz = iyz0; iyz < iyz1; iyz++)
2018     {
2019         iy = iyz/local_ndata[ZZ];
2020         iz = iyz - iy*local_ndata[ZZ];
2021
2022         ky = iy + local_offset[YY];
2023
2024         if (ky < maxky)
2025         {
2026             my = ky;
2027         }
2028         else
2029         {
2030             my = (ky - ny);
2031         }
2032
2033         by = M_PI*vol*pme->bsp_mod[YY][ky];
2034
2035         kz = iz + local_offset[ZZ];
2036
2037         mz = kz;
2038
2039         bz = pme->bsp_mod[ZZ][kz];
2040
2041         /* 0.5 correction for corner points */
2042         corner_fac = 1;
2043         if (kz == 0 || kz == (nz+1)/2)
2044         {
2045             corner_fac = 0.5;
2046         }
2047
2048         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2049
2050         /* We should skip the k-space point (0,0,0) */
2051         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2052         {
2053             kxstart = local_offset[XX];
2054         }
2055         else
2056         {
2057             kxstart = local_offset[XX] + 1;
2058             p0++;
2059         }
2060         kxend = local_offset[XX] + local_ndata[XX];
2061
2062         if (bEnerVir)
2063         {
2064             /* More expensive inner loop, especially because of the storage
2065              * of the mh elements in arrays.
2066              * Because x is the minor grid index, all mh elements
2067              * depend on kx for triclinic unit cells.
2068              */
2069
2070             /* Two explicit loops to avoid a conditional inside the loop */
2071             for (kx = kxstart; kx < maxkx; kx++)
2072             {
2073                 mx = kx;
2074
2075                 mhxk      = mx * rxx;
2076                 mhyk      = mx * ryx + my * ryy;
2077                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2078                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2079                 mhx[kx]   = mhxk;
2080                 mhy[kx]   = mhyk;
2081                 mhz[kx]   = mhzk;
2082                 m2[kx]    = m2k;
2083                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2084                 tmp1[kx]  = -factor*m2k;
2085             }
2086
2087             for (kx = maxkx; kx < kxend; kx++)
2088             {
2089                 mx = (kx - nx);
2090
2091                 mhxk      = mx * rxx;
2092                 mhyk      = mx * ryx + my * ryy;
2093                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2094                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2095                 mhx[kx]   = mhxk;
2096                 mhy[kx]   = mhyk;
2097                 mhz[kx]   = mhzk;
2098                 m2[kx]    = m2k;
2099                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2100                 tmp1[kx]  = -factor*m2k;
2101             }
2102
2103             for (kx = kxstart; kx < kxend; kx++)
2104             {
2105                 m2inv[kx] = 1.0/m2[kx];
2106             }
2107
2108             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2109
2110             for (kx = kxstart; kx < kxend; kx++, p0++)
2111             {
2112                 d1      = p0->re;
2113                 d2      = p0->im;
2114
2115                 p0->re  = d1*eterm[kx];
2116                 p0->im  = d2*eterm[kx];
2117
2118                 struct2 = 2.0*(d1*d1+d2*d2);
2119
2120                 tmp1[kx] = eterm[kx]*struct2;
2121             }
2122
2123             for (kx = kxstart; kx < kxend; kx++)
2124             {
2125                 ets2     = corner_fac*tmp1[kx];
2126                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2127                 energy  += ets2;
2128
2129                 ets2vf   = ets2*vfactor;
2130                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2131                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2132                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2133                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2134                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2135                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2136             }
2137         }
2138         else
2139         {
2140             /* We don't need to calculate the energy and the virial.
2141              * In this case the triclinic overhead is small.
2142              */
2143
2144             /* Two explicit loops to avoid a conditional inside the loop */
2145
2146             for (kx = kxstart; kx < maxkx; kx++)
2147             {
2148                 mx = kx;
2149
2150                 mhxk      = mx * rxx;
2151                 mhyk      = mx * ryx + my * ryy;
2152                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2153                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2154                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2155                 tmp1[kx]  = -factor*m2k;
2156             }
2157
2158             for (kx = maxkx; kx < kxend; kx++)
2159             {
2160                 mx = (kx - nx);
2161
2162                 mhxk      = mx * rxx;
2163                 mhyk      = mx * ryx + my * ryy;
2164                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2165                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2166                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2167                 tmp1[kx]  = -factor*m2k;
2168             }
2169
2170             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2171
2172             for (kx = kxstart; kx < kxend; kx++, p0++)
2173             {
2174                 d1      = p0->re;
2175                 d2      = p0->im;
2176
2177                 p0->re  = d1*eterm[kx];
2178                 p0->im  = d2*eterm[kx];
2179             }
2180         }
2181     }
2182
2183     if (bEnerVir)
2184     {
2185         /* Update virial with local values.
2186          * The virial is symmetric by definition.
2187          * This virial seems OK for isotropic scaling, but I'm
2188          * experiencing problems with semi-isotropic membranes.
2189          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2190          */
2191         work->vir[XX][XX] = 0.25*virxx;
2192         work->vir[YY][YY] = 0.25*viryy;
2193         work->vir[ZZ][ZZ] = 0.25*virzz;
2194         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2195         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2196         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2197
2198         /* This energy should be corrected for a charged system */
2199         work->energy = 0.5*energy;
2200     }
2201
2202     /* Return the loop count */
2203     return local_ndata[YY]*local_ndata[XX];
2204 }
2205
2206 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2207                              real *mesh_energy, matrix vir)
2208 {
2209     /* This function sums output over threads
2210      * and should therefore only be called after thread synchronization.
2211      */
2212     int thread;
2213
2214     *mesh_energy = pme->work[0].energy;
2215     copy_mat(pme->work[0].vir, vir);
2216
2217     for (thread = 1; thread < nthread; thread++)
2218     {
2219         *mesh_energy += pme->work[thread].energy;
2220         m_add(vir, pme->work[thread].vir, vir);
2221     }
2222 }
2223
2224 #define DO_FSPLINE(order)                      \
2225     for (ithx = 0; (ithx < order); ithx++)              \
2226     {                                              \
2227         index_x = (i0+ithx)*pny*pnz;               \
2228         tx      = thx[ithx];                       \
2229         dx      = dthx[ithx];                      \
2230                                                \
2231         for (ithy = 0; (ithy < order); ithy++)          \
2232         {                                          \
2233             index_xy = index_x+(j0+ithy)*pnz;      \
2234             ty       = thy[ithy];                  \
2235             dy       = dthy[ithy];                 \
2236             fxy1     = fz1 = 0;                    \
2237                                                \
2238             for (ithz = 0; (ithz < order); ithz++)      \
2239             {                                      \
2240                 gval  = grid[index_xy+(k0+ithz)];  \
2241                 fxy1 += thz[ithz]*gval;            \
2242                 fz1  += dthz[ithz]*gval;           \
2243             }                                      \
2244             fx += dx*ty*fxy1;                      \
2245             fy += tx*dy*fxy1;                      \
2246             fz += tx*ty*fz1;                       \
2247         }                                          \
2248     }
2249
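/* DO_FSPLINE accumulates, for one charge, the force components fx, fy and fz
 * by contracting the order^3 surrounding grid values with the B-spline
 * weights thx/thy/thz and their derivatives dthx/dthy/dthz.
 */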
2250
2251 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2252                               gmx_bool bClearF, pme_atomcomm_t *atc,
2253                               splinedata_t *spline,
2254                               real scale)
2255 {
2256     /* sum forces for local particles */
2257     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2258     int     index_x, index_xy;
2259     int     nx, ny, nz, pnx, pny, pnz;
2260     int *   idxptr;
2261     real    tx, ty, dx, dy, qn;
2262     real    fx, fy, fz, gval;
2263     real    fxy1, fz1;
2264     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2265     int     norder;
2266     real    rxx, ryx, ryy, rzx, rzy, rzz;
2267     int     order;
2268
2269     pme_spline_work_t *work;
2270
2271     work = pme->spline_work;
2272
2273     order = pme->pme_order;
2274     thx   = spline->theta[XX];
2275     thy   = spline->theta[YY];
2276     thz   = spline->theta[ZZ];
2277     dthx  = spline->dtheta[XX];
2278     dthy  = spline->dtheta[YY];
2279     dthz  = spline->dtheta[ZZ];
2280     nx    = pme->nkx;
2281     ny    = pme->nky;
2282     nz    = pme->nkz;
2283     pnx   = pme->pmegrid_nx;
2284     pny   = pme->pmegrid_ny;
2285     pnz   = pme->pmegrid_nz;
2286
2287     rxx   = pme->recipbox[XX][XX];
2288     ryx   = pme->recipbox[YY][XX];
2289     ryy   = pme->recipbox[YY][YY];
2290     rzx   = pme->recipbox[ZZ][XX];
2291     rzy   = pme->recipbox[ZZ][YY];
2292     rzz   = pme->recipbox[ZZ][ZZ];
2293
2294     for (nn = 0; nn < spline->n; nn++)
2295     {
2296         n  = spline->ind[nn];
2297         qn = scale*atc->q[n];
2298
2299         if (bClearF)
2300         {
2301             atc->f[n][XX] = 0;
2302             atc->f[n][YY] = 0;
2303             atc->f[n][ZZ] = 0;
2304         }
2305         if (qn != 0)
2306         {
2307             fx     = 0;
2308             fy     = 0;
2309             fz     = 0;
2310             idxptr = atc->idx[n];
2311             norder = nn*order;
2312
2313             i0   = idxptr[XX];
2314             j0   = idxptr[YY];
2315             k0   = idxptr[ZZ];
2316
2317             /* Pointer arithmetic alert, next six statements */
2318             thx  = spline->theta[XX] + norder;
2319             thy  = spline->theta[YY] + norder;
2320             thz  = spline->theta[ZZ] + norder;
2321             dthx = spline->dtheta[XX] + norder;
2322             dthy = spline->dtheta[YY] + norder;
2323             dthz = spline->dtheta[ZZ] + norder;
2324
2325             switch (order)
2326             {
2327                 case 4:
2328 #ifdef PME_SSE_SPREAD_GATHER
2329 #ifdef PME_SSE_UNALIGNED
2330 #define PME_GATHER_F_SSE_ORDER4
2331 #else
2332 #define PME_GATHER_F_SSE_ALIGNED
2333 #define PME_ORDER 4
2334 #endif
2335 #include "pme_sse_single.h"
2336 #else
2337                     DO_FSPLINE(4);
2338 #endif
2339                     break;
2340                 case 5:
2341 #ifdef PME_SSE_SPREAD_GATHER
2342 #define PME_GATHER_F_SSE_ALIGNED
2343 #define PME_ORDER 5
2344 #include "pme_sse_single.h"
2345 #else
2346                     DO_FSPLINE(5);
2347 #endif
2348                     break;
2349                 default:
2350                     DO_FSPLINE(order);
2351                     break;
2352             }
2353
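            /* Chain rule: convert the gradient from grid coordinates to
             * Cartesian ones by scaling with the grid dimensions and the
             * reciprocal box, and multiply by -q to obtain the force.
             */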
2354             atc->f[n][XX] += -qn*( fx*nx*rxx );
2355             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2356             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2357         }
2358     }
2359     /* Since the energy and not the forces are interpolated,
2360      * the net force might not be exactly zero.
2361      * This can be solved by also interpolating F, but
2362      * that comes at a cost.
2363      * A better hack is to remove the net force every
2364      * step, but that must be done at a higher level
2365      * since this routine doesn't see all atoms if running
2366      * in parallel. Don't know how important it is?  EL 990726
2367      */
2368 }
2369
2370
2371 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2372                                    pme_atomcomm_t *atc)
2373 {
2374     splinedata_t *spline;
2375     int     n, ithx, ithy, ithz, i0, j0, k0;
2376     int     index_x, index_xy;
2377     int *   idxptr;
2378     real    energy, pot, tx, ty, qn, gval;
2379     real    *thx, *thy, *thz;
2380     int     norder;
2381     int     order;
2382
2383     spline = &atc->spline[0];
2384
2385     order = pme->pme_order;
2386
2387     energy = 0;
2388     for (n = 0; (n < atc->n); n++)
2389     {
2390         qn      = atc->q[n];
2391
2392         if (qn != 0)
2393         {
2394             idxptr = atc->idx[n];
2395             norder = n*order;
2396
2397             i0   = idxptr[XX];
2398             j0   = idxptr[YY];
2399             k0   = idxptr[ZZ];
2400
2401             /* Pointer arithmetic alert, next three statements */
2402             thx  = spline->theta[XX] + norder;
2403             thy  = spline->theta[YY] + norder;
2404             thz  = spline->theta[ZZ] + norder;
2405
2406             pot = 0;
2407             for (ithx = 0; (ithx < order); ithx++)
2408             {
2409                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2410                 tx      = thx[ithx];
2411
2412                 for (ithy = 0; (ithy < order); ithy++)
2413                 {
2414                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2415                     ty       = thy[ithy];
2416
2417                     for (ithz = 0; (ithz < order); ithz++)
2418                     {
2419                         gval  = grid[index_xy+(k0+ithz)];
2420                         pot  += tx*ty*thz[ithz]*gval;
2421                     }
2422
2423                 }
2424             }
2425
2426             energy += pot*qn;
2427         }
2428     }
2429
2430     return energy;
2431 }
2432
2433 /* Macro to force loop unrolling by fixing order.
2434  * This gives a significant performance gain.
2435  */
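/* For each dimension, CALC_SPLINE evaluates the order B-spline weights
 * (theta) and their derivatives (dtheta) at the fractional coordinate in
 * xptr, building the order-n values from the order n-1 ones by recursion.
 */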
2436 #define CALC_SPLINE(order)                     \
2437     {                                              \
2438         int j, k, l;                                 \
2439         real dr, div;                               \
2440         real data[PME_ORDER_MAX];                  \
2441         real ddata[PME_ORDER_MAX];                 \
2442                                                \
2443         for (j = 0; (j < DIM); j++)                     \
2444         {                                          \
2445             dr  = xptr[j];                         \
2446                                                \
2447             /* dr is relative offset from lower cell limit */ \
2448             data[order-1] = 0;                     \
2449             data[1]       = dr;                          \
2450             data[0]       = 1 - dr;                      \
2451                                                \
2452             for (k = 3; (k < order); k++)               \
2453             {                                      \
2454                 div       = 1.0/(k - 1.0);               \
2455                 data[k-1] = div*dr*data[k-2];      \
2456                 for (l = 1; (l < (k-1)); l++)           \
2457                 {                                  \
2458                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2459                                        data[k-l-1]);                \
2460                 }                                  \
2461                 data[0] = div*(1-dr)*data[0];      \
2462             }                                      \
2463             /* differentiate */                    \
2464             ddata[0] = -data[0];                   \
2465             for (k = 1; (k < order); k++)               \
2466             {                                      \
2467                 ddata[k] = data[k-1] - data[k];    \
2468             }                                      \
2469                                                \
2470             div           = 1.0/(order - 1);                 \
2471             data[order-1] = div*dr*data[order-2];  \
2472             for (l = 1; (l < (order-1)); l++)           \
2473             {                                      \
2474                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2475                                        (order-l-dr)*data[order-l-1]); \
2476             }                                      \
2477             data[0] = div*(1 - dr)*data[0];        \
2478                                                \
2479             for (k = 0; k < order; k++)                 \
2480             {                                      \
2481                 theta[j][i*order+k]  = data[k];    \
2482                 dtheta[j][i*order+k] = ddata[k];   \
2483             }                                      \
2484         }                                          \
2485     }
2486
2487 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2488                    rvec fractx[], int nr, int ind[], real charge[],
2489                    gmx_bool bFreeEnergy)
2490 {
2491     /* construct splines for local atoms */
2492     int  i, ii;
2493     real *xptr;
2494
2495     for (i = 0; i < nr; i++)
2496     {
2497         /* With free energy we do not use the charge check.
2498          * In most cases this will be more efficient than calling make_bsplines
2499          * twice, since usually more than half the particles have charges.
2500          */
2501         ii = ind[i];
2502         if (bFreeEnergy || charge[ii] != 0.0)
2503         {
2504             xptr = fractx[ii];
2505             switch (order)
2506             {
2507                 case 4:  CALC_SPLINE(4);     break;
2508                 case 5:  CALC_SPLINE(5);     break;
2509                 default: CALC_SPLINE(order); break;
2510             }
2511         }
2512     }
2513 }
2514
2515
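/* Computes mod[i] = |sum_j data[j]*exp(2*pi*I*i*j/ndata)|^2, the squared
 * modulus of the DFT of data; near-zero entries are replaced by the average
 * of their neighbours, since these moduli end up in denominators later on.
 */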
2516 void make_dft_mod(real *mod, real *data, int ndata)
2517 {
2518     int i, j;
2519     real sc, ss, arg;
2520
2521     for (i = 0; i < ndata; i++)
2522     {
2523         sc = ss = 0;
2524         for (j = 0; j < ndata; j++)
2525         {
2526             arg = (2.0*M_PI*i*j)/ndata;
2527             sc += data[j]*cos(arg);
2528             ss += data[j]*sin(arg);
2529         }
2530         mod[i] = sc*sc+ss*ss;
2531     }
2532     for (i = 0; i < ndata; i++)
2533     {
2534         if (mod[i] < 1e-7)
2535         {
2536             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2537         }
2538     }
2539 }
2540
2541
2542 static void make_bspline_moduli(splinevec bsp_mod,
2543                                 int nx, int ny, int nz, int order)
2544 {
2545     int nmax = max(nx, max(ny, nz));
2546     real *data, *ddata, *bsp_data;
2547     int i, k, l;
2548     real div;
2549
2550     snew(data, order);
2551     snew(ddata, order);
2552     snew(bsp_data, nmax);
2553
2554     data[order-1] = 0;
2555     data[1]       = 0;
2556     data[0]       = 1;
2557
2558     for (k = 3; k < order; k++)
2559     {
2560         div       = 1.0/(k-1.0);
2561         data[k-1] = 0;
2562         for (l = 1; l < (k-1); l++)
2563         {
2564             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2565         }
2566         data[0] = div*data[0];
2567     }
2568     /* differentiate */
2569     ddata[0] = -data[0];
2570     for (k = 1; k < order; k++)
2571     {
2572         ddata[k] = data[k-1]-data[k];
2573     }
2574     div           = 1.0/(order-1);
2575     data[order-1] = 0;
2576     for (l = 1; l < (order-1); l++)
2577     {
2578         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2579     }
2580     data[0] = div*data[0];
2581
2582     for (i = 0; i < nmax; i++)
2583     {
2584         bsp_data[i] = 0;
2585     }
2586     for (i = 1; i <= order; i++)
2587     {
2588         bsp_data[i] = data[i-1];
2589     }
2590
2591     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2592     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2593     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2594
2595     sfree(data);
2596     sfree(ddata);
2597     sfree(bsp_data);
2598 }
2599
2600
2601 /* Return the P3M optimal influence function */
2602 static double do_p3m_influence(double z, int order)
2603 {
2604     double z2, z4;
2605
2606     z2 = z*z;
2607     z4 = z2*z2;
2608
2609     /* The formula and most constants can be found in:
2610      * Ballenegger et al., JCTC 8, 936 (2012)
2611      */
2612     switch (order)
2613     {
2614         case 2:
2615             return 1.0 - 2.0*z2/3.0;
2616             break;
2617         case 3:
2618             return 1.0 - z2 + 2.0*z4/15.0;
2619             break;
2620         case 4:
2621             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2622             break;
2623         case 5:
2624             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2625             break;
2626         case 6:
2627             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2628             break;
2629         case 7:
2630             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2631         case 8:
2632             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2633             break;
2634     }
2635
2636     return 0.0;
2637 }
2638
2639 /* Calculate the P3M B-spline moduli for one dimension */
2640 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2641 {
2642     double zarg, zai, sinzai, infl;
2643     int    maxk, i;
2644
2645     if (order > 8)
2646     {
2647         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2648     }
2649
2650     zarg = M_PI/n;
2651
2652     maxk = (n + 1)/2;
2653
2654     for (i = -maxk; i < 0; i++)
2655     {
2656         zai          = zarg*i;
2657         sinzai       = sin(zai);
2658         infl         = do_p3m_influence(sinzai, order);
2659         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2660     }
2661     bsp_mod[0] = 1.0;
2662     for (i = 1; i < maxk; i++)
2663     {
2664         zai        = zarg*i;
2665         sinzai     = sin(zai);
2666         infl       = do_p3m_influence(sinzai, order);
2667         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2668     }
2669 }
2670
2671 /* Calculate the P3M B-spline moduli */
2672 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2673                                     int nx, int ny, int nz, int order)
2674 {
2675     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2676     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2677     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2678 }
2679
2680
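/* Orders the slab communication so that forward and backward neighbours
 * alternate with increasing distance; e.g. nslab = 4 and nodeid = 0 give
 * node_dest = {1, 3, 2} and node_src = {3, 1, 2}.
 */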
2681 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2682 {
2683     int nslab, n, i;
2684     int fw, bw;
2685
2686     nslab = atc->nslab;
2687
2688     n = 0;
2689     for (i = 1; i <= nslab/2; i++)
2690     {
2691         fw = (atc->nodeid + i) % nslab;
2692         bw = (atc->nodeid - i + nslab) % nslab;
2693         if (n < nslab - 1)
2694         {
2695             atc->node_dest[n] = fw;
2696             atc->node_src[n]  = bw;
2697             n++;
2698         }
2699         if (n < nslab - 1)
2700         {
2701             atc->node_dest[n] = bw;
2702             atc->node_src[n]  = fw;
2703             n++;
2704         }
2705     }
2706 }
2707
2708 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2709 {
2710     int thread;
2711
2712     if (NULL != log)
2713     {
2714         fprintf(log, "Destroying PME data structures.\n");
2715     }
2716
2717     sfree((*pmedata)->nnx);
2718     sfree((*pmedata)->nny);
2719     sfree((*pmedata)->nnz);
2720
2721     pmegrids_destroy(&(*pmedata)->pmegridA);
2722
2723     sfree((*pmedata)->fftgridA);
2724     sfree((*pmedata)->cfftgridA);
2725     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2726
2727     if ((*pmedata)->pmegridB.grid.grid != NULL)
2728     {
2729         pmegrids_destroy(&(*pmedata)->pmegridB);
2730         sfree((*pmedata)->fftgridB);
2731         sfree((*pmedata)->cfftgridB);
2732         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2733     }
2734     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2735     {
2736         free_work(&(*pmedata)->work[thread]);
2737     }
2738     sfree((*pmedata)->work);
2739
2740     sfree(*pmedata);
2741     *pmedata = NULL;
2742
2743     return 0;
2744 }
2745
2746 static int mult_up(int n, int f)
2747 {
2748     return ((n + f - 1)/f)*f;
2749 }
2750
2751
2752 static double pme_load_imbalance(gmx_pme_t pme)
2753 {
2754     int    nma, nmi;
2755     double n1, n2, n3;
2756
2757     nma = pme->nnodes_major;
2758     nmi = pme->nnodes_minor;
2759
2760     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2761     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2762     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2763
2764     /* pme_solve is roughly double the cost of an fft */
2765
2766     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2767 }
2768
2769 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2770                           int dimind, gmx_bool bSpread)
2771 {
2772     int nk, k, s, thread;
2773
2774     atc->dimind    = dimind;
2775     atc->nslab     = 1;
2776     atc->nodeid    = 0;
2777     atc->pd_nalloc = 0;
2778 #ifdef GMX_MPI
2779     if (pme->nnodes > 1)
2780     {
2781         atc->mpi_comm = pme->mpi_comm_d[dimind];
2782         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2783         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2784     }
2785     if (debug)
2786     {
2787         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2788     }
2789 #endif
2790
2791     atc->bSpread   = bSpread;
2792     atc->pme_order = pme->pme_order;
2793
2794     if (atc->nslab > 1)
2795     {
2796         /* These three allocations are not required for particle decomp. */
2797         snew(atc->node_dest, atc->nslab);
2798         snew(atc->node_src, atc->nslab);
2799         setup_coordinate_communication(atc);
2800
2801         snew(atc->count_thread, pme->nthread);
2802         for (thread = 0; thread < pme->nthread; thread++)
2803         {
2804             snew(atc->count_thread[thread], atc->nslab);
2805         }
2806         atc->count = atc->count_thread[0];
2807         snew(atc->rcount, atc->nslab);
2808         snew(atc->buf_index, atc->nslab);
2809     }
2810
2811     atc->nthread = pme->nthread;
2812     if (atc->nthread > 1)
2813     {
2814         snew(atc->thread_plist, atc->nthread);
2815     }
2816     snew(atc->spline, atc->nthread);
2817     for (thread = 0; thread < atc->nthread; thread++)
2818     {
2819         if (atc->nthread > 1)
2820         {
2821             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2822             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2823         }
2824         snew(atc->spline[thread].thread_one, pme->nthread);
2825         atc->spline[thread].thread_one[thread] = 1;
2826     }
2827 }
2828
2829 static void
2830 init_overlap_comm(pme_overlap_t *  ol,
2831                   int              norder,
2832 #ifdef GMX_MPI
2833                   MPI_Comm         comm,
2834 #endif
2835                   int              nnodes,
2836                   int              nodeid,
2837                   int              ndata,
2838                   int              commplainsize)
2839 {
2840     int lbnd, rbnd, maxlr, b, i;
2841     int exten;
2842     int nn, nk;
2843     pme_grid_comm_t *pgc;
2844     gmx_bool bCont;
2845     int fft_start, fft_end, send_index1, recv_index1;
2846 #ifdef GMX_MPI
2847     MPI_Status stat;
2848
2849     ol->mpi_comm = comm;
2850 #endif
2851
2852     ol->nnodes = nnodes;
2853     ol->nodeid = nodeid;
2854
2855     /* Linear translation of the PME grid won't affect reciprocal space
2856      * calculations, so to optimize we only interpolate "upwards",
2857      * which also means we only have to consider overlap in one direction.
2858      * I.e., particles on this node might also be spread to grid indices
2859      * that belong to higher nodes (modulo nnodes).
2860      */
2861
2862     snew(ol->s2g0, ol->nnodes+1);
2863     snew(ol->s2g1, ol->nnodes);
2864     if (debug)
2865     {
2866         fprintf(debug, "PME slab boundaries:");
2867     }
2868     for (i = 0; i < nnodes; i++)
2869     {
2870         /* s2g0 is the local interpolation grid start.
2871          * s2g1 is the local interpolation grid end.
2872          * Because grid overlap communication only goes forward,
2873          * the grid slabs for the FFTs should be rounded down.
2874          */
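        /* For example, ndata = 30, nnodes = 4 and norder = 4 give
         * s2g0 = {0, 7, 15, 22, 30} and s2g1 = {11, 18, 26, 33}, so each
         * slab's interpolation region reaches into the following slab(s).
         */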
2875         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2876         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2877
2878         if (debug)
2879         {
2880             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2881         }
2882     }
2883     ol->s2g0[nnodes] = ndata;
2884     if (debug)
2885     {
2886         fprintf(debug, "\n");
2887     }
2888
2889     /* Determine how many nodes we need to communicate the grid overlap with */
2890     b = 0;
2891     do
2892     {
2893         b++;
2894         bCont = FALSE;
2895         for (i = 0; i < nnodes; i++)
2896         {
2897             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2898                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2899             {
2900                 bCont = TRUE;
2901             }
2902         }
2903     }
2904     while (bCont && b < nnodes);
2905     ol->noverlap_nodes = b - 1;
2906
2907     snew(ol->send_id, ol->noverlap_nodes);
2908     snew(ol->recv_id, ol->noverlap_nodes);
2909     for (b = 0; b < ol->noverlap_nodes; b++)
2910     {
2911         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2912         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2913     }
2914     snew(ol->comm_data, ol->noverlap_nodes);
2915
2916     ol->send_size = 0;
2917     for (b = 0; b < ol->noverlap_nodes; b++)
2918     {
2919         pgc = &ol->comm_data[b];
2920         /* Send */
2921         fft_start        = ol->s2g0[ol->send_id[b]];
2922         fft_end          = ol->s2g0[ol->send_id[b]+1];
2923         if (ol->send_id[b] < nodeid)
2924         {
2925             fft_start += ndata;
2926             fft_end   += ndata;
2927         }
2928         send_index1       = ol->s2g1[nodeid];
2929         send_index1       = min(send_index1, fft_end);
2930         pgc->send_index0  = fft_start;
2931         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2932         ol->send_size    += pgc->send_nindex;
2933
2934         /* We always start receiving at the first index of our slab */
2935         fft_start        = ol->s2g0[ol->nodeid];
2936         fft_end          = ol->s2g0[ol->nodeid+1];
2937         recv_index1      = ol->s2g1[ol->recv_id[b]];
2938         if (ol->recv_id[b] > nodeid)
2939         {
2940             recv_index1 -= ndata;
2941         }
2942         recv_index1      = min(recv_index1, fft_end);
2943         pgc->recv_index0 = fft_start;
2944         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2945     }
2946
2947 #ifdef GMX_MPI
2948     /* Communicate the buffer sizes to receive */
2949     for (b = 0; b < ol->noverlap_nodes; b++)
2950     {
2951         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2952                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2953                      ol->mpi_comm, &stat);
2954     }
2955 #endif
2956
2957     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2958     snew(ol->sendbuf, norder*commplainsize);
2959     snew(ol->recvbuf, norder*commplainsize);
2960 }
2961
2962 static void
2963 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2964                               int **global_to_local,
2965                               real **fraction_shift)
2966 {
2967     int i;
2968     int * gtl;
2969     real * fsh;
2970
2971     snew(gtl, 5*n);
2972     snew(fsh, 5*n);
2973     for (i = 0; (i < 5*n); i++)
2974     {
2975         /* Determine the global to local grid index */
2976         gtl[i] = (i - local_start + n) % n;
2977         /* For coordinates that fall within the local grid the fraction
2978          * is correct, so we don't need to shift it.
2979          */
2980         fsh[i] = 0;
2981         if (local_range < n)
2982         {
2983             /* Due to rounding issues i could be 1 beyond the lower or
2984              * upper boundary of the local grid. Correct the index for this.
2985              * If we shift the index, we need to shift the fraction by
2986              * the same amount in the other direction to not affect
2987              * the weights.
2988              * Note that due to this shifting the weights at the end of
2989              * the spline might change, but that will only involve values
2990              * between zero and values close to the precision of a real,
2991              * which is anyhow the accuracy of the whole mesh calculation.
2992              */
2993             /* With local_range=0 we should not change i=local_start */
2994             if (i % n != local_start)
2995             {
2996                 if (gtl[i] == n-1)
2997                 {
2998                     gtl[i] = 0;
2999                     fsh[i] = -1;
3000                 }
3001                 else if (gtl[i] == local_range)
3002                 {
3003                     gtl[i] = local_range - 1;
3004                     fsh[i] = 1;
3005                 }
3006             }
3007         }
3008     }
3009
3010     *global_to_local = gtl;
3011     *fraction_shift  = fsh;
3012 }
3013
3014 static pme_spline_work_t *make_pme_spline_work(int order)
3015 {
3016     pme_spline_work_t *work;
3017
3018 #ifdef PME_SSE_SPREAD_GATHER
3019     float  tmp[8];
3020     __m128 zero_SSE;
3021     int    of, i;
3022
3023     snew_aligned(work, 1, 16);
3024
3025     zero_SSE = _mm_setzero_ps();
3026
3027     /* Generate bit masks to mask out the unused grid entries,
3028      * as we only operate on 'order' of the 8 grid entries that are
3029      * loaded into 2 SSE float registers.
3030      */
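    /* For example, order = 4 and of = 1 give tmp = {0,1,1,1,1,0,0,0}:
     * mask_SSE0 then selects elements 1-3 of the low register and
     * mask_SSE1 element 0 of the high one.
     */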
3031     for (of = 0; of < 8-(order-1); of++)
3032     {
3033         for (i = 0; i < 8; i++)
3034         {
3035             tmp[i] = (i >= of && i < of+order ? 1 : 0);
3036         }
3037         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
3038         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
3039         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
3040         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
3041     }
3042 #else
3043     work = NULL;
3044 #endif
3045
3046     return work;
3047 }
3048
3049 void gmx_pme_check_restrictions(int pme_order,
3050                                 int nkx, int nky, int nkz,
3051                                 int nnodes_major,
3052                                 int nnodes_minor,
3053                                 gmx_bool bUseThreads,
3054                                 gmx_bool bFatal,
3055                                 gmx_bool *bValidSettings)
3056 {
3057     if (pme_order > PME_ORDER_MAX)
3058     {
3059         if (!bFatal)
3060         {
3061             *bValidSettings = FALSE;
3062             return;
3063         }
3064         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3065                   pme_order, PME_ORDER_MAX);
3066     }
3067
3068     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3069         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3070         nkz <= pme_order)
3071     {
3072         if (!bFatal)
3073         {
3074             *bValidSettings = FALSE;
3075             return;
3076         }
3077         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3078                   pme_order);
3079     }
3080
3081     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3082      * We only allow multiple communication pulses in dim 1, not in dim 0.
3083      */
3084     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3085                         nkx != nnodes_major*(pme_order - 1)))
3086     {
3087         if (!bFatal)
3088         {
3089             *bValidSettings = FALSE;
3090             return;
3091         }
3092         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3093                   nkx/(double)nnodes_major, pme_order);
3094     }
3095
3096     if (bValidSettings != NULL)
3097     {
3098         *bValidSettings = TRUE;
3099     }
3100
3101     return;
3102 }
3103
3104 int gmx_pme_init(gmx_pme_t *         pmedata,
3105                  t_commrec *         cr,
3106                  int                 nnodes_major,
3107                  int                 nnodes_minor,
3108                  t_inputrec *        ir,
3109                  int                 homenr,
3110                  gmx_bool            bFreeEnergy,
3111                  gmx_bool            bReproducible,
3112                  int                 nthread)
3113 {
3114     gmx_pme_t pme = NULL;
3115
3116     int  use_threads, sum_use_threads;
3117     ivec ndata;
3118
3119     if (debug)
3120     {
3121         fprintf(debug, "Creating PME data structures.\n");
3122     }
3123     snew(pme, 1);
3124
3125     pme->redist_init         = FALSE;
3126     pme->sum_qgrid_tmp       = NULL;
3127     pme->sum_qgrid_dd_tmp    = NULL;
3128     pme->buf_nalloc          = 0;
3129     pme->redist_buf_nalloc   = 0;
3130
3131     pme->nnodes              = 1;
3132     pme->bPPnode             = TRUE;
3133
3134     pme->nnodes_major        = nnodes_major;
3135     pme->nnodes_minor        = nnodes_minor;
3136
3137 #ifdef GMX_MPI
3138     if (nnodes_major*nnodes_minor > 1)
3139     {
3140         pme->mpi_comm = cr->mpi_comm_mygroup;
3141
3142         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3143         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3144         if (pme->nnodes != nnodes_major*nnodes_minor)
3145         {
3146             gmx_incons("PME node count mismatch");
3147         }
3148     }
3149     else
3150     {
3151         pme->mpi_comm = MPI_COMM_NULL;
3152     }
3153 #endif
3154
3155     if (pme->nnodes == 1)
3156     {
3157 #ifdef GMX_MPI
3158         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3159         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3160 #endif
3161         pme->ndecompdim   = 0;
3162         pme->nodeid_major = 0;
3163         pme->nodeid_minor = 0;
3164 #ifdef GMX_MPI
3165         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3166 #endif
3167     }
3168     else
3169     {
3170         if (nnodes_minor == 1)
3171         {
3172 #ifdef GMX_MPI
3173             pme->mpi_comm_d[0] = pme->mpi_comm;
3174             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3175 #endif
3176             pme->ndecompdim   = 1;
3177             pme->nodeid_major = pme->nodeid;
3178             pme->nodeid_minor = 0;
3179
3180         }
3181         else if (nnodes_major == 1)
3182         {
3183 #ifdef GMX_MPI
3184             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3185             pme->mpi_comm_d[1] = pme->mpi_comm;
3186 #endif
3187             pme->ndecompdim   = 1;
3188             pme->nodeid_major = 0;
3189             pme->nodeid_minor = pme->nodeid;
3190         }
3191         else
3192         {
3193             if (pme->nnodes % nnodes_major != 0)
3194             {
3195                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3196             }
3197             pme->ndecompdim = 2;
3198
3199 #ifdef GMX_MPI
3200             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3201                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3202             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3203                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
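            /* Example: with 6 PME ranks, nnodes_major = 3 and nnodes_minor = 2,
             * ranks {0,2,4} and {1,3,5} each form a communicator along the
             * major dimension, and {0,1}, {2,3}, {4,5} along the minor one.
             */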
3204
3205             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3206             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3207             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3208             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3209 #endif
3210         }
3211         pme->bPPnode = (cr->duty & DUTY_PP);
3212     }
3213
3214     pme->nthread = nthread;
3215
3216     /* Check if any of the PME MPI ranks uses threads */
3217     use_threads = (pme->nthread > 1 ? 1 : 0);
3218 #ifdef GMX_MPI
3219     if (pme->nnodes > 1)
3220     {
3221         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3222                       MPI_SUM, pme->mpi_comm);
3223     }
3224     else
3225 #endif
3226     {
3227         sum_use_threads = use_threads;
3228     }
3229     pme->bUseThreads = (sum_use_threads > 0);
3230
3231     if (ir->ePBC == epbcSCREW)
3232     {
3233         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3234     }
3235
3236     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3237     pme->nkx         = ir->nkx;
3238     pme->nky         = ir->nky;
3239     pme->nkz         = ir->nkz;
3240     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3241     pme->pme_order   = ir->pme_order;
3242     pme->epsilon_r   = ir->epsilon_r;
3243
3244     /* If we violate restrictions, generate a fatal error here */
3245     gmx_pme_check_restrictions(pme->pme_order,
3246                                pme->nkx, pme->nky, pme->nkz,
3247                                pme->nnodes_major,
3248                                pme->nnodes_minor,
3249                                pme->bUseThreads,
3250                                TRUE,
3251                                NULL);
3252
3253     if (pme->nnodes > 1)
3254     {
3255         double imbal;
3256
3257 #ifdef GMX_MPI
3258         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3259         MPI_Type_commit(&(pme->rvec_mpi));
3260 #endif
3261
3262         /* Note that the charge spreading and force gathering, which usually
3263          * take about the same amount of time as FFT+solve_pme,
3264          * are always fully load balanced
3265          * (unless the charge distribution is inhomogeneous).
3266          */
3267
3268         imbal = pme_load_imbalance(pme);
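        /* Example of the note below: an imbalance factor of 1.25 is
         * reported as 25%, since (int)((1.25 - 1)*100 + 0.5) = 25.
         */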
3269         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3270         {
3271             fprintf(stderr,
3272                     "\n"
3273                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3274                     "      For optimal PME load balancing\n"
3275                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3276                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3277                     "\n",
3278                     (int)((imbal-1)*100 + 0.5),
3279                     pme->nkx, pme->nky, pme->nnodes_major,
3280                     pme->nky, pme->nkz, pme->nnodes_minor);
3281         }
3282     }
3283
3284     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3285     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3286      * y is always copied through a buffer: we don't need padding in z,
3287      * but we do need the overlap in x because of the communication order.
3288      */
3289     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3290 #ifdef GMX_MPI
3291                       pme->mpi_comm_d[0],
3292 #endif
3293                       pme->nnodes_major, pme->nodeid_major,
3294                       pme->nkx,
3295                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
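    /* Worked example of the size expression above (numbers are purely
     * illustrative): with nky = 60, nnodes_minor = 7 and pme_order = 4,
     * div_round_up(60, 7) = 9, so the last argument evaluates to
     * (9 + 4)*(nkz + 3) grid elements.
     */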
3296
3297     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3298      * We do this with an offset buffer of equal size, so we need to allocate
3299      * extra for the offset. That's what the (+1)*pme->nkz is for.
3300      */
3301     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3302 #ifdef GMX_MPI
3303                       pme->mpi_comm_d[1],
3304 #endif
3305                       pme->nnodes_minor, pme->nodeid_minor,
3306                       pme->nky,
3307                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3308
3309     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3310      * Note that gmx_pme_check_restrictions checked for this already.
3311      */
3312     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3313     {
3314         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3315     }
3316
3317     snew(pme->bsp_mod[XX], pme->nkx);
3318     snew(pme->bsp_mod[YY], pme->nky);
3319     snew(pme->bsp_mod[ZZ], pme->nkz);
3320
3321     /* The required size of the interpolation grid, including overlap.
3322      * The allocated size (pmegrid_n?) might be slightly larger.
3323      */
3324     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3325         pme->overlap[0].s2g0[pme->nodeid_major];
3326     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3327         pme->overlap[1].s2g0[pme->nodeid_minor];
3328     pme->pmegrid_nz_base = pme->nkz;
3329     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3330     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
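    /* Example of the grid sizing above (illustrative numbers): with
     * nkz = 64 and pme_order = 4, pmegrid_nz_base = 64 and pmegrid_nz
     * becomes 64 + 3 = 67 before alignment; set_grid_alignment may then
     * round pmegrid_nz up further (depending on the SIMD configuration)
     * so that spreading can work on aligned grid lines.
     */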
3331
3332     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3333     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3334     pme->pmegrid_start_iz = 0;
3335
3336     make_gridindex5_to_localindex(pme->nkx,
3337                                   pme->pmegrid_start_ix,
3338                                   pme->pmegrid_nx - (pme->pme_order-1),
3339                                   &pme->nnx, &pme->fshx);
3340     make_gridindex5_to_localindex(pme->nky,
3341                                   pme->pmegrid_start_iy,
3342                                   pme->pmegrid_ny - (pme->pme_order-1),
3343                                   &pme->nny, &pme->fshy);
3344     make_gridindex5_to_localindex(pme->nkz,
3345                                   pme->pmegrid_start_iz,
3346                                   pme->pmegrid_nz_base,
3347                                   &pme->nnz, &pme->fshz);
3348
3349     pmegrids_init(&pme->pmegridA,
3350                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3351                   pme->pmegrid_nz_base,
3352                   pme->pme_order,
3353                   pme->bUseThreads,
3354                   pme->nthread,
3355                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3356                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3357
3358     pme->spline_work = make_pme_spline_work(pme->pme_order);
3359
3360     ndata[0] = pme->nkx;
3361     ndata[1] = pme->nky;
3362     ndata[2] = pme->nkz;
3363
3364     /* This routine will allocate the grid data to fit the FFTs */
3365     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3366                             &pme->fftgridA, &pme->cfftgridA,
3367                             pme->mpi_comm_d,
3368                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3369                             bReproducible, pme->nthread);
3370
3371     if (bFreeEnergy)
3372     {
3373         pmegrids_init(&pme->pmegridB,
3374                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3375                       pme->pmegrid_nz_base,
3376                       pme->pme_order,
3377                       pme->bUseThreads,
3378                       pme->nthread,
3379                       pme->nkx % pme->nnodes_major != 0,
3380                       pme->nky % pme->nnodes_minor != 0);
3381
3382         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3383                                 &pme->fftgridB, &pme->cfftgridB,
3384                                 pme->mpi_comm_d,
3385                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3386                                 bReproducible, pme->nthread);
3387     }
3388     else
3389     {
3390         pme->pmegridB.grid.grid = NULL;
3391         pme->fftgridB           = NULL;
3392         pme->cfftgridB          = NULL;
3393     }
3394
3395     if (!pme->bP3M)
3396     {
3397         /* Use plain SPME B-spline interpolation */
3398         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3399     }
3400     else
3401     {
3402         /* Use the P3M grid-optimized influence function */
3403         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3404     }
3405
3406     /* Use atc[0] for spreading */
3407     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3408     if (pme->ndecompdim >= 2)
3409     {
3410         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3411     }
3412
3413     if (pme->nnodes == 1)
3414     {
3415         pme->atc[0].n = homenr;
3416         pme_realloc_atomcomm_things(&pme->atc[0]);
3417     }
3418
3419     {
3420         int thread;
3421
3422         /* Use fft5d, order after FFT is y major, z, x minor */
3423
3424         snew(pme->work, pme->nthread);
3425         for (thread = 0; thread < pme->nthread; thread++)
3426         {
3427             realloc_work(&pme->work[thread], pme->nkx);
3428         }
3429     }
3430
3431     *pmedata = pme;
3432
3433     return 0;
3434 }
3435
3436 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3437 {
3438     int d, t;
3439
3440     for (d = 0; d < DIM; d++)
3441     {
3442         if (new->grid.n[d] > old->grid.n[d])
3443         {
3444             return;
3445         }
3446     }
3447
3448     sfree_aligned(new->grid.grid);
3449     new->grid.grid = old->grid.grid;
3450
3451     if (new->grid_th != NULL && new->nthread == old->nthread)
3452     {
3453         sfree_aligned(new->grid_all);
3454         for (t = 0; t < new->nthread; t++)
3455         {
3456             new->grid_th[t].grid = old->grid_th[t].grid;
3457         }
3458     }
3459 }
3460
3461 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3462                    t_commrec *         cr,
3463                    gmx_pme_t           pme_src,
3464                    const t_inputrec *  ir,
3465                    ivec                grid_size)
3466 {
3467     t_inputrec irc;
3468     int homenr;
3469     int ret;
3470
3471     irc     = *ir;
3472     irc.nkx = grid_size[XX];
3473     irc.nky = grid_size[YY];
3474     irc.nkz = grid_size[ZZ];
3475
3476     if (pme_src->nnodes == 1)
3477     {
3478         homenr = pme_src->atc[0].n;
3479     }
3480     else
3481     {
3482         homenr = -1;
3483     }
3484
3485     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3486                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3487
3488     if (ret == 0)
3489     {
3490         /* We can easily reuse the allocated pme grids in pme_src */
3491         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3492         /* We would like to reuse the fft grids, but that's harder */
3493     }
3494
3495     return ret;
3496 }
3497
3498
3499 static void copy_local_grid(gmx_pme_t pme,
3500                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3501 {
3502     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3503     int  fft_my, fft_mz;
3504     int  nsx, nsy, nsz;
3505     ivec nf;
3506     int  offx, offy, offz, x, y, z, i0, i0t;
3507     int  d;
3508     pmegrid_t *pmegrid;
3509     real *grid_th;
3510
3511     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3512                                    local_fft_ndata,
3513                                    local_fft_offset,
3514                                    local_fft_size);
3515     fft_my = local_fft_size[YY];
3516     fft_mz = local_fft_size[ZZ];
3517
3518     pmegrid = &pmegrids->grid_th[thread];
3519
3520     nsx = pmegrid->s[XX];
3521     nsy = pmegrid->s[YY];
3522     nsz = pmegrid->s[ZZ];
3523
3524     for (d = 0; d < DIM; d++)
3525     {
3526         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3527                     local_fft_ndata[d] - pmegrid->offset[d]);
3528     }
3529
3530     offx = pmegrid->offset[XX];
3531     offy = pmegrid->offset[YY];
3532     offz = pmegrid->offset[ZZ];
3533
3534     /* Directly copy the non-overlapping parts of the local grids.
3535      * This also initializes the full grid.
3536      */
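    /* Index layout used below: fftgrid is stored row-major as [x][y][z]
     * with allocated pitches fft_my and fft_mz (which can be larger than
     * the used sizes), while the thread-local grid uses pitches nsy and
     * nsz. Hence fftgrid[((offx+x)*fft_my + (offy+y))*fft_mz + offz + z]
     * receives grid_th[(x*nsy + y)*nsz + z].
     */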
3537     grid_th = pmegrid->grid;
3538     for (x = 0; x < nf[XX]; x++)
3539     {
3540         for (y = 0; y < nf[YY]; y++)
3541         {
3542             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3543             i0t = (x*nsy + y)*nsz;
3544             for (z = 0; z < nf[ZZ]; z++)
3545             {
3546                 fftgrid[i0+z] = grid_th[i0t+z];
3547             }
3548         }
3549     }
3550 }
3551
3552 static void
3553 reduce_threadgrid_overlap(gmx_pme_t pme,
3554                           const pmegrids_t *pmegrids, int thread,
3555                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3556 {
3557     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3558     int  fft_nx, fft_ny, fft_nz;
3559     int  fft_my, fft_mz;
3560     int  buf_my = -1;
3561     int  nsx, nsy, nsz;
3562     ivec ne;
3563     int  offx, offy, offz, x, y, z, i0, i0t;
3564     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3565     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3566     gmx_bool bCommX, bCommY;
3567     int  d;
3568     int  thread_f;
3569     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3570     const real *grid_th;
3571     real *commbuf = NULL;
3572
3573     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3574                                    local_fft_ndata,
3575                                    local_fft_offset,
3576                                    local_fft_size);
3577     fft_nx = local_fft_ndata[XX];
3578     fft_ny = local_fft_ndata[YY];
3579     fft_nz = local_fft_ndata[ZZ];
3580
3581     fft_my = local_fft_size[YY];
3582     fft_mz = local_fft_size[ZZ];
3583
3584     /* This routine is called when all threads have finished spreading.
3585      * Here each thread sums the grid contributions calculated by other threads
3586      * into its thread-local part of the grid volume.
3587      * To minimize the number of grid copying operations,
3588      * this routine sums directly from the pmegrid into the fftgrid.
3589      */
3590
3591     /* Determine which part of the full node grid we should operate on,
3592      * this is our thread local part of the full grid.
3593      */
3594     pmegrid = &pmegrids->grid_th[thread];
3595
3596     for (d = 0; d < DIM; d++)
3597     {
3598         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3599                     local_fft_ndata[d]);
3600     }
3601
3602     offx = pmegrid->offset[XX];
3603     offy = pmegrid->offset[YY];
3604     offz = pmegrid->offset[ZZ];
3605
3606
3607     bClearBufX  = TRUE;
3608     bClearBufY  = TRUE;
3609     bClearBufXY = TRUE;
3610
3611     /* Now loop over all the thread data blocks that contribute
3612      * to the grid region we (our thread) are operating on.
3613      */
3614     /* Note that fft_nx/fft_ny is equal to the number of grid points
3615      * between the first point of our node grid and the first point of the next node's grid.
3616      */
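    /* Illustrative example of the loop structure below: with
     * nthread_comm[XX] = 1, sx takes the values 0 and -1, i.e. we visit
     * our own thread block and the neighbouring block below us in x,
     * whose pme_order-1 points of spreading overlap reach into our grid
     * volume (wrapping around, and switching to the MPI communication
     * buffers, at the node boundary).
     */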
3617     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3618     {
3619         fx     = pmegrid->ci[XX] + sx;
3620         ox     = 0;
3621         bCommX = FALSE;
3622         if (fx < 0)
3623         {
3624             fx    += pmegrids->nc[XX];
3625             ox    -= fft_nx;
3626             bCommX = (pme->nnodes_major > 1);
3627         }
3628         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3629         ox       += pmegrid_g->offset[XX];
3630         if (!bCommX)
3631         {
3632             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3633         }
3634         else
3635         {
3636             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3637         }
3638
3639         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3640         {
3641             fy     = pmegrid->ci[YY] + sy;
3642             oy     = 0;
3643             bCommY = FALSE;
3644             if (fy < 0)
3645             {
3646                 fy    += pmegrids->nc[YY];
3647                 oy    -= fft_ny;
3648                 bCommY = (pme->nnodes_minor > 1);
3649             }
3650             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3651             oy       += pmegrid_g->offset[YY];
3652             if (!bCommY)
3653             {
3654                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3655             }
3656             else
3657             {
3658                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3659             }
3660
3661             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3662             {
3663                 fz = pmegrid->ci[ZZ] + sz;
3664                 oz = 0;
3665                 if (fz < 0)
3666                 {
3667                     fz += pmegrids->nc[ZZ];
3668                     oz -= fft_nz;
3669                 }
3670                 pmegrid_g = &pmegrids->grid_th[fz];
3671                 oz       += pmegrid_g->offset[ZZ];
3672                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3673
3674                 if (sx == 0 && sy == 0 && sz == 0)
3675                 {
3676                     /* We have already added our local contribution
3677                      * before calling this routine, so skip it here.
3678                      */
3679                     continue;
3680                 }
3681
3682                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3683
3684                 pmegrid_f = &pmegrids->grid_th[thread_f];
3685
3686                 grid_th = pmegrid_f->grid;
3687
3688                 nsx = pmegrid_f->s[XX];
3689                 nsy = pmegrid_f->s[YY];
3690                 nsz = pmegrid_f->s[ZZ];
3691
3692 #ifdef DEBUG_PME_REDUCE
3693                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3694                        pme->nodeid, thread, thread_f,
3695                        pme->pmegrid_start_ix,
3696                        pme->pmegrid_start_iy,
3697                        pme->pmegrid_start_iz,
3698                        sx, sy, sz,
3699                        offx-ox, tx1-ox, offx, tx1,
3700                        offy-oy, ty1-oy, offy, ty1,
3701                        offz-oz, tz1-oz, offz, tz1);
3702 #endif
3703
3704                 if (!(bCommX || bCommY))
3705                 {
3706                     /* Copy from the thread local grid to the node grid */
3707                     for (x = offx; x < tx1; x++)
3708                     {
3709                         for (y = offy; y < ty1; y++)
3710                         {
3711                             i0  = (x*fft_my + y)*fft_mz;
3712                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3713                             for (z = offz; z < tz1; z++)
3714                             {
3715                                 fftgrid[i0+z] += grid_th[i0t+z];
3716                             }
3717                         }
3718                     }
3719                 }
3720                 else
3721                 {
3722                     /* The order of this conditional decides
3723                      * where the corner volume gets stored with x+y decomp.
3724                      */
3725                     if (bCommY)
3726                     {
3727                         commbuf = commbuf_y;
3728                         buf_my  = ty1 - offy;
3729                         if (bCommX)
3730                         {
3731                             /* We index commbuf modulo the local grid size */
3732                             commbuf += buf_my*fft_nx*fft_nz;
3733
3734                             bClearBuf   = bClearBufXY;
3735                             bClearBufXY = FALSE;
3736                         }
3737                         else
3738                         {
3739                             bClearBuf  = bClearBufY;
3740                             bClearBufY = FALSE;
3741                         }
3742                     }
3743                     else
3744                     {
3745                         commbuf    = commbuf_x;
3746                         buf_my     = fft_ny;
3747                         bClearBuf  = bClearBufX;
3748                         bClearBufX = FALSE;
3749                     }
3750
3751                     /* Copy to the communication buffer */
3752                     for (x = offx; x < tx1; x++)
3753                     {
3754                         for (y = offy; y < ty1; y++)
3755                         {
3756                             i0  = (x*buf_my + y)*fft_nz;
3757                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3758
3759                             if (bClearBuf)
3760                             {
3761                                 /* First access of commbuf, initialize it */
3762                                 for (z = offz; z < tz1; z++)
3763                                 {
3764                                     commbuf[i0+z]  = grid_th[i0t+z];
3765                                 }
3766                             }
3767                             else
3768                             {
3769                                 for (z = offz; z < tz1; z++)
3770                                 {
3771                                     commbuf[i0+z] += grid_th[i0t+z];
3772                                 }
3773                             }
3774                         }
3775                     }
3776                 }
3777             }
3778         }
3779     }
3780 }
3781
3782
3783 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3784 {
3785     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3786     pme_overlap_t *overlap;
3787     int  send_index0, send_nindex;
3788     int  recv_nindex;
3789 #ifdef GMX_MPI
3790     MPI_Status stat;
3791 #endif
3792     int  send_size_y, recv_size_y;
3793     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3794     real *sendptr, *recvptr;
3795     int  x, y, z, indg, indb;
3796
3797     /* Note that this routine is only used for forward communication.
3798      * Since the force gathering, unlike the charge spreading,
3799      * can be trivially parallelized over the particles,
3800      * the backwards process is much simpler and can use the "old"
3801      * communication setup.
3802      */
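    /* Rough outline of what follows (descriptive only): first the grid
     * overlap is communicated along the minor (y) dimension, possibly in
     * several pulses; the received rows are added into fftgrid, and the
     * part that also belongs to the x overlap is added into
     * pme->overlap[0].sendbuf. Then a single pulse communicates the
     * overlap along the major (x) dimension, which is added into fftgrid.
     */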
3803
3804     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3805                                    local_fft_ndata,
3806                                    local_fft_offset,
3807                                    local_fft_size);
3808
3809     if (pme->nnodes_minor > 1)
3810     {
3811         /* Minor dimension */
3812         overlap = &pme->overlap[1];
3813
3814         if (pme->nnodes_major > 1)
3815         {
3816             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3817         }
3818         else
3819         {
3820             size_yx = 0;
3821         }
3822         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3823
3824         send_size_y = overlap->send_size;
3825
3826         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3827         {
3828             send_id       = overlap->send_id[ipulse];
3829             recv_id       = overlap->recv_id[ipulse];
3830             send_index0   =
3831                 overlap->comm_data[ipulse].send_index0 -
3832                 overlap->comm_data[0].send_index0;
3833             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3834             /* We don't use recv_index0, as we always receive starting at 0 */
3835             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3836             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3837
3838             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3839             recvptr = overlap->recvbuf;
3840
3841 #ifdef GMX_MPI
3842             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3843                          send_id, ipulse,
3844                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3845                          recv_id, ipulse,
3846                          overlap->mpi_comm, &stat);
3847 #endif
3848
3849             for (x = 0; x < local_fft_ndata[XX]; x++)
3850             {
3851                 for (y = 0; y < recv_nindex; y++)
3852                 {
3853                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3854                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3855                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3856                     {
3857                         fftgrid[indg+z] += recvptr[indb+z];
3858                     }
3859                 }
3860             }
3861
3862             if (pme->nnodes_major > 1)
3863             {
3864                 /* Copy from the received buffer to the send buffer for dim 0 */
3865                 sendptr = pme->overlap[0].sendbuf;
3866                 for (x = 0; x < size_yx; x++)
3867                 {
3868                     for (y = 0; y < recv_nindex; y++)
3869                     {
3870                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3871                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3872                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3873                         {
3874                             sendptr[indg+z] += recvptr[indb+z];
3875                         }
3876                     }
3877                 }
3878             }
3879         }
3880     }
3881
3882     /* We only support a single pulse here.
3883      * This is not a severe limitation, as this code path is only used
3884      * with OpenMP, in which case the (PME) domains can be larger.
3885      */
3886     if (pme->nnodes_major > 1)
3887     {
3888         /* Major dimension */
3889         overlap = &pme->overlap[0];
3890
3891         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3892         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3893
3894         ipulse = 0;
3895
3896         send_id       = overlap->send_id[ipulse];
3897         recv_id       = overlap->recv_id[ipulse];
3898         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3899         /* We don't use recv_index0, as we always receive starting at 0 */
3900         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3901
3902         sendptr = overlap->sendbuf;
3903         recvptr = overlap->recvbuf;
3904
3905         if (debug != NULL)
3906         {
3907             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3908                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3909         }
3910
3911 #ifdef GMX_MPI
3912         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3913                      send_id, ipulse,
3914                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3915                      recv_id, ipulse,
3916                      overlap->mpi_comm, &stat);
3917 #endif
3918
3919         for (x = 0; x < recv_nindex; x++)
3920         {
3921             for (y = 0; y < local_fft_ndata[YY]; y++)
3922             {
3923                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3924                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3925                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3926                 {
3927                     fftgrid[indg+z] += recvptr[indb+z];
3928                 }
3929             }
3930         }
3931     }
3932 }
3933
3934
3935 static void spread_on_grid(gmx_pme_t pme,
3936                            pme_atomcomm_t *atc, pmegrids_t *grids,
3937                            gmx_bool bCalcSplines, gmx_bool bSpread,
3938                            real *fftgrid)
3939 {
3940     int nthread, thread;
3941 #ifdef PME_TIME_THREADS
3942     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3943     static double cs1     = 0, cs2 = 0, cs3 = 0;
3944     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3945     static int cnt        = 0;
3946 #endif
3947
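    /* Overview of the thread-parallel phases below (descriptive only):
     * 1) optionally compute the interpolation (grid) indices for all atoms,
     * 2) per thread: compute B-splines, spread the charges onto the
     *    thread-local grid and, when thread grids are used, copy the
     *    non-overlapping part into fftgrid,
     * 3) when thread grids are used: reduce the thread-grid overlap into
     *    fftgrid and, in parallel runs, communicate the node overlap.
     */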
3948     nthread = pme->nthread;
3949     assert(nthread > 0);
3950
3951 #ifdef PME_TIME_THREADS
3952     c1 = omp_cyc_start();
3953 #endif
3954     if (bCalcSplines)
3955     {
3956 #pragma omp parallel for num_threads(nthread) schedule(static)
3957         for (thread = 0; thread < nthread; thread++)
3958         {
3959             int start, end;
3960
3961             start = atc->n* thread   /nthread;
3962             end   = atc->n*(thread+1)/nthread;
3963
3964             /* Compute fftgrid index for all atoms,
3965              * with help of some extra variables.
3966              */
3967             calc_interpolation_idx(pme, atc, start, end, thread);
3968         }
3969     }
3970 #ifdef PME_TIME_THREADS
3971     c1   = omp_cyc_end(c1);
3972     cs1 += (double)c1;
3973 #endif
3974
3975 #ifdef PME_TIME_THREADS
3976     c2 = omp_cyc_start();
3977 #endif
3978 #pragma omp parallel for num_threads(nthread) schedule(static)
3979     for (thread = 0; thread < nthread; thread++)
3980     {
3981         splinedata_t *spline;
3982         pmegrid_t *grid = NULL;
3983
3984         /* make local bsplines  */
3985         if (grids == NULL || !pme->bUseThreads)
3986         {
3987             spline = &atc->spline[0];
3988
3989             spline->n = atc->n;
3990
3991             if (bSpread)
3992             {
3993                 grid = &grids->grid;
3994             }
3995         }
3996         else
3997         {
3998             spline = &atc->spline[thread];
3999
4000             if (grids->nthread == 1)
4001             {
4002                 /* One thread, we operate on all charges */
4003                 spline->n = atc->n;
4004             }
4005             else
4006             {
4007                 /* Get the indices our thread should operate on */
4008                 make_thread_local_ind(atc, thread, spline);
4009             }
4010
4011             grid = &grids->grid_th[thread];
4012         }
4013
4014         if (bCalcSplines)
4015         {
4016             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4017                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
4018         }
4019
4020         if (bSpread)
4021         {
4022             /* put local atoms on grid. */
4023 #ifdef PME_TIME_SPREAD
4024             ct1a = omp_cyc_start();
4025 #endif
4026             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4027
4028             if (pme->bUseThreads)
4029             {
4030                 copy_local_grid(pme, grids, thread, fftgrid);
4031             }
4032 #ifdef PME_TIME_SPREAD
4033             ct1a          = omp_cyc_end(ct1a);
4034             cs1a[thread] += (double)ct1a;
4035 #endif
4036         }
4037     }
4038 #ifdef PME_TIME_THREADS
4039     c2   = omp_cyc_end(c2);
4040     cs2 += (double)c2;
4041 #endif
4042
4043     if (bSpread && pme->bUseThreads)
4044     {
4045 #ifdef PME_TIME_THREADS
4046         c3 = omp_cyc_start();
4047 #endif
4048 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4049         for (thread = 0; thread < grids->nthread; thread++)
4050         {
4051             reduce_threadgrid_overlap(pme, grids, thread,
4052                                       fftgrid,
4053                                       pme->overlap[0].sendbuf,
4054                                       pme->overlap[1].sendbuf);
4055         }
4056 #ifdef PME_TIME_THREADS
4057         c3   = omp_cyc_end(c3);
4058         cs3 += (double)c3;
4059 #endif
4060
4061         if (pme->nnodes > 1)
4062         {
4063             /* Communicate the overlapping part of the fftgrid.
4064              * For this communication call we need to check pme->bUseThreads
4065              * to have all ranks communicate here, regardless of pme->nthread.
4066              */
4067             sum_fftgrid_dd(pme, fftgrid);
4068         }
4069     }
4070
4071 #ifdef PME_TIME_THREADS
4072     cnt++;
4073     if (cnt % 20 == 0)
4074     {
4075         printf("idx %.2f spread %.2f red %.2f",
4076                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4077 #ifdef PME_TIME_SPREAD
4078         for (thread = 0; thread < nthread; thread++)
4079         {
4080             printf(" %.2f", cs1a[thread]*1e-9);
4081         }
4082 #endif
4083         printf("\n");
4084     }
4085 #endif
4086 }
4087
4088
4089 static void dump_grid(FILE *fp,
4090                       int sx, int sy, int sz, int nx, int ny, int nz,
4091                       int my, int mz, const real *g)
4092 {
4093     int x, y, z;
4094
4095     for (x = 0; x < nx; x++)
4096     {
4097         for (y = 0; y < ny; y++)
4098         {
4099             for (z = 0; z < nz; z++)
4100             {
4101                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4102                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4103             }
4104         }
4105     }
4106 }
4107
4108 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4109 {
4110     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4111
4112     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4113                                    local_fft_ndata,
4114                                    local_fft_offset,
4115                                    local_fft_size);
4116
4117     dump_grid(stderr,
4118               pme->pmegrid_start_ix,
4119               pme->pmegrid_start_iy,
4120               pme->pmegrid_start_iz,
4121               pme->pmegrid_nx-pme->pme_order+1,
4122               pme->pmegrid_ny-pme->pme_order+1,
4123               pme->pmegrid_nz-pme->pme_order+1,
4124               local_fft_size[YY],
4125               local_fft_size[ZZ],
4126               fftgrid);
4127 }
4128
4129
4130 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4131 {
4132     pme_atomcomm_t *atc;
4133     pmegrids_t *grid;
4134
4135     if (pme->nnodes > 1)
4136     {
4137         gmx_incons("gmx_pme_calc_energy called in parallel");
4138     }
4139     if (pme->bFEP)
4140     {
4141         gmx_incons("gmx_pme_calc_energy with free energy");
4142     }
4143
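    /* What follows (descriptive note): set up a minimal, single-threaded
     * atom communication struct for the given coordinates and charges,
     * compute only the B-spline coefficients (no spreading), and then sum
     * the B-spline-weighted values of the existing A-grid to obtain the
     * energy; this is used e.g. for test particle insertion.
     */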
4144     atc            = &pme->atc_energy;
4145     atc->nthread   = 1;
4146     if (atc->spline == NULL)
4147     {
4148         snew(atc->spline, atc->nthread);
4149     }
4150     atc->nslab     = 1;
4151     atc->bSpread   = TRUE;
4152     atc->pme_order = pme->pme_order;
4153     atc->n         = n;
4154     pme_realloc_atomcomm_things(atc);
4155     atc->x         = x;
4156     atc->q         = q;
4157
4158     /* We only use the A-charges grid */
4159     grid = &pme->pmegridA;
4160
4161     /* Only calculate the spline coefficients, don't actually spread */
4162     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4163
4164     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4165 }
4166
4167
4168 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4169                                    t_nrnb *nrnb, t_inputrec *ir,
4170                                    gmx_large_int_t step)
4171 {
4172     /* Reset all the counters related to performance over the run */
4173     wallcycle_stop(wcycle, ewcRUN);
4174     wallcycle_reset_all(wcycle);
4175     init_nrnb(nrnb);
4176     if (ir->nsteps >= 0)
4177     {
4178         /* ir->nsteps is not used here, but we update it for consistency */
4179         ir->nsteps -= step - ir->init_step;
4180     }
4181     ir->init_step = step;
4182     wallcycle_start(wcycle, ewcRUN);
4183 }
4184
4185
4186 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4187                                ivec grid_size,
4188                                t_commrec *cr, t_inputrec *ir,
4189                                gmx_pme_t *pme_ret)
4190 {
4191     int ind;
4192     gmx_pme_t pme = NULL;
4193
4194     ind = 0;
4195     while (ind < *npmedata)
4196     {
4197         pme = (*pmedata)[ind];
4198         if (pme->nkx == grid_size[XX] &&
4199             pme->nky == grid_size[YY] &&
4200             pme->nkz == grid_size[ZZ])
4201         {
4202             *pme_ret = pme;
4203
4204             return;
4205         }
4206
4207         ind++;
4208     }
4209
4210     (*npmedata)++;
4211     srenew(*pmedata, *npmedata);
4212
4213     /* Generate a new PME data structure, copying part of the old pointers */
4214     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4215
4216     *pme_ret = (*pmedata)[ind];
4217 }
4218
4219
4220 int gmx_pmeonly(gmx_pme_t pme,
4221                 t_commrec *cr,    t_nrnb *nrnb,
4222                 gmx_wallcycle_t wcycle,
4223                 real ewaldcoeff,  gmx_bool bGatherOnly,
4224                 t_inputrec *ir)
4225 {
4226     int npmedata;
4227     gmx_pme_t *pmedata;
4228     gmx_pme_pp_t pme_pp;
4229     int  ret;
4230     int  natoms;
4231     matrix box;
4232     rvec *x_pp      = NULL, *f_pp = NULL;
4233     real *chargeA   = NULL, *chargeB = NULL;
4234     real lambda     = 0;
4235     int  maxshift_x = 0, maxshift_y = 0;
4236     real energy, dvdlambda;
4237     matrix vir;
4238     float cycles;
4239     int  count;
4240     gmx_bool bEnerVir;
4241     gmx_large_int_t step, step_rel;
4242     ivec grid_switch;
4243
4244     /* This data is only used with PME tuning, i.e. switching PME grids */
4245     npmedata = 1;
4246     snew(pmedata, npmedata);
4247     pmedata[0] = pme;
4248
4249     pme_pp = gmx_pme_pp_init(cr);
4250
4251     init_nrnb(nrnb);
4252
4253     count = 0;
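    /* Message protocol of the loop below (descriptive only): each
     * iteration receives coordinates/charges from the PP nodes; a
     * pmerecvqxSWITCHGRID message switches to a (possibly new) PME grid,
     * pmerecvqxRESETCOUNTERS resets the cycle/flop counters,
     * pmerecvqxFINISH terminates the loop, and otherwise the mesh part is
     * computed and forces, energy and virial are sent back.
     */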
4254     do /****** this is a quasi-loop over time steps! */
4255     {
4256         /* The reason for having a loop here is PME grid tuning/switching */
4257         do
4258         {
4259             /* Receive charges and coordinates from the PP nodes */
4260             ret = gmx_pme_recv_q_x(pme_pp,
4261                                    &natoms,
4262                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4263                                    &maxshift_x, &maxshift_y,
4264                                    &pme->bFEP, &lambda,
4265                                    &bEnerVir,
4266                                    &step,
4267                                    grid_switch, &ewaldcoeff);
4268
4269             if (ret == pmerecvqxSWITCHGRID)
4270             {
4271                 /* Switch the PME grid to grid_switch */
4272                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4273             }
4274
4275             if (ret == pmerecvqxRESETCOUNTERS)
4276             {
4277                 /* Reset the cycle and flop counters */
4278                 reset_pmeonly_counters(cr, wcycle, nrnb, ir, step);
4279             }
4280         }
4281         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4282
4283         if (ret == pmerecvqxFINISH)
4284         {
4285             /* We should stop: break out of the loop */
4286             break;
4287         }
4288
4289         step_rel = step - ir->init_step;
4290
4291         if (count == 0)
4292         {
4293             wallcycle_start(wcycle, ewcRUN);
4294         }
4295
4296         wallcycle_start(wcycle, ewcPMEMESH);
4297
4298         dvdlambda = 0;
4299         clear_mat(vir);
4300         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4301                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4302                    &energy, lambda, &dvdlambda,
4303                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4304
4305         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4306
4307         gmx_pme_send_force_vir_ener(pme_pp,
4308                                     f_pp, vir, energy, dvdlambda,
4309                                     cycles);
4310
4311         count++;
4312     } /***** end of quasi-loop, we stop with the break above */
4313     while (TRUE);
4314
4315     return 0;
4316 }
4317
4318 int gmx_pme_do(gmx_pme_t pme,
4319                int start,       int homenr,
4320                rvec x[],        rvec f[],
4321                real *chargeA,   real *chargeB,
4322                matrix box, t_commrec *cr,
4323                int  maxshift_x, int maxshift_y,
4324                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4325                matrix vir,      real ewaldcoeff,
4326                real *energy,    real lambda,
4327                real *dvdlambda, int flags)
4328 {
4329     int     q, d, i, j, ntot, npme;
4330     int     nx, ny, nz;
4331     int     n_d, local_ny;
4332     pme_atomcomm_t *atc = NULL;
4333     pmegrids_t *pmegrid = NULL;
4334     real    *grid       = NULL;
4335     real    *ptr;
4336     rvec    *x_d, *f_d;
4337     real    *charge = NULL, *q_d;
4338     real    energy_AB[2];
4339     matrix  vir_AB[2];
4340     gmx_bool bClearF;
4341     gmx_parallel_3dfft_t pfft_setup;
4342     real *  fftgrid;
4343     t_complex * cfftgrid;
4344     int     thread;
4345     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4346     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4347
4348     assert(pme->nnodes > 0);
4349     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4350
4351     if (pme->nnodes > 1)
4352     {
4353         atc      = &pme->atc[0];
4354         atc->npd = homenr;
4355         if (atc->npd > atc->pd_nalloc)
4356         {
4357             atc->pd_nalloc = over_alloc_dd(atc->npd);
4358             srenew(atc->pd, atc->pd_nalloc);
4359         }
4360         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4361     }
4362     else
4363     {
4364         /* This could be necessary for TPI */
4365         pme->atc[0].n = homenr;
4366     }
4367
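    /* Note on the loop below: without free energy it runs once (q = 0,
     * the A charges); with free energy it runs twice, once for the A and
     * once for the B charges, and the two energies/virials are combined
     * with weights (1-lambda) and lambda after the loop.
     */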
4368     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4369     {
4370         if (q == 0)
4371         {
4372             pmegrid    = &pme->pmegridA;
4373             fftgrid    = pme->fftgridA;
4374             cfftgrid   = pme->cfftgridA;
4375             pfft_setup = pme->pfft_setupA;
4376             charge     = chargeA+start;
4377         }
4378         else
4379         {
4380             pmegrid    = &pme->pmegridB;
4381             fftgrid    = pme->fftgridB;
4382             cfftgrid   = pme->cfftgridB;
4383             pfft_setup = pme->pfft_setupB;
4384             charge     = chargeB+start;
4385         }
4386         grid = pmegrid->grid.grid;
4387         /* Unpack structure */
4388         if (debug)
4389         {
4390             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4391                     cr->nnodes, cr->nodeid);
4392             fprintf(debug, "Grid = %p\n", (void*)grid);
4393             if (grid == NULL)
4394             {
4395                 gmx_fatal(FARGS, "No grid!");
4396             }
4397         }
4398         where();
4399
4400         m_inv_ur0(box, pme->recipbox);
4401
4402         if (pme->nnodes == 1)
4403         {
4404             atc = &pme->atc[0];
4405             if (DOMAINDECOMP(cr))
4406             {
4407                 atc->n = homenr;
4408                 pme_realloc_atomcomm_things(atc);
4409             }
4410             atc->x = x;
4411             atc->q = charge;
4412             atc->f = f;
4413         }
4414         else
4415         {
4416             wallcycle_start(wcycle, ewcPME_REDISTXF);
4417             for (d = pme->ndecompdim-1; d >= 0; d--)
4418             {
4419                 if (d == pme->ndecompdim-1)
4420                 {
4421                     n_d = homenr;
4422                     x_d = x + start;
4423                     q_d = charge;
4424                 }
4425                 else
4426                 {
4427                     n_d = pme->atc[d+1].n;
4428                     x_d = atc->x;
4429                     q_d = atc->q;
4430                 }
4431                 atc      = &pme->atc[d];
4432                 atc->npd = n_d;
4433                 if (atc->npd > atc->pd_nalloc)
4434                 {
4435                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4436                     srenew(atc->pd, atc->pd_nalloc);
4437                 }
4438                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4439                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4440                 where();
4441
4442                 GMX_BARRIER(cr->mpi_comm_mygroup);
4443                 /* Redistribute x (only once) and qA or qB */
4444                 if (DOMAINDECOMP(cr))
4445                 {
4446                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4447                 }
4448                 else
4449                 {
4450                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4451                 }
4452             }
4453             where();
4454
4455             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4456         }
4457
4458         if (debug)
4459         {
4460             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4461                     cr->nodeid, atc->n);
4462         }
4463
4464         if (flags & GMX_PME_SPREAD_Q)
4465         {
4466             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4467
4468             /* Spread the charges on a grid */
4469             GMX_MPE_LOG(ev_spread_on_grid_start);
4470
4472             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4473             GMX_MPE_LOG(ev_spread_on_grid_finish);
4474
4475             if (q == 0)
4476             {
4477                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4478             }
4479             inc_nrnb(nrnb, eNR_SPREADQBSP,
4480                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4481
4482             if (!pme->bUseThreads)
4483             {
4484                 wrap_periodic_pmegrid(pme, grid);
4485
4486                 /* sum contributions to local grid from other nodes */
4487 #ifdef GMX_MPI
4488                 if (pme->nnodes > 1)
4489                 {
4490                     GMX_BARRIER(cr->mpi_comm_mygroup);
4491                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4492                     where();
4493                 }
4494 #endif
4495
4496                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4497             }
4498
4499             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4500
4501             /*
4502                dump_local_fftgrid(pme,fftgrid);
4503                exit(0);
4504              */
4505         }
4506
4507         /* Here we start a large thread parallel region */
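        /* Inside this region (descriptive outline): forward real-to-complex
         * 3D FFT, reciprocal-space solve (solve_pme_yzx, which also
         * computes energy and virial contributions when requested), and,
         * when forces are requested, the complex-to-real back transform
         * followed by copying the FFT grid back to the PME grid for force
         * gathering.
         */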
4508 #pragma omp parallel num_threads(pme->nthread) private(thread)
4509         {
4510             thread = gmx_omp_get_thread_num();
4511             if (flags & GMX_PME_SOLVE)
4512             {
4513                 int loop_count;
4514
4515                 /* do 3d-fft */
4516                 if (thread == 0)
4517                 {
4518                     GMX_BARRIER(cr->mpi_comm_mygroup);
4519                     GMX_MPE_LOG(ev_gmxfft3d_start);
4520                     wallcycle_start(wcycle, ewcPME_FFT);
4521                 }
4522                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4523                                            fftgrid, cfftgrid, thread, wcycle);
4524                 if (thread == 0)
4525                 {
4526                     wallcycle_stop(wcycle, ewcPME_FFT);
4527                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4528                 }
4529                 where();
4530
4531                 /* solve in k-space for our local cells */
4532                 if (thread == 0)
4533                 {
4534                     GMX_BARRIER(cr->mpi_comm_mygroup);
4535                     GMX_MPE_LOG(ev_solve_pme_start);
4536                     wallcycle_start(wcycle, ewcPME_SOLVE);
4537                 }
4538                 loop_count =
4539                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4540                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4541                                   bCalcEnerVir,
4542                                   pme->nthread, thread);
4543                 if (thread == 0)
4544                 {
4545                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4546                     where();
4547                     GMX_MPE_LOG(ev_solve_pme_finish);
4548                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4549                 }
4550             }
4551
4552             if (bCalcF)
4553             {
4554                 /* do 3d-invfft */
4555                 if (thread == 0)
4556                 {
4557                     GMX_BARRIER(cr->mpi_comm_mygroup);
4558                     GMX_MPE_LOG(ev_gmxfft3d_start);
4559                     where();
4560                     wallcycle_start(wcycle, ewcPME_FFT);
4561                 }
4562                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4563                                            cfftgrid, fftgrid, thread, wcycle);
4564                 if (thread == 0)
4565                 {
4566                     wallcycle_stop(wcycle, ewcPME_FFT);
4567
4568                     where();
4569                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4570
4571                     if (pme->nodeid == 0)
4572                     {
4573                         ntot  = pme->nkx*pme->nky*pme->nkz;
4574                         npme  = ntot*log((real)ntot)/log(2.0);
4575                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4576                     }
4577
4578                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4579                 }
4580
4581                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4582             }
4583         }
4584         /* End of thread parallel section.
4585          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4586          */
4587
4588         if (bCalcF)
4589         {
4590             /* distribute local grid to all nodes */
4591 #ifdef GMX_MPI
4592             if (pme->nnodes > 1)
4593             {
4594                 GMX_BARRIER(cr->mpi_comm_mygroup);
4595                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4596             }
4597 #endif
4598             where();
4599
4600             unwrap_periodic_pmegrid(pme, grid);
4601
4602             /* interpolate forces for our local atoms */
4603             GMX_BARRIER(cr->mpi_comm_mygroup);
4604             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4605
4606             where();
4607
4608             /* If we are running without parallelization,
4609              * atc->f is the actual force array, not a buffer,
4610              * and therefore we should not clear it.
4611              */
4612             bClearF = (q == 0 && PAR(cr));
4613 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4614             for (thread = 0; thread < pme->nthread; thread++)
4615             {
4616                 gather_f_bsplines(pme, grid, bClearF, atc,
4617                                   &atc->spline[thread],
4618                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4619             }
4620
4621             where();
4622
4623             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4624
4625             inc_nrnb(nrnb, eNR_GATHERFBSP,
4626                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4627             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4628         }
4629
4630         if (bCalcEnerVir)
4631         {
4632             /* This should only be called on the master thread
4633              * and after the threads have synchronized.
4634              */
4635             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4636         }
4637     } /* of q-loop */
4638
4639     if (bCalcF && pme->nnodes > 1)
4640     {
4641         wallcycle_start(wcycle, ewcPME_REDISTXF);
4642         for (d = 0; d < pme->ndecompdim; d++)
4643         {
4644             atc = &pme->atc[d];
4645             if (d == pme->ndecompdim - 1)
4646             {
4647                 n_d = homenr;
4648                 f_d = f + start;
4649             }
4650             else
4651             {
4652                 n_d = pme->atc[d+1].n;
4653                 f_d = pme->atc[d+1].f;
4654             }
4655             GMX_BARRIER(cr->mpi_comm_mygroup);
4656             if (DOMAINDECOMP(cr))
4657             {
4658                 dd_pmeredist_f(pme, atc, n_d, f_d,
4659                                d == pme->ndecompdim-1 && pme->bPPnode);
4660             }
4661             else
4662             {
4663                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4664             }
4665         }
4666
4667         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4668     }
4669     where();
4670
4671     if (bCalcEnerVir)
4672     {
4673         if (!pme->bFEP)
4674         {
4675             *energy = energy_AB[0];
4676             m_add(vir, vir_AB[0], vir);
4677         }
4678         else
4679         {
4680             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4681             *dvdlambda += energy_AB[1] - energy_AB[0];
4682             for (i = 0; i < DIM; i++)
4683             {
4684                 for (j = 0; j < DIM; j++)
4685                 {
4686                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4687                         lambda*vir_AB[1][i][j];
4688                 }
4689             }
4690         }
4691     }
4692     else
4693     {
4694         *energy = 0;
4695     }
4696
4697     if (debug)
4698     {
4699         fprintf(debug, "PME mesh energy: %g\n", *energy);
4700     }
4701
4702     return 0;
4703 }