1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team,
6  * check out http://www.gromacs.org for more information.
7  * Copyright (c) 2012,2013, by the GROMACS development team, led by
8  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
9  * others, as listed in the AUTHORS file in the top-level source
10  * directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* IMPORTANT FOR DEVELOPERS:
39  *
40  * Triclinic PME isn't entirely trivial, and we've experienced
41  * some bugs during development (many of them due to me). To avoid
42  * this in the future, please check the following things if you make
43  * changes in this file:
44  *
45  * 1. You should obtain identical (at least to the PME precision)
46  *    energies, forces, and virial for
47  *    a rectangular box and a triclinic one where the z (or y) axis is
48  *    tilted by a whole box side. For instance, you could use these boxes:
49  *
50  *    rectangular       triclinic
51  *     2  0  0           2  0  0
52  *     0  2  0           0  2  0
53  *     0  0  6           2  2  6
54  *
55  * 2. You should check the energy conservation in a triclinic box.
56  *
57  * It might seem like overkill, but better safe than sorry.
58  * /Erik 001109
59  */
60
61 #ifdef HAVE_CONFIG_H
62 #include <config.h>
63 #endif
64
65 #ifdef GMX_LIB_MPI
66 #include <mpi.h>
67 #endif
68 #ifdef GMX_THREAD_MPI
69 #include "tmpi.h"
70 #endif
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93 #include "gmx_omp.h"
94
95 /* Include the SIMD macro file and then check for support */
96 #include "gmx_simd_macros.h"
97 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
98 /* Turn on SIMD intrinsics for PME solve */
99 #define PME_SIMD
100 #endif
101
102 /* SIMD spread+gather is only used in single precision when SSE2 or higher
103  * is available. We might want to switch to gmx_simd_macros.h, but that is
104  * somewhat complicated, as we use unaligned and/or 4-wide-only loads.
105  */
106 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
107 #define PME_SSE_SPREAD_GATHER
108 #include <emmintrin.h>
109 /* Some old AMD processors could have problems with unaligned loads+stores */
110 #ifndef GMX_FAHCORE
111 #define PME_SSE_UNALIGNED
112 #endif
113 #endif
114
115 #include "mpelogging.h"
116
117 #define DFT_TOL 1e-7
118 /* #define PRT_FORCE */
119 /* conditions for on-the-fly time measurement */
120 /* #define TAKETIME (step > 1 && timesteps < 10) */
121 #define TAKETIME FALSE
122
123 /* #define PME_TIME_THREADS */
124
125 #ifdef GMX_DOUBLE
126 #define mpi_type MPI_DOUBLE
127 #else
128 #define mpi_type MPI_FLOAT
129 #endif
130
131 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
132 #define GMX_CACHE_SEP 64
133
134 /* We only define a maximum to be able to use local arrays without allocation.
135  * An order larger than 12 should never be needed, even for test cases.
136  * If needed it can be changed here.
137  */
138 #define PME_ORDER_MAX 12
139
140 /* Internal datastructures */
141 typedef struct {
142     int send_index0;
143     int send_nindex;
144     int recv_index0;
145     int recv_nindex;
146     int recv_size;   /* Receive buffer width, used with OpenMP */
147 } pme_grid_comm_t;
148
149 typedef struct {
150 #ifdef GMX_MPI
151     MPI_Comm         mpi_comm;
152 #endif
153     int              nnodes, nodeid;
154     int             *s2g0;
155     int             *s2g1;
156     int              noverlap_nodes;
157     int             *send_id, *recv_id;
158     int              send_size; /* Send buffer width, used with OpenMP */
159     pme_grid_comm_t *comm_data;
160     real            *sendbuf;
161     real            *recvbuf;
162 } pme_overlap_t;
163
164 typedef struct {
165     int *n;      /* Cumulative counts of the number of particles per thread */
166     int  nalloc; /* Allocation size of i */
167     int *i;      /* Particle indices ordered on thread index (n) */
168 } thread_plist_t;
169
170 typedef struct {
171     int      *thread_one;
172     int       n;
173     int      *ind;
174     splinevec theta;
175     real     *ptr_theta_z;
176     splinevec dtheta;
177     real     *ptr_dtheta_z;
178 } splinedata_t;
179
180 typedef struct {
181     int      dimind;        /* The index of the dimension, 0=x, 1=y */
182     int      nslab;
183     int      nodeid;
184 #ifdef GMX_MPI
185     MPI_Comm mpi_comm;
186 #endif
187
188     int     *node_dest;     /* The nodes to send x and q to with DD */
189     int     *node_src;      /* The nodes to receive x and q from with DD */
190     int     *buf_index;     /* Index for commnode into the buffers */
191
192     int      maxshift;
193
194     int      npd;
195     int      pd_nalloc;
196     int     *pd;
197     int     *count;         /* The number of atoms to send to each node */
198     int    **count_thread;
199     int     *rcount;        /* The number of atoms to receive */
200
201     int      n;
202     int      nalloc;
203     rvec    *x;
204     real    *q;
205     rvec    *f;
206     gmx_bool bSpread;       /* These coordinates are used for spreading */
207     int      pme_order;
208     ivec    *idx;
209     rvec    *fractx;            /* Fractional coordinate relative to the
210                                  * lower cell boundary
211                                  */
212     int             nthread;
213     int            *thread_idx; /* Which thread should spread which charge */
214     thread_plist_t *thread_plist;
215     splinedata_t   *spline;
216 } pme_atomcomm_t;
217
218 #define FLBS  3
219 #define FLBSZ 4
220
221 typedef struct {
222     ivec  ci;     /* The spatial location of this grid         */
223     ivec  n;      /* The used size of *grid, including order-1 */
224     ivec  offset; /* The grid offset from the full node grid   */
225     int   order;  /* PME spreading order                       */
226     ivec  s;      /* The allocated size of *grid, s >= n       */
227     real *grid;   /* The grid of the local thread, size n      */
228 } pmegrid_t;
229
230 typedef struct {
231     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
232     int        nthread;      /* The number of threads operating on this grid     */
233     ivec       nc;           /* The local spatial decomposition over the threads */
234     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
235     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
236     int      **g2t;          /* The grid to thread index                         */
237     ivec       nthread_comm; /* The number of threads to communicate with        */
238 } pmegrids_t;
239
240
241 typedef struct {
242 #ifdef PME_SSE_SPREAD_GATHER
243     /* Masks for SSE aligned spreading and gathering */
244     __m128 mask_SSE0[6], mask_SSE1[6];
245 #else
246     int    dummy; /* C89 requires that struct has at least one member */
247 #endif
248 } pme_spline_work_t;
249
250 typedef struct {
251     /* work data for solve_pme */
252     int      nalloc;
253     real *   mhx;
254     real *   mhy;
255     real *   mhz;
256     real *   m2;
257     real *   denom;
258     real *   tmp1_alloc;
259     real *   tmp1;
260     real *   eterm;
261     real *   m2inv;
262
263     real     energy;
264     matrix   vir;
265 } pme_work_t;
266
267 typedef struct gmx_pme {
268     int           ndecompdim; /* The number of decomposition dimensions */
269     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
270     int           nodeid_major;
271     int           nodeid_minor;
272     int           nnodes;    /* The number of nodes doing PME */
273     int           nnodes_major;
274     int           nnodes_minor;
275
276     MPI_Comm      mpi_comm;
277     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
278 #ifdef GMX_MPI
279     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
280 #endif
281
282     gmx_bool   bUseThreads;   /* Do any of the PME ranks have nthread>1?     */
283     int        nthread;       /* The number of threads doing PME on our rank */
284
285     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
286     gmx_bool   bFEP;          /* Compute Free energy contribution */
287     int        nkx, nky, nkz; /* Grid dimensions */
288     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
289     int        pme_order;
290     real       epsilon_r;
291
292     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
293     pmegrids_t pmegridB;
294     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
295     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
296     /* pmegrid_nz might be larger than strictly necessary to ensure
297      * memory alignment, pmegrid_nz_base gives the real base size.
298      */
299     int     pmegrid_nz_base;
300     /* The local PME grid starting indices */
301     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
302
303     /* Work data for spreading and gathering */
304     pme_spline_work_t    *spline_work;
305
306     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
307     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
308     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
309
310     t_complex            *cfftgridA;  /* Grids for complex FFT data */
311     t_complex            *cfftgridB;
312     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
313
314     gmx_parallel_3dfft_t  pfft_setupA;
315     gmx_parallel_3dfft_t  pfft_setupB;
316
317     int                  *nnx, *nny, *nnz;
318     real                 *fshx, *fshy, *fshz;
319
320     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
321     matrix                recipbox;
322     splinevec             bsp_mod;
323
324     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
325
326     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
327
328     rvec                 *bufv;       /* Communication buffer */
329     real                 *bufr;       /* Communication buffer */
330     int                   buf_nalloc; /* The communication buffer size */
331
332     /* thread local work data for solve_pme */
333     pme_work_t *work;
334
335     /* Work data for PME_redist */
336     gmx_bool redist_init;
337     int *    scounts;
338     int *    rcounts;
339     int *    sdispls;
340     int *    rdispls;
341     int *    sidx;
342     int *    idxa;
343     real *   redist_buf;
344     int      redist_buf_nalloc;
345
346     /* Work data for sum_qgrid */
347     real *   sum_qgrid_tmp;
348     real *   sum_qgrid_dd_tmp;
349 } t_gmx_pme;
350
351
352 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
353                                    int start, int end, int thread)
354 {
355     int             i;
356     int            *idxptr, tix, tiy, tiz;
357     real           *xptr, *fptr, tx, ty, tz;
358     real            rxx, ryx, ryy, rzx, rzy, rzz;
359     int             nx, ny, nz;
360     int             start_ix, start_iy, start_iz;
361     int            *g2tx, *g2ty, *g2tz;
362     gmx_bool        bThreads;
363     int            *thread_idx = NULL;
364     thread_plist_t *tpl        = NULL;
365     int            *tpl_n      = NULL;
366     int             thread_i;
367
368     nx  = pme->nkx;
369     ny  = pme->nky;
370     nz  = pme->nkz;
371
372     start_ix = pme->pmegrid_start_ix;
373     start_iy = pme->pmegrid_start_iy;
374     start_iz = pme->pmegrid_start_iz;
375
376     rxx = pme->recipbox[XX][XX];
377     ryx = pme->recipbox[YY][XX];
378     ryy = pme->recipbox[YY][YY];
379     rzx = pme->recipbox[ZZ][XX];
380     rzy = pme->recipbox[ZZ][YY];
381     rzz = pme->recipbox[ZZ][ZZ];
382
383     g2tx = pme->pmegridA.g2t[XX];
384     g2ty = pme->pmegridA.g2t[YY];
385     g2tz = pme->pmegridA.g2t[ZZ];
386
387     bThreads = (atc->nthread > 1);
388     if (bThreads)
389     {
390         thread_idx = atc->thread_idx;
391
392         tpl   = &atc->thread_plist[thread];
393         tpl_n = tpl->n;
394         for (i = 0; i < atc->nthread; i++)
395         {
396             tpl_n[i] = 0;
397         }
398     }
399
400     for (i = start; i < end; i++)
401     {
402         xptr   = atc->x[i];
403         idxptr = atc->idx[i];
404         fptr   = atc->fractx[i];
405
406         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
407         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
408         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
409         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
410
411         tix = (int)(tx);
412         tiy = (int)(ty);
413         tiz = (int)(tz);
414
415         /* Because decomposition only occurs in x and y,
416          * we never have a fraction correction in z.
417          */
418         fptr[XX] = tx - tix + pme->fshx[tix];
419         fptr[YY] = ty - tiy + pme->fshy[tiy];
420         fptr[ZZ] = tz - tiz;
421
422         idxptr[XX] = pme->nnx[tix];
423         idxptr[YY] = pme->nny[tiy];
424         idxptr[ZZ] = pme->nnz[tiz];
425
426 #ifdef DEBUG
427         range_check(idxptr[XX], 0, pme->pmegrid_nx);
428         range_check(idxptr[YY], 0, pme->pmegrid_ny);
429         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
430 #endif
431
432         if (bThreads)
433         {
434             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
435             thread_idx[i] = thread_i;
436             tpl_n[thread_i]++;
437         }
438     }
439
440     if (bThreads)
441     {
442         /* Make a list of particle indices sorted on thread */
443
444         /* Get the cumulative count */
445         for (i = 1; i < atc->nthread; i++)
446         {
447             tpl_n[i] += tpl_n[i-1];
448         }
449         /* The current implementation distributes particles equally
450          * over the threads, so we could actually allocate for that
451          * in pme_realloc_atomcomm_things.
452          */
453         if (tpl_n[atc->nthread-1] > tpl->nalloc)
454         {
455             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
456             srenew(tpl->i, tpl->nalloc);
457         }
458         /* Set tpl_n to the cumulative start */
459         for (i = atc->nthread-1; i >= 1; i--)
460         {
461             tpl_n[i] = tpl_n[i-1];
462         }
463         tpl_n[0] = 0;
464
465         /* Fill our thread local array with indices sorted on thread */
466         for (i = start; i < end; i++)
467         {
468             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
469         }
470         /* Now tpl_n contains the cumulative count again */
471     }
472 }
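/* Illustration only (compiled out): a minimal sketch of the counting sort
 * used above to order particle indices by owning thread (count, cumulative
 * offsets, scatter). The names below are hypothetical, not GROMACS API.
 */
#if 0
static void example_sort_by_thread(int n, const int *thread_of, int nthread,
                                   int *count, int *sorted)
{
    int i, t, c, sum;

    /* Count how many particles each thread owns */
    for (t = 0; t < nthread; t++)
    {
        count[t] = 0;
    }
    for (i = 0; i < n; i++)
    {
        count[thread_of[i]]++;
    }
    /* Turn the counts into cumulative start offsets */
    sum = 0;
    for (t = 0; t < nthread; t++)
    {
        c        = count[t];
        count[t] = sum;
        sum     += c;
    }
    /* Scatter each particle index into its thread's slot */
    for (i = 0; i < n; i++)
    {
        sorted[count[thread_of[i]]++] = i;
    }
}
#endif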
473
474 static void make_thread_local_ind(pme_atomcomm_t *atc,
475                                   int thread, splinedata_t *spline)
476 {
477     int             n, t, i, start, end;
478     thread_plist_t *tpl;
479
480     /* Combine the indices made by each thread into one index */
481
482     n     = 0;
483     start = 0;
484     for (t = 0; t < atc->nthread; t++)
485     {
486         tpl = &atc->thread_plist[t];
487         /* Copy our part (start - end) from the list of thread t */
488         if (thread > 0)
489         {
490             start = tpl->n[thread-1];
491         }
492         end = tpl->n[thread];
493         for (i = start; i < end; i++)
494         {
495             spline->ind[n++] = tpl->i[i];
496         }
497     }
498
499     spline->n = n;
500 }
501
502
503 static void pme_calc_pidx(int start, int end,
504                           matrix recipbox, rvec x[],
505                           pme_atomcomm_t *atc, int *count)
506 {
507     int   nslab, i;
508     int   si;
509     real *xptr, s;
510     real  rxx, ryx, rzx, ryy, rzy;
511     int  *pd;
512
513     /* Calculate the PME task index (pidx) for each atom.
514      * Here we always assign equally sized slabs to each node
515      * for load balancing reasons (the PME grid spacing is not used).
516      */
517
518     nslab = atc->nslab;
519     pd    = atc->pd;
520
521     /* Reset the count */
522     for (i = 0; i < nslab; i++)
523     {
524         count[i] = 0;
525     }
526
527     if (atc->dimind == 0)
528     {
529         rxx = recipbox[XX][XX];
530         ryx = recipbox[YY][XX];
531         rzx = recipbox[ZZ][XX];
532         /* Calculate the node index in x-dimension */
533         for (i = start; i < end; i++)
534         {
535             xptr   = x[i];
536             /* Fractional coordinates along box vectors */
537             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
538             si    = (int)(s + 2*nslab) % nslab;
539             pd[i] = si;
540             count[si]++;
541         }
542     }
543     else
544     {
545         ryy = recipbox[YY][YY];
546         rzy = recipbox[ZZ][YY];
547         /* Calculate the node index in y-dimension */
548         for (i = start; i < end; i++)
549         {
550             xptr   = x[i];
551             /* Fractional coordinates along box vectors */
552             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
553             si    = (int)(s + 2*nslab) % nslab;
554             pd[i] = si;
555             count[si]++;
556         }
557     }
558 }
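/* Illustration only (compiled out): the slab assignment above maps a
 * fractional coordinate onto 0..nslab-1; adding 2*nslab before the (int)
 * cast keeps the truncation and modulo positive for slightly negative
 * (triclinic) fractions. A minimal stand-alone sketch, hypothetical names:
 */
#if 0
static int example_slab_index(real s_frac, int nslab)
{
    /* s_frac is the fractional coordinate along the decomposed dimension */
    real s = nslab*s_frac;

    /* e.g. nslab = 4, s = -0.3  ->  (int)(7.7) % 4 = 3, the last slab */
    return (int)(s + 2*nslab) % nslab;
}
#endif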
559
560 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
561                                   pme_atomcomm_t *atc)
562 {
563     int nthread, thread, slab;
564
565     nthread = atc->nthread;
566
567 #pragma omp parallel for num_threads(nthread) schedule(static)
568     for (thread = 0; thread < nthread; thread++)
569     {
570         pme_calc_pidx(natoms* thread   /nthread,
571                       natoms*(thread+1)/nthread,
572                       recipbox, x, atc, atc->count_thread[thread]);
573     }
574     /* Non-parallel reduction, since nslab is small */
575
576     for (thread = 1; thread < nthread; thread++)
577     {
578         for (slab = 0; slab < atc->nslab; slab++)
579         {
580             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
581         }
582     }
583 }
584
585 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
586 {
587     const int padding = 4;
588     int       i;
589
590     srenew(th[XX], nalloc);
591     srenew(th[YY], nalloc);
592     /* In z we add padding; this is only required for the aligned SSE code */
593     srenew(*ptr_z, nalloc+2*padding);
594     th[ZZ] = *ptr_z + padding;
595
596     for (i = 0; i < padding; i++)
597     {
598         (*ptr_z)[               i] = 0;
599         (*ptr_z)[padding+nalloc+i] = 0;
600     }
601 }
602
603 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
604 {
605     int i, d;
606
607     srenew(spline->ind, atc->nalloc);
608     /* Initialize the index to identity so it works without threads */
609     for (i = 0; i < atc->nalloc; i++)
610     {
611         spline->ind[i] = i;
612     }
613
614     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
615                       atc->pme_order*atc->nalloc);
616     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
617                       atc->pme_order*atc->nalloc);
618 }
619
620 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
621 {
622     int nalloc_old, i, j, nalloc_tpl;
623
624     /* We must keep atc->x non-NULL to avoid possible
625      * fatal errors in MPI routines.
626      */
627     if (atc->n > atc->nalloc || atc->nalloc == 0)
628     {
629         nalloc_old  = atc->nalloc;
630         atc->nalloc = over_alloc_dd(max(atc->n, 1));
631
632         if (atc->nslab > 1)
633         {
634             srenew(atc->x, atc->nalloc);
635             srenew(atc->q, atc->nalloc);
636             srenew(atc->f, atc->nalloc);
637             for (i = nalloc_old; i < atc->nalloc; i++)
638             {
639                 clear_rvec(atc->f[i]);
640             }
641         }
642         if (atc->bSpread)
643         {
644             srenew(atc->fractx, atc->nalloc);
645             srenew(atc->idx, atc->nalloc);
646
647             if (atc->nthread > 1)
648             {
649                 srenew(atc->thread_idx, atc->nalloc);
650             }
651
652             for (i = 0; i < atc->nthread; i++)
653             {
654                 pme_realloc_splinedata(&atc->spline[i], atc);
655             }
656         }
657     }
658 }
659
660 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
661                          int n, gmx_bool bXF, rvec *x_f, real *charge,
662                          pme_atomcomm_t *atc)
663 /* Redistribute particle data for PME calculation */
664 /* domain decomposition by x coordinate           */
665 {
666     int *idxa;
667     int  i, ii;
668
669     if (FALSE == pme->redist_init)
670     {
671         snew(pme->scounts, atc->nslab);
672         snew(pme->rcounts, atc->nslab);
673         snew(pme->sdispls, atc->nslab);
674         snew(pme->rdispls, atc->nslab);
675         snew(pme->sidx, atc->nslab);
676         pme->redist_init = TRUE;
677     }
678     if (n > pme->redist_buf_nalloc)
679     {
680         pme->redist_buf_nalloc = over_alloc_dd(n);
681         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
682     }
683
684     pme->idxa = atc->pd;
685
686 #ifdef GMX_MPI
687     if (forw && bXF)
688     {
689         /* forward, redistribution from pp to pme */
690
691         /* Calculate send counts and exchange them with other nodes */
692         for (i = 0; (i < atc->nslab); i++)
693         {
694             pme->scounts[i] = 0;
695         }
696         for (i = 0; (i < n); i++)
697         {
698             pme->scounts[pme->idxa[i]]++;
699         }
700         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
701
702         /* Calculate send and receive displacements and index into send
703            buffer */
704         pme->sdispls[0] = 0;
705         pme->rdispls[0] = 0;
706         pme->sidx[0]    = 0;
707         for (i = 1; i < atc->nslab; i++)
708         {
709             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
710             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
711             pme->sidx[i]    = pme->sdispls[i];
712         }
713         /* Total # of particles to be received */
714         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
715
716         pme_realloc_atomcomm_things(atc);
717
718         /* Copy particle coordinates into send buffer and exchange */
719         for (i = 0; (i < n); i++)
720         {
721             ii = DIM*pme->sidx[pme->idxa[i]];
722             pme->sidx[pme->idxa[i]]++;
723             pme->redist_buf[ii+XX] = x_f[i][XX];
724             pme->redist_buf[ii+YY] = x_f[i][YY];
725             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
726         }
727         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
728                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
729                       pme->rvec_mpi, atc->mpi_comm);
730     }
731     if (forw)
732     {
733         /* Copy charges into send buffer and exchange */
734         for (i = 0; i < atc->nslab; i++)
735         {
736             pme->sidx[i] = pme->sdispls[i];
737         }
738         for (i = 0; (i < n); i++)
739         {
740             ii = pme->sidx[pme->idxa[i]];
741             pme->sidx[pme->idxa[i]]++;
742             pme->redist_buf[ii] = charge[i];
743         }
744         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
745                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
746                       atc->mpi_comm);
747     }
748     else   /* backward, redistribution from pme to pp */
749     {
750         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
751                       pme->redist_buf, pme->scounts, pme->sdispls,
752                       pme->rvec_mpi, atc->mpi_comm);
753
754         /* Copy data from receive buffer */
755         for (i = 0; i < atc->nslab; i++)
756         {
757             pme->sidx[i] = pme->sdispls[i];
758         }
759         for (i = 0; (i < n); i++)
760         {
761             ii          = DIM*pme->sidx[pme->idxa[i]];
762             x_f[i][XX] += pme->redist_buf[ii+XX];
763             x_f[i][YY] += pme->redist_buf[ii+YY];
764             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
765             pme->sidx[pme->idxa[i]]++;
766         }
767     }
768 #endif
769 }
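/* Illustration only (compiled out): the count/displacement bookkeeping used
 * above for MPI_Alltoallv, in isolation. Plain MPI, hypothetical names.
 */
#if 0
static void example_alltoallv_setup(int nslab, const int *scounts,
                                    int *rcounts, int *sdispls, int *rdispls,
                                    MPI_Comm comm)
{
    int i;

    /* Every rank tells every other rank how many elements it will send */
    MPI_Alltoall((void *)scounts, 1, MPI_INT, rcounts, 1, MPI_INT, comm);

    /* Displacements are the running sums of the counts */
    sdispls[0] = 0;
    rdispls[0] = 0;
    for (i = 1; i < nslab; i++)
    {
        sdispls[i] = sdispls[i-1] + scounts[i-1];
        rdispls[i] = rdispls[i-1] + rcounts[i-1];
    }
    /* The total number of elements to receive is
     * rdispls[nslab-1] + rcounts[nslab-1].
     */
}
#endif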
770
771 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
772                             gmx_bool bBackward, int shift,
773                             void *buf_s, int nbyte_s,
774                             void *buf_r, int nbyte_r)
775 {
776 #ifdef GMX_MPI
777     int        dest, src;
778     MPI_Status stat;
779
780     if (bBackward == FALSE)
781     {
782         dest = atc->node_dest[shift];
783         src  = atc->node_src[shift];
784     }
785     else
786     {
787         dest = atc->node_src[shift];
788         src  = atc->node_dest[shift];
789     }
790
791     if (nbyte_s > 0 && nbyte_r > 0)
792     {
793         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
794                      dest, shift,
795                      buf_r, nbyte_r, MPI_BYTE,
796                      src, shift,
797                      atc->mpi_comm, &stat);
798     }
799     else if (nbyte_s > 0)
800     {
801         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
802                  dest, shift,
803                  atc->mpi_comm);
804     }
805     else if (nbyte_r > 0)
806     {
807         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
808                  src, shift,
809                  atc->mpi_comm, &stat);
810     }
811 #endif
812 }
813
814 static void dd_pmeredist_x_q(gmx_pme_t pme,
815                              int n, gmx_bool bX, rvec *x, real *charge,
816                              pme_atomcomm_t *atc)
817 {
818     int *commnode, *buf_index;
819     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
820
821     commnode  = atc->node_dest;
822     buf_index = atc->buf_index;
823
824     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
825
826     nsend = 0;
827     for (i = 0; i < nnodes_comm; i++)
828     {
829         buf_index[commnode[i]] = nsend;
830         nsend                 += atc->count[commnode[i]];
831     }
832     if (bX)
833     {
834         if (atc->count[atc->nodeid] + nsend != n)
835         {
836             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
837                       "This usually means that your system is not well equilibrated.",
838                       n - (atc->count[atc->nodeid] + nsend),
839                       pme->nodeid, 'x'+atc->dimind);
840         }
841
842         if (nsend > pme->buf_nalloc)
843         {
844             pme->buf_nalloc = over_alloc_dd(nsend);
845             srenew(pme->bufv, pme->buf_nalloc);
846             srenew(pme->bufr, pme->buf_nalloc);
847         }
848
849         atc->n = atc->count[atc->nodeid];
850         for (i = 0; i < nnodes_comm; i++)
851         {
852             scount = atc->count[commnode[i]];
853             /* Communicate the count */
854             if (debug)
855             {
856                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
857                         atc->dimind, atc->nodeid, commnode[i], scount);
858             }
859             pme_dd_sendrecv(atc, FALSE, i,
860                             &scount, sizeof(int),
861                             &atc->rcount[i], sizeof(int));
862             atc->n += atc->rcount[i];
863         }
864
865         pme_realloc_atomcomm_things(atc);
866     }
867
868     local_pos = 0;
869     for (i = 0; i < n; i++)
870     {
871         node = atc->pd[i];
872         if (node == atc->nodeid)
873         {
874             /* Copy directly to the receive buffer */
875             if (bX)
876             {
877                 copy_rvec(x[i], atc->x[local_pos]);
878             }
879             atc->q[local_pos] = charge[i];
880             local_pos++;
881         }
882         else
883         {
884             /* Copy to the send buffer */
885             if (bX)
886             {
887                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
888             }
889             pme->bufr[buf_index[node]] = charge[i];
890             buf_index[node]++;
891         }
892     }
893
894     buf_pos = 0;
895     for (i = 0; i < nnodes_comm; i++)
896     {
897         scount = atc->count[commnode[i]];
898         rcount = atc->rcount[i];
899         if (scount > 0 || rcount > 0)
900         {
901             if (bX)
902             {
903                 /* Communicate the coordinates */
904                 pme_dd_sendrecv(atc, FALSE, i,
905                                 pme->bufv[buf_pos], scount*sizeof(rvec),
906                                 atc->x[local_pos], rcount*sizeof(rvec));
907             }
908             /* Communicate the charges */
909             pme_dd_sendrecv(atc, FALSE, i,
910                             pme->bufr+buf_pos, scount*sizeof(real),
911                             atc->q+local_pos, rcount*sizeof(real));
912             buf_pos   += scount;
913             local_pos += atc->rcount[i];
914         }
915     }
916 }
917
918 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
919                            int n, rvec *f,
920                            gmx_bool bAddF)
921 {
922     int *commnode, *buf_index;
923     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
924
925     commnode  = atc->node_dest;
926     buf_index = atc->buf_index;
927
928     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
929
930     local_pos = atc->count[atc->nodeid];
931     buf_pos   = 0;
932     for (i = 0; i < nnodes_comm; i++)
933     {
934         scount = atc->rcount[i];
935         rcount = atc->count[commnode[i]];
936         if (scount > 0 || rcount > 0)
937         {
938             /* Communicate the forces */
939             pme_dd_sendrecv(atc, TRUE, i,
940                             atc->f[local_pos], scount*sizeof(rvec),
941                             pme->bufv[buf_pos], rcount*sizeof(rvec));
942             local_pos += scount;
943         }
944         buf_index[commnode[i]] = buf_pos;
945         buf_pos               += rcount;
946     }
947
948     local_pos = 0;
949     if (bAddF)
950     {
951         for (i = 0; i < n; i++)
952         {
953             node = atc->pd[i];
954             if (node == atc->nodeid)
955             {
956                 /* Add from the local force array */
957                 rvec_inc(f[i], atc->f[local_pos]);
958                 local_pos++;
959             }
960             else
961             {
962                 /* Add from the receive buffer */
963                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
964                 buf_index[node]++;
965             }
966         }
967     }
968     else
969     {
970         for (i = 0; i < n; i++)
971         {
972             node = atc->pd[i];
973             if (node == atc->nodeid)
974             {
975                 /* Copy from the local force array */
976                 copy_rvec(atc->f[local_pos], f[i]);
977                 local_pos++;
978             }
979             else
980             {
981                 /* Copy from the receive buffer */
982                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
983                 buf_index[node]++;
984             }
985         }
986     }
987 }
988
989 #ifdef GMX_MPI
990 static void
991 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
992 {
993     pme_overlap_t *overlap;
994     int            send_index0, send_nindex;
995     int            recv_index0, recv_nindex;
996     MPI_Status     stat;
997     int            i, j, k, ix, iy, iz, icnt;
998     int            ipulse, send_id, recv_id, datasize;
999     real          *p;
1000     real          *sendptr, *recvptr;
1001
1002     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
1003     overlap = &pme->overlap[1];
1004
1005     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1006     {
1007         /* Since we have already (un)wrapped the overlap in the z-dimension,
1008          * we only have to communicate 0 to nkz (not pmegrid_nz).
1009          */
1010         if (direction == GMX_SUM_QGRID_FORWARD)
1011         {
1012             send_id       = overlap->send_id[ipulse];
1013             recv_id       = overlap->recv_id[ipulse];
1014             send_index0   = overlap->comm_data[ipulse].send_index0;
1015             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1016             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1017             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1018         }
1019         else
1020         {
1021             send_id       = overlap->recv_id[ipulse];
1022             recv_id       = overlap->send_id[ipulse];
1023             send_index0   = overlap->comm_data[ipulse].recv_index0;
1024             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1025             recv_index0   = overlap->comm_data[ipulse].send_index0;
1026             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1027         }
1028
1029         /* Copy data to contiguous send buffer */
1030         if (debug)
1031         {
1032             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1033                     pme->nodeid, overlap->nodeid, send_id,
1034                     pme->pmegrid_start_iy,
1035                     send_index0-pme->pmegrid_start_iy,
1036                     send_index0-pme->pmegrid_start_iy+send_nindex);
1037         }
1038         icnt = 0;
1039         for (i = 0; i < pme->pmegrid_nx; i++)
1040         {
1041             ix = i;
1042             for (j = 0; j < send_nindex; j++)
1043             {
1044                 iy = j + send_index0 - pme->pmegrid_start_iy;
1045                 for (k = 0; k < pme->nkz; k++)
1046                 {
1047                     iz = k;
1048                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1049                 }
1050             }
1051         }
1052
1053         datasize      = pme->pmegrid_nx * pme->nkz;
1054
1055         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1056                      send_id, ipulse,
1057                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1058                      recv_id, ipulse,
1059                      overlap->mpi_comm, &stat);
1060
1061         /* Get data from contiguous recv buffer */
1062         if (debug)
1063         {
1064             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1065                     pme->nodeid, overlap->nodeid, recv_id,
1066                     pme->pmegrid_start_iy,
1067                     recv_index0-pme->pmegrid_start_iy,
1068                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1069         }
1070         icnt = 0;
1071         for (i = 0; i < pme->pmegrid_nx; i++)
1072         {
1073             ix = i;
1074             for (j = 0; j < recv_nindex; j++)
1075             {
1076                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1077                 for (k = 0; k < pme->nkz; k++)
1078                 {
1079                     iz = k;
1080                     if (direction == GMX_SUM_QGRID_FORWARD)
1081                     {
1082                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1083                     }
1084                     else
1085                     {
1086                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1087                     }
1088                 }
1089             }
1090         }
1091     }
1092
1093     /* Major dimension is easier, no copying required,
1094      * but we might have to sum into a separate array.
1095      * Since we don't copy, we have to communicate up to pmegrid_nz,
1096      * not nkz as for the minor direction.
1097      */
1098     overlap = &pme->overlap[0];
1099
1100     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1101     {
1102         if (direction == GMX_SUM_QGRID_FORWARD)
1103         {
1104             send_id       = overlap->send_id[ipulse];
1105             recv_id       = overlap->recv_id[ipulse];
1106             send_index0   = overlap->comm_data[ipulse].send_index0;
1107             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1108             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1109             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1110             recvptr       = overlap->recvbuf;
1111         }
1112         else
1113         {
1114             send_id       = overlap->recv_id[ipulse];
1115             recv_id       = overlap->send_id[ipulse];
1116             send_index0   = overlap->comm_data[ipulse].recv_index0;
1117             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1118             recv_index0   = overlap->comm_data[ipulse].send_index0;
1119             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1120             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1121         }
1122
1123         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1124         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1125
1126         if (debug)
1127         {
1128             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1129                     pme->nodeid, overlap->nodeid, send_id,
1130                     pme->pmegrid_start_ix,
1131                     send_index0-pme->pmegrid_start_ix,
1132                     send_index0-pme->pmegrid_start_ix+send_nindex);
1133             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1134                     pme->nodeid, overlap->nodeid, recv_id,
1135                     pme->pmegrid_start_ix,
1136                     recv_index0-pme->pmegrid_start_ix,
1137                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1138         }
1139
1140         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1141                      send_id, ipulse,
1142                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1143                      recv_id, ipulse,
1144                      overlap->mpi_comm, &stat);
1145
1146         /* ADD data from contiguous recv buffer */
1147         if (direction == GMX_SUM_QGRID_FORWARD)
1148         {
1149             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1150             for (i = 0; i < recv_nindex*datasize; i++)
1151             {
1152                 p[i] += overlap->recvbuf[i];
1153             }
1154         }
1155     }
1156 }
1157 #endif
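/* Illustration only (compiled out): packing a y-slab of a row-major (x,y,z)
 * grid into a contiguous buffer, as done for the non-contiguous
 * minor-dimension communication above. Hypothetical names.
 */
#if 0
static void example_pack_yslab(const real *grid, real *sendbuf,
                               int nx, int ny, int nz,
                               int iy0, int niy)
{
    int ix, iy, iz, icnt = 0;

    for (ix = 0; ix < nx; ix++)
    {
        for (iy = iy0; iy < iy0 + niy; iy++)
        {
            for (iz = 0; iz < nz; iz++)
            {
                /* Row-major index: x is the slowest index, z the fastest */
                sendbuf[icnt++] = grid[(ix*ny + iy)*nz + iz];
            }
        }
    }
}
#endif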
1158
1159
1160 static int
1161 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1162 {
1163     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1164     ivec    local_pme_size;
1165     int     i, ix, iy, iz;
1166     int     pmeidx, fftidx;
1167
1168     /* Dimensions should be identical for A/B grid, so we just use A here */
1169     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1170                                    local_fft_ndata,
1171                                    local_fft_offset,
1172                                    local_fft_size);
1173
1174     local_pme_size[0] = pme->pmegrid_nx;
1175     local_pme_size[1] = pme->pmegrid_ny;
1176     local_pme_size[2] = pme->pmegrid_nz;
1177
1178     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1179        the offset is identical, and the PME grid always has more data (due to overlap)
1180      */
1181     {
1182 #ifdef DEBUG_PME
1183         FILE *fp, *fp2;
1184         char  fn[STRLEN], format[STRLEN];
1185         real  val;
1186         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1187         fp = ffopen(fn, "w");
1188         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1189         fp2 = ffopen(fn, "w");
1190         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1191 #endif
1192
1193         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1194         {
1195             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1196             {
1197                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1198                 {
1199                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1200                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1201                     fftgrid[fftidx] = pmegrid[pmeidx];
1202 #ifdef DEBUG_PME
1203                     val = 100*pmegrid[pmeidx];
1204                     if (pmegrid[pmeidx] != 0)
1205                     {
1206                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1207                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1208                     }
1209                     if (pmegrid[pmeidx] != 0)
1210                     {
1211                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1212                                 "qgrid",
1213                                 pme->pmegrid_start_ix + ix,
1214                                 pme->pmegrid_start_iy + iy,
1215                                 pme->pmegrid_start_iz + iz,
1216                                 pmegrid[pmeidx]);
1217                     }
1218 #endif
1219                 }
1220             }
1221         }
1222 #ifdef DEBUG_PME
1223         ffclose(fp);
1224         ffclose(fp2);
1225 #endif
1226     }
1227     return 0;
1228 }
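/* Illustration only (compiled out): copying between two row-major grids that
 * share an origin but have different allocated sizes (strides), as above
 * where the PME grid is larger than the FFT grid due to the spreading
 * overlap. Hypothetical names.
 */
#if 0
static void example_copy_subgrid(const real *src, const ivec src_size,
                                 real *dst, const ivec dst_size,
                                 const ivec ndata)
{
    int ix, iy, iz;

    for (ix = 0; ix < ndata[XX]; ix++)
    {
        for (iy = 0; iy < ndata[YY]; iy++)
        {
            for (iz = 0; iz < ndata[ZZ]; iz++)
            {
                /* Same (ix,iy,iz), different strides in the two grids */
                dst[(ix*dst_size[YY] + iy)*dst_size[ZZ] + iz] =
                    src[(ix*src_size[YY] + iy)*src_size[ZZ] + iz];
            }
        }
    }
}
#endif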
1229
1230
1231 static gmx_cycles_t omp_cyc_start()
1232 {
1233     return gmx_cycles_read();
1234 }
1235
1236 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1237 {
1238     return gmx_cycles_read() - c;
1239 }
1240
1241
1242 static int
1243 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1244                         int nthread, int thread)
1245 {
1246     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1247     ivec          local_pme_size;
1248     int           ixy0, ixy1, ixy, ix, iy, iz;
1249     int           pmeidx, fftidx;
1250 #ifdef PME_TIME_THREADS
1251     gmx_cycles_t  c1;
1252     static double cs1 = 0;
1253     static int    cnt = 0;
1254 #endif
1255
1256 #ifdef PME_TIME_THREADS
1257     c1 = omp_cyc_start();
1258 #endif
1259     /* Dimensions should be identical for A/B grid, so we just use A here */
1260     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1261                                    local_fft_ndata,
1262                                    local_fft_offset,
1263                                    local_fft_size);
1264
1265     local_pme_size[0] = pme->pmegrid_nx;
1266     local_pme_size[1] = pme->pmegrid_ny;
1267     local_pme_size[2] = pme->pmegrid_nz;
1268
1269     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1270        the offset is identical, and the PME grid always has more data (due to overlap)
1271      */
1272     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1273     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1274
1275     for (ixy = ixy0; ixy < ixy1; ixy++)
1276     {
1277         ix = ixy/local_fft_ndata[YY];
1278         iy = ixy - ix*local_fft_ndata[YY];
1279
1280         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1281         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1282         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1283         {
1284             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1285         }
1286     }
1287
1288 #ifdef PME_TIME_THREADS
1289     c1   = omp_cyc_end(c1);
1290     cs1 += (double)c1;
1291     cnt++;
1292     if (cnt % 20 == 0)
1293     {
1294         printf("copy %.2f\n", cs1*1e-9);
1295     }
1296 #endif
1297
1298     return 0;
1299 }
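/* Illustration only (compiled out): the static range split used above to
 * divide ntot = nx*ny work items over nthread threads without explicit
 * remainder handling. Hypothetical names.
 */
#if 0
static void example_thread_range(int ntot, int nthread, int thread,
                                 int *i0, int *i1)
{
    /* e.g. ntot = 10, nthread = 4 gives ranges 0-2, 2-5, 5-7, 7-10 */
    *i0 = ( thread   *ntot)/nthread;
    *i1 = ((thread+1)*ntot)/nthread;
}
#endif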
1300
1301
1302 static void
1303 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1304 {
1305     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1306
1307     nx = pme->nkx;
1308     ny = pme->nky;
1309     nz = pme->nkz;
1310
1311     pnx = pme->pmegrid_nx;
1312     pny = pme->pmegrid_ny;
1313     pnz = pme->pmegrid_nz;
1314
1315     overlap = pme->pme_order - 1;
1316
1317     /* Add periodic overlap in z */
1318     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1319     {
1320         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1321         {
1322             for (iz = 0; iz < overlap; iz++)
1323             {
1324                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1325                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1326             }
1327         }
1328     }
1329
1330     if (pme->nnodes_minor == 1)
1331     {
1332         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1333         {
1334             for (iy = 0; iy < overlap; iy++)
1335             {
1336                 for (iz = 0; iz < nz; iz++)
1337                 {
1338                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1339                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1340                 }
1341             }
1342         }
1343     }
1344
1345     if (pme->nnodes_major == 1)
1346     {
1347         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1348
1349         for (ix = 0; ix < overlap; ix++)
1350         {
1351             for (iy = 0; iy < ny_x; iy++)
1352             {
1353                 for (iz = 0; iz < nz; iz++)
1354                 {
1355                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1356                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1357                 }
1358             }
1359         }
1360     }
1361 }
1362
1363
1364 static void
1365 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1366 {
1367     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1368
1369     nx = pme->nkx;
1370     ny = pme->nky;
1371     nz = pme->nkz;
1372
1373     pnx = pme->pmegrid_nx;
1374     pny = pme->pmegrid_ny;
1375     pnz = pme->pmegrid_nz;
1376
1377     overlap = pme->pme_order - 1;
1378
1379     if (pme->nnodes_major == 1)
1380     {
1381         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1382
1383         for (ix = 0; ix < overlap; ix++)
1384         {
1385             int iy, iz;
1386
1387             for (iy = 0; iy < ny_x; iy++)
1388             {
1389                 for (iz = 0; iz < nz; iz++)
1390                 {
1391                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1392                         pmegrid[(ix*pny+iy)*pnz+iz];
1393                 }
1394             }
1395         }
1396     }
1397
1398     if (pme->nnodes_minor == 1)
1399     {
1400 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1401         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1402         {
1403             int iy, iz;
1404
1405             for (iy = 0; iy < overlap; iy++)
1406             {
1407                 for (iz = 0; iz < nz; iz++)
1408                 {
1409                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1410                         pmegrid[(ix*pny+iy)*pnz+iz];
1411                 }
1412             }
1413         }
1414     }
1415
1416     /* Copy periodic overlap in z */
1417 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1418     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1419     {
1420         int iy, iz;
1421
1422         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1423         {
1424             for (iz = 0; iz < overlap; iz++)
1425             {
1426                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1427                     pmegrid[(ix*pny+iy)*pnz+iz];
1428             }
1429         }
1430     }
1431 }
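/* Illustration only (compiled out): the wrap/unwrap above in one dimension.
 * Spreading with a given order writes into order-1 cells beyond the periodic
 * grid; wrapping folds those back onto the start, unwrapping copies the start
 * out again before gathering. Hypothetical names.
 */
#if 0
static void example_wrap_1d(real *line, int n, int order)
{
    int i, overlap = order - 1;

    /* Fold the charge spread beyond the box back onto the start */
    for (i = 0; i < overlap; i++)
    {
        line[i] += line[n + i];
    }
}

static void example_unwrap_1d(real *line, int n, int order)
{
    int i, overlap = order - 1;

    /* Duplicate the start so interpolation past the box edge works */
    for (i = 0; i < overlap; i++)
    {
        line[n + i] = line[i];
    }
}
#endif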
1432
1433 static void clear_grid(int nx, int ny, int nz, real *grid,
1434                        ivec fs, int *flag,
1435                        int fx, int fy, int fz,
1436                        int order)
1437 {
1438     int nc, ncz;
1439     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1440     int flind;
1441
1442     nc  = 2 + (order - 2)/FLBS;
1443     ncz = 2 + (order - 2)/FLBSZ;
1444
1445     for (fsx = fx; fsx < fx+nc; fsx++)
1446     {
1447         for (fsy = fy; fsy < fy+nc; fsy++)
1448         {
1449             for (fsz = fz; fsz < fz+ncz; fsz++)
1450             {
1451                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1452                 if (flag[flind] == 0)
1453                 {
1454                     gx  = fsx*FLBS;
1455                     gy  = fsy*FLBS;
1456                     gz  = fsz*FLBSZ;
1457                     g0x = (gx*ny + gy)*nz + gz;
1458                     for (x = 0; x < FLBS; x++)
1459                     {
1460                         g0y = g0x;
1461                         for (y = 0; y < FLBS; y++)
1462                         {
1463                             for (z = 0; z < FLBSZ; z++)
1464                             {
1465                                 grid[g0y+z] = 0;
1466                             }
1467                             g0y += nz;
1468                         }
1469                         g0x += ny*nz;
1470                     }
1471
1472                     flag[flind] = 1;
1473                 }
1474             }
1475         }
1476     }
1477 }
1478
1479 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1480 #define DO_BSPLINE(order)                            \
1481     for (ithx = 0; (ithx < order); ithx++)                    \
1482     {                                                    \
1483         index_x = (i0+ithx)*pny*pnz;                     \
1484         valx    = qn*thx[ithx];                          \
1485                                                      \
1486         for (ithy = 0; (ithy < order); ithy++)                \
1487         {                                                \
1488             valxy    = valx*thy[ithy];                   \
1489             index_xy = index_x+(j0+ithy)*pnz;            \
1490                                                      \
1491             for (ithz = 0; (ithz < order); ithz++)            \
1492             {                                            \
1493                 index_xyz        = index_xy+(k0+ithz);   \
1494                 grid[index_xyz] += valxy*thz[ithz];      \
1495             }                                            \
1496         }                                                \
1497     }
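/* Illustration only (compiled out): what DO_BSPLINE computes for one charge.
 * The order^3 grid points around (i0,j0,k0) each receive the charge weighted
 * by the separable (tensor-product) B-spline weights:
 *
 *   grid[i0+dx][j0+dy][k0+dz] += q * thx[dx] * thy[dy] * thz[dz]
 *
 * A minimal stand-alone version with hypothetical names:
 */
#if 0
static void example_spread_one_charge(real *grid, int pny, int pnz,
                                      int i0, int j0, int k0, real q,
                                      const real *thx, const real *thy,
                                      const real *thz, int order)
{
    int dx, dy, dz;

    for (dx = 0; dx < order; dx++)
    {
        for (dy = 0; dy < order; dy++)
        {
            for (dz = 0; dz < order; dz++)
            {
                grid[((i0+dx)*pny + (j0+dy))*pnz + (k0+dz)] +=
                    q*thx[dx]*thy[dy]*thz[dz];
            }
        }
    }
}
#endif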
1498
1499
1500 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1501                                      pme_atomcomm_t *atc, splinedata_t *spline,
1502                                      pme_spline_work_t *work)
1503 {
1504
1505     /* spread charges from home atoms to local grid */
1506     real          *grid;
1507     pme_overlap_t *ol;
1508     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1509     int       *    idxptr;
1510     int            order, norder, index_x, index_xy, index_xyz;
1511     real           valx, valxy, qn;
1512     real          *thx, *thy, *thz;
1513     int            localsize, bndsize;
1514     int            pnx, pny, pnz, ndatatot;
1515     int            offx, offy, offz;
1516
1517     pnx = pmegrid->s[XX];
1518     pny = pmegrid->s[YY];
1519     pnz = pmegrid->s[ZZ];
1520
1521     offx = pmegrid->offset[XX];
1522     offy = pmegrid->offset[YY];
1523     offz = pmegrid->offset[ZZ];
1524
1525     ndatatot = pnx*pny*pnz;
1526     grid     = pmegrid->grid;
1527     for (i = 0; i < ndatatot; i++)
1528     {
1529         grid[i] = 0;
1530     }
1531
1532     order = pmegrid->order;
1533
1534     for (nn = 0; nn < spline->n; nn++)
1535     {
1536         n  = spline->ind[nn];
1537         qn = atc->q[n];
1538
1539         if (qn != 0)
1540         {
1541             idxptr = atc->idx[n];
1542             norder = nn*order;
1543
1544             i0   = idxptr[XX] - offx;
1545             j0   = idxptr[YY] - offy;
1546             k0   = idxptr[ZZ] - offz;
1547
1548             thx = spline->theta[XX] + norder;
1549             thy = spline->theta[YY] + norder;
1550             thz = spline->theta[ZZ] + norder;
1551
1552             switch (order)
1553             {
1554                 case 4:
1555 #ifdef PME_SSE_SPREAD_GATHER
1556 #ifdef PME_SSE_UNALIGNED
1557 #define PME_SPREAD_SSE_ORDER4
1558 #else
1559 #define PME_SPREAD_SSE_ALIGNED
1560 #define PME_ORDER 4
1561 #endif
1562 #include "pme_sse_single.h"
1563 #else
1564                     DO_BSPLINE(4);
1565 #endif
1566                     break;
1567                 case 5:
1568 #ifdef PME_SSE_SPREAD_GATHER
1569 #define PME_SPREAD_SSE_ALIGNED
1570 #define PME_ORDER 5
1571 #include "pme_sse_single.h"
1572 #else
1573                     DO_BSPLINE(5);
1574 #endif
1575                     break;
1576                 default:
1577                     DO_BSPLINE(order);
1578                     break;
1579             }
1580         }
1581     }
1582 }
1583
1584 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1585 {
1586 #ifdef PME_SSE_SPREAD_GATHER
1587     if (pme_order == 5
1588 #ifndef PME_SSE_UNALIGNED
1589         || pme_order == 4
1590 #endif
1591         )
1592     {
1593         /* Round nz up to a multiple of 4 to ensure alignment */
1594         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1595     }
1596 #endif
1597 }
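/* Illustration only (compiled out): the rounding used above.
 * (n + 3) & ~3 rounds n up to the next multiple of 4, e.g. 13 -> 16 and
 * 16 -> 16. A generic form for any power-of-two alignment a:
 */
#if 0
static int example_round_up_pow2(int n, int a)
{
    /* a must be a power of two, e.g. a = 4 for the SSE case above */
    return (n + a - 1) & ~(a - 1);
}
#endif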
1598
1599 static void set_gridsize_alignment(int *gridsize, int pme_order)
1600 {
1601 #ifdef PME_SSE_SPREAD_GATHER
1602 #ifndef PME_SSE_UNALIGNED
1603     if (pme_order == 4)
1604     {
1605         /* Add extra elements to ensure that aligned operations do not go
1606          * beyond the allocated grid size.
1607          * Note that for pme_order=5, the pme grid z-size alignment
1608          * ensures that we will not go beyond the grid size.
1609          */
1610         *gridsize += 4;
1611     }
1612 #endif
1613 #endif
1614 }
1615
1616 static void pmegrid_init(pmegrid_t *grid,
1617                          int cx, int cy, int cz,
1618                          int x0, int y0, int z0,
1619                          int x1, int y1, int z1,
1620                          gmx_bool set_alignment,
1621                          int pme_order,
1622                          real *ptr)
1623 {
1624     int nz, gridsize;
1625
1626     grid->ci[XX]     = cx;
1627     grid->ci[YY]     = cy;
1628     grid->ci[ZZ]     = cz;
1629     grid->offset[XX] = x0;
1630     grid->offset[YY] = y0;
1631     grid->offset[ZZ] = z0;
1632     grid->n[XX]      = x1 - x0 + pme_order - 1;
1633     grid->n[YY]      = y1 - y0 + pme_order - 1;
1634     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1635     copy_ivec(grid->n, grid->s);
1636
1637     nz = grid->s[ZZ];
1638     set_grid_alignment(&nz, pme_order);
1639     if (set_alignment)
1640     {
1641         grid->s[ZZ] = nz;
1642     }
1643     else if (nz != grid->s[ZZ])
1644     {
1645         gmx_incons("pmegrid_init call with an unaligned z size");
1646     }
1647
1648     grid->order = pme_order;
1649     if (ptr == NULL)
1650     {
1651         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1652         set_gridsize_alignment(&gridsize, pme_order);
1653         snew_aligned(grid->grid, gridsize, 16);
1654     }
1655     else
1656     {
1657         grid->grid = ptr;
1658     }
1659 }
1660
1661 static int div_round_up(int enumerator, int denominator)
1662 {
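    /* E.g. div_round_up(10, 4) = 3 and div_round_up(8, 4) = 2 */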
1663     return (enumerator + denominator - 1)/denominator;
1664 }
1665
1666 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1667                                   ivec nsub)
1668 {
1669     int gsize_opt, gsize;
1670     int nsx, nsy, nsz;
1671     char *env;
1672
1673     gsize_opt = -1;
1674     for (nsx = 1; nsx <= nthread; nsx++)
1675     {
1676         if (nthread % nsx == 0)
1677         {
1678             for (nsy = 1; nsy <= nthread; nsy++)
1679             {
1680                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1681                 {
1682                     nsz = nthread/(nsx*nsy);
1683
1684                     /* Determine the number of grid points per thread */
1685                     gsize =
1686                         (div_round_up(n[XX], nsx) + ovl)*
1687                         (div_round_up(n[YY], nsy) + ovl)*
1688                         (div_round_up(n[ZZ], nsz) + ovl);
1689
1690                     /* Minimize the number of grid points per thread
1691                      * and, secondarily, the number of cuts in minor dimensions.
1692                      */
1693                     if (gsize_opt == -1 ||
1694                         gsize < gsize_opt ||
1695                         (gsize == gsize_opt &&
1696                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1697                     {
1698                         nsub[XX]  = nsx;
1699                         nsub[YY]  = nsy;
1700                         nsub[ZZ]  = nsz;
1701                         gsize_opt = gsize;
1702                     }
1703                 }
1704             }
1705         }
1706     }
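    /* With, for instance, n = {32, 32, 32}, ovl = 3 and nthread = 6, all six
     * permutations of a 1 x 2 x 3 split give the same 14*19*35 grid points
     * per thread, and the tie-break selects nsub = 3 x 2 x 1: the fewest
     * cuts in z, then in y.
     */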
1707
1708     env = getenv("GMX_PME_THREAD_DIVISION");
1709     if (env != NULL)
1710     {
1711         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1712     }
1713
1714     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1715     {
1716         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1717     }
1718 }
1719
1720 static void pmegrids_init(pmegrids_t *grids,
1721                           int nx, int ny, int nz, int nz_base,
1722                           int pme_order,
1723                           gmx_bool bUseThreads,
1724                           int nthread,
1725                           int overlap_x,
1726                           int overlap_y)
1727 {
1728     ivec n, n_base, g0, g1;
1729     int t, x, y, z, d, i, tfac;
1730     int max_comm_lines = -1;
1731
1732     n[XX] = nx - (pme_order - 1);
1733     n[YY] = ny - (pme_order - 1);
1734     n[ZZ] = nz - (pme_order - 1);
1735
1736     copy_ivec(n, n_base);
1737     n_base[ZZ] = nz_base;
1738
1739     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1740                  NULL);
1741
1742     grids->nthread = nthread;
1743
1744     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1745
1746     if (bUseThreads)
1747     {
1748         ivec nst;
1749         int gridsize;
1750
1751         for (d = 0; d < DIM; d++)
1752         {
1753             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1754         }
1755         set_grid_alignment(&nst[ZZ], pme_order);
1756
1757         if (debug)
1758         {
1759             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1760                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1761             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1762                     nx, ny, nz,
1763                     nst[XX], nst[YY], nst[ZZ]);
1764         }
1765
1766         snew(grids->grid_th, grids->nthread);
1767         t        = 0;
1768         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1769         set_gridsize_alignment(&gridsize, pme_order);
1770         snew_aligned(grids->grid_all,
1771                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1772                      16);
1773
1774         for (x = 0; x < grids->nc[XX]; x++)
1775         {
1776             for (y = 0; y < grids->nc[YY]; y++)
1777             {
1778                 for (z = 0; z < grids->nc[ZZ]; z++)
1779                 {
1780                     pmegrid_init(&grids->grid_th[t],
1781                                  x, y, z,
1782                                  (n[XX]*(x  ))/grids->nc[XX],
1783                                  (n[YY]*(y  ))/grids->nc[YY],
1784                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1785                                  (n[XX]*(x+1))/grids->nc[XX],
1786                                  (n[YY]*(y+1))/grids->nc[YY],
1787                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1788                                  TRUE,
1789                                  pme_order,
1790                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1791                     t++;
1792                 }
1793             }
1794         }
1795     }
1796     else
1797     {
1798         grids->grid_th = NULL;
1799     }
1800
1801     snew(grids->g2t, DIM);
1802     tfac = 1;
1803     for (d = DIM-1; d >= 0; d--)
1804     {
1805         snew(grids->g2t[d], n[d]);
1806         t = 0;
1807         for (i = 0; i < n[d]; i++)
1808         {
1809             /* The second check should match the parameters
1810              * of the pmegrid_init call above.
1811              */
1812             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1813             {
1814                 t++;
1815             }
1816             grids->g2t[d][i] = t*tfac;
1817         }
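        /* For example, with n[d] = 10 and nc[d] = 3 the slice boundaries fall
         * at (10*1)/3 = 3 and (10*2)/3 = 6, so grid lines 0-2 map to thread
         * slice 0, lines 3-5 to slice 1 and lines 6-9 to slice 2.
         */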
1818
1819         tfac *= grids->nc[d];
1820
1821         switch (d)
1822         {
1823             case XX: max_comm_lines = overlap_x;     break;
1824             case YY: max_comm_lines = overlap_y;     break;
1825             case ZZ: max_comm_lines = pme_order - 1; break;
1826         }
1827         grids->nthread_comm[d] = 0;
1828         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1829                grids->nthread_comm[d] < grids->nc[d])
1830         {
1831             grids->nthread_comm[d]++;
1832         }
1833         if (debug != NULL)
1834         {
1835             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1836                     'x'+d, grids->nthread_comm[d]);
1837         }
1838         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1839          * work, but this is not a problematic restriction.
1840          */
1841         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1842         {
1843             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1844         }
1845     }
1846 }
1847
1848
1849 static void pmegrids_destroy(pmegrids_t *grids)
1850 {
1851     int t;
1852
1853     if (grids->grid.grid != NULL)
1854     {
1855         sfree(grids->grid.grid);
1856
1857         if (grids->nthread > 0)
1858         {
1859             for (t = 0; t < grids->nthread; t++)
1860             {
1861                 sfree(grids->grid_th[t].grid);
1862             }
1863             sfree(grids->grid_th);
1864         }
1865     }
1866 }
1867
1868
1869 static void realloc_work(pme_work_t *work, int nkx)
1870 {
1871     if (nkx > work->nalloc)
1872     {
1873         work->nalloc = nkx;
1874         srenew(work->mhx, work->nalloc);
1875         srenew(work->mhy, work->nalloc);
1876         srenew(work->mhz, work->nalloc);
1877         srenew(work->m2, work->nalloc);
1878         /* Allocate an aligned pointer for SIMD operations, including extra
1879          * elements at the end for padding.
1880          */
1881 #ifdef PME_SIMD
1882 #define ALIGN_HERE  GMX_SIMD_WIDTH_HERE
1883 #else
1884 /* We can use any alignment, apart from 0, so we use 4 */
1885 #define ALIGN_HERE  4
1886 #endif
1887         sfree_aligned(work->denom);
1888         sfree_aligned(work->tmp1);
1889         sfree_aligned(work->eterm);
1890         snew_aligned(work->denom, work->nalloc+ALIGN_HERE, ALIGN_HERE*sizeof(real));
1891         snew_aligned(work->tmp1,  work->nalloc+ALIGN_HERE, ALIGN_HERE*sizeof(real));
1892         snew_aligned(work->eterm, work->nalloc+ALIGN_HERE, ALIGN_HERE*sizeof(real));
1893         srenew(work->m2inv, work->nalloc);
1894     }
1895 }
1896
1897
1898 static void free_work(pme_work_t *work)
1899 {
1900     sfree(work->mhx);
1901     sfree(work->mhy);
1902     sfree(work->mhz);
1903     sfree(work->m2);
1904     sfree_aligned(work->denom);
1905     sfree_aligned(work->tmp1);
1906     sfree_aligned(work->eterm);
1907     sfree(work->m2inv);
1908 }
1909
1910
1911 #ifdef PME_SIMD
1912 /* Calculate exponentials through SIMD */
1913 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1914 {
1915     {
1916         const gmx_mm_pr two = gmx_set1_pr(2.0);
1917         gmx_mm_pr f_simd;
1918         gmx_mm_pr lu;
1919         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1920         int kx;
1921         f_simd = gmx_load1_pr(&f);
1922         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1923         {
1924             tmp_d1   = gmx_load_pr(d_aligned+kx);
1925             d_inv    = gmx_inv_pr(tmp_d1);
1926             tmp_r    = gmx_load_pr(r_aligned+kx);
1927             tmp_r    = gmx_exp_pr(tmp_r);
1928             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1929             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1930             gmx_store_pr(e_aligned+kx, tmp_e);
1931         }
1932     }
1933 }
1934 #else
1935 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1936 {
1937     int kx;
1938     for (kx = start; kx < end; kx++)
1939     {
1940         d[kx] = 1.0/d[kx];
1941     }
1942     for (kx = start; kx < end; kx++)
1943     {
1944         r[kx] = exp(r[kx]);
1945     }
1946     for (kx = start; kx < end; kx++)
1947     {
1948         e[kx] = f*r[kx]*d[kx];
1949     }
1950 }
1951 #endif
1952
1953
1954 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1955                          real ewaldcoeff, real vol,
1956                          gmx_bool bEnerVir,
1957                          int nthread, int thread)
1958 {
1959     /* do recip sum over local cells in grid */
1960     /* y major, z middle, x minor or continuous */
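    /* In essence this evaluates the smooth-PME reciprocal-space sum
     * (Essmann et al., J. Chem. Phys. 103, 8577 (1995)):
     *
     *   E_rec = f/(2 pi V) sum_{m != 0} exp(-pi^2 |m|^2 / beta^2) / |m|^2
     *           * |S(m)|^2
     *
     * with f = ONE_4PI_EPS0/epsilon_r, beta = ewaldcoeff, V the box volume
     * and S(m) the structure factor; the bsp_mod arrays supply the B-spline
     * correction relating S(m) to the FFT of the charge grid. The same
     * per-m factor (eterm) also scales the grid in place for the back
     * transform that yields the potential.
     */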
1961     t_complex *p0;
1962     int     kx, ky, kz, maxkx, maxky, maxkz;
1963     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1964     real    mx, my, mz;
1965     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1966     real    ets2, struct2, vfactor, ets2vf;
1967     real    d1, d2, energy = 0;
1968     real    by, bz;
1969     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1970     real    rxx, ryx, ryy, rzx, rzy, rzz;
1971     pme_work_t *work;
1972     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1973     real    mhxk, mhyk, mhzk, m2k;
1974     real    corner_fac;
1975     ivec    complex_order;
1976     ivec    local_ndata, local_offset, local_size;
1977     real    elfac;
1978
1979     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1980
1981     nx = pme->nkx;
1982     ny = pme->nky;
1983     nz = pme->nkz;
1984
1985     /* Dimensions should be identical for A/B grid, so we just use A here */
1986     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1987                                       complex_order,
1988                                       local_ndata,
1989                                       local_offset,
1990                                       local_size);
1991
1992     rxx = pme->recipbox[XX][XX];
1993     ryx = pme->recipbox[YY][XX];
1994     ryy = pme->recipbox[YY][YY];
1995     rzx = pme->recipbox[ZZ][XX];
1996     rzy = pme->recipbox[ZZ][YY];
1997     rzz = pme->recipbox[ZZ][ZZ];
1998
1999     maxkx = (nx+1)/2;
2000     maxky = (ny+1)/2;
2001     maxkz = nz/2+1;
2002
2003     work  = &pme->work[thread];
2004     mhx   = work->mhx;
2005     mhy   = work->mhy;
2006     mhz   = work->mhz;
2007     m2    = work->m2;
2008     denom = work->denom;
2009     tmp1  = work->tmp1;
2010     eterm = work->eterm;
2011     m2inv = work->m2inv;
2012
2013     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2014     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2015
2016     for (iyz = iyz0; iyz < iyz1; iyz++)
2017     {
2018         iy = iyz/local_ndata[ZZ];
2019         iz = iyz - iy*local_ndata[ZZ];
2020
2021         ky = iy + local_offset[YY];
2022
2023         if (ky < maxky)
2024         {
2025             my = ky;
2026         }
2027         else
2028         {
2029             my = (ky - ny);
2030         }
2031
2032         by = M_PI*vol*pme->bsp_mod[YY][ky];
2033
2034         kz = iz + local_offset[ZZ];
2035
2036         mz = kz;
2037
2038         bz = pme->bsp_mod[ZZ][kz];
2039
2040         /* 0.5 correction for corner points */
2041         corner_fac = 1;
2042         if (kz == 0 || kz == (nz+1)/2)
2043         {
2044             corner_fac = 0.5;
2045         }
2046
2047         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2048
2049         /* We should skip the k-space point (0,0,0) */
2050         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2051         {
2052             kxstart = local_offset[XX];
2053         }
2054         else
2055         {
2056             kxstart = local_offset[XX] + 1;
2057             p0++;
2058         }
2059         kxend = local_offset[XX] + local_ndata[XX];
2060
2061         if (bEnerVir)
2062         {
2063             /* More expensive inner loop, especially because of the storage
2064              * of the mh elements in arrays.
2065              * Because x is the minor grid index, all mh elements
2066              * depend on kx for triclinic unit cells.
2067              */
2068
2069             /* Two explicit loops to avoid a conditional inside the loop */
2070             for (kx = kxstart; kx < maxkx; kx++)
2071             {
2072                 mx = kx;
2073
2074                 mhxk      = mx * rxx;
2075                 mhyk      = mx * ryx + my * ryy;
2076                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2077                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2078                 mhx[kx]   = mhxk;
2079                 mhy[kx]   = mhyk;
2080                 mhz[kx]   = mhzk;
2081                 m2[kx]    = m2k;
2082                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2083                 tmp1[kx]  = -factor*m2k;
2084             }
2085
2086             for (kx = maxkx; kx < kxend; kx++)
2087             {
2088                 mx = (kx - nx);
2089
2090                 mhxk      = mx * rxx;
2091                 mhyk      = mx * ryx + my * ryy;
2092                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2093                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2094                 mhx[kx]   = mhxk;
2095                 mhy[kx]   = mhyk;
2096                 mhz[kx]   = mhzk;
2097                 m2[kx]    = m2k;
2098                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2099                 tmp1[kx]  = -factor*m2k;
2100             }
2101
2102             for (kx = kxstart; kx < kxend; kx++)
2103             {
2104                 m2inv[kx] = 1.0/m2[kx];
2105             }
2106
2107             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2108
2109             for (kx = kxstart; kx < kxend; kx++, p0++)
2110             {
2111                 d1      = p0->re;
2112                 d2      = p0->im;
2113
2114                 p0->re  = d1*eterm[kx];
2115                 p0->im  = d2*eterm[kx];
2116
2117                 struct2 = 2.0*(d1*d1+d2*d2);
2118
2119                 tmp1[kx] = eterm[kx]*struct2;
2120             }
2121
2122             for (kx = kxstart; kx < kxend; kx++)
2123             {
2124                 ets2     = corner_fac*tmp1[kx];
2125                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2126                 energy  += ets2;
2127
2128                 ets2vf   = ets2*vfactor;
2129                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2130                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2131                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2132                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2133                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2134                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2135             }
2136         }
2137         else
2138         {
2139             /* We don't need to calculate the energy and the virial.
2140              * In this case the triclinic overhead is small.
2141              */
2142
2143             /* Two explicit loops to avoid a conditional inside the loop */
2144
2145             for (kx = kxstart; kx < maxkx; kx++)
2146             {
2147                 mx = kx;
2148
2149                 mhxk      = mx * rxx;
2150                 mhyk      = mx * ryx + my * ryy;
2151                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2152                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2153                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2154                 tmp1[kx]  = -factor*m2k;
2155             }
2156
2157             for (kx = maxkx; kx < kxend; kx++)
2158             {
2159                 mx = (kx - nx);
2160
2161                 mhxk      = mx * rxx;
2162                 mhyk      = mx * ryx + my * ryy;
2163                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2164                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2165                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2166                 tmp1[kx]  = -factor*m2k;
2167             }
2168
2169             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2170
2171             for (kx = kxstart; kx < kxend; kx++, p0++)
2172             {
2173                 d1      = p0->re;
2174                 d2      = p0->im;
2175
2176                 p0->re  = d1*eterm[kx];
2177                 p0->im  = d2*eterm[kx];
2178             }
2179         }
2180     }
2181
2182     if (bEnerVir)
2183     {
2184         /* Update virial with local values.
2185          * The virial is symmetric by definition.
2186          * this virial seems ok for isotropic scaling, but I'm
2187          * experiencing problems on semiisotropic membranes.
2188          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2189          */
2190         work->vir[XX][XX] = 0.25*virxx;
2191         work->vir[YY][YY] = 0.25*viryy;
2192         work->vir[ZZ][ZZ] = 0.25*virzz;
2193         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2194         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2195         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2196
2197         /* This energy should be corrected for a charged system */
2198         work->energy = 0.5*energy;
2199     }
2200
2201     /* Return the loop count */
2202     return local_ndata[YY]*local_ndata[XX];
2203 }
2204
2205 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2206                              real *mesh_energy, matrix vir)
2207 {
2208     /* This function sums output over threads
2209      * and should therefore only be called after thread synchronization.
2210      */
2211     int thread;
2212
2213     *mesh_energy = pme->work[0].energy;
2214     copy_mat(pme->work[0].vir, vir);
2215
2216     for (thread = 1; thread < nthread; thread++)
2217     {
2218         *mesh_energy += pme->work[thread].energy;
2219         m_add(vir, pme->work[thread].vir, vir);
2220     }
2221 }
2222
2223 #define DO_FSPLINE(order)                      \
2224     for (ithx = 0; (ithx < order); ithx++)              \
2225     {                                              \
2226         index_x = (i0+ithx)*pny*pnz;               \
2227         tx      = thx[ithx];                       \
2228         dx      = dthx[ithx];                      \
2229                                                \
2230         for (ithy = 0; (ithy < order); ithy++)          \
2231         {                                          \
2232             index_xy = index_x+(j0+ithy)*pnz;      \
2233             ty       = thy[ithy];                  \
2234             dy       = dthy[ithy];                 \
2235             fxy1     = fz1 = 0;                    \
2236                                                \
2237             for (ithz = 0; (ithz < order); ithz++)      \
2238             {                                      \
2239                 gval  = grid[index_xy+(k0+ithz)];  \
2240                 fxy1 += thz[ithz]*gval;            \
2241                 fz1  += dthz[ithz]*gval;           \
2242             }                                      \
2243             fx += dx*ty*fxy1;                      \
2244             fy += tx*dy*fxy1;                      \
2245             fz += tx*ty*fz1;                       \
2246         }                                          \
2247     }
2248
2249
2250 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2251                               gmx_bool bClearF, pme_atomcomm_t *atc,
2252                               splinedata_t *spline,
2253                               real scale)
2254 {
2255     /* sum forces for local particles */
2256     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2257     int     index_x, index_xy;
2258     int     nx, ny, nz, pnx, pny, pnz;
2259     int *   idxptr;
2260     real    tx, ty, dx, dy, qn;
2261     real    fx, fy, fz, gval;
2262     real    fxy1, fz1;
2263     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2264     int     norder;
2265     real    rxx, ryx, ryy, rzx, rzy, rzz;
2266     int     order;
2267
2268     pme_spline_work_t *work;
2269
2270     work = pme->spline_work;
2271
2272     order = pme->pme_order;
2273     thx   = spline->theta[XX];
2274     thy   = spline->theta[YY];
2275     thz   = spline->theta[ZZ];
2276     dthx  = spline->dtheta[XX];
2277     dthy  = spline->dtheta[YY];
2278     dthz  = spline->dtheta[ZZ];
2279     nx    = pme->nkx;
2280     ny    = pme->nky;
2281     nz    = pme->nkz;
2282     pnx   = pme->pmegrid_nx;
2283     pny   = pme->pmegrid_ny;
2284     pnz   = pme->pmegrid_nz;
2285
2286     rxx   = pme->recipbox[XX][XX];
2287     ryx   = pme->recipbox[YY][XX];
2288     ryy   = pme->recipbox[YY][YY];
2289     rzx   = pme->recipbox[ZZ][XX];
2290     rzy   = pme->recipbox[ZZ][YY];
2291     rzz   = pme->recipbox[ZZ][ZZ];
2292
2293     for (nn = 0; nn < spline->n; nn++)
2294     {
2295         n  = spline->ind[nn];
2296         qn = scale*atc->q[n];
2297
2298         if (bClearF)
2299         {
2300             atc->f[n][XX] = 0;
2301             atc->f[n][YY] = 0;
2302             atc->f[n][ZZ] = 0;
2303         }
2304         if (qn != 0)
2305         {
2306             fx     = 0;
2307             fy     = 0;
2308             fz     = 0;
2309             idxptr = atc->idx[n];
2310             norder = nn*order;
2311
2312             i0   = idxptr[XX];
2313             j0   = idxptr[YY];
2314             k0   = idxptr[ZZ];
2315
2316             /* Pointer arithmetic alert, next six statements */
2317             thx  = spline->theta[XX] + norder;
2318             thy  = spline->theta[YY] + norder;
2319             thz  = spline->theta[ZZ] + norder;
2320             dthx = spline->dtheta[XX] + norder;
2321             dthy = spline->dtheta[YY] + norder;
2322             dthz = spline->dtheta[ZZ] + norder;
2323
2324             switch (order)
2325             {
2326                 case 4:
2327 #ifdef PME_SSE_SPREAD_GATHER
2328 #ifdef PME_SSE_UNALIGNED
2329 #define PME_GATHER_F_SSE_ORDER4
2330 #else
2331 #define PME_GATHER_F_SSE_ALIGNED
2332 #define PME_ORDER 4
2333 #endif
2334 #include "pme_sse_single.h"
2335 #else
2336                     DO_FSPLINE(4);
2337 #endif
2338                     break;
2339                 case 5:
2340 #ifdef PME_SSE_SPREAD_GATHER
2341 #define PME_GATHER_F_SSE_ALIGNED
2342 #define PME_ORDER 5
2343 #include "pme_sse_single.h"
2344 #else
2345                     DO_FSPLINE(5);
2346 #endif
2347                     break;
2348                 default:
2349                     DO_FSPLINE(order);
2350                     break;
2351             }
2352
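            /* fx, fy and fz now hold the derivatives of the interpolated
             * potential with respect to the grid-scaled fractional
             * coordinates of this charge; the lines below apply the chain
             * rule, via the grid dimensions and the recipbox elements, to
             * obtain the Cartesian force F = -q grad(phi).
             */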
2353             atc->f[n][XX] += -qn*( fx*nx*rxx );
2354             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2355             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2356         }
2357     }
2358     /* Since the energy and not the forces are interpolated
2359      * the net force might not be exactly zero.
2360      * This can be solved by also interpolating F, but
2361      * that comes at a cost.
2362      * A better hack is to remove the net force every
2363      * step, but that must be done at a higher level
2364      * since this routine doesn't see all atoms if running
2365      * in parallel. It is unclear how important this is.  EL 990726
2366      */
2367 }
2368
2369
2370 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2371                                    pme_atomcomm_t *atc)
2372 {
2373     splinedata_t *spline;
2374     int     n, ithx, ithy, ithz, i0, j0, k0;
2375     int     index_x, index_xy;
2376     int *   idxptr;
2377     real    energy, pot, tx, ty, qn, gval;
2378     real    *thx, *thy, *thz;
2379     int     norder;
2380     int     order;
2381
2382     spline = &atc->spline[0];
2383
2384     order = pme->pme_order;
2385
2386     energy = 0;
2387     for (n = 0; (n < atc->n); n++)
2388     {
2389         qn      = atc->q[n];
2390
2391         if (qn != 0)
2392         {
2393             idxptr = atc->idx[n];
2394             norder = n*order;
2395
2396             i0   = idxptr[XX];
2397             j0   = idxptr[YY];
2398             k0   = idxptr[ZZ];
2399
2400             /* Pointer arithmetic alert, next three statements */
2401             thx  = spline->theta[XX] + norder;
2402             thy  = spline->theta[YY] + norder;
2403             thz  = spline->theta[ZZ] + norder;
2404
2405             pot = 0;
2406             for (ithx = 0; (ithx < order); ithx++)
2407             {
2408                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2409                 tx      = thx[ithx];
2410
2411                 for (ithy = 0; (ithy < order); ithy++)
2412                 {
2413                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2414                     ty       = thy[ithy];
2415
2416                     for (ithz = 0; (ithz < order); ithz++)
2417                     {
2418                         gval  = grid[index_xy+(k0+ithz)];
2419                         pot  += tx*ty*thz[ithz]*gval;
2420                     }
2421
2422                 }
2423             }
2424
2425             energy += pot*qn;
2426         }
2427     }
2428
2429     return energy;
2430 }
2431
2432 /* Macro to force loop unrolling by fixing order.
2433  * This gives a significant performance gain.
2434  */
2435 #define CALC_SPLINE(order)                     \
2436     {                                              \
2437         int j, k, l;                                 \
2438         real dr, div;                               \
2439         real data[PME_ORDER_MAX];                  \
2440         real ddata[PME_ORDER_MAX];                 \
2441                                                \
2442         for (j = 0; (j < DIM); j++)                     \
2443         {                                          \
2444             dr  = xptr[j];                         \
2445                                                \
2446             /* dr is relative offset from lower cell limit */ \
2447             data[order-1] = 0;                     \
2448             data[1]       = dr;                          \
2449             data[0]       = 1 - dr;                      \
2450                                                \
2451             for (k = 3; (k < order); k++)               \
2452             {                                      \
2453                 div       = 1.0/(k - 1.0);               \
2454                 data[k-1] = div*dr*data[k-2];      \
2455                 for (l = 1; (l < (k-1)); l++)           \
2456                 {                                  \
2457                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2458                                        data[k-l-1]);                \
2459                 }                                  \
2460                 data[0] = div*(1-dr)*data[0];      \
2461             }                                      \
2462             /* differentiate */                    \
2463             ddata[0] = -data[0];                   \
2464             for (k = 1; (k < order); k++)               \
2465             {                                      \
2466                 ddata[k] = data[k-1] - data[k];    \
2467             }                                      \
2468                                                \
2469             div           = 1.0/(order - 1);                 \
2470             data[order-1] = div*dr*data[order-2];  \
2471             for (l = 1; (l < (order-1)); l++)           \
2472             {                                      \
2473                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2474                                        (order-l-dr)*data[order-l-1]); \
2475             }                                      \
2476             data[0] = div*(1 - dr)*data[0];        \
2477                                                \
2478             for (k = 0; k < order; k++)                 \
2479             {                                      \
2480                 theta[j][i*order+k]  = data[k];    \
2481                 dtheta[j][i*order+k] = ddata[k];   \
2482             }                                      \
2483         }                                          \
2484     }
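/* CALC_SPLINE implements the standard cardinal B-spline recursion
 * (cf. Essmann et al., J. Chem. Phys. 103, 8577 (1995)):
 *
 *   M_2(u)  = 1 - |u - 1|                                for 0 <= u <= 2
 *   M_n(u)  = u/(n-1) M_{n-1}(u) + (n-u)/(n-1) M_{n-1}(u-1)
 *   M_n'(u) = M_{n-1}(u) - M_{n-1}(u-1)
 *
 * evaluated at u = dr, dr+1, ..., dr+order-1 for each dimension (stored in
 * data[]/ddata[] in reverse order), where dr is the fractional offset of
 * the particle within its grid cell.
 */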
2485
2486 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2487                    rvec fractx[], int nr, int ind[], real charge[],
2488                    gmx_bool bFreeEnergy)
2489 {
2490     /* construct splines for local atoms */
2491     int  i, ii;
2492     real *xptr;
2493
2494     for (i = 0; i < nr; i++)
2495     {
2496         /* With free energy we do not use the charge check.
2497          * In most cases this will be more efficient than calling make_bsplines
2498          * twice, since usually more than half the particles have charges.
2499          */
2500         ii = ind[i];
2501         if (bFreeEnergy || charge[ii] != 0.0)
2502         {
2503             xptr = fractx[ii];
2504             switch (order)
2505             {
2506                 case 4:  CALC_SPLINE(4);     break;
2507                 case 5:  CALC_SPLINE(5);     break;
2508                 default: CALC_SPLINE(order); break;
2509             }
2510         }
2511     }
2512 }
2513
2514
2515 void make_dft_mod(real *mod, real *data, int ndata)
2516 {
2517     int i, j;
2518     real sc, ss, arg;
2519
2520     for (i = 0; i < ndata; i++)
2521     {
2522         sc = ss = 0;
2523         for (j = 0; j < ndata; j++)
2524         {
2525             arg = (2.0*M_PI*i*j)/ndata;
2526             sc += data[j]*cos(arg);
2527             ss += data[j]*sin(arg);
2528         }
2529         mod[i] = sc*sc+ss*ss;
2530     }
2531     for (i = 0; i < ndata; i++)
2532     {
2533         if (mod[i] < 1e-7)
2534         {
2535             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2536         }
2537     }
2538 }
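/* make_dft_mod computes
 *
 *   mod[m] = (sum_j data[j] cos(2 pi m j/N))^2
 *          + (sum_j data[j] sin(2 pi m j/N))^2,
 *
 * the squared modulus of the discrete Fourier transform of the B-spline
 * values; these moduli enter the denominator of eterm in solve_pme_yzx.
 */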
2539
2540
2541 static void make_bspline_moduli(splinevec bsp_mod,
2542                                 int nx, int ny, int nz, int order)
2543 {
2544     int nmax = max(nx, max(ny, nz));
2545     real *data, *ddata, *bsp_data;
2546     int i, k, l;
2547     real div;
2548
2549     snew(data, order);
2550     snew(ddata, order);
2551     snew(bsp_data, nmax);
2552
2553     data[order-1] = 0;
2554     data[1]       = 0;
2555     data[0]       = 1;
2556
2557     for (k = 3; k < order; k++)
2558     {
2559         div       = 1.0/(k-1.0);
2560         data[k-1] = 0;
2561         for (l = 1; l < (k-1); l++)
2562         {
2563             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2564         }
2565         data[0] = div*data[0];
2566     }
2567     /* differentiate */
2568     ddata[0] = -data[0];
2569     for (k = 1; k < order; k++)
2570     {
2571         ddata[k] = data[k-1]-data[k];
2572     }
2573     div           = 1.0/(order-1);
2574     data[order-1] = 0;
2575     for (l = 1; l < (order-1); l++)
2576     {
2577         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2578     }
2579     data[0] = div*data[0];
2580
2581     for (i = 0; i < nmax; i++)
2582     {
2583         bsp_data[i] = 0;
2584     }
2585     for (i = 1; i <= order; i++)
2586     {
2587         bsp_data[i] = data[i-1];
2588     }
2589
2590     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2591     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2592     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2593
2594     sfree(data);
2595     sfree(ddata);
2596     sfree(bsp_data);
2597 }
2598
2599
2600 /* Return the P3M optimal influence function */
2601 static double do_p3m_influence(double z, int order)
2602 {
2603     double z2, z4;
2604
2605     z2 = z*z;
2606     z4 = z2*z2;
2607
2608     /* The formula and most constants can be found in:
2609      * Ballenegger et al., JCTC 8, 936 (2012)
2610      */
2611     switch (order)
2612     {
2613         case 2:
2614             return 1.0 - 2.0*z2/3.0;
2615             break;
2616         case 3:
2617             return 1.0 - z2 + 2.0*z4/15.0;
2618             break;
2619         case 4:
2620             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2621             break;
2622         case 5:
2623             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2624             break;
2625         case 6:
2626             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2627             break;
2628         case 7:
2629             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2630         case 8:
2631             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2632             break;
2633     }
2634
2635     return 0.0;
2636 }
2637
2638 /* Calculate the P3M B-spline moduli for one dimension */
2639 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2640 {
2641     double zarg, zai, sinzai, infl;
2642     int    maxk, i;
2643
2644     if (order > 8)
2645     {
2646         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2647     }
2648
2649     zarg = M_PI/n;
2650
2651     maxk = (n + 1)/2;
2652
2653     for (i = -maxk; i < 0; i++)
2654     {
2655         zai          = zarg*i;
2656         sinzai       = sin(zai);
2657         infl         = do_p3m_influence(sinzai, order);
2658         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2659     }
2660     bsp_mod[0] = 1.0;
2661     for (i = 1; i < maxk; i++)
2662     {
2663         zai        = zarg*i;
2664         sinzai     = sin(zai);
2665         infl       = do_p3m_influence(sinzai, order);
2666         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2667     }
2668 }
2669
2670 /* Calculate the P3M B-spline moduli */
2671 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2672                                     int nx, int ny, int nz, int order)
2673 {
2674     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2675     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2676     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2677 }
2678
2679
2680 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2681 {
2682     int nslab, n, i;
2683     int fw, bw;
2684
2685     nslab = atc->nslab;
2686
2687     n = 0;
2688     for (i = 1; i <= nslab/2; i++)
2689     {
2690         fw = (atc->nodeid + i) % nslab;
2691         bw = (atc->nodeid - i + nslab) % nslab;
2692         if (n < nslab - 1)
2693         {
2694             atc->node_dest[n] = fw;
2695             atc->node_src[n]  = bw;
2696             n++;
2697         }
2698         if (n < nslab - 1)
2699         {
2700             atc->node_dest[n] = bw;
2701             atc->node_src[n]  = fw;
2702             n++;
2703         }
2704     }
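    /* For example, with nslab = 4 and nodeid = 1 this gives
     * node_dest = {2, 0, 3} and node_src = {0, 2, 3}: alternating forward
     * and backward neighbours, with the half-way slab appearing only once.
     */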
2705 }
2706
2707 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2708 {
2709     int thread;
2710
2711     if (NULL != log)
2712     {
2713         fprintf(log, "Destroying PME data structures.\n");
2714     }
2715
2716     sfree((*pmedata)->nnx);
2717     sfree((*pmedata)->nny);
2718     sfree((*pmedata)->nnz);
2719
2720     pmegrids_destroy(&(*pmedata)->pmegridA);
2721
2722     sfree((*pmedata)->fftgridA);
2723     sfree((*pmedata)->cfftgridA);
2724     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2725
2726     if ((*pmedata)->pmegridB.grid.grid != NULL)
2727     {
2728         pmegrids_destroy(&(*pmedata)->pmegridB);
2729         sfree((*pmedata)->fftgridB);
2730         sfree((*pmedata)->cfftgridB);
2731         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2732     }
2733     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2734     {
2735         free_work(&(*pmedata)->work[thread]);
2736     }
2737     sfree((*pmedata)->work);
2738
2739     sfree(*pmedata);
2740     *pmedata = NULL;
2741
2742     return 0;
2743 }
2744
2745 static int mult_up(int n, int f)
2746 {
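    /* E.g. mult_up(25, 7) = 28: the smallest multiple of f that is >= n */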
2747     return ((n + f - 1)/f)*f;
2748 }
2749
2750
2751 static double pme_load_imbalance(gmx_pme_t pme)
2752 {
2753     int    nma, nmi;
2754     double n1, n2, n3;
2755
2756     nma = pme->nnodes_major;
2757     nmi = pme->nnodes_minor;
2758
2759     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2760     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2761     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2762
2763     /* pme_solve is roughly double the cost of an fft */
2764
2765     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2766 }
2767
2768 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2769                           int dimind, gmx_bool bSpread)
2770 {
2771     int nk, k, s, thread;
2772
2773     atc->dimind    = dimind;
2774     atc->nslab     = 1;
2775     atc->nodeid    = 0;
2776     atc->pd_nalloc = 0;
2777 #ifdef GMX_MPI
2778     if (pme->nnodes > 1)
2779     {
2780         atc->mpi_comm = pme->mpi_comm_d[dimind];
2781         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2782         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2783     }
2784     if (debug)
2785     {
2786         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2787     }
2788 #endif
2789
2790     atc->bSpread   = bSpread;
2791     atc->pme_order = pme->pme_order;
2792
2793     if (atc->nslab > 1)
2794     {
2795         /* These three allocations are not required for particle decomp. */
2796         snew(atc->node_dest, atc->nslab);
2797         snew(atc->node_src, atc->nslab);
2798         setup_coordinate_communication(atc);
2799
2800         snew(atc->count_thread, pme->nthread);
2801         for (thread = 0; thread < pme->nthread; thread++)
2802         {
2803             snew(atc->count_thread[thread], atc->nslab);
2804         }
2805         atc->count = atc->count_thread[0];
2806         snew(atc->rcount, atc->nslab);
2807         snew(atc->buf_index, atc->nslab);
2808     }
2809
2810     atc->nthread = pme->nthread;
2811     if (atc->nthread > 1)
2812     {
2813         snew(atc->thread_plist, atc->nthread);
2814     }
2815     snew(atc->spline, atc->nthread);
2816     for (thread = 0; thread < atc->nthread; thread++)
2817     {
2818         if (atc->nthread > 1)
2819         {
2820             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2821             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2822         }
2823         snew(atc->spline[thread].thread_one, pme->nthread);
2824         atc->spline[thread].thread_one[thread] = 1;
2825     }
2826 }
2827
2828 static void
2829 init_overlap_comm(pme_overlap_t *  ol,
2830                   int              norder,
2831 #ifdef GMX_MPI
2832                   MPI_Comm         comm,
2833 #endif
2834                   int              nnodes,
2835                   int              nodeid,
2836                   int              ndata,
2837                   int              commplainsize)
2838 {
2839     int lbnd, rbnd, maxlr, b, i;
2840     int exten;
2841     int nn, nk;
2842     pme_grid_comm_t *pgc;
2843     gmx_bool bCont;
2844     int fft_start, fft_end, send_index1, recv_index1;
2845 #ifdef GMX_MPI
2846     MPI_Status stat;
2847
2848     ol->mpi_comm = comm;
2849 #endif
2850
2851     ol->nnodes = nnodes;
2852     ol->nodeid = nodeid;
2853
2854     /* Linear translation of the PME grid won't affect reciprocal space
2855      * calculations, so to optimize we only interpolate "upwards",
2856      * which also means we only have to consider overlap in one direction.
2857      * I.e., particles on this node might also be spread to grid indices
2858      * that belong to higher nodes (modulo nnodes)
2859      */
2860
2861     snew(ol->s2g0, ol->nnodes+1);
2862     snew(ol->s2g1, ol->nnodes);
2863     if (debug)
2864     {
2865         fprintf(debug, "PME slab boundaries:");
2866     }
2867     for (i = 0; i < nnodes; i++)
2868     {
2869         /* s2g0 the local interpolation grid start.
2870          * s2g1 the local interpolation grid end.
2871          * Because grid overlap communication only goes forward,
2872          * the grid slabs for the FFTs should be rounded down.
2873          */
2874         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2875         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2876
2877         if (debug)
2878         {
2879             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2880         }
2881     }
2882     ol->s2g0[nnodes] = ndata;
2883     if (debug)
2884     {
2885         fprintf(debug, "\n");
2886     }
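    /* For example, ndata = 30, nnodes = 4 and norder = 4 give
     * s2g0 = {0, 7, 15, 22, 30} and s2g1 = {11, 18, 26, 33}; s2g1 entries
     * may extend beyond ndata, which corresponds to wrapping around the
     * periodic grid.
     */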
2887
2888     /* Determine with how many nodes we need to communicate the grid overlap */
2889     b = 0;
2890     do
2891     {
2892         b++;
2893         bCont = FALSE;
2894         for (i = 0; i < nnodes; i++)
2895         {
2896             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2897                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2898             {
2899                 bCont = TRUE;
2900             }
2901         }
2902     }
2903     while (bCont && b < nnodes);
2904     ol->noverlap_nodes = b - 1;
2905
2906     snew(ol->send_id, ol->noverlap_nodes);
2907     snew(ol->recv_id, ol->noverlap_nodes);
2908     for (b = 0; b < ol->noverlap_nodes; b++)
2909     {
2910         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2911         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2912     }
2913     snew(ol->comm_data, ol->noverlap_nodes);
2914
2915     ol->send_size = 0;
2916     for (b = 0; b < ol->noverlap_nodes; b++)
2917     {
2918         pgc = &ol->comm_data[b];
2919         /* Send */
2920         fft_start        = ol->s2g0[ol->send_id[b]];
2921         fft_end          = ol->s2g0[ol->send_id[b]+1];
2922         if (ol->send_id[b] < nodeid)
2923         {
2924             fft_start += ndata;
2925             fft_end   += ndata;
2926         }
2927         send_index1       = ol->s2g1[nodeid];
2928         send_index1       = min(send_index1, fft_end);
2929         pgc->send_index0  = fft_start;
2930         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2931         ol->send_size    += pgc->send_nindex;
2932
2933         /* We always start receiving at the first index of our slab */
2934         fft_start        = ol->s2g0[ol->nodeid];
2935         fft_end          = ol->s2g0[ol->nodeid+1];
2936         recv_index1      = ol->s2g1[ol->recv_id[b]];
2937         if (ol->recv_id[b] > nodeid)
2938         {
2939             recv_index1 -= ndata;
2940         }
2941         recv_index1      = min(recv_index1, fft_end);
2942         pgc->recv_index0 = fft_start;
2943         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2944     }
2945
2946 #ifdef GMX_MPI
2947     /* Communicate the buffer sizes to receive */
2948     for (b = 0; b < ol->noverlap_nodes; b++)
2949     {
2950         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2951                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2952                      ol->mpi_comm, &stat);
2953     }
2954 #endif
2955
2956     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2957     snew(ol->sendbuf, norder*commplainsize);
2958     snew(ol->recvbuf, norder*commplainsize);
2959 }
2960
2961 static void
2962 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2963                               int **global_to_local,
2964                               real **fraction_shift)
2965 {
2966     int i;
2967     int * gtl;
2968     real * fsh;
2969
2970     snew(gtl, 5*n);
2971     snew(fsh, 5*n);
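    /* For instance, with n = 10 and local_start = 4 the table maps a global
     * index i to (i - 4 + 10) % 10, so i = 4 -> 0, i = 13 -> 9 and i = 14
     * wraps back to 0. fsh[] only becomes non-zero when the grid is
     * decomposed (local_range < n) and an index rounds just outside the
     * local slab.
     */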
2972     for (i = 0; (i < 5*n); i++)
2973     {
2974         /* Determine the global to local grid index */
2975         gtl[i] = (i - local_start + n) % n;
2976         /* For coordinates that fall within the local grid the fraction
2977          * is correct, we don't need to shift it.
2978          */
2979         fsh[i] = 0;
2980         if (local_range < n)
2981         {
2982             /* Due to rounding issues i could be 1 beyond the lower or
2983              * upper boundary of the local grid. Correct the index for this.
2984              * If we shift the index, we need to shift the fraction by
2985              * the same amount in the other direction to not affect
2986              * the weights.
2987              * Note that due to this shifting the weights at the end of
2988              * the spline might change, but that will only involve values
2989              * between zero and values close to the precision of a real,
2990              * which is anyhow the accuracy of the whole mesh calculation.
2991              */
2992             /* With local_range=0 we should not change i=local_start */
2993             if (i % n != local_start)
2994             {
2995                 if (gtl[i] == n-1)
2996                 {
2997                     gtl[i] = 0;
2998                     fsh[i] = -1;
2999                 }
3000                 else if (gtl[i] == local_range)
3001                 {
3002                     gtl[i] = local_range - 1;
3003                     fsh[i] = 1;
3004                 }
3005             }
3006         }
3007     }
3008
3009     *global_to_local = gtl;
3010     *fraction_shift  = fsh;
3011 }
3012
3013 static pme_spline_work_t *make_pme_spline_work(int order)
3014 {
3015     pme_spline_work_t *work;
3016
3017 #ifdef PME_SSE_SPREAD_GATHER
3018     float  tmp[8];
3019     __m128 zero_SSE;
3020     int    of, i;
3021
3022     snew_aligned(work, 1, 16);
3023
3024     zero_SSE = _mm_setzero_ps();
3025
3026     /* Generate bit masks to mask out the unused grid entries,
3027      * as we only operate on 'order' out of the 8 grid entries that are
3028      * loaded into 2 SSE float registers.
3029      */
3030     for (of = 0; of < 8-(order-1); of++)
3031     {
3032         for (i = 0; i < 8; i++)
3033         {
3034             tmp[i] = (i >= of && i < of+order ? 1 : 0);
3035         }
3036         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
3037         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
3038         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
3039         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
3040     }
3041 #else
3042     work = NULL;
3043 #endif
3044
3045     return work;
3046 }
3047
3048 void gmx_pme_check_restrictions(int pme_order,
3049                                 int nkx, int nky, int nkz,
3050                                 int nnodes_major,
3051                                 int nnodes_minor,
3052                                 gmx_bool bUseThreads,
3053                                 gmx_bool bFatal,
3054                                 gmx_bool *bValidSettings)
3055 {
3056     if (pme_order > PME_ORDER_MAX)
3057     {
3058         if (!bFatal)
3059         {
3060             *bValidSettings = FALSE;
3061             return;
3062         }
3063         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3064                   pme_order, PME_ORDER_MAX);
3065     }
3066
3067     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3068         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3069         nkz <= pme_order)
3070     {
3071         if (!bFatal)
3072         {
3073             *bValidSettings = FALSE;
3074             return;
3075         }
3076         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3077                   pme_order);
3078     }
3079
3080     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3081      * We only allow multiple communication pulses in dim 1, not in dim 0.
3082      */
3083     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3084                         nkx != nnodes_major*(pme_order - 1)))
3085     {
3086         if (!bFatal)
3087         {
3088             *bValidSettings = FALSE;
3089             return;
3090         }
3091         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3092                   nkx/(double)nnodes_major, pme_order);
3093     }
3094
3095     if (bValidSettings != NULL)
3096     {
3097         *bValidSettings = TRUE;
3098     }
3099
3100     return;
3101 }
3102
3103 int gmx_pme_init(gmx_pme_t *         pmedata,
3104                  t_commrec *         cr,
3105                  int                 nnodes_major,
3106                  int                 nnodes_minor,
3107                  t_inputrec *        ir,
3108                  int                 homenr,
3109                  gmx_bool            bFreeEnergy,
3110                  gmx_bool            bReproducible,
3111                  int                 nthread)
3112 {
3113     gmx_pme_t pme = NULL;
3114
3115     int  use_threads, sum_use_threads;
3116     ivec ndata;
3117
3118     if (debug)
3119     {
3120         fprintf(debug, "Creating PME data structures.\n");
3121     }
3122     snew(pme, 1);
3123
3124     pme->redist_init         = FALSE;
3125     pme->sum_qgrid_tmp       = NULL;
3126     pme->sum_qgrid_dd_tmp    = NULL;
3127     pme->buf_nalloc          = 0;
3128     pme->redist_buf_nalloc   = 0;
3129
3130     pme->nnodes              = 1;
3131     pme->bPPnode             = TRUE;
3132
3133     pme->nnodes_major        = nnodes_major;
3134     pme->nnodes_minor        = nnodes_minor;
3135
3136 #ifdef GMX_MPI
3137     if (nnodes_major*nnodes_minor > 1)
3138     {
3139         pme->mpi_comm = cr->mpi_comm_mygroup;
3140
3141         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3142         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3143         if (pme->nnodes != nnodes_major*nnodes_minor)
3144         {
3145             gmx_incons("PME node count mismatch");
3146         }
3147     }
3148     else
3149     {
3150         pme->mpi_comm = MPI_COMM_NULL;
3151     }
3152 #endif
3153
3154     if (pme->nnodes == 1)
3155     {
3156 #ifdef GMX_MPI
3157         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3158         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3159 #endif
3160         pme->ndecompdim   = 0;
3161         pme->nodeid_major = 0;
3162         pme->nodeid_minor = 0;
3166     }
3167     else
3168     {
3169         if (nnodes_minor == 1)
3170         {
3171 #ifdef GMX_MPI
3172             pme->mpi_comm_d[0] = pme->mpi_comm;
3173             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3174 #endif
3175             pme->ndecompdim   = 1;
3176             pme->nodeid_major = pme->nodeid;
3177             pme->nodeid_minor = 0;
3178
3179         }
3180         else if (nnodes_major == 1)
3181         {
3182 #ifdef GMX_MPI
3183             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3184             pme->mpi_comm_d[1] = pme->mpi_comm;
3185 #endif
3186             pme->ndecompdim   = 1;
3187             pme->nodeid_major = 0;
3188             pme->nodeid_minor = pme->nodeid;
3189         }
3190         else
3191         {
3192             if (pme->nnodes % nnodes_major != 0)
3193             {
3194                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3195             }
3196             pme->ndecompdim = 2;
3197
3198 #ifdef GMX_MPI
3199             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3200                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3201             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3202                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3203
3204             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3205             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3206             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3207             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3208 #endif
3209         }
3210         pme->bPPnode = (cr->duty & DUTY_PP);
3211     }
3212
3213     pme->nthread = nthread;
3214
3215     /* Check if any of the PME MPI ranks uses threads */
3216     use_threads = (pme->nthread > 1 ? 1 : 0);
3217 #ifdef GMX_MPI
3218     if (pme->nnodes > 1)
3219     {
3220         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3221                       MPI_SUM, pme->mpi_comm);
3222     }
3223     else
3224 #endif
3225     {
3226         sum_use_threads = use_threads;
3227     }
3228     pme->bUseThreads = (sum_use_threads > 0);
3229
3230     if (ir->ePBC == epbcSCREW)
3231     {
3232         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3233     }
3234
3235     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3236     pme->nkx         = ir->nkx;
3237     pme->nky         = ir->nky;
3238     pme->nkz         = ir->nkz;
3239     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3240     pme->pme_order   = ir->pme_order;
3241     pme->epsilon_r   = ir->epsilon_r;
3242
3243     /* If we violate restrictions, generate a fatal error here */
3244     gmx_pme_check_restrictions(pme->pme_order,
3245                                pme->nkx, pme->nky, pme->nkz,
3246                                pme->nnodes_major,
3247                                pme->nnodes_minor,
3248                                pme->bUseThreads,
3249                                TRUE,
3250                                NULL);
3251
3252     if (pme->nnodes > 1)
3253     {
3254         double imbal;
3255
3256 #ifdef GMX_MPI
3257         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3258         MPI_Type_commit(&(pme->rvec_mpi));
3259 #endif
3260
3261         /* Note that the charge spreading and force gathering, which usually
3262          * take about the same amount of time as FFT+solve_pme,
3263          * are always fully load balanced
3264          * (unless the charge distribution is inhomogeneous).
3265          */
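        /* Illustrative example (assumed numbers): with nkx = 50 split over
         * 4 ranks along the major dimension, the best split is 13+13+12+12
         * grid lines, so the busiest rank holds 13 lines against an average
         * of 12.5, i.e. roughly a 4% FFT/solve imbalance.
         */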
3266
3267         imbal = pme_load_imbalance(pme);
3268         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3269         {
3270             fprintf(stderr,
3271                     "\n"
3272                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3273                     "      For optimal PME load balancing\n"
3274                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3275                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3276                     "\n",
3277                     (int)((imbal-1)*100 + 0.5),
3278                     pme->nkx, pme->nky, pme->nnodes_major,
3279                     pme->nky, pme->nkz, pme->nnodes_minor);
3280         }
3281     }
3282
3283     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3284     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3285      * y is always copied through a buffer: we don't need padding in z,
3286      * but we do need the overlap in x because of the communication order.
3287      */
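    /* The size passed as the last argument below,
     * (div_round_up(nky, nnodes_minor) + pme_order)*(nkz + pme_order - 1),
     * covers the largest y-slab held by any minor-dimension rank plus the
     * B-spline overlap padding in y and z; it is presumably used by
     * init_overlap_comm to size the communication buffers.
     */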
3288     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3289 #ifdef GMX_MPI
3290                       pme->mpi_comm_d[0],
3291 #endif
3292                       pme->nnodes_major, pme->nodeid_major,
3293                       pme->nkx,
3294                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3295
3296     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3297      * We do this with an offset buffer of equal size, so we need to allocate
3298      * extra for the offset. That's what the (+1)*pme->nkz is for.
3299      */
3300     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3301 #ifdef GMX_MPI
3302                       pme->mpi_comm_d[1],
3303 #endif
3304                       pme->nnodes_minor, pme->nodeid_minor,
3305                       pme->nky,
3306                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3307
3308     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3309      * Note that gmx_pme_check_restrictions checked for this already.
3310      */
3311     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3312     {
3313         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3314     }
3315
3316     snew(pme->bsp_mod[XX], pme->nkx);
3317     snew(pme->bsp_mod[YY], pme->nky);
3318     snew(pme->bsp_mod[ZZ], pme->nkz);
3319
3320     /* The required size of the interpolation grid, including overlap.
3321      * The allocated size (pmegrid_n?) might be slightly larger.
3322      */
3323     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3324         pme->overlap[0].s2g0[pme->nodeid_major];
3325     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3326         pme->overlap[1].s2g0[pme->nodeid_minor];
3327     pme->pmegrid_nz_base = pme->nkz;
3328     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3329     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
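    /* For example, with nkz = 32 and pme_order = 4 this gives
     * pmegrid_nz_base = 32 and pmegrid_nz = 35 before alignment;
     * set_grid_alignment may round pmegrid_nz up further for aligned access.
     */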
3330
3331     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3332     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3333     pme->pmegrid_start_iz = 0;
3334
3335     make_gridindex5_to_localindex(pme->nkx,
3336                                   pme->pmegrid_start_ix,
3337                                   pme->pmegrid_nx - (pme->pme_order-1),
3338                                   &pme->nnx, &pme->fshx);
3339     make_gridindex5_to_localindex(pme->nky,
3340                                   pme->pmegrid_start_iy,
3341                                   pme->pmegrid_ny - (pme->pme_order-1),
3342                                   &pme->nny, &pme->fshy);
3343     make_gridindex5_to_localindex(pme->nkz,
3344                                   pme->pmegrid_start_iz,
3345                                   pme->pmegrid_nz_base,
3346                                   &pme->nnz, &pme->fshz);
3347
3348     pmegrids_init(&pme->pmegridA,
3349                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3350                   pme->pmegrid_nz_base,
3351                   pme->pme_order,
3352                   pme->bUseThreads,
3353                   pme->nthread,
3354                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3355                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3356
3357     pme->spline_work = make_pme_spline_work(pme->pme_order);
3358
3359     ndata[0] = pme->nkx;
3360     ndata[1] = pme->nky;
3361     ndata[2] = pme->nkz;
3362
3363     /* This routine will allocate the grid data to fit the FFTs */
3364     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3365                             &pme->fftgridA, &pme->cfftgridA,
3366                             pme->mpi_comm_d,
3367                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3368                             bReproducible, pme->nthread);
3369
3370     if (bFreeEnergy)
3371     {
3372         pmegrids_init(&pme->pmegridB,
3373                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3374                       pme->pmegrid_nz_base,
3375                       pme->pme_order,
3376                       pme->bUseThreads,
3377                       pme->nthread,
3378                       pme->nkx % pme->nnodes_major != 0,
3379                       pme->nky % pme->nnodes_minor != 0);
3380
3381         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3382                                 &pme->fftgridB, &pme->cfftgridB,
3383                                 pme->mpi_comm_d,
3384                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3385                                 bReproducible, pme->nthread);
3386     }
3387     else
3388     {
3389         pme->pmegridB.grid.grid = NULL;
3390         pme->fftgridB           = NULL;
3391         pme->cfftgridB          = NULL;
3392     }
3393
3394     if (!pme->bP3M)
3395     {
3396         /* Use plain SPME B-spline interpolation */
3397         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3398     }
3399     else
3400     {
3401         /* Use the P3M grid-optimized influence function */
3402         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3403     }
3404
3405     /* Use atc[0] for spreading */
3406     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3407     if (pme->ndecompdim >= 2)
3408     {
3409         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3410     }
3411
3412     if (pme->nnodes == 1)
3413     {
3414         pme->atc[0].n = homenr;
3415         pme_realloc_atomcomm_things(&pme->atc[0]);
3416     }
3417
3418     {
3419         int thread;
3420
3421         /* Use fft5d, order after FFT is y major, z, x minor */
3422
3423         snew(pme->work, pme->nthread);
3424         for (thread = 0; thread < pme->nthread; thread++)
3425         {
3426             realloc_work(&pme->work[thread], pme->nkx);
3427         }
3428     }
3429
3430     *pmedata = pme;
3431
3432     return 0;
3433 }
3434
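/* Reuse the grid storage of 'old' for 'new' when the new grid is not larger
 * than the old one in any dimension; otherwise the freshly allocated grid in
 * 'new' is kept. The per-thread grids are only reused when thread grids exist
 * and the thread counts match.
 */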
3435 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3436 {
3437     int d, t;
3438
3439     for (d = 0; d < DIM; d++)
3440     {
3441         if (new->grid.n[d] > old->grid.n[d])
3442         {
3443             return;
3444         }
3445     }
3446
3447     sfree_aligned(new->grid.grid);
3448     new->grid.grid = old->grid.grid;
3449
3450     if (new->grid_th != NULL && new->nthread == old->nthread)
3451     {
3452         sfree_aligned(new->grid_all);
3453         for (t = 0; t < new->nthread; t++)
3454         {
3455             new->grid_th[t].grid = old->grid_th[t].grid;
3456         }
3457     }
3458 }
3459
3460 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3461                    t_commrec *         cr,
3462                    gmx_pme_t           pme_src,
3463                    const t_inputrec *  ir,
3464                    ivec                grid_size)
3465 {
3466     t_inputrec irc;
3467     int homenr;
3468     int ret;
3469
3470     irc     = *ir;
3471     irc.nkx = grid_size[XX];
3472     irc.nky = grid_size[YY];
3473     irc.nkz = grid_size[ZZ];
3474
3475     if (pme_src->nnodes == 1)
3476     {
3477         homenr = pme_src->atc[0].n;
3478     }
3479     else
3480     {
3481         homenr = -1;
3482     }
3483
3484     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3485                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3486
3487     if (ret == 0)
3488     {
3489         /* We can easily reuse the allocated pme grids in pme_src */
3490         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3491         /* We would like to reuse the fft grids, but that's harder */
3492     }
3493
3494     return ret;
3495 }
3496
3497
3498 static void copy_local_grid(gmx_pme_t pme,
3499                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3500 {
3501     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3502     int  fft_my, fft_mz;
3503     int  nsx, nsy, nsz;
3504     ivec nf;
3505     int  offx, offy, offz, x, y, z, i0, i0t;
3506     int  d;
3507     pmegrid_t *pmegrid;
3508     real *grid_th;
3509
3510     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3511                                    local_fft_ndata,
3512                                    local_fft_offset,
3513                                    local_fft_size);
3514     fft_my = local_fft_size[YY];
3515     fft_mz = local_fft_size[ZZ];
3516
3517     pmegrid = &pmegrids->grid_th[thread];
3518
3519     nsx = pmegrid->s[XX];
3520     nsy = pmegrid->s[YY];
3521     nsz = pmegrid->s[ZZ];
3522
3523     for (d = 0; d < DIM; d++)
3524     {
3525         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3526                     local_fft_ndata[d] - pmegrid->offset[d]);
3527     }
3528
3529     offx = pmegrid->offset[XX];
3530     offy = pmegrid->offset[YY];
3531     offz = pmegrid->offset[ZZ];
3532
3533     /* Directly copy the non-overlapping parts of the local grids.
3534      * This also initializes the full grid.
3535      */
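    /* Index arithmetic in the loops below: the fftgrid is stored row-major
     * with padded strides fft_my and fft_mz, i.e. index (x*fft_my + y)*fft_mz + z,
     * while the thread-local grid uses its own strides nsy and nsz, so i0 and
     * i0t address the same (x,y,z) point in the two arrays.
     */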
3536     grid_th = pmegrid->grid;
3537     for (x = 0; x < nf[XX]; x++)
3538     {
3539         for (y = 0; y < nf[YY]; y++)
3540         {
3541             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3542             i0t = (x*nsy + y)*nsz;
3543             for (z = 0; z < nf[ZZ]; z++)
3544             {
3545                 fftgrid[i0+z] = grid_th[i0t+z];
3546             }
3547         }
3548     }
3549 }
3550
3551 static void
3552 reduce_threadgrid_overlap(gmx_pme_t pme,
3553                           const pmegrids_t *pmegrids, int thread,
3554                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3555 {
3556     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3557     int  fft_nx, fft_ny, fft_nz;
3558     int  fft_my, fft_mz;
3559     int  buf_my = -1;
3560     int  nsx, nsy, nsz;
3561     ivec ne;
3562     int  offx, offy, offz, x, y, z, i0, i0t;
3563     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3564     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3565     gmx_bool bCommX, bCommY;
3566     int  d;
3567     int  thread_f;
3568     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3569     const real *grid_th;
3570     real *commbuf = NULL;
3571
3572     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3573                                    local_fft_ndata,
3574                                    local_fft_offset,
3575                                    local_fft_size);
3576     fft_nx = local_fft_ndata[XX];
3577     fft_ny = local_fft_ndata[YY];
3578     fft_nz = local_fft_ndata[ZZ];
3579
3580     fft_my = local_fft_size[YY];
3581     fft_mz = local_fft_size[ZZ];
3582
3583     /* This routine is called when all threads have finished spreading.
3584      * Here each thread sums grid contributions calculated by other threads
3585      * into its thread-local grid volume.
3586      * To minimize the number of grid copying operations,
3587      * this routine sums directly from the pmegrid to the fftgrid.
3588      */
3589
3590     /* Determine which part of the full node grid we should operate on;
3591      * this is our thread-local part of the full grid.
3592      */
3593     pmegrid = &pmegrids->grid_th[thread];
3594
3595     for (d = 0; d < DIM; d++)
3596     {
3597         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3598                     local_fft_ndata[d]);
3599     }
3600
3601     offx = pmegrid->offset[XX];
3602     offy = pmegrid->offset[YY];
3603     offz = pmegrid->offset[ZZ];
3604
3605
3606     bClearBufX  = TRUE;
3607     bClearBufY  = TRUE;
3608     bClearBufXY = TRUE;
3609
3610     /* Now loop over all the thread data blocks that contribute
3611      * to the grid region we (our thread) are operating on.
3612      */
3613     /* Note that fft_nx/y is equal to the number of grid points
3614      * between the first point of our node grid and that of the next node.
3615      */
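    /* When a source block index (fx, fy or fz) becomes negative it wraps around
     * to the last thread block along that dimension and the offset (ox/oy/oz)
     * is shifted by minus the local grid size, so that the wrapped block lines
     * up with the low edge of our volume. If there is more than one MPI rank
     * along that dimension, the wrapped contribution belongs to a neighboring
     * rank and is accumulated into a communication buffer instead of fftgrid.
     */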
3616     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3617     {
3618         fx     = pmegrid->ci[XX] + sx;
3619         ox     = 0;
3620         bCommX = FALSE;
3621         if (fx < 0)
3622         {
3623             fx    += pmegrids->nc[XX];
3624             ox    -= fft_nx;
3625             bCommX = (pme->nnodes_major > 1);
3626         }
3627         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3628         ox       += pmegrid_g->offset[XX];
3629         if (!bCommX)
3630         {
3631             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3632         }
3633         else
3634         {
3635             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3636         }
3637
3638         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3639         {
3640             fy     = pmegrid->ci[YY] + sy;
3641             oy     = 0;
3642             bCommY = FALSE;
3643             if (fy < 0)
3644             {
3645                 fy    += pmegrids->nc[YY];
3646                 oy    -= fft_ny;
3647                 bCommY = (pme->nnodes_minor > 1);
3648             }
3649             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3650             oy       += pmegrid_g->offset[YY];
3651             if (!bCommY)
3652             {
3653                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3654             }
3655             else
3656             {
3657                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3658             }
3659
3660             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3661             {
3662                 fz = pmegrid->ci[ZZ] + sz;
3663                 oz = 0;
3664                 if (fz < 0)
3665                 {
3666                     fz += pmegrids->nc[ZZ];
3667                     oz -= fft_nz;
3668                 }
3669                 pmegrid_g = &pmegrids->grid_th[fz];
3670                 oz       += pmegrid_g->offset[ZZ];
3671                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3672
3673                 if (sx == 0 && sy == 0 && sz == 0)
3674                 {
3675                     /* We have already added our local contribution
3676                      * before calling this routine, so skip it here.
3677                      */
3678                     continue;
3679                 }
3680
3681                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3682
3683                 pmegrid_f = &pmegrids->grid_th[thread_f];
3684
3685                 grid_th = pmegrid_f->grid;
3686
3687                 nsx = pmegrid_f->s[XX];
3688                 nsy = pmegrid_f->s[YY];
3689                 nsz = pmegrid_f->s[ZZ];
3690
3691 #ifdef DEBUG_PME_REDUCE
3692                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3693                        pme->nodeid, thread, thread_f,
3694                        pme->pmegrid_start_ix,
3695                        pme->pmegrid_start_iy,
3696                        pme->pmegrid_start_iz,
3697                        sx, sy, sz,
3698                        offx-ox, tx1-ox, offx, tx1,
3699                        offy-oy, ty1-oy, offy, ty1,
3700                        offz-oz, tz1-oz, offz, tz1);
3701 #endif
3702
3703                 if (!(bCommX || bCommY))
3704                 {
3705                     /* Copy from the thread local grid to the node grid */
3706                     for (x = offx; x < tx1; x++)
3707                     {
3708                         for (y = offy; y < ty1; y++)
3709                         {
3710                             i0  = (x*fft_my + y)*fft_mz;
3711                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3712                             for (z = offz; z < tz1; z++)
3713                             {
3714                                 fftgrid[i0+z] += grid_th[i0t+z];
3715                             }
3716                         }
3717                     }
3718                 }
3719                 else
3720                 {
3721                     /* The order of this conditional decides
3722                      * where the corner volume gets stored with x+y decomp.
3723                      */
3724                     if (bCommY)
3725                     {
3726                         commbuf = commbuf_y;
3727                         buf_my  = ty1 - offy;
3728                         if (bCommX)
3729                         {
3730                             /* We index commbuf modulo the local grid size */
3731                             commbuf += buf_my*fft_nx*fft_nz;
3732
3733                             bClearBuf   = bClearBufXY;
3734                             bClearBufXY = FALSE;
3735                         }
3736                         else
3737                         {
3738                             bClearBuf  = bClearBufY;
3739                             bClearBufY = FALSE;
3740                         }
3741                     }
3742                     else
3743                     {
3744                         commbuf    = commbuf_x;
3745                         buf_my     = fft_ny;
3746                         bClearBuf  = bClearBufX;
3747                         bClearBufX = FALSE;
3748                     }
3749
3750                     /* Copy to the communication buffer */
3751                     for (x = offx; x < tx1; x++)
3752                     {
3753                         for (y = offy; y < ty1; y++)
3754                         {
3755                             i0  = (x*buf_my + y)*fft_nz;
3756                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3757
3758                             if (bClearBuf)
3759                             {
3760                                 /* First access of commbuf, initialize it */
3761                                 for (z = offz; z < tz1; z++)
3762                                 {
3763                                     commbuf[i0+z]  = grid_th[i0t+z];
3764                                 }
3765                             }
3766                             else
3767                             {
3768                                 for (z = offz; z < tz1; z++)
3769                                 {
3770                                     commbuf[i0+z] += grid_th[i0t+z];
3771                                 }
3772                             }
3773                         }
3774                     }
3775                 }
3776             }
3777         }
3778     }
3779 }
3780
3781
3782 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3783 {
3784     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3785     pme_overlap_t *overlap;
3786     int  send_index0, send_nindex;
3787     int  recv_nindex;
3788 #ifdef GMX_MPI
3789     MPI_Status stat;
3790 #endif
3791     int  send_size_y, recv_size_y;
3792     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3793     real *sendptr, *recvptr;
3794     int  x, y, z, indg, indb;
3795
3796     /* Note that this routine is only used for forward communication.
3797      * Since the force gathering, unlike the charge spreading,
3798      * can be trivially parallelized over the particles,
3799      * the backwards process is much simpler and can use the "old"
3800      * communication setup.
3801      */
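    /* The reduction below runs in two stages: first the overlap along the
     * minor (y) dimension is exchanged in one or more pulses, and the corner
     * region destined for the major-dimension neighbor is added into
     * pme->overlap[0].sendbuf (which already holds the x-overlap from the
     * thread reduction); then a single pulse along the major (x) dimension
     * sends that buffer.
     */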
3802
3803     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3804                                    local_fft_ndata,
3805                                    local_fft_offset,
3806                                    local_fft_size);
3807
3808     if (pme->nnodes_minor > 1)
3809     {
3810         /* Minor dimension */
3811         overlap = &pme->overlap[1];
3812
3813         if (pme->nnodes_major > 1)
3814         {
3815             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3816         }
3817         else
3818         {
3819             size_yx = 0;
3820         }
3821         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3822
3823         send_size_y = overlap->send_size;
3824
3825         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3826         {
3827             send_id       = overlap->send_id[ipulse];
3828             recv_id       = overlap->recv_id[ipulse];
3829             send_index0   =
3830                 overlap->comm_data[ipulse].send_index0 -
3831                 overlap->comm_data[0].send_index0;
3832             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3833             /* We don't use recv_index0, as we always receive starting at 0 */
3834             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3835             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3836
3837             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3838             recvptr = overlap->recvbuf;
3839
3840 #ifdef GMX_MPI
3841             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3842                          send_id, ipulse,
3843                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3844                          recv_id, ipulse,
3845                          overlap->mpi_comm, &stat);
3846 #endif
3847
3848             for (x = 0; x < local_fft_ndata[XX]; x++)
3849             {
3850                 for (y = 0; y < recv_nindex; y++)
3851                 {
3852                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3853                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3854                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3855                     {
3856                         fftgrid[indg+z] += recvptr[indb+z];
3857                     }
3858                 }
3859             }
3860
3861             if (pme->nnodes_major > 1)
3862             {
3863                 /* Copy from the received buffer to the send buffer for dim 0 */
3864                 sendptr = pme->overlap[0].sendbuf;
3865                 for (x = 0; x < size_yx; x++)
3866                 {
3867                     for (y = 0; y < recv_nindex; y++)
3868                     {
3869                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3870                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3871                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3872                         {
3873                             sendptr[indg+z] += recvptr[indb+z];
3874                         }
3875                     }
3876                 }
3877             }
3878         }
3879     }
3880
3881     /* We only support a single pulse here.
3882      * This is not a severe limitation, as this code is only used
3883      * with OpenMP and with OpenMP the (PME) domains can be larger.
3884      */
3885     if (pme->nnodes_major > 1)
3886     {
3887         /* Major dimension */
3888         overlap = &pme->overlap[0];
3889
3890         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3891         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3892
3893         ipulse = 0;
3894
3895         send_id       = overlap->send_id[ipulse];
3896         recv_id       = overlap->recv_id[ipulse];
3897         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3898         /* We don't use recv_index0, as we always receive starting at 0 */
3899         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3900
3901         sendptr = overlap->sendbuf;
3902         recvptr = overlap->recvbuf;
3903
3904         if (debug != NULL)
3905         {
3906             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3907                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3908         }
3909
3910 #ifdef GMX_MPI
3911         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3912                      send_id, ipulse,
3913                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3914                      recv_id, ipulse,
3915                      overlap->mpi_comm, &stat);
3916 #endif
3917
3918         for (x = 0; x < recv_nindex; x++)
3919         {
3920             for (y = 0; y < local_fft_ndata[YY]; y++)
3921             {
3922                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3923                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3924                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3925                 {
3926                     fftgrid[indg+z] += recvptr[indb+z];
3927                 }
3928             }
3929         }
3930     }
3931 }
3932
3933
3934 static void spread_on_grid(gmx_pme_t pme,
3935                            pme_atomcomm_t *atc, pmegrids_t *grids,
3936                            gmx_bool bCalcSplines, gmx_bool bSpread,
3937                            real *fftgrid)
3938 {
3939     int nthread, thread;
3940 #ifdef PME_TIME_THREADS
3941     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3942     static double cs1     = 0, cs2 = 0, cs3 = 0;
3943     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3944     static int cnt        = 0;
3945 #endif
3946
3947     nthread = pme->nthread;
3948     assert(nthread > 0);
3949
3950 #ifdef PME_TIME_THREADS
3951     c1 = omp_cyc_start();
3952 #endif
3953     if (bCalcSplines)
3954     {
3955 #pragma omp parallel for num_threads(nthread) schedule(static)
3956         for (thread = 0; thread < nthread; thread++)
3957         {
3958             int start, end;
3959
3960             start = atc->n* thread   /nthread;
3961             end   = atc->n*(thread+1)/nthread;
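            /* Integer partitioning example: with atc->n = 10 and nthread = 4
             * the ranges are [0,2), [2,5), [5,7) and [7,10), so every atom is
             * assigned to exactly one thread.
             */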
3962
3963             /* Compute fftgrid index for all atoms,
3964              * with help of some extra variables.
3965              */
3966             calc_interpolation_idx(pme, atc, start, end, thread);
3967         }
3968     }
3969 #ifdef PME_TIME_THREADS
3970     c1   = omp_cyc_end(c1);
3971     cs1 += (double)c1;
3972 #endif
3973
3974 #ifdef PME_TIME_THREADS
3975     c2 = omp_cyc_start();
3976 #endif
3977 #pragma omp parallel for num_threads(nthread) schedule(static)
3978     for (thread = 0; thread < nthread; thread++)
3979     {
3980         splinedata_t *spline;
3981         pmegrid_t *grid = NULL;
3982
3983         /* make local bsplines  */
3984         if (grids == NULL || !pme->bUseThreads)
3985         {
3986             spline = &atc->spline[0];
3987
3988             spline->n = atc->n;
3989
3990             if (bSpread)
3991             {
3992                 grid = &grids->grid;
3993             }
3994         }
3995         else
3996         {
3997             spline = &atc->spline[thread];
3998
3999             if (grids->nthread == 1)
4000             {
4001                 /* One thread, we operate on all charges */
4002                 spline->n = atc->n;
4003             }
4004             else
4005             {
4006                 /* Get the indices our thread should operate on */
4007                 make_thread_local_ind(atc, thread, spline);
4008             }
4009
4010             grid = &grids->grid_th[thread];
4011         }
4012
4013         if (bCalcSplines)
4014         {
4015             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4016                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
4017         }
4018
4019         if (bSpread)
4020         {
4021             /* put local atoms on grid. */
4022 #ifdef PME_TIME_SPREAD
4023             ct1a = omp_cyc_start();
4024 #endif
4025             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4026
4027             if (pme->bUseThreads)
4028             {
4029                 copy_local_grid(pme, grids, thread, fftgrid);
4030             }
4031 #ifdef PME_TIME_SPREAD
4032             ct1a          = omp_cyc_end(ct1a);
4033             cs1a[thread] += (double)ct1a;
4034 #endif
4035         }
4036     }
4037 #ifdef PME_TIME_THREADS
4038     c2   = omp_cyc_end(c2);
4039     cs2 += (double)c2;
4040 #endif
4041
4042     if (bSpread && pme->bUseThreads)
4043     {
4044 #ifdef PME_TIME_THREADS
4045         c3 = omp_cyc_start();
4046 #endif
4047 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4048         for (thread = 0; thread < grids->nthread; thread++)
4049         {
4050             reduce_threadgrid_overlap(pme, grids, thread,
4051                                       fftgrid,
4052                                       pme->overlap[0].sendbuf,
4053                                       pme->overlap[1].sendbuf);
4054         }
4055 #ifdef PME_TIME_THREADS
4056         c3   = omp_cyc_end(c3);
4057         cs3 += (double)c3;
4058 #endif
4059
4060         if (pme->nnodes > 1)
4061         {
4062             /* Communicate the overlapping part of the fftgrid.
4063              * For this communication call we need to check pme->bUseThreads
4064              * to have all ranks communicate here, regardless of pme->nthread.
4065              */
4066             sum_fftgrid_dd(pme, fftgrid);
4067         }
4068     }
4069
4070 #ifdef PME_TIME_THREADS
4071     cnt++;
4072     if (cnt % 20 == 0)
4073     {
4074         printf("idx %.2f spread %.2f red %.2f",
4075                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4076 #ifdef PME_TIME_SPREAD
4077         for (thread = 0; thread < nthread; thread++)
4078         {
4079             printf(" %.2f", cs1a[thread]*1e-9);
4080         }
4081 #endif
4082         printf("\n");
4083     }
4084 #endif
4085 }
4086
4087
4088 static void dump_grid(FILE *fp,
4089                       int sx, int sy, int sz, int nx, int ny, int nz,
4090                       int my, int mz, const real *g)
4091 {
4092     int x, y, z;
4093
4094     for (x = 0; x < nx; x++)
4095     {
4096         for (y = 0; y < ny; y++)
4097         {
4098             for (z = 0; z < nz; z++)
4099             {
4100                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4101                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4102             }
4103         }
4104     }
4105 }
4106
4107 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4108 {
4109     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4110
4111     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4112                                    local_fft_ndata,
4113                                    local_fft_offset,
4114                                    local_fft_size);
4115
4116     dump_grid(stderr,
4117               pme->pmegrid_start_ix,
4118               pme->pmegrid_start_iy,
4119               pme->pmegrid_start_iz,
4120               pme->pmegrid_nx-pme->pme_order+1,
4121               pme->pmegrid_ny-pme->pme_order+1,
4122               pme->pmegrid_nz-pme->pme_order+1,
4123               local_fft_size[YY],
4124               local_fft_size[ZZ],
4125               fftgrid);
4126 }
4127
4128
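/* Evaluate the mesh energy of n charges q at positions x using the grid
 * already present in pme->pmegridA: only the B-spline coefficients are
 * computed here (nothing is spread), after which gather_energy_bsplines
 * evaluates the energy from the grid at those positions. As the checks
 * below enforce, this only works on a single rank and without free energy.
 */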
4129 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4130 {
4131     pme_atomcomm_t *atc;
4132     pmegrids_t *grid;
4133
4134     if (pme->nnodes > 1)
4135     {
4136         gmx_incons("gmx_pme_calc_energy called in parallel");
4137     }
4138     if (pme->bFEP)
4139     {
4140         gmx_incons("gmx_pme_calc_energy with free energy");
4141     }
4142
4143     atc            = &pme->atc_energy;
4144     atc->nthread   = 1;
4145     if (atc->spline == NULL)
4146     {
4147         snew(atc->spline, atc->nthread);
4148     }
4149     atc->nslab     = 1;
4150     atc->bSpread   = TRUE;
4151     atc->pme_order = pme->pme_order;
4152     atc->n         = n;
4153     pme_realloc_atomcomm_things(atc);
4154     atc->x         = x;
4155     atc->q         = q;
4156
4157     /* We only use the A-charges grid */
4158     grid = &pme->pmegridA;
4159
4160     /* Only calculate the spline coefficients, don't actually spread */
4161     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4162
4163     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4164 }
4165
4166
4167 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4168                                    gmx_runtime_t *runtime,
4169                                    t_nrnb *nrnb, t_inputrec *ir,
4170                                    gmx_large_int_t step)
4171 {
4172     /* Reset all the counters related to performance over the run */
4173     wallcycle_stop(wcycle, ewcRUN);
4174     wallcycle_reset_all(wcycle);
4175     init_nrnb(nrnb);
4176     if (ir->nsteps >= 0)
4177     {
4178         /* ir->nsteps is not used here, but we update it for consistency */
4179         ir->nsteps -= step - ir->init_step;
4180     }
4181     ir->init_step = step;
4182     wallcycle_start(wcycle, ewcRUN);
4183     runtime_start(runtime);
4184 }
4185
4186
4187 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4188                                ivec grid_size,
4189                                t_commrec *cr, t_inputrec *ir,
4190                                gmx_pme_t *pme_ret)
4191 {
4192     int ind;
4193     gmx_pme_t pme = NULL;
4194
4195     ind = 0;
4196     while (ind < *npmedata)
4197     {
4198         pme = (*pmedata)[ind];
4199         if (pme->nkx == grid_size[XX] &&
4200             pme->nky == grid_size[YY] &&
4201             pme->nkz == grid_size[ZZ])
4202         {
4203             *pme_ret = pme;
4204
4205             return;
4206         }
4207
4208         ind++;
4209     }
4210
4211     (*npmedata)++;
4212     srenew(*pmedata, *npmedata);
4213
4214     /* Generate a new PME data structure, copying part of the old pointers */
4215     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4216
4217     *pme_ret = (*pmedata)[ind];
4218 }
4219
4220
4221 int gmx_pmeonly(gmx_pme_t pme,
4222                 t_commrec *cr,    t_nrnb *nrnb,
4223                 gmx_wallcycle_t wcycle,
4224                 gmx_runtime_t *runtime,
4225                 real ewaldcoeff,  gmx_bool bGatherOnly,
4226                 t_inputrec *ir)
4227 {
4228     int npmedata;
4229     gmx_pme_t *pmedata;
4230     gmx_pme_pp_t pme_pp;
4231     int  ret;
4232     int  natoms;
4233     matrix box;
4234     rvec *x_pp      = NULL, *f_pp = NULL;
4235     real *chargeA   = NULL, *chargeB = NULL;
4236     real lambda     = 0;
4237     int  maxshift_x = 0, maxshift_y = 0;
4238     real energy, dvdlambda;
4239     matrix vir;
4240     float cycles;
4241     int  count;
4242     gmx_bool bEnerVir;
4243     gmx_large_int_t step, step_rel;
4244     ivec grid_switch;
4245
4246     /* This data will only be used with PME tuning, i.e. switching PME grids */
4247     npmedata = 1;
4248     snew(pmedata, npmedata);
4249     pmedata[0] = pme;
4250
4251     pme_pp = gmx_pme_pp_init(cr);
4252
4253     init_nrnb(nrnb);
4254
4255     count = 0;
4256     do /****** this is a quasi-loop over time steps! */
4257     {
4258         /* The reason for having a loop here is PME grid tuning/switching */
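        /* Sketch of one outer iteration: receive charges and coordinates from
         * the PP ranks (possibly switching PME grids or resetting counters
         * first), run the mesh part of PME for this step and send the forces,
         * energy and virial back to the PP ranks.
         */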
4259         do
4260         {
4261             /* Domain decomposition */
4262             ret = gmx_pme_recv_q_x(pme_pp,
4263                                    &natoms,
4264                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4265                                    &maxshift_x, &maxshift_y,
4266                                    &pme->bFEP, &lambda,
4267                                    &bEnerVir,
4268                                    &step,
4269                                    grid_switch, &ewaldcoeff);
4270
4271             if (ret == pmerecvqxSWITCHGRID)
4272             {
4273                 /* Switch the PME grid to grid_switch */
4274                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4275             }
4276
4277             if (ret == pmerecvqxRESETCOUNTERS)
4278             {
4279                 /* Reset the cycle and flop counters */
4280                 reset_pmeonly_counters(cr, wcycle, runtime, nrnb, ir, step);
4281             }
4282         }
4283         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4284
4285         if (ret == pmerecvqxFINISH)
4286         {
4287             /* We should stop: break out of the loop */
4288             break;
4289         }
4290
4291         step_rel = step - ir->init_step;
4292
4293         if (count == 0)
4294         {
4295             wallcycle_start(wcycle, ewcRUN);
4296             runtime_start(runtime);
4297         }
4298
4299         wallcycle_start(wcycle, ewcPMEMESH);
4300
4301         dvdlambda = 0;
4302         clear_mat(vir);
4303         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4304                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4305                    &energy, lambda, &dvdlambda,
4306                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4307
4308         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4309
4310         gmx_pme_send_force_vir_ener(pme_pp,
4311                                     f_pp, vir, energy, dvdlambda,
4312                                     cycles);
4313
4314         count++;
4315     } /***** end of quasi-loop, we stop with the break above */
4316     while (TRUE);
4317
4318     runtime_end(runtime);
4319
4320     return 0;
4321 }
4322
4323 int gmx_pme_do(gmx_pme_t pme,
4324                int start,       int homenr,
4325                rvec x[],        rvec f[],
4326                real *chargeA,   real *chargeB,
4327                matrix box, t_commrec *cr,
4328                int  maxshift_x, int maxshift_y,
4329                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4330                matrix vir,      real ewaldcoeff,
4331                real *energy,    real lambda,
4332                real *dvdlambda, int flags)
4333 {
4334     int     q, d, i, j, ntot, npme;
4335     int     nx, ny, nz;
4336     int     n_d, local_ny;
4337     pme_atomcomm_t *atc = NULL;
4338     pmegrids_t *pmegrid = NULL;
4339     real    *grid       = NULL;
4340     real    *ptr;
4341     rvec    *x_d, *f_d;
4342     real    *charge = NULL, *q_d;
4343     real    energy_AB[2];
4344     matrix  vir_AB[2];
4345     gmx_bool bClearF;
4346     gmx_parallel_3dfft_t pfft_setup;
4347     real *  fftgrid;
4348     t_complex * cfftgrid;
4349     int     thread;
4350     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4351     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4352
4353     assert(pme->nnodes > 0);
4354     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4355
4356     if (pme->nnodes > 1)
4357     {
4358         atc      = &pme->atc[0];
4359         atc->npd = homenr;
4360         if (atc->npd > atc->pd_nalloc)
4361         {
4362             atc->pd_nalloc = over_alloc_dd(atc->npd);
4363             srenew(atc->pd, atc->pd_nalloc);
4364         }
4365         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4366     }
4367     else
4368     {
4369         /* This could be necessary for TPI */
4370         pme->atc[0].n = homenr;
4371     }
4372
4373     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4374     {
4375         if (q == 0)
4376         {
4377             pmegrid    = &pme->pmegridA;
4378             fftgrid    = pme->fftgridA;
4379             cfftgrid   = pme->cfftgridA;
4380             pfft_setup = pme->pfft_setupA;
4381             charge     = chargeA+start;
4382         }
4383         else
4384         {
4385             pmegrid    = &pme->pmegridB;
4386             fftgrid    = pme->fftgridB;
4387             cfftgrid   = pme->cfftgridB;
4388             pfft_setup = pme->pfft_setupB;
4389             charge     = chargeB+start;
4390         }
4391         grid = pmegrid->grid.grid;
4392         /* Unpack structure */
4393         if (debug)
4394         {
4395             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4396                     cr->nnodes, cr->nodeid);
4397             fprintf(debug, "Grid = %p\n", (void*)grid);
4398             if (grid == NULL)
4399             {
4400                 gmx_fatal(FARGS, "No grid!");
4401             }
4402         }
4403         where();
4404
4405         m_inv_ur0(box, pme->recipbox);
4406
4407         if (pme->nnodes == 1)
4408         {
4409             atc = &pme->atc[0];
4410             if (DOMAINDECOMP(cr))
4411             {
4412                 atc->n = homenr;
4413                 pme_realloc_atomcomm_things(atc);
4414             }
4415             atc->x = x;
4416             atc->q = charge;
4417             atc->f = f;
4418         }
4419         else
4420         {
4421             wallcycle_start(wcycle, ewcPME_REDISTXF);
4422             for (d = pme->ndecompdim-1; d >= 0; d--)
4423             {
4424                 if (d == pme->ndecompdim-1)
4425                 {
4426                     n_d = homenr;
4427                     x_d = x + start;
4428                     q_d = charge;
4429                 }
4430                 else
4431                 {
4432                     n_d = pme->atc[d+1].n;
4433                     x_d = atc->x;
4434                     q_d = atc->q;
4435                 }
4436                 atc      = &pme->atc[d];
4437                 atc->npd = n_d;
4438                 if (atc->npd > atc->pd_nalloc)
4439                 {
4440                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4441                     srenew(atc->pd, atc->pd_nalloc);
4442                 }
4443                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4444                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4445                 where();
4446
4447                 GMX_BARRIER(cr->mpi_comm_mygroup);
4448                 /* Redistribute x (only once) and qA or qB */
4449                 if (DOMAINDECOMP(cr))
4450                 {
4451                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4452                 }
4453                 else
4454                 {
4455                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4456                 }
4457             }
4458             where();
4459
4460             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4461         }
4462
4463         if (debug)
4464         {
4465             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4466                     cr->nodeid, atc->n);
4467         }
4468
4469         if (flags & GMX_PME_SPREAD_Q)
4470         {
4471             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4472
4473             /* Spread the charges on a grid */
4474             GMX_MPE_LOG(ev_spread_on_grid_start);
4477             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4478             GMX_MPE_LOG(ev_spread_on_grid_finish);
4479
4480             if (q == 0)
4481             {
4482                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4483             }
4484             inc_nrnb(nrnb, eNR_SPREADQBSP,
4485                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4486
4487             if (!pme->bUseThreads)
4488             {
4489                 wrap_periodic_pmegrid(pme, grid);
4490
4491                 /* sum contributions to local grid from other nodes */
4492 #ifdef GMX_MPI
4493                 if (pme->nnodes > 1)
4494                 {
4495                     GMX_BARRIER(cr->mpi_comm_mygroup);
4496                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4497                     where();
4498                 }
4499 #endif
4500
4501                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4502             }
4503
4504             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4505
4506             /*
4507                dump_local_fftgrid(pme,fftgrid);
4508                exit(0);
4509              */
4510         }
4511
4512         /* Here we start a large thread parallel region */
4513 #pragma omp parallel num_threads(pme->nthread) private(thread)
4514         {
4515             thread = gmx_omp_get_thread_num();
4516             if (flags & GMX_PME_SOLVE)
4517             {
4518                 int loop_count;
4519
4520                 /* do 3d-fft */
4521                 if (thread == 0)
4522                 {
4523                     GMX_BARRIER(cr->mpi_comm_mygroup);
4524                     GMX_MPE_LOG(ev_gmxfft3d_start);
4525                     wallcycle_start(wcycle, ewcPME_FFT);
4526                 }
4527                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4528                                            fftgrid, cfftgrid, thread, wcycle);
4529                 if (thread == 0)
4530                 {
4531                     wallcycle_stop(wcycle, ewcPME_FFT);
4532                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4533                 }
4534                 where();
4535
4536                 /* solve in k-space for our local cells */
4537                 if (thread == 0)
4538                 {
4539                     GMX_BARRIER(cr->mpi_comm_mygroup);
4540                     GMX_MPE_LOG(ev_solve_pme_start);
4541                     wallcycle_start(wcycle, ewcPME_SOLVE);
4542                 }
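                /* GROMACS stores the box as a lower-triangular matrix, so the
                 * product of the diagonal elements passed below is the
                 * determinant, i.e. the box volume.
                 */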
4543                 loop_count =
4544                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4545                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4546                                   bCalcEnerVir,
4547                                   pme->nthread, thread);
4548                 if (thread == 0)
4549                 {
4550                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4551                     where();
4552                     GMX_MPE_LOG(ev_solve_pme_finish);
4553                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4554                 }
4555             }
4556
4557             if (bCalcF)
4558             {
4559                 /* do 3d-invfft */
4560                 if (thread == 0)
4561                 {
4562                     GMX_BARRIER(cr->mpi_comm_mygroup);
4563                     GMX_MPE_LOG(ev_gmxfft3d_start);
4564                     where();
4565                     wallcycle_start(wcycle, ewcPME_FFT);
4566                 }
4567                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4568                                            cfftgrid, fftgrid, thread, wcycle);
4569                 if (thread == 0)
4570                 {
4571                     wallcycle_stop(wcycle, ewcPME_FFT);
4572
4573                     where();
4574                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4575
4576                     if (pme->nodeid == 0)
4577                     {
4578                         ntot  = pme->nkx*pme->nky*pme->nkz;
4579                         npme  = ntot*log((real)ntot)/log(2.0);
4580                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4581                     }
4582
4583                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4584                 }
4585
4586                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4587             }
4588         }
4589         /* End of thread parallel section.
4590          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4591          */
4592
4593         if (bCalcF)
4594         {
4595             /* distribute local grid to all nodes */
4596 #ifdef GMX_MPI
4597             if (pme->nnodes > 1)
4598             {
4599                 GMX_BARRIER(cr->mpi_comm_mygroup);
4600                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4601             }
4602 #endif
4603             where();
4604
4605             unwrap_periodic_pmegrid(pme, grid);
4606
4607             /* interpolate forces for our local atoms */
4608             GMX_BARRIER(cr->mpi_comm_mygroup);
4609             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4610
4611             where();
4612
4613             /* If we are running without parallelization,
4614              * atc->f is the actual force array, not a buffer,
4615              * therefore we should not clear it.
4616              */
4617             bClearF = (q == 0 && PAR(cr));
4618 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4619             for (thread = 0; thread < pme->nthread; thread++)
4620             {
4621                 gather_f_bsplines(pme, grid, bClearF, atc,
4622                                   &atc->spline[thread],
4623                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4624             }
4625
4626             where();
4627
4628             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4629
4630             inc_nrnb(nrnb, eNR_GATHERFBSP,
4631                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4632             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4633         }
4634
4635         if (bCalcEnerVir)
4636         {
4637             /* This should only be called on the master thread
4638              * and after the threads have synchronized.
4639              */
4640             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4641         }
4642     } /* of q-loop */
4643
4644     if (bCalcF && pme->nnodes > 1)
4645     {
4646         wallcycle_start(wcycle, ewcPME_REDISTXF);
4647         for (d = 0; d < pme->ndecompdim; d++)
4648         {
4649             atc = &pme->atc[d];
4650             if (d == pme->ndecompdim - 1)
4651             {
4652                 n_d = homenr;
4653                 f_d = f + start;
4654             }
4655             else
4656             {
4657                 n_d = pme->atc[d+1].n;
4658                 f_d = pme->atc[d+1].f;
4659             }
4660             GMX_BARRIER(cr->mpi_comm_mygroup);
4661             if (DOMAINDECOMP(cr))
4662             {
4663                 dd_pmeredist_f(pme, atc, n_d, f_d,
4664                                d == pme->ndecompdim-1 && pme->bPPnode);
4665             }
4666             else
4667             {
4668                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4669             }
4670         }
4671
4672         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4673     }
4674     where();
4675
4676     if (bCalcEnerVir)
4677     {
4678         if (!pme->bFEP)
4679         {
4680             *energy = energy_AB[0];
4681             m_add(vir, vir_AB[0], vir);
4682         }
4683         else
4684         {
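            /* Linear coupling of the A and B states:
             * E(lambda) = (1 - lambda)*E_A + lambda*E_B, hence
             * dE/dlambda = E_B - E_A; the virial is mixed the same way,
             * component by component, below.
             */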
4685             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4686             *dvdlambda += energy_AB[1] - energy_AB[0];
4687             for (i = 0; i < DIM; i++)
4688             {
4689                 for (j = 0; j < DIM; j++)
4690                 {
4691                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4692                         lambda*vir_AB[1][i][j];
4693                 }
4694             }
4695         }
4696     }
4697     else
4698     {
4699         *energy = 0;
4700     }
4701
4702     if (debug)
4703     {
4704         fprintf(debug, "PME mesh energy: %g\n", *energy);
4705     }
4706
4707     return 0;
4708 }