1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
56  * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #include "gmxpre.h"
61
62 #include "pme.h"
63
64 #include "config.h"
65
66 #include <assert.h>
67 #include <math.h>
68 #include <stdio.h>
69 #include <stdlib.h>
70
71 #include "gromacs/ewald/pme-internal.h"
72 #include "gromacs/fft/fft.h"
73 #include "gromacs/fft/parallel_3dfft.h"
74 #include "gromacs/legacyheaders/macros.h"
75 #include "gromacs/legacyheaders/nrnb.h"
76 #include "gromacs/legacyheaders/types/commrec.h"
77 #include "gromacs/legacyheaders/types/enums.h"
78 #include "gromacs/legacyheaders/types/forcerec.h"
79 #include "gromacs/legacyheaders/types/inputrec.h"
80 #include "gromacs/legacyheaders/types/nrnb.h"
81 #include "gromacs/math/gmxcomplex.h"
82 #include "gromacs/math/units.h"
83 #include "gromacs/math/vec.h"
84 #include "gromacs/math/vectypes.h"
85 /* Include the SIMD macro file and then check for support */
86 #include "gromacs/simd/simd.h"
87 #include "gromacs/simd/simd_math.h"
88 #include "gromacs/timing/cyclecounter.h"
89 #include "gromacs/timing/wallcycle.h"
90 #include "gromacs/timing/walltime_accounting.h"
91 #include "gromacs/utility/basedefinitions.h"
92 #include "gromacs/utility/fatalerror.h"
93 #include "gromacs/utility/gmxmpi.h"
94 #include "gromacs/utility/gmxomp.h"
95 #include "gromacs/utility/real.h"
96 #include "gromacs/utility/smalloc.h"
97
98 #ifdef DEBUG_PME
99 #include "gromacs/fileio/pdbio.h"
100 #include "gromacs/utility/cstringutil.h"
101 #include "gromacs/utility/futil.h"
102 #endif
103
104 #ifdef GMX_SIMD_HAVE_REAL
105 /* Turn on arbitrary width SIMD intrinsics for PME solve */
106 #    define PME_SIMD_SOLVE
107 #endif
108
109 #define PME_GRID_QA    0 /* Grid index for A-state for Q */
110 #define PME_GRID_C6A   2 /* Grid index for A-state for LJ */
111 #define DO_Q           2 /* Electrostatic grids have index q<2 */
112 #define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
113 #define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */
114
115 /* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules */
116 const real lb_scale_factor[] = {
117     1.0/64, 6.0/64, 15.0/64, 20.0/64,
118     15.0/64, 6.0/64, 1.0/64
119 };
120
121 /* Pascal triangle coefficients used in solve_pme_lj_yzx; due to symmetry only 4 values are needed */
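/* Each entry k below equals lb_scale_factor[k] + lb_scale_factor[6-k] for k = 0..2;
 * the last entry is the unpaired central coefficient 20/64.
 */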
122 const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
123
124 /* Check if we have 4-wide SIMD macro support */
125 #if (defined GMX_SIMD4_HAVE_REAL)
126 /* Do PME spread and gather with 4-wide SIMD.
127  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
128  */
129 #    define PME_SIMD4_SPREAD_GATHER
130
131 #    if (defined GMX_SIMD_HAVE_LOADU) && (defined GMX_SIMD_HAVE_STOREU)
132 /* With PME-order=4 on x86, unaligned load+store is slightly faster
133  * than doubling all SIMD operations when using aligned load+store.
134  */
135 #        define PME_SIMD4_UNALIGNED
136 #    endif
137 #endif
138
139 #define DFT_TOL 1e-7
140 /* #define PRT_FORCE */
141 /* Conditions for on-the-fly time measurement */
142 /* #define TAKETIME (step > 1 && timesteps < 10) */
143 #define TAKETIME FALSE
144
145 /* #define PME_TIME_THREADS */
146
147 #ifdef GMX_DOUBLE
148 #define mpi_type MPI_DOUBLE
149 #else
150 #define mpi_type MPI_FLOAT
151 #endif
152
153 #ifdef PME_SIMD4_SPREAD_GATHER
154 #    define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
155 #else
156 /* We can use any alignment, apart from 0, so we use 4 reals */
157 #    define SIMD4_ALIGNMENT  (4*sizeof(real))
158 #endif
159
160 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
161  * to preserve alignment.
162  */
163 #define GMX_CACHE_SEP 64
164
165 /* We only define a maximum to be able to use local arrays without allocation.
166  * An order larger than 12 should never be needed, even for test cases.
167  * If needed it can be changed here.
168  */
169 #define PME_ORDER_MAX 12
170
171 /* Internal datastructures */
172 typedef struct {
173     int send_index0;
174     int send_nindex;
175     int recv_index0;
176     int recv_nindex;
177     int recv_size;   /* Receive buffer width, used with OpenMP */
178 } pme_grid_comm_t;
179
180 typedef real *splinevec[DIM];
181
182 typedef struct {
183 #ifdef GMX_MPI
184     MPI_Comm         mpi_comm;
185 #endif
186     int              nnodes, nodeid;
187     int             *s2g0;
188     int             *s2g1;
189     int              noverlap_nodes;
190     int             *send_id, *recv_id;
191     int              send_size; /* Send buffer width, used with OpenMP */
192     pme_grid_comm_t *comm_data;
193     real            *sendbuf;
194     real            *recvbuf;
195 } pme_overlap_t;
196
197 typedef struct {
198     int *n;      /* Cumulative counts of the number of particles per thread */
199     int  nalloc; /* Allocation size of i */
200     int *i;      /* Particle indices ordered on thread index (n) */
201 } thread_plist_t;
202
203 typedef struct {
204     int      *thread_one;
205     int       n;
206     int      *ind;
207     splinevec theta;
208     real     *ptr_theta_z;
209     splinevec dtheta;
210     real     *ptr_dtheta_z;
211 } splinedata_t;
212
213 typedef struct {
214     int      dimind;        /* The index of the dimension, 0=x, 1=y */
215     int      nslab;
216     int      nodeid;
217 #ifdef GMX_MPI
218     MPI_Comm mpi_comm;
219 #endif
220
221     int     *node_dest;     /* The nodes to send x and q to with DD */
222     int     *node_src;      /* The nodes to receive x and q from with DD */
223     int     *buf_index;     /* Index for commnode into the buffers */
224
225     int      maxshift;
226
227     int      npd;
228     int      pd_nalloc;
229     int     *pd;
230     int     *count;         /* The number of atoms to send to each node */
231     int    **count_thread;
232     int     *rcount;        /* The number of atoms to receive */
233
234     int      n;
235     int      nalloc;
236     rvec    *x;
237     real    *coefficient;
238     rvec    *f;
239     gmx_bool bSpread;       /* These coordinates are used for spreading */
240     int      pme_order;
241     ivec    *idx;
242     rvec    *fractx;            /* Fractional coordinate relative to
243                                  * the lower cell boundary
244                                  */
245     int             nthread;
246     int            *thread_idx; /* Which thread should spread which coefficient */
247     thread_plist_t *thread_plist;
248     splinedata_t   *spline;
249 } pme_atomcomm_t;
250
251 #define FLBS  3
252 #define FLBSZ 4
253
254 typedef struct {
255     ivec  ci;     /* The spatial location of this grid         */
256     ivec  n;      /* The used size of *grid, including order-1 */
257     ivec  offset; /* The grid offset from the full node grid   */
258     int   order;  /* PME spreading order                       */
259     ivec  s;      /* The allocated size of *grid, s >= n       */
260     real *grid;   /* The grid of the local thread, size n      */
261 } pmegrid_t;
262
263 typedef struct {
264     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
265     int        nthread;      /* The number of threads operating on this grid     */
266     ivec       nc;           /* The local spatial decomposition over the threads */
267     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
268     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
269     int      **g2t;          /* The grid to thread index                         */
270     ivec       nthread_comm; /* The number of threads to communicate with        */
271 } pmegrids_t;
272
273 typedef struct {
274 #ifdef PME_SIMD4_SPREAD_GATHER
275     /* Masks for 4-wide SIMD aligned spreading and gathering */
276     gmx_simd4_bool_t mask_S0[6], mask_S1[6];
277 #else
278     int              dummy; /* C89 requires that struct has at least one member */
279 #endif
280 } pme_spline_work_t;
281
282 typedef struct {
283     /* work data for solve_pme */
284     int      nalloc;
285     real *   mhx;
286     real *   mhy;
287     real *   mhz;
288     real *   m2;
289     real *   denom;
290     real *   tmp1_alloc;
291     real *   tmp1;
292     real *   tmp2;
293     real *   eterm;
294     real *   m2inv;
295
296     real     energy_q;
297     matrix   vir_q;
298     real     energy_lj;
299     matrix   vir_lj;
300 } pme_work_t;
301
302 typedef struct gmx_pme {
303     int           ndecompdim; /* The number of decomposition dimensions */
304     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
305     int           nodeid_major;
306     int           nodeid_minor;
307     int           nnodes;    /* The number of nodes doing PME */
308     int           nnodes_major;
309     int           nnodes_minor;
310
311     MPI_Comm      mpi_comm;
312     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
313 #ifdef GMX_MPI
314     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
315 #endif
316
317     gmx_bool   bUseThreads;   /* Do any of the PME ranks have nthread>1?     */
318     int        nthread;       /* The number of threads doing PME on our rank */
319
320     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
321     gmx_bool   bFEP;          /* Compute Free energy contribution */
322     gmx_bool   bFEP_q;
323     gmx_bool   bFEP_lj;
324     int        nkx, nky, nkz; /* Grid dimensions */
325     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
326     int        pme_order;
327     real       epsilon_r;
328
329     int        ljpme_combination_rule;  /* Type of combination rule in LJ-PME */
330
331     int        ngrids;                  /* number of grids we maintain for pmegrid, (c)fftgrid and pfft_setups*/
332
333     pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
334                                          * includes overlap. Grid indices are ordered as
335                                          * follows:
336                                          * 0: Coulomb PME, state A
337                                          * 1: Coulomb PME, state B
338                                          * 2-8: LJ-PME
339                                          * This can probably be done in a better way
340                                          * but this simple hack works for now
341                                          */
342     /* The PME coefficient spreading grid sizes/strides, includes pme_order-1 */
343     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
344     /* pmegrid_nz might be larger than strictly necessary to ensure
345      * memory alignment, pmegrid_nz_base gives the real base size.
346      */
347     int     pmegrid_nz_base;
348     /* The local PME grid starting indices */
349     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
350
351     /* Work data for spreading and gathering */
352     pme_spline_work_t     *spline_work;
353
354     real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition this can be a pointer
355                                      * inside the interpolation grid, but separate for 2D PME decomposition. */
356     int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;
357
358     t_complex            **cfftgrid;  /* Grids for complex FFT data */
359
360     int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
361
362     gmx_parallel_3dfft_t  *pfft_setup;
363
364     int                   *nnx, *nny, *nnz;
365     real                  *fshx, *fshy, *fshz;
366
367     pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
368     matrix                 recipbox;
369     splinevec              bsp_mod;
370     /* Buffers to store data for local atoms for L-B combination rule
371      * calculations in LJ-PME. lb_buf1 stores either the coefficients
372      * for spreading/gathering (in serial), or the C6 coefficient for
373      * local atoms (in parallel).  lb_buf2 is only used in parallel,
374      * and stores the sigma values for local atoms. */
375     real                 *lb_buf1, *lb_buf2;
376     int                   lb_buf_nalloc; /* Allocation size for the above buffers. */
377
378     pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */
379
380     pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */
381
382     rvec                 *bufv;          /* Communication buffer */
383     real                 *bufr;          /* Communication buffer */
384     int                   buf_nalloc;    /* The communication buffer size */
385
386     /* thread local work data for solve_pme */
387     pme_work_t *work;
388
389     /* Work data for sum_qgrid */
390     real *   sum_qgrid_tmp;
391     real *   sum_qgrid_dd_tmp;
392 } t_gmx_pme;
393
394 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
395                                    int start, int grid_index, int end, int thread)
396 {
397     int             i;
398     int            *idxptr, tix, tiy, tiz;
399     real           *xptr, *fptr, tx, ty, tz;
400     real            rxx, ryx, ryy, rzx, rzy, rzz;
401     int             nx, ny, nz;
402     int             start_ix, start_iy, start_iz;
403     int            *g2tx, *g2ty, *g2tz;
404     gmx_bool        bThreads;
405     int            *thread_idx = NULL;
406     thread_plist_t *tpl        = NULL;
407     int            *tpl_n      = NULL;
408     int             thread_i;
409
410     nx  = pme->nkx;
411     ny  = pme->nky;
412     nz  = pme->nkz;
413
414     start_ix = pme->pmegrid_start_ix;
415     start_iy = pme->pmegrid_start_iy;
416     start_iz = pme->pmegrid_start_iz;
417
418     rxx = pme->recipbox[XX][XX];
419     ryx = pme->recipbox[YY][XX];
420     ryy = pme->recipbox[YY][YY];
421     rzx = pme->recipbox[ZZ][XX];
422     rzy = pme->recipbox[ZZ][YY];
423     rzz = pme->recipbox[ZZ][ZZ];
424
425     g2tx = pme->pmegrid[grid_index].g2t[XX];
426     g2ty = pme->pmegrid[grid_index].g2t[YY];
427     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
428
429     bThreads = (atc->nthread > 1);
430     if (bThreads)
431     {
432         thread_idx = atc->thread_idx;
433
434         tpl   = &atc->thread_plist[thread];
435         tpl_n = tpl->n;
436         for (i = 0; i < atc->nthread; i++)
437         {
438             tpl_n[i] = 0;
439         }
440     }
441
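    /* For each atom: compute fractional coordinates along the (possibly
     * triclinic) box vectors via the reciprocal-box matrix, scale by the grid
     * size, and split into an integer grid index plus a fractional remainder.
     * The +2.0 below keeps the value positive before truncation.  fshx/fshy
     * hold the per-index fractional shifts needed for the x/y decomposition,
     * and nnx/nny/nnz are precomputed tables mapping to local, wrapped grid
     * indices.
     */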
442     for (i = start; i < end; i++)
443     {
444         xptr   = atc->x[i];
445         idxptr = atc->idx[i];
446         fptr   = atc->fractx[i];
447
448         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
449         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
450         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
451         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
452
453         tix = (int)(tx);
454         tiy = (int)(ty);
455         tiz = (int)(tz);
456
457         /* Because decomposition only occurs in x and y,
458          * we never have a fraction correction in z.
459          */
460         fptr[XX] = tx - tix + pme->fshx[tix];
461         fptr[YY] = ty - tiy + pme->fshy[tiy];
462         fptr[ZZ] = tz - tiz;
463
464         idxptr[XX] = pme->nnx[tix];
465         idxptr[YY] = pme->nny[tiy];
466         idxptr[ZZ] = pme->nnz[tiz];
467
468 #ifdef DEBUG
469         range_check(idxptr[XX], 0, pme->pmegrid_nx);
470         range_check(idxptr[YY], 0, pme->pmegrid_ny);
471         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
472 #endif
473
474         if (bThreads)
475         {
476             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
477             thread_idx[i] = thread_i;
478             tpl_n[thread_i]++;
479         }
480     }
481
482     if (bThreads)
483     {
484         /* Make a list of particle indices sorted on thread */
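        /* This is a counting sort on thread index: per-thread counts were
         * accumulated above, here they are turned into prefix sums and the
         * particle indices are then scattered into tpl->i.
         */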
485
486         /* Get the cumulative count */
487         for (i = 1; i < atc->nthread; i++)
488         {
489             tpl_n[i] += tpl_n[i-1];
490         }
491         /* The current implementation distributes particles equally
492          * over the threads, so we could actually allocate for that
493          * in pme_realloc_atomcomm_things.
494          */
495         if (tpl_n[atc->nthread-1] > tpl->nalloc)
496         {
497             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
498             srenew(tpl->i, tpl->nalloc);
499         }
500         /* Set tpl_n to the cumulative start */
501         for (i = atc->nthread-1; i >= 1; i--)
502         {
503             tpl_n[i] = tpl_n[i-1];
504         }
505         tpl_n[0] = 0;
506
507         /* Fill our thread local array with indices sorted on thread */
508         for (i = start; i < end; i++)
509         {
510             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
511         }
512         /* Now tpl_n contains the cumulative count again */
513     }
514 }
515
516 static void make_thread_local_ind(pme_atomcomm_t *atc,
517                                   int thread, splinedata_t *spline)
518 {
519     int             n, t, i, start, end;
520     thread_plist_t *tpl;
521
522     /* Combine the indices made by each thread into one index */
523
524     n     = 0;
525     start = 0;
526     for (t = 0; t < atc->nthread; t++)
527     {
528         tpl = &atc->thread_plist[t];
529         /* Copy our part (start - end) from the list of thread t */
530         if (thread > 0)
531         {
532             start = tpl->n[thread-1];
533         }
534         end = tpl->n[thread];
535         for (i = start; i < end; i++)
536         {
537             spline->ind[n++] = tpl->i[i];
538         }
539     }
540
541     spline->n = n;
542 }
543
544
545 static void pme_calc_pidx(int start, int end,
546                           matrix recipbox, rvec x[],
547                           pme_atomcomm_t *atc, int *count)
548 {
549     int   nslab, i;
550     int   si;
551     real *xptr, s;
552     real  rxx, ryx, rzx, ryy, rzy;
553     int  *pd;
554
555     /* Calculate the PME task index (pidx) for each atom.
556      * Here we always assign equally sized slabs to each node
557      * for load balancing reasons (the PME grid spacing is not used).
558      */
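    /* Below, s can be slightly negative, e.g. for atoms just outside the unit
     * cell, so 2*nslab is added before the integer conversion and the result
     * is folded back into [0, nslab) with the modulo.
     */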
559
560     nslab = atc->nslab;
561     pd    = atc->pd;
562
563     /* Reset the count */
564     for (i = 0; i < nslab; i++)
565     {
566         count[i] = 0;
567     }
568
569     if (atc->dimind == 0)
570     {
571         rxx = recipbox[XX][XX];
572         ryx = recipbox[YY][XX];
573         rzx = recipbox[ZZ][XX];
574         /* Calculate the node index in x-dimension */
575         for (i = start; i < end; i++)
576         {
577             xptr   = x[i];
578             /* Fractional coordinates along box vectors */
579             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
580             si    = (int)(s + 2*nslab) % nslab;
581             pd[i] = si;
582             count[si]++;
583         }
584     }
585     else
586     {
587         ryy = recipbox[YY][YY];
588         rzy = recipbox[ZZ][YY];
589         /* Calculate the node index in y-dimension */
590         for (i = start; i < end; i++)
591         {
592             xptr   = x[i];
593             /* Fractional coordinates along box vectors */
594             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
595             si    = (int)(s + 2*nslab) % nslab;
596             pd[i] = si;
597             count[si]++;
598         }
599     }
600 }
601
602 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
603                                   pme_atomcomm_t *atc)
604 {
605     int nthread, thread, slab;
606
607     nthread = atc->nthread;
608
609 #pragma omp parallel for num_threads(nthread) schedule(static)
610     for (thread = 0; thread < nthread; thread++)
611     {
612         pme_calc_pidx(natoms* thread   /nthread,
613                       natoms*(thread+1)/nthread,
614                       recipbox, x, atc, atc->count_thread[thread]);
615     }
616     /* Non-parallel reduction, since nslab is small */
617
618     for (thread = 1; thread < nthread; thread++)
619     {
620         for (slab = 0; slab < atc->nslab; slab++)
621         {
622             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
623         }
624     }
625 }
626
627 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
628 {
629     const int padding = 4;
630     int       i;
631
632     srenew(th[XX], nalloc);
633     srenew(th[YY], nalloc);
634     /* In z we add padding; this is only required for the aligned SIMD code */
635     sfree_aligned(*ptr_z);
636     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
637     th[ZZ] = *ptr_z + padding;
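    /* The pad elements on both sides of the z splines keep 4-wide aligned
     * loads in the SIMD spread/gather kernels inside allocated memory when
     * they start before or run past the order-long spline span; they are
     * cleared below so the values read there are well defined.
     */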
638
639     for (i = 0; i < padding; i++)
640     {
641         (*ptr_z)[               i] = 0;
642         (*ptr_z)[padding+nalloc+i] = 0;
643     }
644 }
645
646 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
647 {
648     int i, d;
649
650     srenew(spline->ind, atc->nalloc);
651     /* Initialize the index to identity so it works without threads */
652     for (i = 0; i < atc->nalloc; i++)
653     {
654         spline->ind[i] = i;
655     }
656
657     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
658                       atc->pme_order*atc->nalloc);
659     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
660                       atc->pme_order*atc->nalloc);
661 }
662
663 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
664 {
665     int nalloc_old, i, j, nalloc_tpl;
666
667     /* We must avoid a NULL pointer for atc->x to prevent
668      * possible fatal errors in MPI routines.
669      */
670     if (atc->n > atc->nalloc || atc->nalloc == 0)
671     {
672         nalloc_old  = atc->nalloc;
673         atc->nalloc = over_alloc_dd(max(atc->n, 1));
674
675         if (atc->nslab > 1)
676         {
677             srenew(atc->x, atc->nalloc);
678             srenew(atc->coefficient, atc->nalloc);
679             srenew(atc->f, atc->nalloc);
680             for (i = nalloc_old; i < atc->nalloc; i++)
681             {
682                 clear_rvec(atc->f[i]);
683             }
684         }
685         if (atc->bSpread)
686         {
687             srenew(atc->fractx, atc->nalloc);
688             srenew(atc->idx, atc->nalloc);
689
690             if (atc->nthread > 1)
691             {
692                 srenew(atc->thread_idx, atc->nalloc);
693             }
694
695             for (i = 0; i < atc->nthread; i++)
696             {
697                 pme_realloc_splinedata(&atc->spline[i], atc);
698             }
699         }
700     }
701 }
702
703 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
704                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
705                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
706                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
707 {
708 #ifdef GMX_MPI
709     int        dest, src;
710     MPI_Status stat;
711
712     if (bBackward == FALSE)
713     {
714         dest = atc->node_dest[shift];
715         src  = atc->node_src[shift];
716     }
717     else
718     {
719         dest = atc->node_src[shift];
720         src  = atc->node_dest[shift];
721     }
722
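    /* Use a combined send+receive only when both sides have data; otherwise
     * issue a plain send or receive so no zero-sized message is posted.
     */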
723     if (nbyte_s > 0 && nbyte_r > 0)
724     {
725         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
726                      dest, shift,
727                      buf_r, nbyte_r, MPI_BYTE,
728                      src, shift,
729                      atc->mpi_comm, &stat);
730     }
731     else if (nbyte_s > 0)
732     {
733         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
734                  dest, shift,
735                  atc->mpi_comm);
736     }
737     else if (nbyte_r > 0)
738     {
739         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
740                  src, shift,
741                  atc->mpi_comm, &stat);
742     }
743 #endif
744 }
745
746 static void dd_pmeredist_pos_coeffs(gmx_pme_t pme,
747                                     int n, gmx_bool bX, rvec *x, real *data,
748                                     pme_atomcomm_t *atc)
749 {
750     int *commnode, *buf_index;
751     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
752
753     commnode  = atc->node_dest;
754     buf_index = atc->buf_index;
755
756     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
757
758     nsend = 0;
759     for (i = 0; i < nnodes_comm; i++)
760     {
761         buf_index[commnode[i]] = nsend;
762         nsend                 += atc->count[commnode[i]];
763     }
764     if (bX)
765     {
766         if (atc->count[atc->nodeid] + nsend != n)
767         {
768             gmx_fatal(FARGS, "%d particles communicated to PME rank %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
769                       "This usually means that your system is not well equilibrated.",
770                       n - (atc->count[atc->nodeid] + nsend),
771                       pme->nodeid, 'x'+atc->dimind);
772         }
773
774         if (nsend > pme->buf_nalloc)
775         {
776             pme->buf_nalloc = over_alloc_dd(nsend);
777             srenew(pme->bufv, pme->buf_nalloc);
778             srenew(pme->bufr, pme->buf_nalloc);
779         }
780
781         atc->n = atc->count[atc->nodeid];
782         for (i = 0; i < nnodes_comm; i++)
783         {
784             scount = atc->count[commnode[i]];
785             /* Communicate the count */
786             if (debug)
787             {
788                 fprintf(debug, "dimind %d PME rank %d send to rank %d: %d\n",
789                         atc->dimind, atc->nodeid, commnode[i], scount);
790             }
791             pme_dd_sendrecv(atc, FALSE, i,
792                             &scount, sizeof(int),
793                             &atc->rcount[i], sizeof(int));
794             atc->n += atc->rcount[i];
795         }
796
797         pme_realloc_atomcomm_things(atc);
798     }
799
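    /* Pack the data: atoms assigned to this rank go straight into the atc
     * arrays, all others are placed in the send buffers at the per-node
     * offsets (buf_index) computed above.
     */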
800     local_pos = 0;
801     for (i = 0; i < n; i++)
802     {
803         node = atc->pd[i];
804         if (node == atc->nodeid)
805         {
806             /* Copy direct to the receive buffer */
807             if (bX)
808             {
809                 copy_rvec(x[i], atc->x[local_pos]);
810             }
811             atc->coefficient[local_pos] = data[i];
812             local_pos++;
813         }
814         else
815         {
816             /* Copy to the send buffer */
817             if (bX)
818             {
819                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
820             }
821             pme->bufr[buf_index[node]] = data[i];
822             buf_index[node]++;
823         }
824     }
825
826     buf_pos = 0;
827     for (i = 0; i < nnodes_comm; i++)
828     {
829         scount = atc->count[commnode[i]];
830         rcount = atc->rcount[i];
831         if (scount > 0 || rcount > 0)
832         {
833             if (bX)
834             {
835                 /* Communicate the coordinates */
836                 pme_dd_sendrecv(atc, FALSE, i,
837                                 pme->bufv[buf_pos], scount*sizeof(rvec),
838                                 atc->x[local_pos], rcount*sizeof(rvec));
839             }
840             /* Communicate the coefficients */
841             pme_dd_sendrecv(atc, FALSE, i,
842                             pme->bufr+buf_pos, scount*sizeof(real),
843                             atc->coefficient+local_pos, rcount*sizeof(real));
844             buf_pos   += scount;
845             local_pos += atc->rcount[i];
846         }
847     }
848 }
849
850 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
851                            int n, rvec *f,
852                            gmx_bool bAddF)
853 {
854     int *commnode, *buf_index;
855     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
856
857     commnode  = atc->node_dest;
858     buf_index = atc->buf_index;
859
860     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
861
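    /* Reverse of the coordinate/coefficient redistribution: forces computed
     * here for non-local atoms are sent back to their home ranks, so the
     * send/receive counts are swapped and pme_dd_sendrecv is called with
     * bBackward == TRUE.
     */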
862     local_pos = atc->count[atc->nodeid];
863     buf_pos   = 0;
864     for (i = 0; i < nnodes_comm; i++)
865     {
866         scount = atc->rcount[i];
867         rcount = atc->count[commnode[i]];
868         if (scount > 0 || rcount > 0)
869         {
870             /* Communicate the forces */
871             pme_dd_sendrecv(atc, TRUE, i,
872                             atc->f[local_pos], scount*sizeof(rvec),
873                             pme->bufv[buf_pos], rcount*sizeof(rvec));
874             local_pos += scount;
875         }
876         buf_index[commnode[i]] = buf_pos;
877         buf_pos               += rcount;
878     }
879
880     local_pos = 0;
881     if (bAddF)
882     {
883         for (i = 0; i < n; i++)
884         {
885             node = atc->pd[i];
886             if (node == atc->nodeid)
887             {
888                 /* Add from the local force array */
889                 rvec_inc(f[i], atc->f[local_pos]);
890                 local_pos++;
891             }
892             else
893             {
894                 /* Add from the receive buffer */
895                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
896                 buf_index[node]++;
897             }
898         }
899     }
900     else
901     {
902         for (i = 0; i < n; i++)
903         {
904             node = atc->pd[i];
905             if (node == atc->nodeid)
906             {
907                 /* Copy from the local force array */
908                 copy_rvec(atc->f[local_pos], f[i]);
909                 local_pos++;
910             }
911             else
912             {
913                 /* Copy from the receive buffer */
914                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
915                 buf_index[node]++;
916             }
917         }
918     }
919 }
920
921 #ifdef GMX_MPI
922 static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
923 {
924     pme_overlap_t *overlap;
925     int            send_index0, send_nindex;
926     int            recv_index0, recv_nindex;
927     MPI_Status     stat;
928     int            i, j, k, ix, iy, iz, icnt;
929     int            ipulse, send_id, recv_id, datasize;
930     real          *p;
931     real          *sendptr, *recvptr;
932
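    /* In the forward direction (GMX_SUM_GRID_FORWARD) the overlap regions
     * spread by neighbouring ranks are received and added onto our local
     * grid; in the backward direction the communication pattern is mirrored
     * and the received data overwrites the local overlap region instead.
     */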
933     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
934     overlap = &pme->overlap[1];
935
936     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
937     {
938         /* Since we have already (un)wrapped the overlap in the z-dimension,
939          * we only have to communicate 0 to nkz (not pmegrid_nz).
940          */
941         if (direction == GMX_SUM_GRID_FORWARD)
942         {
943             send_id       = overlap->send_id[ipulse];
944             recv_id       = overlap->recv_id[ipulse];
945             send_index0   = overlap->comm_data[ipulse].send_index0;
946             send_nindex   = overlap->comm_data[ipulse].send_nindex;
947             recv_index0   = overlap->comm_data[ipulse].recv_index0;
948             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
949         }
950         else
951         {
952             send_id       = overlap->recv_id[ipulse];
953             recv_id       = overlap->send_id[ipulse];
954             send_index0   = overlap->comm_data[ipulse].recv_index0;
955             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
956             recv_index0   = overlap->comm_data[ipulse].send_index0;
957             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
958         }
959
960         /* Copy data to contiguous send buffer */
961         if (debug)
962         {
963             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
964                     pme->nodeid, overlap->nodeid, send_id,
965                     pme->pmegrid_start_iy,
966                     send_index0-pme->pmegrid_start_iy,
967                     send_index0-pme->pmegrid_start_iy+send_nindex);
968         }
969         icnt = 0;
970         for (i = 0; i < pme->pmegrid_nx; i++)
971         {
972             ix = i;
973             for (j = 0; j < send_nindex; j++)
974             {
975                 iy = j + send_index0 - pme->pmegrid_start_iy;
976                 for (k = 0; k < pme->nkz; k++)
977                 {
978                     iz = k;
979                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
980                 }
981             }
982         }
983
984         datasize      = pme->pmegrid_nx * pme->nkz;
985
986         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
987                      send_id, ipulse,
988                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
989                      recv_id, ipulse,
990                      overlap->mpi_comm, &stat);
991
992         /* Get data from contiguous recv buffer */
993         if (debug)
994         {
995             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
996                     pme->nodeid, overlap->nodeid, recv_id,
997                     pme->pmegrid_start_iy,
998                     recv_index0-pme->pmegrid_start_iy,
999                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1000         }
1001         icnt = 0;
1002         for (i = 0; i < pme->pmegrid_nx; i++)
1003         {
1004             ix = i;
1005             for (j = 0; j < recv_nindex; j++)
1006             {
1007                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1008                 for (k = 0; k < pme->nkz; k++)
1009                 {
1010                     iz = k;
1011                     if (direction == GMX_SUM_GRID_FORWARD)
1012                     {
1013                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1014                     }
1015                     else
1016                     {
1017                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1018                     }
1019                 }
1020             }
1021         }
1022     }
1023
1024     /* Major dimension is easier, no copying required,
1025      * but we might have to sum into a separate array.
1026      * Since we don't copy, we have to communicate up to pmegrid_nz,
1027      * not nkz as for the minor direction.
1028      */
1029     overlap = &pme->overlap[0];
1030
1031     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1032     {
1033         if (direction == GMX_SUM_GRID_FORWARD)
1034         {
1035             send_id       = overlap->send_id[ipulse];
1036             recv_id       = overlap->recv_id[ipulse];
1037             send_index0   = overlap->comm_data[ipulse].send_index0;
1038             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1039             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1040             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1041             recvptr       = overlap->recvbuf;
1042         }
1043         else
1044         {
1045             send_id       = overlap->recv_id[ipulse];
1046             recv_id       = overlap->send_id[ipulse];
1047             send_index0   = overlap->comm_data[ipulse].recv_index0;
1048             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1049             recv_index0   = overlap->comm_data[ipulse].send_index0;
1050             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1051             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1052         }
1053
1054         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1055         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1056
1057         if (debug)
1058         {
1059             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
1060                     pme->nodeid, overlap->nodeid, send_id,
1061                     pme->pmegrid_start_ix,
1062                     send_index0-pme->pmegrid_start_ix,
1063                     send_index0-pme->pmegrid_start_ix+send_nindex);
1064             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
1065                     pme->nodeid, overlap->nodeid, recv_id,
1066                     pme->pmegrid_start_ix,
1067                     recv_index0-pme->pmegrid_start_ix,
1068                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1069         }
1070
1071         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1072                      send_id, ipulse,
1073                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1074                      recv_id, ipulse,
1075                      overlap->mpi_comm, &stat);
1076
1077         /* ADD data from contiguous recv buffer */
1078         if (direction == GMX_SUM_GRID_FORWARD)
1079         {
1080             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1081             for (i = 0; i < recv_nindex*datasize; i++)
1082             {
1083                 p[i] += overlap->recvbuf[i];
1084             }
1085         }
1086     }
1087 }
1088 #endif
1089
1090
1091 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1092 {
1093     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1094     ivec    local_pme_size;
1095     int     i, ix, iy, iz;
1096     int     pmeidx, fftidx;
1097
1098     /* Dimensions should be identical for A/B grid, so we just use A here */
1099     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1100                                    local_fft_ndata,
1101                                    local_fft_offset,
1102                                    local_fft_size);
1103
1104     local_pme_size[0] = pme->pmegrid_nx;
1105     local_pme_size[1] = pme->pmegrid_ny;
1106     local_pme_size[2] = pme->pmegrid_nz;
1107
1108     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1109        the offset is identical, and the PME grid always has more data (due to overlap)
1110      */
1111     {
1112 #ifdef DEBUG_PME
1113         FILE *fp, *fp2;
1114         char  fn[STRLEN];
1115         real  val;
1116         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1117         fp = gmx_ffopen(fn, "w");
1118         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1119         fp2 = gmx_ffopen(fn, "w");
1120 #endif
1121
1122         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1123         {
1124             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1125             {
1126                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1127                 {
1128                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1129                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1130                     fftgrid[fftidx] = pmegrid[pmeidx];
1131 #ifdef DEBUG_PME
1132                     val = 100*pmegrid[pmeidx];
1133                     if (pmegrid[pmeidx] != 0)
1134                     {
1135                         gmx_fprintf_pdb_atomline(fp, epdbATOM, pmeidx, "CA", ' ', "GLY", ' ', pmeidx, ' ',
1136                                                  5.0*ix, 5.0*iy, 5.0*iz, 1.0, val, "");
1137                     }
1138                     if (pmegrid[pmeidx] != 0)
1139                     {
1140                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1141                                 "qgrid",
1142                                 pme->pmegrid_start_ix + ix,
1143                                 pme->pmegrid_start_iy + iy,
1144                                 pme->pmegrid_start_iz + iz,
1145                                 pmegrid[pmeidx]);
1146                     }
1147 #endif
1148                 }
1149             }
1150         }
1151 #ifdef DEBUG_PME
1152         gmx_ffclose(fp);
1153         gmx_ffclose(fp2);
1154 #endif
1155     }
1156     return 0;
1157 }
1158
1159
1160 static gmx_cycles_t omp_cyc_start()
1161 {
1162     return gmx_cycles_read();
1163 }
1164
1165 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1166 {
1167     return gmx_cycles_read() - c;
1168 }
1169
1170
1171 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1172                                    int nthread, int thread)
1173 {
1174     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1175     ivec          local_pme_size;
1176     int           ixy0, ixy1, ixy, ix, iy, iz;
1177     int           pmeidx, fftidx;
1178 #ifdef PME_TIME_THREADS
1179     gmx_cycles_t  c1;
1180     static double cs1 = 0;
1181     static int    cnt = 0;
1182 #endif
1183
1184 #ifdef PME_TIME_THREADS
1185     c1 = omp_cyc_start();
1186 #endif
1187     /* Dimensions should be identical for A/B grid, so we just use A here */
1188     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1189                                    local_fft_ndata,
1190                                    local_fft_offset,
1191                                    local_fft_size);
1192
1193     local_pme_size[0] = pme->pmegrid_nx;
1194     local_pme_size[1] = pme->pmegrid_ny;
1195     local_pme_size[2] = pme->pmegrid_nz;
1196
1197     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1198        the offset is identical, and the PME grid always has more data (due to overlap)
1199      */
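    /* Divide the flattened x*y index range evenly over the threads; each
     * thread then copies whole z-columns, which differ only in stride
     * between the FFT grid and the (larger, overlap-padded) PME grid.
     */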
1200     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1201     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1202
1203     for (ixy = ixy0; ixy < ixy1; ixy++)
1204     {
1205         ix = ixy/local_fft_ndata[YY];
1206         iy = ixy - ix*local_fft_ndata[YY];
1207
1208         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1209         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1210         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1211         {
1212             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1213         }
1214     }
1215
1216 #ifdef PME_TIME_THREADS
1217     c1   = omp_cyc_end(c1);
1218     cs1 += (double)c1;
1219     cnt++;
1220     if (cnt % 20 == 0)
1221     {
1222         printf("copy %.2f\n", cs1*1e-9);
1223     }
1224 #endif
1225
1226     return 0;
1227 }
1228
1229
1230 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1231 {
1232     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1233
1234     nx = pme->nkx;
1235     ny = pme->nky;
1236     nz = pme->nkz;
1237
1238     pnx = pme->pmegrid_nx;
1239     pny = pme->pmegrid_ny;
1240     pnz = pme->pmegrid_nz;
1241
1242     overlap = pme->pme_order - 1;
1243
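    /* Spreading writes into an extra pme_order-1 planes beyond the grid in
     * each dimension; here those contributions are folded back onto the
     * periodic images.  This is always done locally in z, and in y/x only
     * when that dimension is not decomposed over PME ranks (otherwise the
     * overlap is handled by gmx_sum_qgrid_dd).
     */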
1244     /* Add periodic overlap in z */
1245     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1246     {
1247         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1248         {
1249             for (iz = 0; iz < overlap; iz++)
1250             {
1251                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1252                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1253             }
1254         }
1255     }
1256
1257     if (pme->nnodes_minor == 1)
1258     {
1259         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1260         {
1261             for (iy = 0; iy < overlap; iy++)
1262             {
1263                 for (iz = 0; iz < nz; iz++)
1264                 {
1265                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1266                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1267                 }
1268             }
1269         }
1270     }
1271
1272     if (pme->nnodes_major == 1)
1273     {
1274         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1275
1276         for (ix = 0; ix < overlap; ix++)
1277         {
1278             for (iy = 0; iy < ny_x; iy++)
1279             {
1280                 for (iz = 0; iz < nz; iz++)
1281                 {
1282                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1283                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1284                 }
1285             }
1286         }
1287     }
1288 }
1289
1290
1291 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1292 {
1293     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1294
1295     nx = pme->nkx;
1296     ny = pme->nky;
1297     nz = pme->nkz;
1298
1299     pnx = pme->pmegrid_nx;
1300     pny = pme->pmegrid_ny;
1301     pnz = pme->pmegrid_nz;
1302
1303     overlap = pme->pme_order - 1;
1304
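    /* Inverse of wrap_periodic_pmegrid: copy the first pme_order-1 planes
     * out to the overlap region so that force gathering near the upper grid
     * boundary reads the correct periodic values.
     */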
1305     if (pme->nnodes_major == 1)
1306     {
1307         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1308
1309         for (ix = 0; ix < overlap; ix++)
1310         {
1311             int iy, iz;
1312
1313             for (iy = 0; iy < ny_x; iy++)
1314             {
1315                 for (iz = 0; iz < nz; iz++)
1316                 {
1317                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1318                         pmegrid[(ix*pny+iy)*pnz+iz];
1319                 }
1320             }
1321         }
1322     }
1323
1324     if (pme->nnodes_minor == 1)
1325     {
1326 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1327         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1328         {
1329             int iy, iz;
1330
1331             for (iy = 0; iy < overlap; iy++)
1332             {
1333                 for (iz = 0; iz < nz; iz++)
1334                 {
1335                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1336                         pmegrid[(ix*pny+iy)*pnz+iz];
1337                 }
1338             }
1339         }
1340     }
1341
1342     /* Copy periodic overlap in z */
1343 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1344     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1345     {
1346         int iy, iz;
1347
1348         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1349         {
1350             for (iz = 0; iz < overlap; iz++)
1351             {
1352                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1353                     pmegrid[(ix*pny+iy)*pnz+iz];
1354             }
1355         }
1356     }
1357 }
1358
1359
1360 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
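/* DO_BSPLINE(order) adds coefficient*thx[ithx]*thy[ithy]*thz[ithz] to the
 * order^3 grid points starting at (i0, j0, k0); all loop variables and the
 * grid strides (pny, pnz) come from the enclosing function.
 */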
1361 #define DO_BSPLINE(order)                            \
1362     for (ithx = 0; (ithx < order); ithx++)                    \
1363     {                                                    \
1364         index_x = (i0+ithx)*pny*pnz;                     \
1365         valx    = coefficient*thx[ithx];                          \
1366                                                      \
1367         for (ithy = 0; (ithy < order); ithy++)                \
1368         {                                                \
1369             valxy    = valx*thy[ithy];                   \
1370             index_xy = index_x+(j0+ithy)*pnz;            \
1371                                                      \
1372             for (ithz = 0; (ithz < order); ithz++)            \
1373             {                                            \
1374                 index_xyz        = index_xy+(k0+ithz);   \
1375                 grid[index_xyz] += valxy*thz[ithz];      \
1376             }                                            \
1377         }                                                \
1378     }
1379
1380
1381 static void spread_coefficients_bsplines_thread(pmegrid_t                    *pmegrid,
1382                                                 pme_atomcomm_t               *atc,
1383                                                 splinedata_t                 *spline,
1384                                                 pme_spline_work_t gmx_unused *work)
1385 {
1386
1387     /* spread coefficients from home atoms to local grid */
1388     real          *grid;
1389     pme_overlap_t *ol;
1390     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1391     int       *    idxptr;
1392     int            order, norder, index_x, index_xy, index_xyz;
1393     real           valx, valxy, coefficient;
1394     real          *thx, *thy, *thz;
1395     int            localsize, bndsize;
1396     int            pnx, pny, pnz, ndatatot;
1397     int            offx, offy, offz;
1398
1399 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1400     real           thz_buffer[GMX_SIMD4_WIDTH*3], *thz_aligned;
1401
1402     thz_aligned = gmx_simd4_align_r(thz_buffer);
1403 #endif
1404
1405     pnx = pmegrid->s[XX];
1406     pny = pmegrid->s[YY];
1407     pnz = pmegrid->s[ZZ];
1408
1409     offx = pmegrid->offset[XX];
1410     offy = pmegrid->offset[YY];
1411     offz = pmegrid->offset[ZZ];
1412
1413     ndatatot = pnx*pny*pnz;
1414     grid     = pmegrid->grid;
1415     for (i = 0; i < ndatatot; i++)
1416     {
1417         grid[i] = 0;
1418     }
1419
1420     order = pmegrid->order;
1421
1422     for (nn = 0; nn < spline->n; nn++)
1423     {
1424         n           = spline->ind[nn];
1425         coefficient = atc->coefficient[n];
1426
1427         if (coefficient != 0)
1428         {
1429             idxptr = atc->idx[n];
1430             norder = nn*order;
1431
1432             i0   = idxptr[XX] - offx;
1433             j0   = idxptr[YY] - offy;
1434             k0   = idxptr[ZZ] - offz;
1435
1436             thx = spline->theta[XX] + norder;
1437             thy = spline->theta[YY] + norder;
1438             thz = spline->theta[ZZ] + norder;
1439
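            /* For the common orders 4 and 5 a SIMD4 kernel is generated by
             * including pme-simd4.h with the PME_SPREAD_SIMD4_* and
             * PME_ORDER macros set; other orders fall back to the generic
             * DO_BSPLINE macro.
             */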
1440             switch (order)
1441             {
1442                 case 4:
1443 #ifdef PME_SIMD4_SPREAD_GATHER
1444 #ifdef PME_SIMD4_UNALIGNED
1445 #define PME_SPREAD_SIMD4_ORDER4
1446 #else
1447 #define PME_SPREAD_SIMD4_ALIGNED
1448 #define PME_ORDER 4
1449 #endif
1450 #include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
1451 #else
1452                     DO_BSPLINE(4);
1453 #endif
1454                     break;
1455                 case 5:
1456 #ifdef PME_SIMD4_SPREAD_GATHER
1457 #define PME_SPREAD_SIMD4_ALIGNED
1458 #define PME_ORDER 5
1459 #include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
1460 #else
1461                     DO_BSPLINE(5);
1462 #endif
1463                     break;
1464                 default:
1465                     DO_BSPLINE(order);
1466                     break;
1467             }
1468         }
1469     }
1470 }
1471
1472 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1473 {
1474 #ifdef PME_SIMD4_SPREAD_GATHER
1475     if (pme_order == 5
1476 #ifndef PME_SIMD4_UNALIGNED
1477         || pme_order == 4
1478 #endif
1479         )
1480     {
1481         /* Round nz up to a multiple of 4 to ensure alignment */
1482         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1483     }
1484 #endif
1485 }
1486
1487 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1488 {
1489 #ifdef PME_SIMD4_SPREAD_GATHER
1490 #ifndef PME_SIMD4_UNALIGNED
1491     if (pme_order == 4)
1492     {
1493         /* Add extra elements to ensure that aligned operations do not go
1494          * beyond the allocated grid size.
1495          * Note that for pme_order=5, the pme grid z-size alignment
1496          * ensures that we will not go beyond the grid size.
1497          */
1498         *gridsize += 4;
1499     }
1500 #endif
1501 #endif
1502 }
1503
1504 static void pmegrid_init(pmegrid_t *grid,
1505                          int cx, int cy, int cz,
1506                          int x0, int y0, int z0,
1507                          int x1, int y1, int z1,
1508                          gmx_bool set_alignment,
1509                          int pme_order,
1510                          real *ptr)
1511 {
1512     int nz, gridsize;
1513
1514     grid->ci[XX]     = cx;
1515     grid->ci[YY]     = cy;
1516     grid->ci[ZZ]     = cz;
1517     grid->offset[XX] = x0;
1518     grid->offset[YY] = y0;
1519     grid->offset[ZZ] = z0;
1520     grid->n[XX]      = x1 - x0 + pme_order - 1;
1521     grid->n[YY]      = y1 - y0 + pme_order - 1;
1522     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1523     copy_ivec(grid->n, grid->s);
1524
1525     nz = grid->s[ZZ];
1526     set_grid_alignment(&nz, pme_order);
1527     if (set_alignment)
1528     {
1529         grid->s[ZZ] = nz;
1530     }
1531     else if (nz != grid->s[ZZ])
1532     {
1533         gmx_incons("pmegrid_init call with an unaligned z size");
1534     }
1535
1536     grid->order = pme_order;
1537     if (ptr == NULL)
1538     {
1539         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1540         set_gridsize_alignment(&gridsize, pme_order);
1541         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1542     }
1543     else
1544     {
1545         grid->grid = ptr;
1546     }
1547 }
1548
1549 static int div_round_up(int numerator, int denominator)
1550 {
1551     return (numerator + denominator - 1)/denominator;
1552 }
1553
1554 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1555                                   ivec nsub)
1556 {
1557     int gsize_opt, gsize;
1558     int nsx, nsy, nsz;
1559     char *env;
1560
1561     gsize_opt = -1;
1562     for (nsx = 1; nsx <= nthread; nsx++)
1563     {
1564         if (nthread % nsx == 0)
1565         {
1566             for (nsy = 1; nsy <= nthread; nsy++)
1567             {
1568                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1569                 {
1570                     nsz = nthread/(nsx*nsy);
1571
1572                     /* Determine the number of grid points per thread */
1573                     gsize =
1574                         (div_round_up(n[XX], nsx) + ovl)*
1575                         (div_round_up(n[YY], nsy) + ovl)*
1576                         (div_round_up(n[ZZ], nsz) + ovl);
1577
1578                     /* Minimize the number of grid points per thread
1579                      * and, secondarily, the number of cuts in minor dimensions.
1580                      */
1581                     if (gsize_opt == -1 ||
1582                         gsize < gsize_opt ||
1583                         (gsize == gsize_opt &&
1584                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1585                     {
1586                         nsub[XX]  = nsx;
1587                         nsub[YY]  = nsy;
1588                         nsub[ZZ]  = nsz;
1589                         gsize_opt = gsize;
1590                     }
1591                 }
1592             }
1593         }
1594     }
1595
1596     env = getenv("GMX_PME_THREAD_DIVISION");
1597     if (env != NULL)
1598     {
1599         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1600     }
1601
1602     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1603     {
1604         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1605     }
1606 }
1607
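/* Initialize the full pme grid plus, when bUseThreads is set, one local grid
 * per thread, all sharing a single aligned allocation (grid_all) separated
 * by cache-line padding. Also build the g2t lookup tables that map a grid
 * line in each dimension to the owning thread, and determine nthread_comm,
 * the number of neighboring thread grids the overlap reduction has to
 * consider in each dimension.
 */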
1608 static void pmegrids_init(pmegrids_t *grids,
1609                           int nx, int ny, int nz, int nz_base,
1610                           int pme_order,
1611                           gmx_bool bUseThreads,
1612                           int nthread,
1613                           int overlap_x,
1614                           int overlap_y)
1615 {
1616     ivec n, n_base, g0, g1;
1617     int t, x, y, z, d, i, tfac;
1618     int max_comm_lines = -1;
1619
1620     n[XX] = nx - (pme_order - 1);
1621     n[YY] = ny - (pme_order - 1);
1622     n[ZZ] = nz - (pme_order - 1);
1623
1624     copy_ivec(n, n_base);
1625     n_base[ZZ] = nz_base;
1626
1627     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1628                  NULL);
1629
1630     grids->nthread = nthread;
1631
1632     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1633
1634     if (bUseThreads)
1635     {
1636         ivec nst;
1637         int gridsize;
1638
1639         for (d = 0; d < DIM; d++)
1640         {
1641             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1642         }
1643         set_grid_alignment(&nst[ZZ], pme_order);
1644
1645         if (debug)
1646         {
1647             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1648                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1649             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1650                     nx, ny, nz,
1651                     nst[XX], nst[YY], nst[ZZ]);
1652         }
1653
1654         snew(grids->grid_th, grids->nthread);
1655         t        = 0;
1656         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1657         set_gridsize_alignment(&gridsize, pme_order);
1658         snew_aligned(grids->grid_all,
1659                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1660                      SIMD4_ALIGNMENT);
1661
1662         for (x = 0; x < grids->nc[XX]; x++)
1663         {
1664             for (y = 0; y < grids->nc[YY]; y++)
1665             {
1666                 for (z = 0; z < grids->nc[ZZ]; z++)
1667                 {
1668                     pmegrid_init(&grids->grid_th[t],
1669                                  x, y, z,
1670                                  (n[XX]*(x  ))/grids->nc[XX],
1671                                  (n[YY]*(y  ))/grids->nc[YY],
1672                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1673                                  (n[XX]*(x+1))/grids->nc[XX],
1674                                  (n[YY]*(y+1))/grids->nc[YY],
1675                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1676                                  TRUE,
1677                                  pme_order,
1678                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1679                     t++;
1680                 }
1681             }
1682         }
1683     }
1684     else
1685     {
1686         grids->grid_th = NULL;
1687     }
1688
1689     snew(grids->g2t, DIM);
1690     tfac = 1;
1691     for (d = DIM-1; d >= 0; d--)
1692     {
1693         snew(grids->g2t[d], n[d]);
1694         t = 0;
1695         for (i = 0; i < n[d]; i++)
1696         {
1697             /* The second check should match the parameters
1698              * of the pmegrid_init call above.
1699              */
1700             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1701             {
1702                 t++;
1703             }
1704             grids->g2t[d][i] = t*tfac;
1705         }
1706
1707         tfac *= grids->nc[d];
1708
1709         switch (d)
1710         {
1711             case XX: max_comm_lines = overlap_x;     break;
1712             case YY: max_comm_lines = overlap_y;     break;
1713             case ZZ: max_comm_lines = pme_order - 1; break;
1714         }
1715         grids->nthread_comm[d] = 0;
1716         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1717                grids->nthread_comm[d] < grids->nc[d])
1718         {
1719             grids->nthread_comm[d]++;
1720         }
1721         if (debug != NULL)
1722         {
1723             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1724                     'x'+d, grids->nthread_comm[d]);
1725         }
1726         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1727          * work, but this is not a problematic restriction.
1728          */
1729         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1730         {
1731             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1732         }
1733     }
1734 }
1735
1736
1737 static void pmegrids_destroy(pmegrids_t *grids)
1738 {
1739     int t;
1740
1741     if (grids->grid.grid != NULL)
1742     {
1743         sfree(grids->grid.grid);
1744
1745         if (grids->nthread > 0)
1746         {
1747             for (t = 0; t < grids->nthread; t++)
1748             {
1749                 sfree(grids->grid_th[t].grid);
1750             }
1751             sfree(grids->grid_th);
1752         }
1753     }
1754 }
1755
1756
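/* (Re)allocate the per-thread solve work arrays for nkx grid lines. The
 * arrays used in SIMD loops (denom, tmp1, tmp2, eterm) are padded with
 * simd_width extra elements so a full-width load starting at the last valid
 * index stays inside the allocation; e.g. assuming GMX_SIMD_REAL_WIDTH == 4
 * and nkx == 10, 14 reals are allocated and the load at kx == 8 does not
 * run past the end.
 */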
1757 static void realloc_work(pme_work_t *work, int nkx)
1758 {
1759     int simd_width;
1760
1761     if (nkx > work->nalloc)
1762     {
1763         work->nalloc = nkx;
1764         srenew(work->mhx, work->nalloc);
1765         srenew(work->mhy, work->nalloc);
1766         srenew(work->mhz, work->nalloc);
1767         srenew(work->m2, work->nalloc);
1768         /* Allocate an aligned pointer for SIMD operations, including extra
1769          * elements at the end for padding.
1770          */
1771 #ifdef PME_SIMD_SOLVE
1772         simd_width = GMX_SIMD_REAL_WIDTH;
1773 #else
1774         /* We can use any alignment, apart from 0, so we use 4 */
1775         simd_width = 4;
1776 #endif
1777         sfree_aligned(work->denom);
1778         sfree_aligned(work->tmp1);
1779         sfree_aligned(work->tmp2);
1780         sfree_aligned(work->eterm);
1781         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1782         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1783         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1784         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1785         srenew(work->m2inv, work->nalloc);
1786     }
1787 }
1788
1789
1790 static void free_work(pme_work_t *work)
1791 {
1792     sfree(work->mhx);
1793     sfree(work->mhy);
1794     sfree(work->mhz);
1795     sfree(work->m2);
1796     sfree_aligned(work->denom);
1797     sfree_aligned(work->tmp1);
1798     sfree_aligned(work->tmp2);
1799     sfree_aligned(work->eterm);
1800     sfree(work->m2inv);
1801 }
1802
1803
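/* Both implementations of calc_exponentials_q compute, for kx in [start,end):
 *   e[kx] = f*exp(r[kx])/d[kx]
 * The scalar fallback additionally stores 1/d[kx] and exp(r[kx]) back into
 * d and r; the callers below overwrite those arrays afterwards, so only the
 * e output matters.
 */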
1804 #if defined PME_SIMD_SOLVE
1805 /* Calculate exponentials through SIMD */
1806 gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1807 {
1808     {
1809         const gmx_simd_real_t two = gmx_simd_set1_r(2.0);
1810         gmx_simd_real_t f_simd;
1811         gmx_simd_real_t lu;
1812         gmx_simd_real_t tmp_d1, d_inv, tmp_r, tmp_e;
1813         int kx;
1814         f_simd = gmx_simd_set1_r(f);
1815         /* We only need to calculate from start. But since start is 0 or 1
1816          * and we want to use aligned loads/stores, we always start from 0.
1817          */
1818         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1819         {
1820             tmp_d1   = gmx_simd_load_r(d_aligned+kx);
1821             d_inv    = gmx_simd_inv_r(tmp_d1);
1822             tmp_r    = gmx_simd_load_r(r_aligned+kx);
1823             tmp_r    = gmx_simd_exp_r(tmp_r);
1824             tmp_e    = gmx_simd_mul_r(f_simd, d_inv);
1825             tmp_e    = gmx_simd_mul_r(tmp_e, tmp_r);
1826             gmx_simd_store_r(e_aligned+kx, tmp_e);
1827         }
1828     }
1829 }
1830 #else
1831 gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1832 {
1833     int kx;
1834     for (kx = start; kx < end; kx++)
1835     {
1836         d[kx] = 1.0/d[kx];
1837     }
1838     for (kx = start; kx < end; kx++)
1839     {
1840         r[kx] = exp(r[kx]);
1841     }
1842     for (kx = start; kx < end; kx++)
1843     {
1844         e[kx] = f*r[kx]*d[kx];
1845     }
1846 }
1847 #endif
1848
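/* Both implementations of calc_exponentials_lj update, for kx in [start,end):
 *   d[kx]      -> 1/d[kx]
 *   r[kx]      -> exp(r[kx])
 *   factor[kx] -> sqrt(pi)*m*erfc(m)   with m the original factor[kx]
 * (the scalar variant calls the third array tmp2).
 */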
1849 #if defined PME_SIMD_SOLVE
1850 /* Calculate exponentials through SIMD */
1851 gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1852 {
1853     gmx_simd_real_t tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1854     const gmx_simd_real_t sqr_PI = gmx_simd_sqrt_r(gmx_simd_set1_r(M_PI));
1855     int kx;
1856     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1857     {
1858         /* We only need to calculate from start. But since start is 0 or 1
1859          * and we want to use aligned loads/stores, we always start from 0.
1860          */
1861         tmp_d = gmx_simd_load_r(d_aligned+kx);
1862         d_inv = gmx_simd_inv_r(tmp_d);
1863         gmx_simd_store_r(d_aligned+kx, d_inv);
1864         tmp_r = gmx_simd_load_r(r_aligned+kx);
1865         tmp_r = gmx_simd_exp_r(tmp_r);
1866         gmx_simd_store_r(r_aligned+kx, tmp_r);
1867         tmp_mk  = gmx_simd_load_r(factor_aligned+kx);
1868         tmp_fac = gmx_simd_mul_r(sqr_PI, gmx_simd_mul_r(tmp_mk, gmx_simd_erfc_r(tmp_mk)));
1869         gmx_simd_store_r(factor_aligned+kx, tmp_fac);
1870     }
1871 }
1872 #else
1873 gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1874 {
1875     int kx;
1876     real mk;
1877     for (kx = start; kx < end; kx++)
1878     {
1879         d[kx] = 1.0/d[kx];
1880     }
1881
1882     for (kx = start; kx < end; kx++)
1883     {
1884         r[kx] = exp(r[kx]);
1885     }
1886
1887     for (kx = start; kx < end; kx++)
1888     {
1889         mk       = tmp2[kx];
1890         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1891     }
1892 }
1893 #endif
1894
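/* Solve the Coulomb PME reciprocal-space part for the local grid slab.
 * As implemented below, each complex grid element is multiplied by
 *   eterm = (ONE_4PI_EPS0/epsilon_r) * exp(-pi^2 m^2 / ewaldcoeff^2)
 *           / (pi * V * m^2 * B(mx)*B(my)*B(mz))
 * where m is the reciprocal lattice vector and B the B-spline moduli.
 * With bEnerVir set, energy and virial contributions are accumulated in the
 * per-thread work struct. The work is divided over threads along the
 * combined y*z index. Returns a loop count used for flop accounting.
 */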
1895 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1896                          real ewaldcoeff, real vol,
1897                          gmx_bool bEnerVir,
1898                          int nthread, int thread)
1899 {
1900     /* do recip sum over local cells in grid */
1901     /* y major, z middle, x minor or continuous */
1902     t_complex *p0;
1903     int     kx, ky, kz, maxkx, maxky, maxkz;
1904     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1905     real    mx, my, mz;
1906     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1907     real    ets2, struct2, vfactor, ets2vf;
1908     real    d1, d2, energy = 0;
1909     real    by, bz;
1910     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1911     real    rxx, ryx, ryy, rzx, rzy, rzz;
1912     pme_work_t *work;
1913     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1914     real    mhxk, mhyk, mhzk, m2k;
1915     real    corner_fac;
1916     ivec    complex_order;
1917     ivec    local_ndata, local_offset, local_size;
1918     real    elfac;
1919
1920     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1921
1922     nx = pme->nkx;
1923     ny = pme->nky;
1924     nz = pme->nkz;
1925
1926     /* Dimensions should be identical for A/B grid, so we just use A here */
1927     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
1928                                       complex_order,
1929                                       local_ndata,
1930                                       local_offset,
1931                                       local_size);
1932
1933     rxx = pme->recipbox[XX][XX];
1934     ryx = pme->recipbox[YY][XX];
1935     ryy = pme->recipbox[YY][YY];
1936     rzx = pme->recipbox[ZZ][XX];
1937     rzy = pme->recipbox[ZZ][YY];
1938     rzz = pme->recipbox[ZZ][ZZ];
1939
1940     maxkx = (nx+1)/2;
1941     maxky = (ny+1)/2;
1942     maxkz = nz/2+1;
1943
1944     work  = &pme->work[thread];
1945     mhx   = work->mhx;
1946     mhy   = work->mhy;
1947     mhz   = work->mhz;
1948     m2    = work->m2;
1949     denom = work->denom;
1950     tmp1  = work->tmp1;
1951     eterm = work->eterm;
1952     m2inv = work->m2inv;
1953
1954     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1955     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1956
1957     for (iyz = iyz0; iyz < iyz1; iyz++)
1958     {
1959         iy = iyz/local_ndata[ZZ];
1960         iz = iyz - iy*local_ndata[ZZ];
1961
1962         ky = iy + local_offset[YY];
1963
1964         if (ky < maxky)
1965         {
1966             my = ky;
1967         }
1968         else
1969         {
1970             my = (ky - ny);
1971         }
1972
1973         by = M_PI*vol*pme->bsp_mod[YY][ky];
1974
1975         kz = iz + local_offset[ZZ];
1976
1977         mz = kz;
1978
1979         bz = pme->bsp_mod[ZZ][kz];
1980
1981         /* 0.5 correction for corner points */
1982         corner_fac = 1;
1983         if (kz == 0 || kz == (nz+1)/2)
1984         {
1985             corner_fac = 0.5;
1986         }
1987
1988         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1989
1990         /* We should skip the k-space point (0,0,0) */
1991         /* Note that since here x is the minor index, local_offset[XX]=0 */
1992         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1993         {
1994             kxstart = local_offset[XX];
1995         }
1996         else
1997         {
1998             kxstart = local_offset[XX] + 1;
1999             p0++;
2000         }
2001         kxend = local_offset[XX] + local_ndata[XX];
2002
2003         if (bEnerVir)
2004         {
2005             /* More expensive inner loop, especially because of the storage
2006              * of the mh elements in arrays.
2007              * Because x is the minor grid index, all mh elements
2008              * depend on kx for triclinic unit cells.
2009              */
2010
2011             /* Two explicit loops to avoid a conditional inside the loop */
2012             for (kx = kxstart; kx < maxkx; kx++)
2013             {
2014                 mx = kx;
2015
2016                 mhxk      = mx * rxx;
2017                 mhyk      = mx * ryx + my * ryy;
2018                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2019                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2020                 mhx[kx]   = mhxk;
2021                 mhy[kx]   = mhyk;
2022                 mhz[kx]   = mhzk;
2023                 m2[kx]    = m2k;
2024                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2025                 tmp1[kx]  = -factor*m2k;
2026             }
2027
2028             for (kx = maxkx; kx < kxend; kx++)
2029             {
2030                 mx = (kx - nx);
2031
2032                 mhxk      = mx * rxx;
2033                 mhyk      = mx * ryx + my * ryy;
2034                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2035                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2036                 mhx[kx]   = mhxk;
2037                 mhy[kx]   = mhyk;
2038                 mhz[kx]   = mhzk;
2039                 m2[kx]    = m2k;
2040                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2041                 tmp1[kx]  = -factor*m2k;
2042             }
2043
2044             for (kx = kxstart; kx < kxend; kx++)
2045             {
2046                 m2inv[kx] = 1.0/m2[kx];
2047             }
2048
2049             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2050
2051             for (kx = kxstart; kx < kxend; kx++, p0++)
2052             {
2053                 d1      = p0->re;
2054                 d2      = p0->im;
2055
2056                 p0->re  = d1*eterm[kx];
2057                 p0->im  = d2*eterm[kx];
2058
2059                 struct2 = 2.0*(d1*d1+d2*d2);
2060
2061                 tmp1[kx] = eterm[kx]*struct2;
2062             }
2063
2064             for (kx = kxstart; kx < kxend; kx++)
2065             {
2066                 ets2     = corner_fac*tmp1[kx];
2067                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2068                 energy  += ets2;
2069
2070                 ets2vf   = ets2*vfactor;
2071                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2072                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2073                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2074                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2075                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2076                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2077             }
2078         }
2079         else
2080         {
2081             /* We don't need to calculate the energy and the virial.
2082              * In this case the triclinic overhead is small.
2083              */
2084
2085             /* Two explicit loops to avoid a conditional inside the loop */
2086
2087             for (kx = kxstart; kx < maxkx; kx++)
2088             {
2089                 mx = kx;
2090
2091                 mhxk      = mx * rxx;
2092                 mhyk      = mx * ryx + my * ryy;
2093                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2094                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2095                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2096                 tmp1[kx]  = -factor*m2k;
2097             }
2098
2099             for (kx = maxkx; kx < kxend; kx++)
2100             {
2101                 mx = (kx - nx);
2102
2103                 mhxk      = mx * rxx;
2104                 mhyk      = mx * ryx + my * ryy;
2105                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2106                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2107                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2108                 tmp1[kx]  = -factor*m2k;
2109             }
2110
2111             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2112
2113             for (kx = kxstart; kx < kxend; kx++, p0++)
2114             {
2115                 d1      = p0->re;
2116                 d2      = p0->im;
2117
2118                 p0->re  = d1*eterm[kx];
2119                 p0->im  = d2*eterm[kx];
2120             }
2121         }
2122     }
2123
2124     if (bEnerVir)
2125     {
2126         /* Update virial with local values.
2127          * The virial is symmetric by definition.
2128          * This virial seems OK for isotropic scaling, but I'm
2129          * experiencing problems with semi-isotropic membranes.
2130          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2131          */
2132         work->vir_q[XX][XX] = 0.25*virxx;
2133         work->vir_q[YY][YY] = 0.25*viryy;
2134         work->vir_q[ZZ][ZZ] = 0.25*virzz;
2135         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
2136         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
2137         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
2138
2139         /* This energy should be corrected for a charged system */
2140         work->energy_q = 0.5*energy;
2141     }
2142
2143     /* Return the loop count */
2144     return local_ndata[YY]*local_ndata[XX];
2145 }
2146
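/* Solve the LJ-PME reciprocal-space part for the local grid slab, analogous
 * to solve_pme_yzx but for the dispersion (r^-6) interaction. Without bLB a
 * single C6 grid is convolved; with Lorentz-Berthelot combination rules
 * (bLB) seven grids are used and, exploiting the symmetry of the grid pairs
 * (ig, 6-ig), only four structure-factor terms are accumulated explicitly.
 * Energy and virial are accumulated per thread when bEnerVir is set.
 * Returns a loop count used for flop accounting.
 */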
2147 static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
2148                             real ewaldcoeff, real vol,
2149                             gmx_bool bEnerVir, int nthread, int thread)
2150 {
2151     /* do recip sum over local cells in grid */
2152     /* y major, z middle, x minor or continuous */
2153     int     ig, gcount;
2154     int     kx, ky, kz, maxkx, maxky, maxkz;
2155     int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
2156     real    mx, my, mz;
2157     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2158     real    ets2, ets2vf;
2159     real    eterm, vterm, d1, d2, energy = 0;
2160     real    by, bz;
2161     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2162     real    rxx, ryx, ryy, rzx, rzy, rzz;
2163     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
2164     real    mhxk, mhyk, mhzk, m2k;
2165     real    mk;
2166     pme_work_t *work;
2167     real    corner_fac;
2168     ivec    complex_order;
2169     ivec    local_ndata, local_offset, local_size;
2170     nx = pme->nkx;
2171     ny = pme->nky;
2172     nz = pme->nkz;
2173
2174     /* Dimensions should be identical for A/B grid, so we just use A here */
2175     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
2176                                       complex_order,
2177                                       local_ndata,
2178                                       local_offset,
2179                                       local_size);
2180     rxx = pme->recipbox[XX][XX];
2181     ryx = pme->recipbox[YY][XX];
2182     ryy = pme->recipbox[YY][YY];
2183     rzx = pme->recipbox[ZZ][XX];
2184     rzy = pme->recipbox[ZZ][YY];
2185     rzz = pme->recipbox[ZZ][ZZ];
2186
2187     maxkx = (nx+1)/2;
2188     maxky = (ny+1)/2;
2189     maxkz = nz/2+1;
2190
2191     work  = &pme->work[thread];
2192     mhx   = work->mhx;
2193     mhy   = work->mhy;
2194     mhz   = work->mhz;
2195     m2    = work->m2;
2196     denom = work->denom;
2197     tmp1  = work->tmp1;
2198     tmp2  = work->tmp2;
2199
2200     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2201     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2202
2203     for (iyz = iyz0; iyz < iyz1; iyz++)
2204     {
2205         iy = iyz/local_ndata[ZZ];
2206         iz = iyz - iy*local_ndata[ZZ];
2207
2208         ky = iy + local_offset[YY];
2209
2210         if (ky < maxky)
2211         {
2212             my = ky;
2213         }
2214         else
2215         {
2216             my = (ky - ny);
2217         }
2218
2219         by = 3.0*vol*pme->bsp_mod[YY][ky]
2220             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
2221
2222         kz = iz + local_offset[ZZ];
2223
2224         mz = kz;
2225
2226         bz = pme->bsp_mod[ZZ][kz];
2227
2228         /* 0.5 correction for corner points */
2229         corner_fac = 1;
2230         if (kz == 0 || kz == (nz+1)/2)
2231         {
2232             corner_fac = 0.5;
2233         }
2234
2235         kxstart = local_offset[XX];
2236         kxend   = local_offset[XX] + local_ndata[XX];
2237         if (bEnerVir)
2238         {
2239             /* More expensive inner loop, especially because of the
2240              * storage of the mh elements in arrays.  Because x is the
2241              * minor grid index, all mh elements depend on kx for
2242              * triclinic unit cells.
2243              */
2244
2245             /* Two explicit loops to avoid a conditional inside the loop */
2246             for (kx = kxstart; kx < maxkx; kx++)
2247             {
2248                 mx = kx;
2249
2250                 mhxk      = mx * rxx;
2251                 mhyk      = mx * ryx + my * ryy;
2252                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2253                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2254                 mhx[kx]   = mhxk;
2255                 mhy[kx]   = mhyk;
2256                 mhz[kx]   = mhzk;
2257                 m2[kx]    = m2k;
2258                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2259                 tmp1[kx]  = -factor*m2k;
2260                 tmp2[kx]  = sqrt(factor*m2k);
2261             }
2262
2263             for (kx = maxkx; kx < kxend; kx++)
2264             {
2265                 mx = (kx - nx);
2266
2267                 mhxk      = mx * rxx;
2268                 mhyk      = mx * ryx + my * ryy;
2269                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2270                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2271                 mhx[kx]   = mhxk;
2272                 mhy[kx]   = mhyk;
2273                 mhz[kx]   = mhzk;
2274                 m2[kx]    = m2k;
2275                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2276                 tmp1[kx]  = -factor*m2k;
2277                 tmp2[kx]  = sqrt(factor*m2k);
2278             }
2279
2280             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2281
2282             for (kx = kxstart; kx < kxend; kx++)
2283             {
2284                 m2k   = factor*m2[kx];
2285                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
2286                           + 2.0*m2k*tmp2[kx]);
2287                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
2288                 tmp1[kx] = eterm*denom[kx];
2289                 tmp2[kx] = vterm*denom[kx];
2290             }
2291
2292             if (!bLB)
2293             {
2294                 t_complex *p0;
2295                 real       struct2;
2296
2297                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2298                 for (kx = kxstart; kx < kxend; kx++, p0++)
2299                 {
2300                     d1      = p0->re;
2301                     d2      = p0->im;
2302
2303                     eterm   = tmp1[kx];
2304                     vterm   = tmp2[kx];
2305                     p0->re  = d1*eterm;
2306                     p0->im  = d2*eterm;
2307
2308                     struct2 = 2.0*(d1*d1+d2*d2);
2309
2310                     tmp1[kx] = eterm*struct2;
2311                     tmp2[kx] = vterm*struct2;
2312                 }
2313             }
2314             else
2315             {
2316                 real *struct2 = denom;
2317                 real  str2;
2318
2319                 for (kx = kxstart; kx < kxend; kx++)
2320                 {
2321                     struct2[kx] = 0.0;
2322                 }
2323                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
2324                 for (ig = 0; ig <= 3; ++ig)
2325                 {
2326                     t_complex *p0, *p1;
2327                     real       scale;
2328
2329                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2330                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2331                     scale = 2.0*lb_scale_factor_symm[ig];
2332                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
2333                     {
2334                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
2335                     }
2336
2337                 }
2338                 for (ig = 0; ig <= 6; ++ig)
2339                 {
2340                     t_complex *p0;
2341
2342                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2343                     for (kx = kxstart; kx < kxend; kx++, p0++)
2344                     {
2345                         d1     = p0->re;
2346                         d2     = p0->im;
2347
2348                         eterm  = tmp1[kx];
2349                         p0->re = d1*eterm;
2350                         p0->im = d2*eterm;
2351                     }
2352                 }
2353                 for (kx = kxstart; kx < kxend; kx++)
2354                 {
2355                     eterm    = tmp1[kx];
2356                     vterm    = tmp2[kx];
2357                     str2     = struct2[kx];
2358                     tmp1[kx] = eterm*str2;
2359                     tmp2[kx] = vterm*str2;
2360                 }
2361             }
2362
2363             for (kx = kxstart; kx < kxend; kx++)
2364             {
2365                 ets2     = corner_fac*tmp1[kx];
2366                 vterm    = 2.0*factor*tmp2[kx];
2367                 energy  += ets2;
2368                 ets2vf   = corner_fac*vterm;
2369                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2370                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2371                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2372                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2373                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2374                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2375             }
2376         }
2377         else
2378         {
2379             /* We don't need to calculate the energy and the virial.
2380              *  In this case the triclinic overhead is small.
2381              */
2382
2383             /* Two explicit loops to avoid a conditional inside the loop */
2384
2385             for (kx = kxstart; kx < maxkx; kx++)
2386             {
2387                 mx = kx;
2388
2389                 mhxk      = mx * rxx;
2390                 mhyk      = mx * ryx + my * ryy;
2391                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2392                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2393                 m2[kx]    = m2k;
2394                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2395                 tmp1[kx]  = -factor*m2k;
2396                 tmp2[kx]  = sqrt(factor*m2k);
2397             }
2398
2399             for (kx = maxkx; kx < kxend; kx++)
2400             {
2401                 mx = (kx - nx);
2402
2403                 mhxk      = mx * rxx;
2404                 mhyk      = mx * ryx + my * ryy;
2405                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2406                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2407                 m2[kx]    = m2k;
2408                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2409                 tmp1[kx]  = -factor*m2k;
2410                 tmp2[kx]  = sqrt(factor*m2k);
2411             }
2412
2413             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2414
2415             for (kx = kxstart; kx < kxend; kx++)
2416             {
2417                 m2k    = factor*m2[kx];
2418                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
2419                            + 2.0*m2k*tmp2[kx]);
2420                 tmp1[kx] = eterm*denom[kx];
2421             }
2422             gcount = (bLB ? 7 : 1);
2423             for (ig = 0; ig < gcount; ++ig)
2424             {
2425                 t_complex *p0;
2426
2427                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2428                 for (kx = kxstart; kx < kxend; kx++, p0++)
2429                 {
2430                     d1      = p0->re;
2431                     d2      = p0->im;
2432
2433                     eterm   = tmp1[kx];
2434
2435                     p0->re  = d1*eterm;
2436                     p0->im  = d2*eterm;
2437                 }
2438             }
2439         }
2440     }
2441     if (bEnerVir)
2442     {
2443         work->vir_lj[XX][XX] = 0.25*virxx;
2444         work->vir_lj[YY][YY] = 0.25*viryy;
2445         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
2446         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
2447         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
2448         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
2449
2450         /* This energy should be corrected for a charged system */
2451         work->energy_lj = 0.5*energy;
2452     }
2453     /* Return the loop count */
2454     return local_ndata[YY]*local_ndata[XX];
2455 }
2456
2457 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2458                                real *mesh_energy, matrix vir)
2459 {
2460     /* This function sums output over threads and should therefore
2461      * only be called after thread synchronization.
2462      */
2463     int thread;
2464
2465     *mesh_energy = pme->work[0].energy_q;
2466     copy_mat(pme->work[0].vir_q, vir);
2467
2468     for (thread = 1; thread < nthread; thread++)
2469     {
2470         *mesh_energy += pme->work[thread].energy_q;
2471         m_add(vir, pme->work[thread].vir_q, vir);
2472     }
2473 }
2474
2475 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2476                                 real *mesh_energy, matrix vir)
2477 {
2478     /* This function sums output over threads and should therefore
2479      * only be called after thread synchronization.
2480      */
2481     int thread;
2482
2483     *mesh_energy = pme->work[0].energy_lj;
2484     copy_mat(pme->work[0].vir_lj, vir);
2485
2486     for (thread = 1; thread < nthread; thread++)
2487     {
2488         *mesh_energy += pme->work[thread].energy_lj;
2489         m_add(vir, pme->work[thread].vir_lj, vir);
2490     }
2491 }
2492
2493
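/* Loop body for force gathering: for one atom, accumulate in (fx,fy,fz) the
 * gradient of the interpolated potential in fractional coordinates, i.e. the
 * sums over the order^3 stencil of grid values weighted by the B-spline
 * derivative in one dimension and the B-spline values in the other two.
 * The caller converts this to Cartesian forces using the reciprocal box.
 */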
2494 #define DO_FSPLINE(order)                      \
2495     for (ithx = 0; (ithx < order); ithx++)              \
2496     {                                              \
2497         index_x = (i0+ithx)*pny*pnz;               \
2498         tx      = thx[ithx];                       \
2499         dx      = dthx[ithx];                      \
2500                                                \
2501         for (ithy = 0; (ithy < order); ithy++)          \
2502         {                                          \
2503             index_xy = index_x+(j0+ithy)*pnz;      \
2504             ty       = thy[ithy];                  \
2505             dy       = dthy[ithy];                 \
2506             fxy1     = fz1 = 0;                    \
2507                                                \
2508             for (ithz = 0; (ithz < order); ithz++)      \
2509             {                                      \
2510                 gval  = grid[index_xy+(k0+ithz)];  \
2511                 fxy1 += thz[ithz]*gval;            \
2512                 fz1  += dthz[ithz]*gval;           \
2513             }                                      \
2514             fx += dx*ty*fxy1;                      \
2515             fy += tx*dy*fxy1;                      \
2516             fz += tx*ty*fz1;                       \
2517         }                                          \
2518     }
2519
2520
2521 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2522                               gmx_bool bClearF, pme_atomcomm_t *atc,
2523                               splinedata_t *spline,
2524                               real scale)
2525 {
2526     /* sum forces for local particles */
2527     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2528     int     index_x, index_xy;
2529     int     nx, ny, nz, pnx, pny, pnz;
2530     int *   idxptr;
2531     real    tx, ty, dx, dy, coefficient;
2532     real    fx, fy, fz, gval;
2533     real    fxy1, fz1;
2534     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2535     int     norder;
2536     real    rxx, ryx, ryy, rzx, rzy, rzz;
2537     int     order;
2538
2539     pme_spline_work_t *work;
2540
2541 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2542     real           thz_buffer[GMX_SIMD4_WIDTH*3],  *thz_aligned;
2543     real           dthz_buffer[GMX_SIMD4_WIDTH*3], *dthz_aligned;
2544
2545     thz_aligned  = gmx_simd4_align_r(thz_buffer);
2546     dthz_aligned = gmx_simd4_align_r(dthz_buffer);
2547 #endif
2548
2549     work = pme->spline_work;
2550
2551     order = pme->pme_order;
2552     thx   = spline->theta[XX];
2553     thy   = spline->theta[YY];
2554     thz   = spline->theta[ZZ];
2555     dthx  = spline->dtheta[XX];
2556     dthy  = spline->dtheta[YY];
2557     dthz  = spline->dtheta[ZZ];
2558     nx    = pme->nkx;
2559     ny    = pme->nky;
2560     nz    = pme->nkz;
2561     pnx   = pme->pmegrid_nx;
2562     pny   = pme->pmegrid_ny;
2563     pnz   = pme->pmegrid_nz;
2564
2565     rxx   = pme->recipbox[XX][XX];
2566     ryx   = pme->recipbox[YY][XX];
2567     ryy   = pme->recipbox[YY][YY];
2568     rzx   = pme->recipbox[ZZ][XX];
2569     rzy   = pme->recipbox[ZZ][YY];
2570     rzz   = pme->recipbox[ZZ][ZZ];
2571
2572     for (nn = 0; nn < spline->n; nn++)
2573     {
2574         n           = spline->ind[nn];
2575         coefficient = scale*atc->coefficient[n];
2576
2577         if (bClearF)
2578         {
2579             atc->f[n][XX] = 0;
2580             atc->f[n][YY] = 0;
2581             atc->f[n][ZZ] = 0;
2582         }
2583         if (coefficient != 0)
2584         {
2585             fx     = 0;
2586             fy     = 0;
2587             fz     = 0;
2588             idxptr = atc->idx[n];
2589             norder = nn*order;
2590
2591             i0   = idxptr[XX];
2592             j0   = idxptr[YY];
2593             k0   = idxptr[ZZ];
2594
2595             /* Pointer arithmetic alert, next six statements */
2596             thx  = spline->theta[XX] + norder;
2597             thy  = spline->theta[YY] + norder;
2598             thz  = spline->theta[ZZ] + norder;
2599             dthx = spline->dtheta[XX] + norder;
2600             dthy = spline->dtheta[YY] + norder;
2601             dthz = spline->dtheta[ZZ] + norder;
2602
2603             switch (order)
2604             {
2605                 case 4:
2606 #ifdef PME_SIMD4_SPREAD_GATHER
2607 #ifdef PME_SIMD4_UNALIGNED
2608 #define PME_GATHER_F_SIMD4_ORDER4
2609 #else
2610 #define PME_GATHER_F_SIMD4_ALIGNED
2611 #define PME_ORDER 4
2612 #endif
2613 #include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
2614 #else
2615                     DO_FSPLINE(4);
2616 #endif
2617                     break;
2618                 case 5:
2619 #ifdef PME_SIMD4_SPREAD_GATHER
2620 #define PME_GATHER_F_SIMD4_ALIGNED
2621 #define PME_ORDER 5
2622 #include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
2623 #else
2624                     DO_FSPLINE(5);
2625 #endif
2626                     break;
2627                 default:
2628                     DO_FSPLINE(order);
2629                     break;
2630             }
2631
2632             atc->f[n][XX] += -coefficient*( fx*nx*rxx );
2633             atc->f[n][YY] += -coefficient*( fx*nx*ryx + fy*ny*ryy );
2634             atc->f[n][ZZ] += -coefficient*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2635         }
2636     }
2637     /* Since the energy, and not the forces, is interpolated
2638      * the net force might not be exactly zero.
2639      * This can be solved by also interpolating F, but
2640      * that comes at a cost.
2641      * A better hack is to remove the net force every
2642      * step, but that must be done at a higher level
2643      * since this routine doesn't see all atoms if running
2644      * in parallel. It is unclear how important this is.  EL 990726
2645      */
2646 }
2647
2648
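/* Interpolate the potential energy of the local atoms directly from the
 * (already convolved) real-space grid, using the spline coefficients of
 * spline set 0; presumably intended for code paths where only the mesh
 * energy is needed.
 */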
2649 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2650                                    pme_atomcomm_t *atc)
2651 {
2652     splinedata_t *spline;
2653     int     n, ithx, ithy, ithz, i0, j0, k0;
2654     int     index_x, index_xy;
2655     int *   idxptr;
2656     real    energy, pot, tx, ty, coefficient, gval;
2657     real    *thx, *thy, *thz;
2658     int     norder;
2659     int     order;
2660
2661     spline = &atc->spline[0];
2662
2663     order = pme->pme_order;
2664
2665     energy = 0;
2666     for (n = 0; (n < atc->n); n++)
2667     {
2668         coefficient      = atc->coefficient[n];
2669
2670         if (coefficient != 0)
2671         {
2672             idxptr = atc->idx[n];
2673             norder = n*order;
2674
2675             i0   = idxptr[XX];
2676             j0   = idxptr[YY];
2677             k0   = idxptr[ZZ];
2678
2679             /* Pointer arithmetic alert, next three statements */
2680             thx  = spline->theta[XX] + norder;
2681             thy  = spline->theta[YY] + norder;
2682             thz  = spline->theta[ZZ] + norder;
2683
2684             pot = 0;
2685             for (ithx = 0; (ithx < order); ithx++)
2686             {
2687                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2688                 tx      = thx[ithx];
2689
2690                 for (ithy = 0; (ithy < order); ithy++)
2691                 {
2692                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2693                     ty       = thy[ithy];
2694
2695                     for (ithz = 0; (ithz < order); ithz++)
2696                     {
2697                         gval  = grid[index_xy+(k0+ithz)];
2698                         pot  += tx*ty*thz[ithz]*gval;
2699                     }
2700
2701                 }
2702             }
2703
2704             energy += pot*coefficient;
2705         }
2706     }
2707
2708     return energy;
2709 }
2710
2711 /* Macro to force loop unrolling by fixing order.
2712  * This gives a significant performance gain.
2713  */
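/* The construction below follows the standard B-spline recursion of smooth
 * PME (Essmann et al., J. Chem. Phys. 103, 8577 (1995)): the order-k
 * coefficients are built from the order-(k-1) ones, the derivative is the
 * difference of two order-(n-1) coefficients, and a final pass completes
 * the order-n values from the order-(n-1) ones.
 */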
2714 #define CALC_SPLINE(order)                     \
2715     {                                              \
2716         int j, k, l;                                 \
2717         real dr, div;                               \
2718         real data[PME_ORDER_MAX];                  \
2719         real ddata[PME_ORDER_MAX];                 \
2720                                                \
2721         for (j = 0; (j < DIM); j++)                     \
2722         {                                          \
2723             dr  = xptr[j];                         \
2724                                                \
2725             /* dr is relative offset from lower cell limit */ \
2726             data[order-1] = 0;                     \
2727             data[1]       = dr;                          \
2728             data[0]       = 1 - dr;                      \
2729                                                \
2730             for (k = 3; (k < order); k++)               \
2731             {                                      \
2732                 div       = 1.0/(k - 1.0);               \
2733                 data[k-1] = div*dr*data[k-2];      \
2734                 for (l = 1; (l < (k-1)); l++)           \
2735                 {                                  \
2736                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2737                                        data[k-l-1]);                \
2738                 }                                  \
2739                 data[0] = div*(1-dr)*data[0];      \
2740             }                                      \
2741             /* differentiate */                    \
2742             ddata[0] = -data[0];                   \
2743             for (k = 1; (k < order); k++)               \
2744             {                                      \
2745                 ddata[k] = data[k-1] - data[k];    \
2746             }                                      \
2747                                                \
2748             div           = 1.0/(order - 1);                 \
2749             data[order-1] = div*dr*data[order-2];  \
2750             for (l = 1; (l < (order-1)); l++)           \
2751             {                                      \
2752                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2753                                        (order-l-dr)*data[order-l-1]); \
2754             }                                      \
2755             data[0] = div*(1 - dr)*data[0];        \
2756                                                \
2757             for (k = 0; k < order; k++)                 \
2758             {                                      \
2759                 theta[j][i*order+k]  = data[k];    \
2760                 dtheta[j][i*order+k] = ddata[k];   \
2761             }                                      \
2762         }                                          \
2763     }
2764
2765 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2766                    rvec fractx[], int nr, int ind[], real coefficient[],
2767                    gmx_bool bDoSplines)
2768 {
2769     /* construct splines for local atoms */
2770     int  i, ii;
2771     real *xptr;
2772
2773     for (i = 0; i < nr; i++)
2774     {
2775         /* With free energy we do not use the coefficient check.
2776          * In most cases this will be more efficient than calling make_bsplines
2777          * twice, since usually more than half the particles have non-zero coefficients.
2778          */
2779         ii = ind[i];
2780         if (bDoSplines || coefficient[ii] != 0.0)
2781         {
2782             xptr = fractx[ii];
2783             switch (order)
2784             {
2785                 case 4:  CALC_SPLINE(4);     break;
2786                 case 5:  CALC_SPLINE(5);     break;
2787                 default: CALC_SPLINE(order); break;
2788             }
2789         }
2790     }
2791 }
2792
2793
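/* Compute the squared modulus of the discrete Fourier transform of data:
 *   mod[i] = |sum_j data[j] exp(2*pi*I*i*j/ndata)|^2
 * Values that are (nearly) zero, which can occur at the half-grid frequency
 * for some interpolation orders, are replaced by the average of their
 * neighbors; in practice such zeros only occur at interior indices, so the
 * i-1 and i+1 accesses stay within the array.
 */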
2794 void make_dft_mod(real *mod, real *data, int ndata)
2795 {
2796     int i, j;
2797     real sc, ss, arg;
2798
2799     for (i = 0; i < ndata; i++)
2800     {
2801         sc = ss = 0;
2802         for (j = 0; j < ndata; j++)
2803         {
2804             arg = (2.0*M_PI*i*j)/ndata;
2805             sc += data[j]*cos(arg);
2806             ss += data[j]*sin(arg);
2807         }
2808         mod[i] = sc*sc+ss*ss;
2809     }
2810     for (i = 0; i < ndata; i++)
2811     {
2812         if (mod[i] < 1e-7)
2813         {
2814             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2815         }
2816     }
2817 }
2818
2819
2820 static void make_bspline_moduli(splinevec bsp_mod,
2821                                 int nx, int ny, int nz, int order)
2822 {
2823     int nmax = max(nx, max(ny, nz));
2824     real *data, *ddata, *bsp_data;
2825     int i, k, l;
2826     real div;
2827
2828     snew(data, order);
2829     snew(ddata, order);
2830     snew(bsp_data, nmax);
2831
2832     data[order-1] = 0;
2833     data[1]       = 0;
2834     data[0]       = 1;
2835
2836     for (k = 3; k < order; k++)
2837     {
2838         div       = 1.0/(k-1.0);
2839         data[k-1] = 0;
2840         for (l = 1; l < (k-1); l++)
2841         {
2842             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2843         }
2844         data[0] = div*data[0];
2845     }
2846     /* differentiate */
2847     ddata[0] = -data[0];
2848     for (k = 1; k < order; k++)
2849     {
2850         ddata[k] = data[k-1]-data[k];
2851     }
2852     div           = 1.0/(order-1);
2853     data[order-1] = 0;
2854     for (l = 1; l < (order-1); l++)
2855     {
2856         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2857     }
2858     data[0] = div*data[0];
2859
2860     for (i = 0; i < nmax; i++)
2861     {
2862         bsp_data[i] = 0;
2863     }
2864     for (i = 1; i <= order; i++)
2865     {
2866         bsp_data[i] = data[i-1];
2867     }
2868
2869     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2870     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2871     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2872
2873     sfree(data);
2874     sfree(ddata);
2875     sfree(bsp_data);
2876 }
2877
2878
2879 /* Return the P3M optimal influence function */
2880 static double do_p3m_influence(double z, int order)
2881 {
2882     double z2, z4;
2883
2884     z2 = z*z;
2885     z4 = z2*z2;
2886
2887     /* The formula and most constants can be found in:
2888      * Ballenegger et al., JCTC 8, 936 (2012)
2889      */
2890     switch (order)
2891     {
2892         case 2:
2893             return 1.0 - 2.0*z2/3.0;
2894             break;
2895         case 3:
2896             return 1.0 - z2 + 2.0*z4/15.0;
2897             break;
2898         case 4:
2899             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2900             break;
2901         case 5:
2902             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2903             break;
2904         case 6:
2905             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2906             break;
2907         case 7:
2908             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2909         case 8:
2910             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2911             break;
2912     }
2913
2914     return 0.0;
2915 }
2916
2917 /* Calculate the P3M B-spline moduli for one dimension */
2918 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2919 {
2920     double zarg, zai, sinzai, infl;
2921     int    maxk, i;
2922
2923     if (order > 8)
2924     {
2925         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2926     }
2927
2928     zarg = M_PI/n;
2929
2930     maxk = (n + 1)/2;
2931
2932     for (i = -maxk; i < 0; i++)
2933     {
2934         zai          = zarg*i;
2935         sinzai       = sin(zai);
2936         infl         = do_p3m_influence(sinzai, order);
2937         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2938     }
2939     bsp_mod[0] = 1.0;
2940     for (i = 1; i < maxk; i++)
2941     {
2942         zai        = zarg*i;
2943         sinzai     = sin(zai);
2944         infl       = do_p3m_influence(sinzai, order);
2945         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2946     }
2947 }
2948
2949 /* Calculate the P3M B-spline moduli */
2950 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2951                                     int nx, int ny, int nz, int order)
2952 {
2953     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2954     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2955     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2956 }
2957
2958
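/* Fill the destination/source rank lists for the slab coordinate
 * communication. Ranks are ordered by increasing distance, alternating
 * forward and backward, so that nearby slabs are exchanged first.
 */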
2959 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2960 {
2961     int nslab, n, i;
2962     int fw, bw;
2963
2964     nslab = atc->nslab;
2965
2966     n = 0;
2967     for (i = 1; i <= nslab/2; i++)
2968     {
2969         fw = (atc->nodeid + i) % nslab;
2970         bw = (atc->nodeid - i + nslab) % nslab;
2971         if (n < nslab - 1)
2972         {
2973             atc->node_dest[n] = fw;
2974             atc->node_src[n]  = bw;
2975             n++;
2976         }
2977         if (n < nslab - 1)
2978         {
2979             atc->node_dest[n] = bw;
2980             atc->node_src[n]  = fw;
2981             n++;
2982         }
2983     }
2984 }
2985
2986 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2987 {
2988     int thread, i;
2989
2990     if (NULL != log)
2991     {
2992         fprintf(log, "Destroying PME data structures.\n");
2993     }
2994
2995     sfree((*pmedata)->nnx);
2996     sfree((*pmedata)->nny);
2997     sfree((*pmedata)->nnz);
2998
2999     for (i = 0; i < (*pmedata)->ngrids; ++i)
3000     {
3001         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
3002         sfree((*pmedata)->fftgrid[i]);
3003         sfree((*pmedata)->cfftgrid[i]);
3004         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
3005     }
3006
3007     sfree((*pmedata)->lb_buf1);
3008     sfree((*pmedata)->lb_buf2);
3009
3010     for (thread = 0; thread < (*pmedata)->nthread; thread++)
3011     {
3012         free_work(&(*pmedata)->work[thread]);
3013     }
3014     sfree((*pmedata)->work);
3015
3016     sfree(*pmedata);
3017     *pmedata = NULL;
3018
3019     return 0;
3020 }
3021
3022 static int mult_up(int n, int f)
3023 {
3024     return ((n + f - 1)/f)*f;
3025 }
3026
3027
3028 static double pme_load_imbalance(gmx_pme_t pme)
3029 {
3030     int    nma, nmi;
3031     double n1, n2, n3;
3032
3033     nma = pme->nnodes_major;
3034     nmi = pme->nnodes_minor;
3035
3036     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3037     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3038     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3039
3040     /* pme_solve is roughly double the cost of an fft */
3041
3042     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3043 }
3044
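/* Initialize the atom communication structure for decomposition dimension
 * dimind: slab count and rank from the MPI communicator (when running with
 * more than one PME rank), the slab send/receive rank lists, and the
 * per-thread counting, redistribution and spline buffers.
 */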
3045 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3046                           int dimind, gmx_bool bSpread)
3047 {
3048     int nk, k, s, thread;
3049
3050     atc->dimind    = dimind;
3051     atc->nslab     = 1;
3052     atc->nodeid    = 0;
3053     atc->pd_nalloc = 0;
3054 #ifdef GMX_MPI
3055     if (pme->nnodes > 1)
3056     {
3057         atc->mpi_comm = pme->mpi_comm_d[dimind];
3058         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3059         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3060     }
3061     if (debug)
3062     {
3063         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3064     }
3065 #endif
3066
3067     atc->bSpread   = bSpread;
3068     atc->pme_order = pme->pme_order;
3069
3070     if (atc->nslab > 1)
3071     {
3072         snew(atc->node_dest, atc->nslab);
3073         snew(atc->node_src, atc->nslab);
3074         setup_coordinate_communication(atc);
3075
3076         snew(atc->count_thread, pme->nthread);
3077         for (thread = 0; thread < pme->nthread; thread++)
3078         {
3079             snew(atc->count_thread[thread], atc->nslab);
3080         }
3081         atc->count = atc->count_thread[0];
3082         snew(atc->rcount, atc->nslab);
3083         snew(atc->buf_index, atc->nslab);
3084     }
3085
3086     atc->nthread = pme->nthread;
3087     if (atc->nthread > 1)
3088     {
3089         snew(atc->thread_plist, atc->nthread);
3090     }
3091     snew(atc->spline, atc->nthread);
3092     for (thread = 0; thread < atc->nthread; thread++)
3093     {
3094         if (atc->nthread > 1)
3095         {
3096             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3097             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3098         }
3099         snew(atc->spline[thread].thread_one, pme->nthread);
3100         atc->spline[thread].thread_one[thread] = 1;
3101     }
3102 }
3103
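/* Set up the grid overlap communication for one decomposition dimension:
 * the slab boundaries s2g0/s2g1 on the interpolation grid, the number of
 * neighboring ranks the overlap region extends over, and per neighbor the
 * send/receive index ranges and buffer sizes.
 */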
3104 static void
3105 init_overlap_comm(pme_overlap_t *  ol,
3106                   int              norder,
3107 #ifdef GMX_MPI
3108                   MPI_Comm         comm,
3109 #endif
3110                   int              nnodes,
3111                   int              nodeid,
3112                   int              ndata,
3113                   int              commplainsize)
3114 {
3115     int lbnd, rbnd, maxlr, b, i;
3116     int exten;
3117     int nn, nk;
3118     pme_grid_comm_t *pgc;
3119     gmx_bool bCont;
3120     int fft_start, fft_end, send_index1, recv_index1;
3121 #ifdef GMX_MPI
3122     MPI_Status stat;
3123
3124     ol->mpi_comm = comm;
3125 #endif
3126
3127     ol->nnodes = nnodes;
3128     ol->nodeid = nodeid;
3129
3130     /* Linear translation of the PME grid won't affect reciprocal space
3131      * calculations, so to optimize we only interpolate "upwards",
3132      * which also means we only have to consider overlap in one direction.
3133      * I.e., particles on this node might also be spread to grid indices
3134      * that belong to higher nodes (modulo nnodes)
3135      */
3136
3137     snew(ol->s2g0, ol->nnodes+1);
3138     snew(ol->s2g1, ol->nnodes);
3139     if (debug)
3140     {
3141         fprintf(debug, "PME slab boundaries:");
3142     }
3143     for (i = 0; i < nnodes; i++)
3144     {
3145         /* s2g0 the local interpolation grid start.
3146          * s2g1 the local interpolation grid end.
3147          * Since in calc_pidx we divide particles, and not grid lines,
3148          * spatially uniformly along dimension x or y, we need to round
3149          * s2g0 down and s2g1 up.
3150          */
3151         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3152         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3153
3154         if (debug)
3155         {
3156             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3157         }
3158     }
3159     ol->s2g0[nnodes] = ndata;
3160     if (debug)
3161     {
3162         fprintf(debug, "\n");
3163     }
3164
3165     /* Determine with how many nodes we need to communicate the grid overlap */
3166     b = 0;
3167     do
3168     {
3169         b++;
3170         bCont = FALSE;
3171         for (i = 0; i < nnodes; i++)
3172         {
3173             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3174                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3175             {
3176                 bCont = TRUE;
3177             }
3178         }
3179     }
3180     while (bCont && b < nnodes);
3181     ol->noverlap_nodes = b - 1;
3182
3183     snew(ol->send_id, ol->noverlap_nodes);
3184     snew(ol->recv_id, ol->noverlap_nodes);
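         /* Pulse b sends our overlap data to the slab b+1 places "up" and
          * receives from the slab b+1 places "down" (circularly), consistent
          * with spreading only in the upward direction.
          */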
3185     for (b = 0; b < ol->noverlap_nodes; b++)
3186     {
3187         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3188         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3189     }
3190     snew(ol->comm_data, ol->noverlap_nodes);
3191
3192     ol->send_size = 0;
3193     for (b = 0; b < ol->noverlap_nodes; b++)
3194     {
3195         pgc = &ol->comm_data[b];
3196         /* Send */
3197         fft_start        = ol->s2g0[ol->send_id[b]];
3198         fft_end          = ol->s2g0[ol->send_id[b]+1];
3199         if (ol->send_id[b] < nodeid)
3200         {
3201             fft_start += ndata;
3202             fft_end   += ndata;
3203         }
3204         send_index1       = ol->s2g1[nodeid];
3205         send_index1       = min(send_index1, fft_end);
3206         pgc->send_index0  = fft_start;
3207         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3208         ol->send_size    += pgc->send_nindex;
3209
3210         /* We always start receiving at the first index of our slab */
3211         fft_start        = ol->s2g0[ol->nodeid];
3212         fft_end          = ol->s2g0[ol->nodeid+1];
3213         recv_index1      = ol->s2g1[ol->recv_id[b]];
3214         if (ol->recv_id[b] > nodeid)
3215         {
3216             recv_index1 -= ndata;
3217         }
3218         recv_index1      = min(recv_index1, fft_end);
3219         pgc->recv_index0 = fft_start;
3220         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3221     }
3222
3223 #ifdef GMX_MPI
3224     /* Communicate the buffer sizes to receive */
3225     for (b = 0; b < ol->noverlap_nodes; b++)
3226     {
3227         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3228                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3229                      ol->mpi_comm, &stat);
3230     }
3231 #endif
3232
3233     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3234     snew(ol->sendbuf, norder*commplainsize);
3235     snew(ol->recvbuf, norder*commplainsize);
3236 }
3237
3238 static void
3239 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3240                               int **global_to_local,
3241                               real **fraction_shift)
3242 {
3243     int i;
3244     int * gtl;
3245     real * fsh;
3246
3247     snew(gtl, 5*n);
3248     snew(fsh, 5*n);
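         /* The tables span 5*n entries: the global grid index computed during
          * spreading includes a shift of two box lengths (+2*n) and, with
          * triclinic skew and rounding, can presumably fall anywhere in [0, 5*n).
          */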
3249     for (i = 0; (i < 5*n); i++)
3250     {
3251         /* Determine the global to local grid index */
3252         gtl[i] = (i - local_start + n) % n;
3253         /* For coordinates that fall within the local grid the fraction
3254          * is correct, we don't need to shift it.
3255          */
3256         fsh[i] = 0;
3257         if (local_range < n)
3258         {
3259             /* Due to rounding issues i could be 1 beyond the lower or
3260              * upper boundary of the local grid. Correct the index for this.
3261              * If we shift the index, we need to shift the fraction by
3262              * the same amount in the other direction to not affect
3263              * the weights.
3264              * Note that due to this shifting the weights at the end of
3265              * the spline might change, but that will only involve values
3266              * between zero and values close to the precision of a real,
3267              * which is anyhow the accuracy of the whole mesh calculation.
3268              */
3269             /* With local_range=0 we should not change i=local_start */
3270             if (i % n != local_start)
3271             {
3272                 if (gtl[i] == n-1)
3273                 {
3274                     gtl[i] = 0;
3275                     fsh[i] = -1;
3276                 }
3277                 else if (gtl[i] == local_range)
3278                 {
3279                     gtl[i] = local_range - 1;
3280                     fsh[i] = 1;
3281                 }
3282             }
3283         }
3284     }
3285
3286     *global_to_local = gtl;
3287     *fraction_shift  = fsh;
3288 }
3289
3290 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
3291 {
3292     pme_spline_work_t *work;
3293
3294 #ifdef PME_SIMD4_SPREAD_GATHER
3295     real             tmp[GMX_SIMD4_WIDTH*3], *tmp_aligned;
3296     gmx_simd4_real_t zero_S;
3297     gmx_simd4_real_t real_mask_S0, real_mask_S1;
3298     int              of, i;
3299
3300     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3301
3302     tmp_aligned = gmx_simd4_align_r(tmp);
3303
3304     zero_S = gmx_simd4_setzero_r();
3305
3306     /* Generate bit masks to mask out the unused grid entries,
3307      * as we only operate on `order` of the 8 grid entries that are
3308      * loaded into the 2 SIMD registers.
3309      */
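         /* Example: with order=4 and GMX_SIMD4_WIDTH=4, 'of' runs from 0 to 4
          * and each mask pair selects exactly the 4 entries [of, of+4) out of
          * the 8 SIMD lanes.
          */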
3310     for (of = 0; of < 2*GMX_SIMD4_WIDTH-(order-1); of++)
3311     {
3312         for (i = 0; i < 2*GMX_SIMD4_WIDTH; i++)
3313         {
3314             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3315         }
3316         real_mask_S0      = gmx_simd4_load_r(tmp_aligned);
3317         real_mask_S1      = gmx_simd4_load_r(tmp_aligned+GMX_SIMD4_WIDTH);
3318         work->mask_S0[of] = gmx_simd4_cmplt_r(real_mask_S0, zero_S);
3319         work->mask_S1[of] = gmx_simd4_cmplt_r(real_mask_S1, zero_S);
3320     }
3321 #else
3322     work = NULL;
3323 #endif
3324
3325     return work;
3326 }
3327
3328 void gmx_pme_check_restrictions(int pme_order,
3329                                 int nkx, int nky, int nkz,
3330                                 int nnodes_major,
3331                                 int nnodes_minor,
3332                                 gmx_bool bUseThreads,
3333                                 gmx_bool bFatal,
3334                                 gmx_bool *bValidSettings)
3335 {
3336     if (pme_order > PME_ORDER_MAX)
3337     {
3338         if (!bFatal)
3339         {
3340             *bValidSettings = FALSE;
3341             return;
3342         }
3343         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3344                   pme_order, PME_ORDER_MAX);
3345     }
3346
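         /* Example: with pme_order=4 and domain decomposition along x,
          * nkx must exceed 2*pme_order = 8, i.e. nkx >= 9.
          */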
3347     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3348         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3349         nkz <= pme_order)
3350     {
3351         if (!bFatal)
3352         {
3353             *bValidSettings = FALSE;
3354             return;
3355         }
3356         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and, for dimensions with domain decomposition, larger than 2*pme_order",
3357                   pme_order);
3358     }
3359
3360     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3361      * We only allow multiple communication pulses in dim 1, not in dim 0.
3362      */
3363     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3364                         nkx != nnodes_major*(pme_order - 1)))
3365     {
3366         if (!bFatal)
3367         {
3368             *bValidSettings = FALSE;
3369             return;
3370         }
3371                 gmx_fatal(FARGS, "The number of PME grid lines per rank along x is %g. But when using OpenMP threads, the number of grid lines per rank along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer ranks along x (and possibly more along y and/or z) by specifying -dd manually.",
3372                   nkx/(double)nnodes_major, pme_order);
3373     }
3374
3375     if (bValidSettings != NULL)
3376     {
3377         *bValidSettings = TRUE;
3378     }
3379
3380     return;
3381 }
3382
3383 int gmx_pme_init(gmx_pme_t *         pmedata,
3384                  t_commrec *         cr,
3385                  int                 nnodes_major,
3386                  int                 nnodes_minor,
3387                  t_inputrec *        ir,
3388                  int                 homenr,
3389                  gmx_bool            bFreeEnergy_q,
3390                  gmx_bool            bFreeEnergy_lj,
3391                  gmx_bool            bReproducible,
3392                  int                 nthread)
3393 {
3394     gmx_pme_t pme = NULL;
3395
3396     int  use_threads, sum_use_threads, i;
3397     ivec ndata;
3398
3399     if (debug)
3400     {
3401         fprintf(debug, "Creating PME data structures.\n");
3402     }
3403     snew(pme, 1);
3404
3405     pme->sum_qgrid_tmp       = NULL;
3406     pme->sum_qgrid_dd_tmp    = NULL;
3407     pme->buf_nalloc          = 0;
3408
3409     pme->nnodes              = 1;
3410     pme->bPPnode             = TRUE;
3411
3412     pme->nnodes_major        = nnodes_major;
3413     pme->nnodes_minor        = nnodes_minor;
3414
3415 #ifdef GMX_MPI
3416     if (nnodes_major*nnodes_minor > 1)
3417     {
3418         pme->mpi_comm = cr->mpi_comm_mygroup;
3419
3420         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3421         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3422         if (pme->nnodes != nnodes_major*nnodes_minor)
3423         {
3424             gmx_incons("PME rank count mismatch");
3425         }
3426     }
3427     else
3428     {
3429         pme->mpi_comm = MPI_COMM_NULL;
3430     }
3431 #endif
3432
3433     if (pme->nnodes == 1)
3434     {
3435 #ifdef GMX_MPI
3436         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3437         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3438 #endif
3439         pme->ndecompdim   = 0;
3440         pme->nodeid_major = 0;
3441         pme->nodeid_minor = 0;
3445     }
3446     else
3447     {
3448         if (nnodes_minor == 1)
3449         {
3450 #ifdef GMX_MPI
3451             pme->mpi_comm_d[0] = pme->mpi_comm;
3452             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3453 #endif
3454             pme->ndecompdim   = 1;
3455             pme->nodeid_major = pme->nodeid;
3456             pme->nodeid_minor = 0;
3457
3458         }
3459         else if (nnodes_major == 1)
3460         {
3461 #ifdef GMX_MPI
3462             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3463             pme->mpi_comm_d[1] = pme->mpi_comm;
3464 #endif
3465             pme->ndecompdim   = 1;
3466             pme->nodeid_major = 0;
3467             pme->nodeid_minor = pme->nodeid;
3468         }
3469         else
3470         {
3471             if (pme->nnodes % nnodes_major != 0)
3472             {
3473                 gmx_incons("For 2D PME decomposition, #PME ranks must be divisible by the number of ranks in the major dimension");
3474             }
3475             pme->ndecompdim = 2;
3476
3477 #ifdef GMX_MPI
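                 /* The PME ranks are laid out with the minor index running
                  * fastest: nodeid = nodeid_major*nnodes_minor + nodeid_minor.
                  * E.g. with 6 ranks split 3x2 (major x minor), rank 5 gets
                  * nodeid_major = 2 and nodeid_minor = 1.
                  */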
3478             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3479                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3480             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3481                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3482
3483             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3484             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3485             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3486             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3487 #endif
3488         }
3489         pme->bPPnode = (cr->duty & DUTY_PP);
3490     }
3491
3492     pme->nthread = nthread;
3493
3494     /* Check if any of the PME MPI ranks uses threads */
3495     use_threads = (pme->nthread > 1 ? 1 : 0);
3496 #ifdef GMX_MPI
3497     if (pme->nnodes > 1)
3498     {
3499         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3500                       MPI_SUM, pme->mpi_comm);
3501     }
3502     else
3503 #endif
3504     {
3505         sum_use_threads = use_threads;
3506     }
3507     pme->bUseThreads = (sum_use_threads > 0);
3508
3509     if (ir->ePBC == epbcSCREW)
3510     {
3511         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3512     }
3513
3514     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3515     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3516     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3517     pme->nkx         = ir->nkx;
3518     pme->nky         = ir->nky;
3519     pme->nkz         = ir->nkz;
3520     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3521     pme->pme_order   = ir->pme_order;
3522
3523     /* Always constant electrostatics coefficients */
3524     pme->epsilon_r   = ir->epsilon_r;
3525
3526     /* Always constant LJ coefficients */
3527     pme->ljpme_combination_rule = ir->ljpme_combination_rule;
3528
3529     /* If we violate restrictions, generate a fatal error here */
3530     gmx_pme_check_restrictions(pme->pme_order,
3531                                pme->nkx, pme->nky, pme->nkz,
3532                                pme->nnodes_major,
3533                                pme->nnodes_minor,
3534                                pme->bUseThreads,
3535                                TRUE,
3536                                NULL);
3537
3538     if (pme->nnodes > 1)
3539     {
3540         double imbal;
3541
3542 #ifdef GMX_MPI
3543         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3544         MPI_Type_commit(&(pme->rvec_mpi));
3545 #endif
3546
3547         /* Note that the coefficient spreading and force gathering, which usually
3548          * take about the same amount of time as FFT+solve_pme,
3549          * are always fully load balanced
3550          * (unless the coefficient distribution is inhomogeneous).
3551          */
3552
3553         imbal = pme_load_imbalance(pme);
3554         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3555         {
3556             fprintf(stderr,
3557                     "\n"
3558                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3559                     "      For optimal PME load balancing\n"
3560                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_ranks_x (%d)\n"
3561                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_ranks_y (%d)\n"
3562                     "\n",
3563                     (int)((imbal-1)*100 + 0.5),
3564                     pme->nkx, pme->nky, pme->nnodes_major,
3565                     pme->nky, pme->nkz, pme->nnodes_minor);
3566         }
3567     }
3568
3569     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3570     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3571      * y is always copied through a buffer: we don't need padding in z,
3572      * but we do need the overlap in x because of the communication order.
3573      */
3574     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3575 #ifdef GMX_MPI
3576                       pme->mpi_comm_d[0],
3577 #endif
3578                       pme->nnodes_major, pme->nodeid_major,
3579                       pme->nkx,
3580                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3581
3582     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3583      * We do this with an offset buffer of equal size, so we need to allocate
3584      * extra for the offset. That's what the (+1)*pme->nkz is for.
3585      */
3586     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3587 #ifdef GMX_MPI
3588                       pme->mpi_comm_d[1],
3589 #endif
3590                       pme->nnodes_minor, pme->nodeid_minor,
3591                       pme->nky,
3592                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3593
3594     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3595      * Note that gmx_pme_check_restrictions checked for this already.
3596      */
3597     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3598     {
3599         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3600     }
3601
3602     snew(pme->bsp_mod[XX], pme->nkx);
3603     snew(pme->bsp_mod[YY], pme->nky);
3604     snew(pme->bsp_mod[ZZ], pme->nkz);
3605
3606     /* The required size of the interpolation grid, including overlap.
3607      * The allocated size (pmegrid_n?) might be slightly larger.
3608      */
3609     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3610         pme->overlap[0].s2g0[pme->nodeid_major];
3611     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3612         pme->overlap[1].s2g0[pme->nodeid_minor];
3613     pme->pmegrid_nz_base = pme->nkz;
3614     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3615     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3616
3617     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3618     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3619     pme->pmegrid_start_iz = 0;
3620
3621     make_gridindex5_to_localindex(pme->nkx,
3622                                   pme->pmegrid_start_ix,
3623                                   pme->pmegrid_nx - (pme->pme_order-1),
3624                                   &pme->nnx, &pme->fshx);
3625     make_gridindex5_to_localindex(pme->nky,
3626                                   pme->pmegrid_start_iy,
3627                                   pme->pmegrid_ny - (pme->pme_order-1),
3628                                   &pme->nny, &pme->fshy);
3629     make_gridindex5_to_localindex(pme->nkz,
3630                                   pme->pmegrid_start_iz,
3631                                   pme->pmegrid_nz_base,
3632                                   &pme->nnz, &pme->fshz);
3633
3634     pme->spline_work = make_pme_spline_work(pme->pme_order);
3635
3636     ndata[0]    = pme->nkx;
3637     ndata[1]    = pme->nky;
3638     ndata[2]    = pme->nkz;
3639     /* It doesn't matter if we reserve pointers for a few too many grids here,
3640      * we only allocate and use the grids we actually need.
3641      */
3642     if (EVDW_PME(ir->vdwtype))
3643     {
3644         pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3645     }
3646     else
3647     {
3648         pme->ngrids = DO_Q;
3649     }
3650     snew(pme->fftgrid, pme->ngrids);
3651     snew(pme->cfftgrid, pme->ngrids);
3652     snew(pme->pfft_setup, pme->ngrids);
3653
3654     for (i = 0; i < pme->ngrids; ++i)
3655     {
3656         if ((i <  DO_Q && EEL_PME(ir->coulombtype) && (i == 0 ||
3657                                                        bFreeEnergy_q)) ||
3658             (i >= DO_Q && EVDW_PME(ir->vdwtype) && (i == 2 ||
3659                                                     bFreeEnergy_lj ||
3660                                                     ir->ljpme_combination_rule == eljpmeLB)))
3661         {
3662             pmegrids_init(&pme->pmegrid[i],
3663                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3664                           pme->pmegrid_nz_base,
3665                           pme->pme_order,
3666                           pme->bUseThreads,
3667                           pme->nthread,
3668                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3669                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3670             /* This routine will allocate the grid data to fit the FFTs */
3671             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3672                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3673                                     pme->mpi_comm_d,
3674                                     bReproducible, pme->nthread);
3675
3676         }
3677     }
3678
3679     if (!pme->bP3M)
3680     {
3681         /* Use plain SPME B-spline interpolation */
3682         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3683     }
3684     else
3685     {
3686         /* Use the P3M grid-optimized influence function */
3687         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3688     }
3689
3690     /* Use atc[0] for spreading */
3691     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3692     if (pme->ndecompdim >= 2)
3693     {
3694         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3695     }
3696
3697     if (pme->nnodes == 1)
3698     {
3699         pme->atc[0].n = homenr;
3700         pme_realloc_atomcomm_things(&pme->atc[0]);
3701     }
3702
3703     pme->lb_buf1       = NULL;
3704     pme->lb_buf2       = NULL;
3705     pme->lb_buf_nalloc = 0;
3706
3707     {
3708         int thread;
3709
3710         /* Use fft5d, order after FFT is y major, z, x minor */
3711
3712         snew(pme->work, pme->nthread);
3713         for (thread = 0; thread < pme->nthread; thread++)
3714         {
3715             realloc_work(&pme->work[thread], pme->nkx);
3716         }
3717     }
3718
3719     *pmedata = pme;
3720
3721     return 0;
3722 }
3723
3724 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3725 {
3726     int d, t;
3727
3728     for (d = 0; d < DIM; d++)
3729     {
3730         if (new->grid.n[d] > old->grid.n[d])
3731         {
3732             return;
3733         }
3734     }
3735
3736     sfree_aligned(new->grid.grid);
3737     new->grid.grid = old->grid.grid;
3738
3739     if (new->grid_th != NULL && new->nthread == old->nthread)
3740     {
3741         sfree_aligned(new->grid_all);
3742         for (t = 0; t < new->nthread; t++)
3743         {
3744             new->grid_th[t].grid = old->grid_th[t].grid;
3745         }
3746     }
3747 }
3748
3749 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3750                    t_commrec *         cr,
3751                    gmx_pme_t           pme_src,
3752                    const t_inputrec *  ir,
3753                    ivec                grid_size)
3754 {
3755     t_inputrec irc;
3756     int homenr;
3757     int ret;
3758
3759     irc     = *ir;
3760     irc.nkx = grid_size[XX];
3761     irc.nky = grid_size[YY];
3762     irc.nkz = grid_size[ZZ];
3763
3764     if (pme_src->nnodes == 1)
3765     {
3766         homenr = pme_src->atc[0].n;
3767     }
3768     else
3769     {
3770         homenr = -1;
3771     }
3772
3773     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3774                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3775
3776     if (ret == 0)
3777     {
3778         /* We can easily reuse the allocated pme grids in pme_src */
3779         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3780         /* We would like to reuse the fft grids, but that's harder */
3781     }
3782
3783     return ret;
3784 }
3785
3786
3787 static void copy_local_grid(gmx_pme_t pme, pmegrids_t *pmegrids,
3788                             int grid_index, int thread, real *fftgrid)
3789 {
3790     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3791     int  fft_my, fft_mz;
3792     int  nsx, nsy, nsz;
3793     ivec nf;
3794     int  offx, offy, offz, x, y, z, i0, i0t;
3795     int  d;
3796     pmegrid_t *pmegrid;
3797     real *grid_th;
3798
3799     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3800                                    local_fft_ndata,
3801                                    local_fft_offset,
3802                                    local_fft_size);
3803     fft_my = local_fft_size[YY];
3804     fft_mz = local_fft_size[ZZ];
3805
3806     pmegrid = &pmegrids->grid_th[thread];
3807
3808     nsx = pmegrid->s[XX];
3809     nsy = pmegrid->s[YY];
3810     nsz = pmegrid->s[ZZ];
3811
3812     for (d = 0; d < DIM; d++)
3813     {
3814         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3815                     local_fft_ndata[d] - pmegrid->offset[d]);
3816     }
3817
3818     offx = pmegrid->offset[XX];
3819     offy = pmegrid->offset[YY];
3820     offz = pmegrid->offset[ZZ];
3821
3822     /* Directly copy the non-overlapping parts of the local grids.
3823      * This also initializes the full grid.
3824      */
3825     grid_th = pmegrid->grid;
3826     for (x = 0; x < nf[XX]; x++)
3827     {
3828         for (y = 0; y < nf[YY]; y++)
3829         {
3830             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3831             i0t = (x*nsy + y)*nsz;
3832             for (z = 0; z < nf[ZZ]; z++)
3833             {
3834                 fftgrid[i0+z] = grid_th[i0t+z];
3835             }
3836         }
3837     }
3838 }
3839
3840 static void
3841 reduce_threadgrid_overlap(gmx_pme_t pme,
3842                           const pmegrids_t *pmegrids, int thread,
3843                           real *fftgrid, real *commbuf_x, real *commbuf_y,
3844                           int grid_index)
3845 {
3846     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3847     int  fft_nx, fft_ny, fft_nz;
3848     int  fft_my, fft_mz;
3849     int  buf_my = -1;
3850     int  nsx, nsy, nsz;
3851     ivec localcopy_end;
3852     int  offx, offy, offz, x, y, z, i0, i0t;
3853     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3854     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3855     gmx_bool bCommX, bCommY;
3856     int  d;
3857     int  thread_f;
3858     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3859     const real *grid_th;
3860     real *commbuf = NULL;
3861
3862     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3863                                    local_fft_ndata,
3864                                    local_fft_offset,
3865                                    local_fft_size);
3866     fft_nx = local_fft_ndata[XX];
3867     fft_ny = local_fft_ndata[YY];
3868     fft_nz = local_fft_ndata[ZZ];
3869
3870     fft_my = local_fft_size[YY];
3871     fft_mz = local_fft_size[ZZ];
3872
3873     /* This routine is called when all threads have finished spreading.
3874      * Here each thread sums grid contributions calculated by other threads
3875      * to the thread local grid volume.
3876      * To minimize the number of grid copying operations,
3877      * this routine sums immediately from the pmegrid to the fftgrid.
3878      */
3879
3880     /* Determine which part of the full node grid we should operate on,
3881      * this is our thread local part of the full grid.
3882      */
3883     pmegrid = &pmegrids->grid_th[thread];
3884
3885     for (d = 0; d < DIM; d++)
3886     {
3887         /* Determine up to where our thread needs to copy from the
3888          * thread-local charge spreading grid to the rank-local FFT grid.
3889          * This is up to our spreading grid end minus order-1 and
3890          * not beyond the local FFT grid.
3891          */
3892         localcopy_end[d] =
3893             min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3894                 local_fft_ndata[d]);
3895     }
3896
3897     offx = pmegrid->offset[XX];
3898     offy = pmegrid->offset[YY];
3899     offz = pmegrid->offset[ZZ];
3900
3901
3902     bClearBufX  = TRUE;
3903     bClearBufY  = TRUE;
3904     bClearBufXY = TRUE;
3905
3906     /* Now loop over all the thread data blocks that contribute
3907      * to the grid region we (our thread) are operating on.
3908      */
3909     /* Note that fft_nx/y is equal to the number of grid points
3910      * between the first point of our node grid and the one of the next node.
3911      */
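         /* The loops below run over our own thread block (sx=sy=sz=0) and its
          * lower neighbours only: since spreading extends only "upwards",
          * only lower thread blocks (or the neighbouring rank below, wrapped)
          * can have spread charge into our part of the grid.
          */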
3912     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3913     {
3914         fx     = pmegrid->ci[XX] + sx;
3915         ox     = 0;
3916         bCommX = FALSE;
3917         if (fx < 0)
3918         {
3919             fx    += pmegrids->nc[XX];
3920             ox    -= fft_nx;
3921             bCommX = (pme->nnodes_major > 1);
3922         }
3923         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3924         ox       += pmegrid_g->offset[XX];
3925         /* Determine the end of our part of the source grid */
3926         if (!bCommX)
3927         {
3928             /* Use our thread local source grid and target grid part */
3929             tx1 = min(ox + pmegrid_g->n[XX], localcopy_end[XX]);
3930         }
3931         else
3932         {
3933             /* Use our thread local source grid and the spreading range */
3934             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3935         }
3936
3937         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3938         {
3939             fy     = pmegrid->ci[YY] + sy;
3940             oy     = 0;
3941             bCommY = FALSE;
3942             if (fy < 0)
3943             {
3944                 fy    += pmegrids->nc[YY];
3945                 oy    -= fft_ny;
3946                 bCommY = (pme->nnodes_minor > 1);
3947             }
3948             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3949             oy       += pmegrid_g->offset[YY];
3950             /* Determine the end of our part of the source grid */
3951             if (!bCommY)
3952             {
3953                 /* Use our thread local source grid and target grid part */
3954                 ty1 = min(oy + pmegrid_g->n[YY], localcopy_end[YY]);
3955             }
3956             else
3957             {
3958                 /* Use our thread local source grid and the spreading range */
3959                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3960             }
3961
3962             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3963             {
3964                 fz = pmegrid->ci[ZZ] + sz;
3965                 oz = 0;
3966                 if (fz < 0)
3967                 {
3968                     fz += pmegrids->nc[ZZ];
3969                     oz -= fft_nz;
3970                 }
3971                 pmegrid_g = &pmegrids->grid_th[fz];
3972                 oz       += pmegrid_g->offset[ZZ];
3973                 tz1       = min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);
3974
3975                 if (sx == 0 && sy == 0 && sz == 0)
3976                 {
3977                     /* We have already added our local contribution
3978                      * before calling this routine, so skip it here.
3979                      */
3980                     continue;
3981                 }
3982
3983                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3984
3985                 pmegrid_f = &pmegrids->grid_th[thread_f];
3986
3987                 grid_th = pmegrid_f->grid;
3988
3989                 nsx = pmegrid_f->s[XX];
3990                 nsy = pmegrid_f->s[YY];
3991                 nsz = pmegrid_f->s[ZZ];
3992
3993 #ifdef DEBUG_PME_REDUCE
3994                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3995                        pme->nodeid, thread, thread_f,
3996                        pme->pmegrid_start_ix,
3997                        pme->pmegrid_start_iy,
3998                        pme->pmegrid_start_iz,
3999                        sx, sy, sz,
4000                        offx-ox, tx1-ox, offx, tx1,
4001                        offy-oy, ty1-oy, offy, ty1,
4002                        offz-oz, tz1-oz, offz, tz1);
4003 #endif
4004
4005                 if (!(bCommX || bCommY))
4006                 {
4007                     /* Copy from the thread local grid to the node grid */
4008                     for (x = offx; x < tx1; x++)
4009                     {
4010                         for (y = offy; y < ty1; y++)
4011                         {
4012                             i0  = (x*fft_my + y)*fft_mz;
4013                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4014                             for (z = offz; z < tz1; z++)
4015                             {
4016                                 fftgrid[i0+z] += grid_th[i0t+z];
4017                             }
4018                         }
4019                     }
4020                 }
4021                 else
4022                 {
4023                     /* The order of this conditional decides
4024                      * where the corner volume gets stored with x+y decomp.
4025                      */
4026                     if (bCommY)
4027                     {
4028                         commbuf = commbuf_y;
4029                         /* The y-size of the communication buffer is set by
4030                          * the overlap of the grid part of our local slab
4031                          * with the part starting at the next slab.
4032                          */
4033                         buf_my  =
4034                             pme->overlap[1].s2g1[pme->nodeid_minor] -
4035                             pme->overlap[1].s2g0[pme->nodeid_minor+1];
4036                         if (bCommX)
4037                         {
4038                             /* We index commbuf modulo the local grid size */
4039                             commbuf += buf_my*fft_nx*fft_nz;
4040
4041                             bClearBuf   = bClearBufXY;
4042                             bClearBufXY = FALSE;
4043                         }
4044                         else
4045                         {
4046                             bClearBuf  = bClearBufY;
4047                             bClearBufY = FALSE;
4048                         }
4049                     }
4050                     else
4051                     {
4052                         commbuf    = commbuf_x;
4053                         buf_my     = fft_ny;
4054                         bClearBuf  = bClearBufX;
4055                         bClearBufX = FALSE;
4056                     }
4057
4058                     /* Copy to the communication buffer */
4059                     for (x = offx; x < tx1; x++)
4060                     {
4061                         for (y = offy; y < ty1; y++)
4062                         {
4063                             i0  = (x*buf_my + y)*fft_nz;
4064                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4065
4066                             if (bClearBuf)
4067                             {
4068                                 /* First access of commbuf, initialize it */
4069                                 for (z = offz; z < tz1; z++)
4070                                 {
4071                                     commbuf[i0+z]  = grid_th[i0t+z];
4072                                 }
4073                             }
4074                             else
4075                             {
4076                                 for (z = offz; z < tz1; z++)
4077                                 {
4078                                     commbuf[i0+z] += grid_th[i0t+z];
4079                                 }
4080                             }
4081                         }
4082                     }
4083                 }
4084             }
4085         }
4086     }
4087 }
4088
4089
4090 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid, int grid_index)
4091 {
4092     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4093     pme_overlap_t *overlap;
4094     int  send_index0, send_nindex;
4095     int  recv_nindex;
4096 #ifdef GMX_MPI
4097     MPI_Status stat;
4098 #endif
4099     int  send_size_y, recv_size_y;
4100     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
4101     real *sendptr, *recvptr;
4102     int  x, y, z, indg, indb;
4103
4104     /* Note that this routine is only used for forward communication.
4105      * Since the force gathering, unlike the coefficient spreading,
4106      * can be trivially parallelized over the particles,
4107      * the backwards process is much simpler and can use the "old"
4108      * communication setup.
4109      */
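         /* The reduction proceeds in two phases: first the overlap along the
          * minor dimension (y) is received and summed into the local FFT grid,
          * while the part that also belongs to the x-neighbour is added into
          * pme->overlap[0].sendbuf; then a single pulse along the major
          * dimension (x) reduces that staged data.
          */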
4110
4111     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
4112                                    local_fft_ndata,
4113                                    local_fft_offset,
4114                                    local_fft_size);
4115
4116     if (pme->nnodes_minor > 1)
4117     {
4118         /* Minor dimension */
4119         overlap = &pme->overlap[1];
4120
4121         if (pme->nnodes_major > 1)
4122         {
4123             size_yx = pme->overlap[0].comm_data[0].send_nindex;
4124         }
4125         else
4126         {
4127             size_yx = 0;
4128         }
4129         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
4130
4131         send_size_y = overlap->send_size;
4132
4133         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
4134         {
4135             send_id       = overlap->send_id[ipulse];
4136             recv_id       = overlap->recv_id[ipulse];
4137             send_index0   =
4138                 overlap->comm_data[ipulse].send_index0 -
4139                 overlap->comm_data[0].send_index0;
4140             send_nindex   = overlap->comm_data[ipulse].send_nindex;
4141             /* We don't use recv_index0, as we always receive starting at 0 */
4142             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4143             recv_size_y   = overlap->comm_data[ipulse].recv_size;
4144
4145             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
4146             recvptr = overlap->recvbuf;
4147
4148             if (debug != NULL)
4149             {
4150                 fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n",
4151                         local_fft_ndata[XX], send_nindex, local_fft_ndata[ZZ]);
4152             }
4153
4154 #ifdef GMX_MPI
4155             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
4156                          send_id, ipulse,
4157                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
4158                          recv_id, ipulse,
4159                          overlap->mpi_comm, &stat);
4160 #endif
4161
4162             for (x = 0; x < local_fft_ndata[XX]; x++)
4163             {
4164                 for (y = 0; y < recv_nindex; y++)
4165                 {
4166                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
4167                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
4168                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
4169                     {
4170                         fftgrid[indg+z] += recvptr[indb+z];
4171                     }
4172                 }
4173             }
4174
4175             if (pme->nnodes_major > 1)
4176             {
4177                 /* Copy from the received buffer to the send buffer for dim 0 */
4178                 sendptr = pme->overlap[0].sendbuf;
4179                 for (x = 0; x < size_yx; x++)
4180                 {
4181                     for (y = 0; y < recv_nindex; y++)
4182                     {
4183                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4184                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
4185                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
4186                         {
4187                             sendptr[indg+z] += recvptr[indb+z];
4188                         }
4189                     }
4190                 }
4191             }
4192         }
4193     }
4194
4195     /* We only support a single pulse here.
4196      * This is not a severe limitation, as this code is only used
4197      * with OpenMP and with OpenMP the (PME) domains can be larger.
4198      */
4199     if (pme->nnodes_major > 1)
4200     {
4201         /* Major dimension */
4202         overlap = &pme->overlap[0];
4203
4204         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
4205         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
4206
4207         ipulse = 0;
4208
4209         send_id       = overlap->send_id[ipulse];
4210         recv_id       = overlap->recv_id[ipulse];
4211         send_nindex   = overlap->comm_data[ipulse].send_nindex;
4212         /* We don't use recv_index0, as we always receive starting at 0 */
4213         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4214
4215         sendptr = overlap->sendbuf;
4216         recvptr = overlap->recvbuf;
4217
4218         if (debug != NULL)
4219         {
4220             fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n",
4221                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
4222         }
4223
4224 #ifdef GMX_MPI
4225         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
4226                      send_id, ipulse,
4227                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
4228                      recv_id, ipulse,
4229                      overlap->mpi_comm, &stat);
4230 #endif
4231
4232         for (x = 0; x < recv_nindex; x++)
4233         {
4234             for (y = 0; y < local_fft_ndata[YY]; y++)
4235             {
4236                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
4237                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4238                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
4239                 {
4240                     fftgrid[indg+z] += recvptr[indb+z];
4241                 }
4242             }
4243         }
4244     }
4245 }
4246
4247
4248 static void spread_on_grid(gmx_pme_t pme,
4249                            pme_atomcomm_t *atc, pmegrids_t *grids,
4250                            gmx_bool bCalcSplines, gmx_bool bSpread,
4251                            real *fftgrid, gmx_bool bDoSplines, int grid_index)
4252 {
4253     int nthread, thread;
4254 #ifdef PME_TIME_THREADS
4255     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
4256     static double cs1     = 0, cs2 = 0, cs3 = 0;
4257     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
4258     static int cnt        = 0;
4259 #endif
4260
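     /* Spreading proceeds in three stages:
      *  1) compute the interpolation grid index for every atom,
      *  2) per thread, compute the B-splines and spread the coefficients
      *     onto the thread-local (or single shared) grid,
      *  3) when thread-local grids are used, reduce their overlapping
      *     regions into the FFT grid and, in parallel runs, communicate
      *     the inter-rank grid overlap (sum_fftgrid_dd).
      */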
4261     nthread = pme->nthread;
4262     assert(nthread > 0);
4263
4264 #ifdef PME_TIME_THREADS
4265     c1 = omp_cyc_start();
4266 #endif
4267     if (bCalcSplines)
4268     {
4269 #pragma omp parallel for num_threads(nthread) schedule(static)
4270         for (thread = 0; thread < nthread; thread++)
4271         {
4272             int start, end;
4273
4274             start = atc->n* thread   /nthread;
4275             end   = atc->n*(thread+1)/nthread;
4276
4277             /* Compute fftgrid index for all atoms,
4278              * with help of some extra variables.
4279              */
4280             calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
4281         }
4282     }
4283 #ifdef PME_TIME_THREADS
4284     c1   = omp_cyc_end(c1);
4285     cs1 += (double)c1;
4286 #endif
4287
4288 #ifdef PME_TIME_THREADS
4289     c2 = omp_cyc_start();
4290 #endif
4291 #pragma omp parallel for num_threads(nthread) schedule(static)
4292     for (thread = 0; thread < nthread; thread++)
4293     {
4294         splinedata_t *spline;
4295         pmegrid_t *grid = NULL;
4296
4297         /* make local bsplines  */
4298         if (grids == NULL || !pme->bUseThreads)
4299         {
4300             spline = &atc->spline[0];
4301
4302             spline->n = atc->n;
4303
4304             if (bSpread)
4305             {
4306                 grid = &grids->grid;
4307             }
4308         }
4309         else
4310         {
4311             spline = &atc->spline[thread];
4312
4313             if (grids->nthread == 1)
4314             {
4315                 /* One thread, we operate on all coefficients */
4316                 spline->n = atc->n;
4317             }
4318             else
4319             {
4320                 /* Get the indices our thread should operate on */
4321                 make_thread_local_ind(atc, thread, spline);
4322             }
4323
4324             grid = &grids->grid_th[thread];
4325         }
4326
4327         if (bCalcSplines)
4328         {
4329             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4330                           atc->fractx, spline->n, spline->ind, atc->coefficient, bDoSplines);
4331         }
4332
4333         if (bSpread)
4334         {
4335             /* put local atoms on grid. */
4336 #ifdef PME_TIME_SPREAD
4337             ct1a = omp_cyc_start();
4338 #endif
4339             spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
4340
4341             if (pme->bUseThreads)
4342             {
4343                 copy_local_grid(pme, grids, grid_index, thread, fftgrid);
4344             }
4345 #ifdef PME_TIME_SPREAD
4346             ct1a          = omp_cyc_end(ct1a);
4347             cs1a[thread] += (double)ct1a;
4348 #endif
4349         }
4350     }
4351 #ifdef PME_TIME_THREADS
4352     c2   = omp_cyc_end(c2);
4353     cs2 += (double)c2;
4354 #endif
4355
4356     if (bSpread && pme->bUseThreads)
4357     {
4358 #ifdef PME_TIME_THREADS
4359         c3 = omp_cyc_start();
4360 #endif
4361 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4362         for (thread = 0; thread < grids->nthread; thread++)
4363         {
4364             reduce_threadgrid_overlap(pme, grids, thread,
4365                                       fftgrid,
4366                                       pme->overlap[0].sendbuf,
4367                                       pme->overlap[1].sendbuf,
4368                                       grid_index);
4369         }
4370 #ifdef PME_TIME_THREADS
4371         c3   = omp_cyc_end(c3);
4372         cs3 += (double)c3;
4373 #endif
4374
4375         if (pme->nnodes > 1)
4376         {
4377             /* Communicate the overlapping part of the fftgrid.
4378              * For this communication call we need to check pme->bUseThreads
4379              * to have all ranks communicate here, regardless of pme->nthread.
4380              */
4381             sum_fftgrid_dd(pme, fftgrid, grid_index);
4382         }
4383     }
4384
4385 #ifdef PME_TIME_THREADS
4386     cnt++;
4387     if (cnt % 20 == 0)
4388     {
4389         printf("idx %.2f spread %.2f red %.2f",
4390                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4391 #ifdef PME_TIME_SPREAD
4392         for (thread = 0; thread < nthread; thread++)
4393         {
4394             printf(" %.2f", cs1a[thread]*1e-9);
4395         }
4396 #endif
4397         printf("\n");
4398     }
4399 #endif
4400 }
4401
4402
4403 static void dump_grid(FILE *fp,
4404                       int sx, int sy, int sz, int nx, int ny, int nz,
4405                       int my, int mz, const real *g)
4406 {
4407     int x, y, z;
4408
4409     for (x = 0; x < nx; x++)
4410     {
4411         for (y = 0; y < ny; y++)
4412         {
4413             for (z = 0; z < nz; z++)
4414             {
4415                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4416                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4417             }
4418         }
4419     }
4420 }
4421
4422 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4423 {
4424     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4425
4426     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4427                                    local_fft_ndata,
4428                                    local_fft_offset,
4429                                    local_fft_size);
4430
4431     dump_grid(stderr,
4432               pme->pmegrid_start_ix,
4433               pme->pmegrid_start_iy,
4434               pme->pmegrid_start_iz,
4435               pme->pmegrid_nx-pme->pme_order+1,
4436               pme->pmegrid_ny-pme->pme_order+1,
4437               pme->pmegrid_nz-pme->pme_order+1,
4438               local_fft_size[YY],
4439               local_fft_size[ZZ],
4440               fftgrid);
4441 }
4442
4443
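     /* Computes the reciprocal-space energy of the n particles with
      * coefficients q at positions x, reusing the grid already present in pme:
      * only the spline coefficients are recomputed, nothing is spread.
      * Serial only; used e.g. for test-particle insertion.
      */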
4444 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4445 {
4446     pme_atomcomm_t *atc;
4447     pmegrids_t *grid;
4448
4449     if (pme->nnodes > 1)
4450     {
4451         gmx_incons("gmx_pme_calc_energy called in parallel");
4452     }
4453     if (pme->bFEP_q)
4454     {
4455         gmx_incons("gmx_pme_calc_energy with free energy");
4456     }
4457
4458     atc            = &pme->atc_energy;
4459     atc->nthread   = 1;
4460     if (atc->spline == NULL)
4461     {
4462         snew(atc->spline, atc->nthread);
4463     }
4464     atc->nslab     = 1;
4465     atc->bSpread   = TRUE;
4466     atc->pme_order = pme->pme_order;
4467     atc->n         = n;
4468     pme_realloc_atomcomm_things(atc);
4469     atc->x           = x;
4470     atc->coefficient = q;
4471
4472     /* We only use the A-charges grid */
4473     grid = &pme->pmegrid[PME_GRID_QA];
4474
4475     /* Only calculate the spline coefficients, don't actually spread */
4476     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE, PME_GRID_QA);
4477
4478     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4479 }
4480
4481
4482 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4483                                    gmx_walltime_accounting_t walltime_accounting,
4484                                    t_nrnb *nrnb, t_inputrec *ir,
4485                                    gmx_int64_t step)
4486 {
4487     /* Reset all the counters related to performance over the run */
4488     wallcycle_stop(wcycle, ewcRUN);
4489     wallcycle_reset_all(wcycle);
4490     init_nrnb(nrnb);
4491     if (ir->nsteps >= 0)
4492     {
4493         /* ir->nsteps is not used here, but we update it for consistency */
4494         ir->nsteps -= step - ir->init_step;
4495     }
4496     ir->init_step = step;
4497     wallcycle_start(wcycle, ewcRUN);
4498     walltime_accounting_start(walltime_accounting);
4499 }
4500
4501
4502 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4503                                ivec grid_size,
4504                                t_commrec *cr, t_inputrec *ir,
4505                                gmx_pme_t *pme_ret)
4506 {
4507     int ind;
4508     gmx_pme_t pme = NULL;
4509
4510     ind = 0;
4511     while (ind < *npmedata)
4512     {
4513         pme = (*pmedata)[ind];
4514         if (pme->nkx == grid_size[XX] &&
4515             pme->nky == grid_size[YY] &&
4516             pme->nkz == grid_size[ZZ])
4517         {
4518             *pme_ret = pme;
4519
4520             return;
4521         }
4522
4523         ind++;
4524     }
4525
4526     (*npmedata)++;
4527     srenew(*pmedata, *npmedata);
4528
4529     /* Generate a new PME data structure, copying part of the old pointers */
4530     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4531
4532     *pme_ret = (*pmedata)[ind];
4533 }
4534
4535 int gmx_pmeonly(gmx_pme_t pme,
4536                 t_commrec *cr,    t_nrnb *mynrnb,
4537                 gmx_wallcycle_t wcycle,
4538                 gmx_walltime_accounting_t walltime_accounting,
4539                 real ewaldcoeff_q, real ewaldcoeff_lj,
4540                 t_inputrec *ir)
4541 {
4542     int npmedata;
4543     gmx_pme_t *pmedata;
4544     gmx_pme_pp_t pme_pp;
4545     int  ret;
4546     int  natoms;
4547     matrix box;
4548     rvec *x_pp      = NULL, *f_pp = NULL;
4549     real *chargeA   = NULL, *chargeB = NULL;
4550     real *c6A       = NULL, *c6B = NULL;
4551     real *sigmaA    = NULL, *sigmaB = NULL;
4552     real lambda_q   = 0;
4553     real lambda_lj  = 0;
4554     int  maxshift_x = 0, maxshift_y = 0;
4555     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4556     matrix vir_q, vir_lj;
4557     float cycles;
4558     int  count;
4559     gmx_bool bEnerVir;
4560     int pme_flags;
4561     gmx_int64_t step, step_rel;
4562     ivec grid_switch;
4563
4564     /* This data is only used with PME tuning, i.e. switching PME grids */
4565     npmedata = 1;
4566     snew(pmedata, npmedata);
4567     pmedata[0] = pme;
4568
4569     pme_pp = gmx_pme_pp_init(cr);
4570
4571     init_nrnb(mynrnb);
4572
4573     count = 0;
4574     do /****** this is a quasi-loop over time steps! */
4575     {
4576         /* The reason for having a loop here is PME grid tuning/switching */
4577         do
4578         {
4579             /* Domain decomposition */
4580             ret = gmx_pme_recv_coeffs_coords(pme_pp,
4581                                              &natoms,
4582                                              &chargeA, &chargeB,
4583                                              &c6A, &c6B,
4584                                              &sigmaA, &sigmaB,
4585                                              box, &x_pp, &f_pp,
4586                                              &maxshift_x, &maxshift_y,
4587                                              &pme->bFEP_q, &pme->bFEP_lj,
4588                                              &lambda_q, &lambda_lj,
4589                                              &bEnerVir,
4590                                              &pme_flags,
4591                                              &step,
4592                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4593
4594             if (ret == pmerecvqxSWITCHGRID)
4595             {
4596                 /* Switch the PME grid to grid_switch */
4597                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4598             }
4599
4600             if (ret == pmerecvqxRESETCOUNTERS)
4601             {
4602                 /* Reset the cycle and flop counters */
4603                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4604             }
4605         }
4606         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4607
4608         if (ret == pmerecvqxFINISH)
4609         {
4610             /* We should stop: break out of the loop */
4611             break;
4612         }
4613
4614         step_rel = step - ir->init_step;
4615
4616         if (count == 0)
4617         {
4618             wallcycle_start(wcycle, ewcRUN);
4619             walltime_accounting_start(walltime_accounting);
4620         }
4621
4622         wallcycle_start(wcycle, ewcPMEMESH);
4623
4624         dvdlambda_q  = 0;
4625         dvdlambda_lj = 0;
4626         clear_mat(vir_q);
4627         clear_mat(vir_lj);
4628
4629         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4630                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4631                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4632                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4633                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4634                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4635
4636         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4637
4638         gmx_pme_send_force_vir_ener(pme_pp,
4639                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4640                                     dvdlambda_q, dvdlambda_lj, cycles);
4641
4642         count++;
4643     } /***** end of quasi-loop, we stop with the break above */
4644     while (TRUE);
4645
4646     walltime_accounting_end(walltime_accounting);
4647
4648     return 0;
4649 }
4650
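     /* calc_initial_lb_coeffs and calc_next_lb_coeffs build the per-atom
      * coefficients for the Lorentz-Berthelot LJ-PME grids: the first sets
      * coefficient_i = C6_i/sigma_i^4 and each subsequent call multiplies in
      * one more power of sigma_i, giving one sigma-weighted coefficient set
      * per LB grid (presumably the terms of the (sigma_i+sigma_j)^6 expansion).
      */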
4651 static void
4652 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4653 {
4654     int  i;
4655
4656     for (i = 0; i < pme->atc[0].n; ++i)
4657     {
4658         real sigma4;
4659
4660         sigma4                     = local_sigma[i];
4661         sigma4                     = sigma4*sigma4;
4662         sigma4                     = sigma4*sigma4;
4663         pme->atc[0].coefficient[i] = local_c6[i] / sigma4;
4664     }
4665 }
4666
4667 static void
4668 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4669 {
4670     int  i;
4671
4672     for (i = 0; i < pme->atc[0].n; ++i)
4673     {
4674         pme->atc[0].coefficient[i] *= local_sigma[i];
4675     }
4676 }
4677
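     /* Redistributes positions and per-atom coefficients to the PME ranks that
      * own them, one decomposition dimension per iteration: the last dimension
      * starts from the home atoms, and each earlier dimension takes the output
      * of the previous redistribution as its input.
      */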
4678 static void
4679 do_redist_pos_coeffs(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4680                      gmx_bool bFirst, rvec x[], real *data)
4681 {
4682     int      d;
4683     pme_atomcomm_t *atc;
4684     atc = &pme->atc[0];
4685
4686     for (d = pme->ndecompdim - 1; d >= 0; d--)
4687     {
4688         int             n_d;
4689         rvec           *x_d;
4690         real           *param_d;
4691
4692         if (d == pme->ndecompdim - 1)
4693         {
4694             n_d     = homenr;
4695             x_d     = x + start;
4696             param_d = data;
4697         }
4698         else
4699         {
4700             n_d     = pme->atc[d + 1].n;
4701             x_d     = atc->x;
4702             param_d = atc->coefficient;
4703         }
4704         atc      = &pme->atc[d];
4705         atc->npd = n_d;
4706         if (atc->npd > atc->pd_nalloc)
4707         {
4708             atc->pd_nalloc = over_alloc_dd(atc->npd);
4709             srenew(atc->pd, atc->pd_nalloc);
4710         }
4711         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4712         where();
4713         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4714         if (DOMAINDECOMP(cr))
4715         {
4716             dd_pmeredist_pos_coeffs(pme, n_d, bFirst, x_d, param_d, atc);
4717         }
4718     }
4719 }
4720
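/* Compute the PME mesh part for the local atoms: coefficient spreading,
 * 3D FFTs, the reciprocal-space solve and, when requested, force gathering
 * and energies/virials. The stages and interactions are selected through
 * flags (GMX_PME_DO_COULOMB, GMX_PME_DO_LJ, GMX_PME_SPREAD, GMX_PME_SOLVE,
 * GMX_PME_CALC_F, GMX_PME_CALC_ENER_VIR); see the call in the PME-only rank
 * loop above for a typical combination. Always returns 0.
 */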
4721 int gmx_pme_do(gmx_pme_t pme,
4722                int start,       int homenr,
4723                rvec x[],        rvec f[],
4724                real *chargeA,   real *chargeB,
4725                real *c6A,       real *c6B,
4726                real *sigmaA,    real *sigmaB,
4727                matrix box, t_commrec *cr,
4728                int  maxshift_x, int maxshift_y,
4729                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4730                matrix vir_q,      real ewaldcoeff_q,
4731                matrix vir_lj,   real ewaldcoeff_lj,
4732                real *energy_q,  real *energy_lj,
4733                real lambda_q, real lambda_lj,
4734                real *dvdlambda_q, real *dvdlambda_lj,
4735                int flags)
4736 {
4737     int     d, i, j, k, ntot, npme, grid_index, max_grid_index;
4738     int     nx, ny, nz;
4739     int     n_d, local_ny;
4740     pme_atomcomm_t *atc = NULL;
4741     pmegrids_t *pmegrid = NULL;
4742     real    *grid       = NULL;
4743     real    *ptr;
4744     rvec    *x_d, *f_d;
4745     real    *coefficient = NULL;
4746     real    energy_AB[4];
4747     matrix  vir_AB[4];
4748     real    scale, lambda;
4749     gmx_bool bClearF;
4750     gmx_parallel_3dfft_t pfft_setup;
4751     real *  fftgrid;
4752     t_complex * cfftgrid;
4753     int     thread;
4754     gmx_bool bFirst, bDoSplines;
4755     int fep_state;
4756     int fep_states_lj           = pme->bFEP_lj ? 2 : 1;
4757     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4758     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4759
4760     assert(pme->nnodes > 0);
4761     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4762
4763     if (pme->nnodes > 1)
4764     {
4765         atc      = &pme->atc[0];
4766         atc->npd = homenr;
4767         if (atc->npd > atc->pd_nalloc)
4768         {
4769             atc->pd_nalloc = over_alloc_dd(atc->npd);
4770             srenew(atc->pd, atc->pd_nalloc);
4771         }
4772         for (d = pme->ndecompdim-1; d >= 0; d--)
4773         {
4774             atc           = &pme->atc[d];
4775             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4776         }
4777     }
4778     else
4779     {
4780         atc = &pme->atc[0];
4781         /* This could be necessary for TPI */
4782         pme->atc[0].n = homenr;
4783         if (DOMAINDECOMP(cr))
4784         {
4785             pme_realloc_atomcomm_things(atc);
4786         }
4787         atc->x = x;
4788         atc->f = f;
4789     }
4790
4791     m_inv_ur0(box, pme->recipbox);
4792     bFirst = TRUE;
4793
4794     /* For simplicity, we construct the splines for all particles if
4795      * more than one PME calculation is needed. Some optimization
4796      * could be done by keeping track of which atoms have splines
4797      * constructed, and constructing new splines on each pass for atoms
4798      * that don't yet have them.
4799      */
4800
4801     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4802
4803     /* We need a maximum of four separate PME calculations:
4804      * grid_index=0: Coulomb PME with charges from state A
4805      * grid_index=1: Coulomb PME with charges from state B
4806      * grid_index=2: LJ PME with C6 from state A
4807      * grid_index=3: LJ PME with C6 from state B
4808      * For Lorentz-Berthelot combination rules, a separate loop is used to
4809      * calculate all the terms.
4810      */
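    /* Note: grid indices below DO_Q are the Coulomb grids and the rest are
     * the LJ grids; looping up to DO_Q_AND_LJ covers all four cases listed
     * above, while the LB path handles its LJ grids in a separate loop below.
     */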
4811
4812     /* If we are doing LJ-PME with LB, we only do Q here */
4813     max_grid_index = (pme->ljpme_combination_rule == eljpmeLB) ? DO_Q : DO_Q_AND_LJ;
4814
4815     for (grid_index = 0; grid_index < max_grid_index; ++grid_index)
4816     {
4817         /* Check if we should do calculations at this grid_index
4818          * If grid_index is odd we should be doing FEP
4819          * If grid_index < 2 we should be doing electrostatic PME
4820          * If grid_index >= 2 we should be doing LJ-PME
4821          */
4822         if ((grid_index <  DO_Q && (!(flags & GMX_PME_DO_COULOMB) ||
4823                                     (grid_index == 1 && !pme->bFEP_q))) ||
4824             (grid_index >= DO_Q && (!(flags & GMX_PME_DO_LJ) ||
4825                                     (grid_index == 3 && !pme->bFEP_lj))))
4826         {
4827             continue;
4828         }
4829         /* Unpack structure */
4830         pmegrid    = &pme->pmegrid[grid_index];
4831         fftgrid    = pme->fftgrid[grid_index];
4832         cfftgrid   = pme->cfftgrid[grid_index];
4833         pfft_setup = pme->pfft_setup[grid_index];
4834         switch (grid_index)
4835         {
4836             case 0: coefficient = chargeA + start; break;
4837             case 1: coefficient = chargeB + start; break;
4838             case 2: coefficient = c6A + start; break;
4839             case 3: coefficient = c6B + start; break;
4840         }
4841
4842         grid = pmegrid->grid.grid;
4843
4844         if (debug)
4845         {
4846             fprintf(debug, "PME: number of ranks = %d, rank = %d\n",
4847                     cr->nnodes, cr->nodeid);
4848             fprintf(debug, "Grid = %p\n", (void*)grid);
4849             if (grid == NULL)
4850             {
4851                 gmx_fatal(FARGS, "No grid!");
4852             }
4853         }
4854         where();
4855
4856         if (pme->nnodes == 1)
4857         {
4858             atc->coefficient = coefficient;
4859         }
4860         else
4861         {
4862             wallcycle_start(wcycle, ewcPME_REDISTXF);
4863             do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, coefficient);
4864             where();
4865
4866             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4867         }
4868
4869         if (debug)
4870         {
4871             fprintf(debug, "Rank= %6d, pme local particles=%6d\n",
4872                     cr->nodeid, atc->n);
4873         }
4874
4875         if (flags & GMX_PME_SPREAD)
4876         {
4877             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4878
4879             /* Spread the coefficients on a grid */
4880             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
4881
4882             if (bFirst)
4883             {
4884                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4885             }
4886             inc_nrnb(nrnb, eNR_SPREADBSP,
4887                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4888
4889             if (!pme->bUseThreads)
4890             {
4891                 wrap_periodic_pmegrid(pme, grid);
4892
4893                 /* sum contributions to local grid from other nodes */
4894 #ifdef GMX_MPI
4895                 if (pme->nnodes > 1)
4896                 {
4897                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
4898                     where();
4899                 }
4900 #endif
4901
4902                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
4903             }
4904
4905             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4906
4907             /*
4908                dump_local_fftgrid(pme,fftgrid);
4909                exit(0);
4910              */
4911         }
4912
4913         /* Here we start a large thread parallel region */
4914 #pragma omp parallel num_threads(pme->nthread) private(thread)
4915         {
4916             thread = gmx_omp_get_thread_num();
4917             if (flags & GMX_PME_SOLVE)
4918             {
4919                 int loop_count;
4920
4921                 /* do 3d-fft */
4922                 if (thread == 0)
4923                 {
4924                     wallcycle_start(wcycle, ewcPME_FFT);
4925                 }
4926                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4927                                            thread, wcycle);
4928                 if (thread == 0)
4929                 {
4930                     wallcycle_stop(wcycle, ewcPME_FFT);
4931                 }
4932                 where();
4933
4934                 /* solve in k-space for our local cells */
4935                 if (thread == 0)
4936                 {
4937                     wallcycle_start(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4938                 }
4939                 if (grid_index < DO_Q)
4940                 {
4941                     loop_count =
4942                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
4943                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4944                                       bCalcEnerVir,
4945                                       pme->nthread, thread);
4946                 }
4947                 else
4948                 {
4949                     loop_count =
4950                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
4951                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4952                                          bCalcEnerVir,
4953                                          pme->nthread, thread);
4954                 }
4955
4956                 if (thread == 0)
4957                 {
4958                     wallcycle_stop(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4959                     where();
4960                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4961                 }
4962             }
4963
4964             if (bCalcF)
4965             {
4966                 /* do 3d-invfft */
4967                 if (thread == 0)
4968                 {
4969                     where();
4970                     wallcycle_start(wcycle, ewcPME_FFT);
4971                 }
4972                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4973                                            thread, wcycle);
4974                 if (thread == 0)
4975                 {
4976                     wallcycle_stop(wcycle, ewcPME_FFT);
4977
4978                     where();
4979
4980                     if (pme->nodeid == 0)
4981                     {
4982                         ntot  = pme->nkx*pme->nky*pme->nkz;
4983                         npme  = ntot*log((real)ntot)/log(2.0);
4984                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4985                     }
4986
4987                     /* Note: this wallcycle region is closed below
4988                        outside an OpenMP region, so take care if
4989                        refactoring code here. */
4990                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4991                 }
4992
4993                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
4994             }
4995         }
4996         /* End of thread parallel section.
4997          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4998          */
4999
5000         if (bCalcF)
5001         {
5002             /* distribute local grid to all nodes */
5003 #ifdef GMX_MPI
5004             if (pme->nnodes > 1)
5005             {
5006                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
5007             }
5008 #endif
5009             where();
5010
5011             unwrap_periodic_pmegrid(pme, grid);
5012
5013             /* interpolate forces for our local atoms */
5014
5015             where();
5016
5017             /* If we are running without parallelization,
5018              * atc->f is the actual force array, not a buffer,
5019              * so we should not clear it.
5020              */
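            /* Consequently bClearF below is TRUE only for the first grid
             * gathered in a parallel run; later grids accumulate into the
             * same force buffer.
             */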
5021             lambda  = grid_index < DO_Q ? lambda_q : lambda_lj;
5022             bClearF = (bFirst && PAR(cr));
5023 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5024             for (thread = 0; thread < pme->nthread; thread++)
5025             {
5026                 gather_f_bsplines(pme, grid, bClearF, atc,
5027                                   &atc->spline[thread],
5028                                   pme->bFEP ? (grid_index % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
5029             }
5030
5031             where();
5032
5033             inc_nrnb(nrnb, eNR_GATHERFBSP,
5034                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5035             /* Note: this wallcycle region is opened above inside an OpenMP
5036                region, so take care if refactoring code here. */
5037             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5038         }
5039
5040         if (bCalcEnerVir)
5041         {
5042             /* This should only be called on the master thread
5043              * and after the threads have synchronized.
5044              */
5045             if (grid_index < 2)
5046             {
5047                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5048             }
5049             else
5050             {
5051                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5052             }
5053         }
5054         bFirst = FALSE;
5055     } /* end of the grid_index loop */
5056
5057     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
5058      * seven terms. */
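    /* Background sketch: with Lorentz-Berthelot rules
     * sigma_ij = (sigma_i + sigma_j)/2, so the combined dispersion
     * coefficient contains (sigma_i + sigma_j)^6, which the binomial theorem
     * splits into seven separable terms sigma_i^k * sigma_j^(6-k), k = 0..6.
     * Each term gets its own grid (grid_index 2 to 8), and
     * calc_initial_lb_coeffs()/calc_next_lb_coeffs() build the per-atom
     * factors for the successive powers of sigma.
     */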
5059
5060     if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB)
5061     {
5062         /* Loop over A- and B-state if we are doing FEP */
5063         for (fep_state = 0; fep_state < fep_states_lj; ++fep_state)
5064         {
5065             real *local_c6 = NULL, *local_sigma = NULL, *RedistC6 = NULL, *RedistSigma = NULL;
5066             if (pme->nnodes == 1)
5067             {
5068                 if (pme->lb_buf1 == NULL)
5069                 {
5070                     pme->lb_buf_nalloc = pme->atc[0].n;
5071                     snew(pme->lb_buf1, pme->lb_buf_nalloc);
5072                 }
5073                 pme->atc[0].coefficient = pme->lb_buf1;
5074                 switch (fep_state)
5075                 {
5076                     case 0:
5077                         local_c6      = c6A;
5078                         local_sigma   = sigmaA;
5079                         break;
5080                     case 1:
5081                         local_c6      = c6B;
5082                         local_sigma   = sigmaB;
5083                         break;
5084                     default:
5085                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5086                 }
5087             }
5088             else
5089             {
5090                 atc = &pme->atc[0];
5091                 switch (fep_state)
5092                 {
5093                     case 0:
5094                         RedistC6      = c6A;
5095                         RedistSigma   = sigmaA;
5096                         break;
5097                     case 1:
5098                         RedistC6      = c6B;
5099                         RedistSigma   = sigmaB;
5100                         break;
5101                     default:
5102                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5103                 }
5104                 wallcycle_start(wcycle, ewcPME_REDISTXF);
5105
5106                 do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, RedistC6);
5107                 if (pme->lb_buf_nalloc < atc->n)
5108                 {
5109                     pme->lb_buf_nalloc = atc->nalloc;
5110                     srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5111                     srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5112                 }
5113                 local_c6 = pme->lb_buf1;
5114                 for (i = 0; i < atc->n; ++i)
5115                 {
5116                     local_c6[i] = atc->coefficient[i];
5117                 }
5118                 where();
5119
5120                 do_redist_pos_coeffs(pme, cr, start, homenr, FALSE, x, RedistSigma);
5121                 local_sigma = pme->lb_buf2;
5122                 for (i = 0; i < atc->n; ++i)
5123                 {
5124                     local_sigma[i] = atc->coefficient[i];
5125                 }
5126                 where();
5127
5128                 wallcycle_stop(wcycle, ewcPME_REDISTXF);
5129             }
5130             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5131
5132             /* Seven terms in LJ-PME with LB; grid_index < 2 is reserved for electrostatics */
5133             for (grid_index = 2; grid_index < 9; ++grid_index)
5134             {
5135                 /* Unpack structure */
5136                 pmegrid    = &pme->pmegrid[grid_index];
5137                 fftgrid    = pme->fftgrid[grid_index];
5138                 cfftgrid   = pme->cfftgrid[grid_index];
5139                 pfft_setup = pme->pfft_setup[grid_index];
5140                 calc_next_lb_coeffs(pme, local_sigma);
5141                 grid = pmegrid->grid.grid;
5142                 where();
5143
5144                 if (flags & GMX_PME_SPREAD)
5145                 {
5146                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5147                     /* Spread the current LB coefficients on a grid */
5148                     spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
5149
5150                     if (bFirst)
5151                     {
5152                         inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5153                     }
5154
5155                     inc_nrnb(nrnb, eNR_SPREADBSP,
5156                              pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5157                     if (pme->nthread == 1)
5158                     {
5159                         wrap_periodic_pmegrid(pme, grid);
5160                         /* sum contributions to local grid from other nodes */
5161 #ifdef GMX_MPI
5162                         if (pme->nnodes > 1)
5163                         {
5164                             gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
5165                             where();
5166                         }
5167 #endif
5168                         copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
5169                     }
5170                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5171                 }
5172                 /* Here we start a large thread parallel region */
5173 #pragma omp parallel num_threads(pme->nthread) private(thread)
5174                 {
5175                     thread = gmx_omp_get_thread_num();
5176                     if (flags & GMX_PME_SOLVE)
5177                     {
5178                         /* do 3d-fft */
5179                         if (thread == 0)
5180                         {
5181                             wallcycle_start(wcycle, ewcPME_FFT);
5182                         }
5183
5184                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5185                                                    thread, wcycle);
5186                         if (thread == 0)
5187                         {
5188                             wallcycle_stop(wcycle, ewcPME_FFT);
5189                         }
5190                         where();
5191                     }
5192                 }
5193                 bFirst = FALSE;
5194             }
5195             if (flags & GMX_PME_SOLVE)
5196             {
5197                 /* solve in k-space for our local cells */
5198 #pragma omp parallel num_threads(pme->nthread) private(thread)
5199                 {
5200                     int loop_count;
5201                     thread = gmx_omp_get_thread_num();
5202                     if (thread == 0)
5203                     {
5204                         wallcycle_start(wcycle, ewcLJPME);
5205                     }
5206
5207                     loop_count =
5208                         solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5209                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5210                                          bCalcEnerVir,
5211                                          pme->nthread, thread);
5212                     if (thread == 0)
5213                     {
5214                         wallcycle_stop(wcycle, ewcLJPME);
5215                         where();
5216                         inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5217                     }
5218                 }
5219             }
5220
5221             if (bCalcEnerVir)
5222             {
5223                 /* This should only be called on the master thread and
5224                  * after the threads have synchronized.
5225                  */
5226                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
5227             }
5228
5229             if (bCalcF)
5230             {
5231                 bFirst = !(flags & GMX_PME_DO_COULOMB);
5232                 calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5233                 for (grid_index = 8; grid_index >= 2; --grid_index)
5234                 {
5235                     /* Unpack structure */
5236                     pmegrid    = &pme->pmegrid[grid_index];
5237                     fftgrid    = pme->fftgrid[grid_index];
5238                     cfftgrid   = pme->cfftgrid[grid_index];
5239                     pfft_setup = pme->pfft_setup[grid_index];
5240                     grid       = pmegrid->grid.grid;
5241                     calc_next_lb_coeffs(pme, local_sigma);
5242                     where();
5243 #pragma omp parallel num_threads(pme->nthread) private(thread)
5244                     {
5245                         thread = gmx_omp_get_thread_num();
5246                         /* do 3d-invfft */
5247                         if (thread == 0)
5248                         {
5249                             where();
5250                             wallcycle_start(wcycle, ewcPME_FFT);
5251                         }
5252
5253                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5254                                                    thread, wcycle);
5255                         if (thread == 0)
5256                         {
5257                             wallcycle_stop(wcycle, ewcPME_FFT);
5258
5259                             where();
5260
5261                             if (pme->nodeid == 0)
5262                             {
5263                                 ntot  = pme->nkx*pme->nky*pme->nkz;
5264                                 npme  = ntot*log((real)ntot)/log(2.0);
5265                                 inc_nrnb(nrnb, eNR_FFT, 2*npme);
5266                             }
5267                             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5268                         }
5269
5270                         copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
5271
5272                     } /*#pragma omp parallel*/
5273
5274                     /* distribute local grid to all nodes */
5275 #ifdef GMX_MPI
5276                     if (pme->nnodes > 1)
5277                     {
5278                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
5279                     }
5280 #endif
5281                     where();
5282
5283                     unwrap_periodic_pmegrid(pme, grid);
5284
5285                     /* interpolate forces for our local atoms */
5286                     where();
5287                     bClearF = (bFirst && PAR(cr));
5288                     scale   = pme->bFEP ? (fep_state < 1 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5289                     scale  *= lb_scale_factor[grid_index-2];
5290 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5291                     for (thread = 0; thread < pme->nthread; thread++)
5292                     {
5293                         gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5294                                           &pme->atc[0].spline[thread],
5295                                           scale);
5296                     }
5297                     where();
5298
5299                     inc_nrnb(nrnb, eNR_GATHERFBSP,
5300                              pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5301                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5302
5303                     bFirst = FALSE;
5304                 } /* for (grid_index = 8; grid_index >= 2; --grid_index) */
5305             }     /* if (bCalcF) */
5306         }         /* for (fep_state = 0; fep_state < fep_states_lj; ++fep_state) */
5307     }             /* if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB) */
5308
5309     if (bCalcF && pme->nnodes > 1)
5310     {
5311         wallcycle_start(wcycle, ewcPME_REDISTXF);
5312         for (d = 0; d < pme->ndecompdim; d++)
5313         {
5314             atc = &pme->atc[d];
5315             if (d == pme->ndecompdim - 1)
5316             {
5317                 n_d = homenr;
5318                 f_d = f + start;
5319             }
5320             else
5321             {
5322                 n_d = pme->atc[d+1].n;
5323                 f_d = pme->atc[d+1].f;
5324             }
5325             if (DOMAINDECOMP(cr))
5326             {
5327                 dd_pmeredist_f(pme, atc, n_d, f_d,
5328                                d == pme->ndecompdim-1 && pme->bPPnode);
5329             }
5330         }
5331
5332         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5333     }
5334     where();
5335
5336     if (bCalcEnerVir)
5337     {
5338         if (flags & GMX_PME_DO_COULOMB)
5339         {
5340             if (!pme->bFEP_q)
5341             {
5342                 *energy_q = energy_AB[0];
5343                 m_add(vir_q, vir_AB[0], vir_q);
5344             }
5345             else
5346             {
5347                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5348                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5349                 for (i = 0; i < DIM; i++)
5350                 {
5351                     for (j = 0; j < DIM; j++)
5352                     {
5353                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5354                             lambda_q*vir_AB[1][i][j];
5355                     }
5356                 }
5357             }
5358             if (debug)
5359             {
5360                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5361             }
5362         }
5363         else
5364         {
5365             *energy_q = 0;
5366         }
5367
5368         if (flags & GMX_PME_DO_LJ)
5369         {
5370             if (!pme->bFEP_lj)
5371             {
5372                 *energy_lj = energy_AB[2];
5373                 m_add(vir_lj, vir_AB[2], vir_lj);
5374             }
5375             else
5376             {
5377                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5378                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5379                 for (i = 0; i < DIM; i++)
5380                 {
5381                     for (j = 0; j < DIM; j++)
5382                     {
5383                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5384                     }
5385                 }
5386             }
5387             if (debug)
5388             {
5389                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5390             }
5391         }
5392         else
5393         {
5394             *energy_lj = 0;
5395         }
5396     }
5397     return 0;
5398 }