Enable fp-exceptions
[alexxy/gromacs.git] / src / gromacs / ewald / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
 * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #include "gmxpre.h"
61
62 #include "pme.h"
63
64 #include "config.h"
65
66 #include <assert.h>
67 #include <math.h>
68 #include <stdio.h>
69 #include <stdlib.h>
70
71 #include "gromacs/ewald/pme-internal.h"
72 #include "gromacs/fft/fft.h"
73 #include "gromacs/fft/parallel_3dfft.h"
74 #include "gromacs/legacyheaders/macros.h"
75 #include "gromacs/legacyheaders/nrnb.h"
76 #include "gromacs/legacyheaders/types/commrec.h"
77 #include "gromacs/legacyheaders/types/enums.h"
78 #include "gromacs/legacyheaders/types/forcerec.h"
79 #include "gromacs/legacyheaders/types/inputrec.h"
80 #include "gromacs/legacyheaders/types/nrnb.h"
81 #include "gromacs/math/gmxcomplex.h"
82 #include "gromacs/math/units.h"
83 #include "gromacs/math/vec.h"
84 #include "gromacs/math/vectypes.h"
85 /* Include the SIMD macro file and then check for support */
86 #include "gromacs/simd/simd.h"
87 #include "gromacs/simd/simd_math.h"
88 #include "gromacs/timing/cyclecounter.h"
89 #include "gromacs/timing/wallcycle.h"
90 #include "gromacs/timing/walltime_accounting.h"
91 #include "gromacs/utility/basedefinitions.h"
92 #include "gromacs/utility/fatalerror.h"
93 #include "gromacs/utility/gmxmpi.h"
94 #include "gromacs/utility/gmxomp.h"
95 #include "gromacs/utility/real.h"
96 #include "gromacs/utility/smalloc.h"
97
98 #ifdef DEBUG_PME
99 #include "gromacs/fileio/pdbio.h"
100 #include "gromacs/utility/cstringutil.h"
101 #include "gromacs/utility/futil.h"
102 #endif
103
#ifdef GMX_SIMD_HAVE_REAL
/* Turn on arbitrary width SIMD intrinsics for PME solve */
#    define PME_SIMD_SOLVE
#endif

/* Indices into the pmegrid[] array of the PME data struct. Following the
 * ordering documented on that array, 0/1 are the Coulomb A/B-state grids and
 * 2-8 are the LJ-PME grids. The DO_* values are the total number of grids
 * (an exclusive upper bound on the grid index) for each kind of calculation.
 */
#define PME_GRID_QA    0 /* Gridindex for A-state for Q */
#define PME_GRID_C6A   2 /* Gridindex for A-state for LJ */
#define DO_Q           2 /* Electrostatic grids have index q<2 */
#define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
#define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */

/* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules,
 * i.e. the binomial coefficients C(6,k)/64 for k = 0..6. Presumably one weight
 * per LJ-PME grid of the 7 used with LB rules.
 */
const real lb_scale_factor[] = {
    1.0/64, 6.0/64, 15.0/64, 20.0/64,
    15.0/64, 6.0/64, 1.0/64
};

/* Pascal triangle coefficients used in solve_pme_lj_yzx, only need to do 4 calculations due to symmetry.
 * Entries 0-2 are twice lb_scale_factor[0..2] (they fold in the mirrored
 * terms 6..4); entry 3 is the unpaired middle term C(6,3)/64.
 */
const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
123
/* Check if we have 4-wide SIMD macro support */
#if (defined GMX_SIMD4_HAVE_REAL)
/* Do PME spread and gather with 4-wide SIMD.
 * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
 */
#    define PME_SIMD4_SPREAD_GATHER

#    if (defined GMX_SIMD_HAVE_LOADU) && (defined GMX_SIMD_HAVE_STOREU)
/* With PME-order=4 on x86, unaligned load+store is slightly faster
 * than doubling all SIMD operations when using aligned load+store.
 */
#        define PME_SIMD4_UNALIGNED
#    endif
#endif

/* NOTE(review): DFT_TOL is not used in this part of the file; its purpose
 * should be confirmed where it is referenced. */
#define DFT_TOL 1e-7
/* #define PRT_FORCE */
/* conditions for on the fly time-measurement */
/* #define TAKETIME (step > 1 && timesteps < 10) */
#define TAKETIME FALSE

/* #define PME_TIME_THREADS */

/* MPI datatype matching the GROMACS `real` type of this build */
#ifdef GMX_DOUBLE
#define mpi_type MPI_DOUBLE
#else
#define mpi_type MPI_FLOAT
#endif

/* Alignment (in bytes) of aligned spline arrays, e.g. the z component
 * allocated in realloc_splinevec.
 */
#ifdef PME_SIMD4_SPREAD_GATHER
#    define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
#else
/* We can use any alignment, apart from 0, so we use 4 reals */
#    define SIMD4_ALIGNMENT  (4*sizeof(real))
#endif

/* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
 * to preserve alignment.
 * NOTE(review): whether this counts reals or bytes is not visible here.
 */
#define GMX_CACHE_SEP 64

/* We only define a maximum to be able to use local arrays without allocation.
 * An order larger than 12 should never be needed, even for test cases.
 * If needed it can be changed here.
 */
#define PME_ORDER_MAX 12
170
/* Internal datastructures */

/* Send/receive grid index ranges for one pulse of the grid-overlap
 * communication (see pme_overlap_t.comm_data).
 */
typedef struct {
    int send_index0; /* First index of the range sent in this pulse */
    int send_nindex; /* Number of indices sent in this pulse */
    int recv_index0; /* First index of the range received in this pulse */
    int recv_nindex; /* Number of indices received in this pulse */
    int recv_size;   /* Receive buffer width, used with OpenMP */
} pme_grid_comm_t;

/* One array of spline values per dimension, indexed with XX, YY, ZZ */
typedef real *splinevec[DIM];

/* Setup for exchanging/summing the overlapping grid regions with
 * neighbouring PME ranks along one decomposition dimension
 * (the PME struct holds two of these, indexed on dimension 0=x, 1=y).
 */
typedef struct {
#ifdef GMX_MPI
    MPI_Comm         mpi_comm;
#endif
    int              nnodes, nodeid;
    int             *s2g0;      /* NOTE(review): presumably per-rank start index (slab to grid) - confirm */
    int             *s2g1;      /* NOTE(review): presumably per-rank end index (slab to grid) - confirm */
    int              noverlap_nodes; /* Number of communication pulses */
    int             *send_id, *recv_id; /* Rank to send to / receive from, per pulse */
    int              send_size; /* Send buffer width, used with OpenMP */
    pme_grid_comm_t *comm_data; /* Index ranges, per pulse */
    real            *sendbuf;
    real            *recvbuf;
} pme_overlap_t;

/* Per-thread list of particle indices sorted on the thread that will spread
 * them; filled by calc_interpolation_idx and read by make_thread_local_ind.
 */
typedef struct {
    int *n;      /* Cumulative counts of the number of particles per thread */
    int  nalloc; /* Allocation size of i */
    int *i;      /* Particle indices ordered on thread index (n) */
} thread_plist_t;

/* Spline coefficients for the atoms handled by one thread */
typedef struct {
    int      *thread_one;   /* NOTE(review): not used in this part of the file */
    int       n;            /* Number of atoms listed in ind */
    int      *ind;          /* Atom indices (identity when not threaded) */
    splinevec theta;
    real     *ptr_theta_z;  /* Padded aligned allocation that theta[ZZ] points into */
    splinevec dtheta;
    real     *ptr_dtheta_z; /* Padded aligned allocation that dtheta[ZZ] points into */
} splinedata_t;
210     real     *ptr_dtheta_z;
211 } splinedata_t;
212
213 typedef struct {
214     int      dimind;        /* The index of the dimension, 0=x, 1=y */
215     int      nslab;
216     int      nodeid;
217 #ifdef GMX_MPI
218     MPI_Comm mpi_comm;
219 #endif
220
221     int     *node_dest;     /* The nodes to send x and q to with DD */
222     int     *node_src;      /* The nodes to receive x and q from with DD */
223     int     *buf_index;     /* Index for commnode into the buffers */
224
225     int      maxshift;
226
227     int      npd;
228     int      pd_nalloc;
229     int     *pd;
230     int     *count;         /* The number of atoms to send to each node */
231     int    **count_thread;
232     int     *rcount;        /* The number of atoms to receive */
233
234     int      n;
235     int      nalloc;
236     rvec    *x;
237     real    *coefficient;
238     rvec    *f;
239     gmx_bool bSpread;       /* These coordinates are used for spreading */
240     int      pme_order;
241     ivec    *idx;
242     rvec    *fractx;            /* Fractional coordinate relative to
243                                  * the lower cell boundary
244                                  */
245     int             nthread;
246     int            *thread_idx; /* Which thread should spread which coefficient */
247     thread_plist_t *thread_plist;
248     splinedata_t   *spline;
249 } pme_atomcomm_t;
250
251 #define FLBS  3
252 #define FLBSZ 4
253
/* One PME spreading grid: either a full node grid or a thread-local part */
typedef struct {
    ivec  ci;     /* The spatial location of this grid         */
    ivec  n;      /* The used size of *grid, including order-1 */
    ivec  offset; /* The grid offset from the full node grid   */
    int   order;  /* PME spreading order                       */
    ivec  s;      /* The allocated size of *grid, s >= n       */
    real *grid;   /* The grid of the local thread, size n      */
} pmegrid_t;

/* A node PME grid together with its decomposition over OpenMP threads */
typedef struct {
    pmegrid_t  grid;         /* The full node grid (non thread-local)            */
    int        nthread;      /* The number of threads operating on this grid     */
    ivec       nc;           /* The local spatial decomposition over the threads */
    pmegrid_t *grid_th;      /* Array of grids for each thread                   */
    real      *grid_all;     /* Allocated array for the grids in *grid_th        */
    int      **g2t;          /* The grid to thread index                         */
    ivec       nthread_comm; /* The number of threads to communicate with        */
} pmegrids_t;

/* Precomputed data for the 4-wide SIMD spread/gather code */
typedef struct {
#ifdef PME_SIMD4_SPREAD_GATHER
    /* Masks for 4-wide SIMD aligned spreading and gathering */
    gmx_simd4_bool_t mask_S0[6], mask_S1[6];
#else
    int              dummy; /* C89 requires that struct has at least one member */
#endif
} pme_spline_work_t;

/* Work arrays and results of the PME solve step. The PME struct keeps these
 * as thread-local work data.
 */
typedef struct {
    /* work data for solve_pme */
    int      nalloc;  /* NOTE(review): presumably allocation size of the arrays below */
    real *   mhx;
    real *   mhy;
    real *   mhz;
    real *   m2;
    real *   denom;
    real *   tmp1_alloc;
    real *   tmp1;
    real *   tmp2;
    real *   eterm;
    real *   m2inv;

    /* NOTE(review): by name, the electrostatic (q) and LJ-PME (lj) energy and
     * virial contributions - confirm in the solver code. */
    real     energy_q;
    matrix   vir_q;
    real     energy_lj;
    matrix   vir_lj;
} pme_work_t;
301
/* Complete PME state of one rank: decomposition, spreading grids, FFT setup,
 * atom redistribution data and work buffers. Functions in this file access
 * it through a gmx_pme_t handle (pme->...).
 */
typedef struct gmx_pme {
    int           ndecompdim; /* The number of decomposition dimensions */
    int           nodeid;     /* Our nodeid in mpi_comm */
    int           nodeid_major;
    int           nodeid_minor;
    int           nnodes;    /* The number of nodes doing PME */
    int           nnodes_major;
    int           nnodes_minor;

    MPI_Comm      mpi_comm;
    MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
#ifdef GMX_MPI
    MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
#endif

    gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
    int        nthread;       /* The number of threads doing PME on our rank */

    gmx_bool   bPPnode;       /* Node also does particle-particle forces */
    gmx_bool   bFEP;          /* Compute Free energy contribution */
    gmx_bool   bFEP_q;
    gmx_bool   bFEP_lj;
    int        nkx, nky, nkz; /* Grid dimensions */
    gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
    int        pme_order;
    real       epsilon_r;

    int        ljpme_combination_rule;  /* Type of combination rule in LJ-PME */

    int        ngrids;                  /* number of grids we maintain for pmegrid, (c)fftgrid and pfft_setups*/

    pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
                                         * includes overlap Grid indices are ordered as
                                         * follows:
                                         * 0: Coulomb PME, state A
                                         * 1: Coulomb PME, state B
                                         * 2-8: LJ-PME
                                         * This can probably be done in a better way
                                         * but this simple hack works for now
                                         */
    /* The PME coefficient spreading grid sizes/strides, includes pme_order-1 */
    int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
    /* pmegrid_nz might be larger than strictly necessary to ensure
     * memory alignment, pmegrid_nz_base gives the real base size.
     */
    int     pmegrid_nz_base;
    /* The local PME grid starting indices */
    int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;

    /* Work data for spreading and gathering */
    pme_spline_work_t     *spline_work;

    real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
    /* inside the interpolation grid, but separate for 2D PME decomp. */
    int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;

    t_complex            **cfftgrid;  /* Grids for complex FFT data */

    int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;

    gmx_parallel_3dfft_t  *pfft_setup;

    /* Lookup tables indexed with the integer part of the shifted fractional
     * coordinate (see calc_interpolation_idx): nn* give the grid index,
     * fsh* a correction added to the fractional part (not used for z).
     */
    int                   *nnx, *nny, *nnz;
    real                  *fshx, *fshy, *fshz;

    pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
    matrix                 recipbox;
    splinevec              bsp_mod;
    /* Buffers to store data for local atoms for L-B combination rule
     * calculations in LJ-PME. lb_buf1 stores either the coefficients
     * for spreading/gathering (in serial), or the C6 coefficient for
     * local atoms (in parallel).  lb_buf2 is only used in parallel,
     * and stores the sigma values for local atoms. */
    real                 *lb_buf1, *lb_buf2;
    int                   lb_buf_nalloc; /* Allocation size for the above buffers. */

    pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */

    pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */

    rvec                 *bufv;          /* Communication buffer (atom redistribution) */
    real                 *bufr;          /* Communication buffer (atom redistribution) */
    int                   buf_nalloc;    /* The communication buffer size */

    /* thread local work data for solve_pme */
    pme_work_t *work;

    /* Work data for sum_qgrid */
    real *   sum_qgrid_tmp;
    real *   sum_qgrid_dd_tmp;
} t_gmx_pme;
393
394 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
395                                    int start, int grid_index, int end, int thread)
396 {
397     int             i;
398     int            *idxptr, tix, tiy, tiz;
399     real           *xptr, *fptr, tx, ty, tz;
400     real            rxx, ryx, ryy, rzx, rzy, rzz;
401     int             nx, ny, nz;
402     int             start_ix, start_iy, start_iz;
403     int            *g2tx, *g2ty, *g2tz;
404     gmx_bool        bThreads;
405     int            *thread_idx = NULL;
406     thread_plist_t *tpl        = NULL;
407     int            *tpl_n      = NULL;
408     int             thread_i;
409
410     nx  = pme->nkx;
411     ny  = pme->nky;
412     nz  = pme->nkz;
413
414     start_ix = pme->pmegrid_start_ix;
415     start_iy = pme->pmegrid_start_iy;
416     start_iz = pme->pmegrid_start_iz;
417
418     rxx = pme->recipbox[XX][XX];
419     ryx = pme->recipbox[YY][XX];
420     ryy = pme->recipbox[YY][YY];
421     rzx = pme->recipbox[ZZ][XX];
422     rzy = pme->recipbox[ZZ][YY];
423     rzz = pme->recipbox[ZZ][ZZ];
424
425     g2tx = pme->pmegrid[grid_index].g2t[XX];
426     g2ty = pme->pmegrid[grid_index].g2t[YY];
427     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
428
429     bThreads = (atc->nthread > 1);
430     if (bThreads)
431     {
432         thread_idx = atc->thread_idx;
433
434         tpl   = &atc->thread_plist[thread];
435         tpl_n = tpl->n;
436         for (i = 0; i < atc->nthread; i++)
437         {
438             tpl_n[i] = 0;
439         }
440     }
441
442     for (i = start; i < end; i++)
443     {
444         xptr   = atc->x[i];
445         idxptr = atc->idx[i];
446         fptr   = atc->fractx[i];
447
448         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
449         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
450         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
451         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
452
453         tix = (int)(tx);
454         tiy = (int)(ty);
455         tiz = (int)(tz);
456
457         /* Because decomposition only occurs in x and y,
458          * we never have a fraction correction in z.
459          */
460         fptr[XX] = tx - tix + pme->fshx[tix];
461         fptr[YY] = ty - tiy + pme->fshy[tiy];
462         fptr[ZZ] = tz - tiz;
463
464         idxptr[XX] = pme->nnx[tix];
465         idxptr[YY] = pme->nny[tiy];
466         idxptr[ZZ] = pme->nnz[tiz];
467
468 #ifdef DEBUG
469         range_check(idxptr[XX], 0, pme->pmegrid_nx);
470         range_check(idxptr[YY], 0, pme->pmegrid_ny);
471         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
472 #endif
473
474         if (bThreads)
475         {
476             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
477             thread_idx[i] = thread_i;
478             tpl_n[thread_i]++;
479         }
480     }
481
482     if (bThreads)
483     {
484         /* Make a list of particle indices sorted on thread */
485
486         /* Get the cumulative count */
487         for (i = 1; i < atc->nthread; i++)
488         {
489             tpl_n[i] += tpl_n[i-1];
490         }
491         /* The current implementation distributes particles equally
492          * over the threads, so we could actually allocate for that
493          * in pme_realloc_atomcomm_things.
494          */
495         if (tpl_n[atc->nthread-1] > tpl->nalloc)
496         {
497             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
498             srenew(tpl->i, tpl->nalloc);
499         }
500         /* Set tpl_n to the cumulative start */
501         for (i = atc->nthread-1; i >= 1; i--)
502         {
503             tpl_n[i] = tpl_n[i-1];
504         }
505         tpl_n[0] = 0;
506
507         /* Fill our thread local array with indices sorted on thread */
508         for (i = start; i < end; i++)
509         {
510             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
511         }
512         /* Now tpl_n contains the cummulative count again */
513     }
514 }
515
516 static void make_thread_local_ind(pme_atomcomm_t *atc,
517                                   int thread, splinedata_t *spline)
518 {
519     int             n, t, i, start, end;
520     thread_plist_t *tpl;
521
522     /* Combine the indices made by each thread into one index */
523
524     n     = 0;
525     start = 0;
526     for (t = 0; t < atc->nthread; t++)
527     {
528         tpl = &atc->thread_plist[t];
529         /* Copy our part (start - end) from the list of thread t */
530         if (thread > 0)
531         {
532             start = tpl->n[thread-1];
533         }
534         end = tpl->n[thread];
535         for (i = start; i < end; i++)
536         {
537             spline->ind[n++] = tpl->i[i];
538         }
539     }
540
541     spline->n = n;
542 }
543
544
545 static void pme_calc_pidx(int start, int end,
546                           matrix recipbox, rvec x[],
547                           pme_atomcomm_t *atc, int *count)
548 {
549     int   nslab, i;
550     int   si;
551     real *xptr, s;
552     real  rxx, ryx, rzx, ryy, rzy;
553     int  *pd;
554
555     /* Calculate PME task index (pidx) for each grid index.
556      * Here we always assign equally sized slabs to each node
557      * for load balancing reasons (the PME grid spacing is not used).
558      */
559
560     nslab = atc->nslab;
561     pd    = atc->pd;
562
563     /* Reset the count */
564     for (i = 0; i < nslab; i++)
565     {
566         count[i] = 0;
567     }
568
569     if (atc->dimind == 0)
570     {
571         rxx = recipbox[XX][XX];
572         ryx = recipbox[YY][XX];
573         rzx = recipbox[ZZ][XX];
574         /* Calculate the node index in x-dimension */
575         for (i = start; i < end; i++)
576         {
577             xptr   = x[i];
578             /* Fractional coordinates along box vectors */
579             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
580             si    = (int)(s + 2*nslab) % nslab;
581             pd[i] = si;
582             count[si]++;
583         }
584     }
585     else
586     {
587         ryy = recipbox[YY][YY];
588         rzy = recipbox[ZZ][YY];
589         /* Calculate the node index in y-dimension */
590         for (i = start; i < end; i++)
591         {
592             xptr   = x[i];
593             /* Fractional coordinates along box vectors */
594             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
595             si    = (int)(s + 2*nslab) % nslab;
596             pd[i] = si;
597             count[si]++;
598         }
599     }
600 }
601
602 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
603                                   pme_atomcomm_t *atc)
604 {
605     int nthread, thread, slab;
606
607     nthread = atc->nthread;
608
609 #pragma omp parallel for num_threads(nthread) schedule(static)
610     for (thread = 0; thread < nthread; thread++)
611     {
612         pme_calc_pidx(natoms* thread   /nthread,
613                       natoms*(thread+1)/nthread,
614                       recipbox, x, atc, atc->count_thread[thread]);
615     }
616     /* Non-parallel reduction, since nslab is small */
617
618     for (thread = 1; thread < nthread; thread++)
619     {
620         for (slab = 0; slab < atc->nslab; slab++)
621         {
622             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
623         }
624     }
625 }
626
627 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
628 {
629     const int padding = 4;
630     int       i;
631
632     srenew(th[XX], nalloc);
633     srenew(th[YY], nalloc);
634     /* In z we add padding, this is only required for the aligned SIMD code */
635     sfree_aligned(*ptr_z);
636     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
637     th[ZZ] = *ptr_z + padding;
638
639     for (i = 0; i < padding; i++)
640     {
641         (*ptr_z)[               i] = 0;
642         (*ptr_z)[padding+nalloc+i] = 0;
643     }
644 }
645
646 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
647 {
648     int i, d;
649
650     srenew(spline->ind, atc->nalloc);
651     /* Initialize the index to identity so it works without threads */
652     for (i = 0; i < atc->nalloc; i++)
653     {
654         spline->ind[i] = i;
655     }
656
657     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
658                       atc->pme_order*atc->nalloc);
659     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
660                       atc->pme_order*atc->nalloc);
661 }
662
663 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
664 {
665     int nalloc_old, i, j, nalloc_tpl;
666
667     /* We have to avoid a NULL pointer for atc->x to avoid
668      * possible fatal errors in MPI routines.
669      */
670     if (atc->n > atc->nalloc || atc->nalloc == 0)
671     {
672         nalloc_old  = atc->nalloc;
673         atc->nalloc = over_alloc_dd(max(atc->n, 1));
674
675         if (atc->nslab > 1)
676         {
677             srenew(atc->x, atc->nalloc);
678             srenew(atc->coefficient, atc->nalloc);
679             srenew(atc->f, atc->nalloc);
680             for (i = nalloc_old; i < atc->nalloc; i++)
681             {
682                 clear_rvec(atc->f[i]);
683             }
684         }
685         if (atc->bSpread)
686         {
687             srenew(atc->fractx, atc->nalloc);
688             srenew(atc->idx, atc->nalloc);
689
690             if (atc->nthread > 1)
691             {
692                 srenew(atc->thread_idx, atc->nalloc);
693             }
694
695             for (i = 0; i < atc->nthread; i++)
696             {
697                 pme_realloc_splinedata(&atc->spline[i], atc);
698             }
699         }
700     }
701 }
702
703 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
704                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
705                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
706                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
707 {
708 #ifdef GMX_MPI
709     int        dest, src;
710     MPI_Status stat;
711
712     if (bBackward == FALSE)
713     {
714         dest = atc->node_dest[shift];
715         src  = atc->node_src[shift];
716     }
717     else
718     {
719         dest = atc->node_src[shift];
720         src  = atc->node_dest[shift];
721     }
722
723     if (nbyte_s > 0 && nbyte_r > 0)
724     {
725         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
726                      dest, shift,
727                      buf_r, nbyte_r, MPI_BYTE,
728                      src, shift,
729                      atc->mpi_comm, &stat);
730     }
731     else if (nbyte_s > 0)
732     {
733         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
734                  dest, shift,
735                  atc->mpi_comm);
736     }
737     else if (nbyte_r > 0)
738     {
739         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
740                  src, shift,
741                  atc->mpi_comm, &stat);
742     }
743 #endif
744 }
745
/* Redistribute atom coordinates and/or per-atom coefficients over the PME
 * slabs along the decomposition dimension of atc.
 *
 * pme  - pme->bufv/pme->bufr are used as send buffers (grown when bX is TRUE).
 * n    - number of input atoms in x/data.
 * bX   - when TRUE, coordinates are sent too, the per-neighbour receive counts
 *        are exchanged, atc->n is set and the atc arrays are resized. When
 *        FALSE only data is sent, and the receive counts in atc->rcount and the
 *        buffers are reused (presumably from a preceding call with bX TRUE).
 * x    - input coordinates, only read when bX is TRUE.
 * data - input per-atom coefficients.
 * atc  - atc->pd (destination slab per atom) and atc->count (atoms per slab)
 *        must already be filled in.
 *
 * On return, atc->x (when bX) and atc->coefficient first hold our own atoms in
 * input order, followed by the atoms received from each neighbour in turn.
 * With bX, calls gmx_fatal when some atoms are destined for slabs outside the
 * range of communicated neighbours.
 */
static void dd_pmeredist_pos_coeffs(gmx_pme_t pme,
                                    int n, gmx_bool bX, rvec *x, real *data,
                                    pme_atomcomm_t *atc)
{
    int *commnode, *buf_index;
    int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;

    commnode  = atc->node_dest;
    buf_index = atc->buf_index;

    nnodes_comm = min(2*atc->maxshift, atc->nslab-1);

    /* Lay out the send buffer: the atoms for neighbour commnode[i] start at
     * buf_index[commnode[i]]; nsend becomes the total number of atoms to send.
     */
    nsend = 0;
    for (i = 0; i < nnodes_comm; i++)
    {
        buf_index[commnode[i]] = nsend;
        nsend                 += atc->count[commnode[i]];
    }
    if (bX)
    {
        /* Atoms not kept and not sent to a neighbour would be lost */
        if (atc->count[atc->nodeid] + nsend != n)
        {
            gmx_fatal(FARGS, "%d particles communicated to PME rank %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
                      "This usually means that your system is not well equilibrated.",
                      n - (atc->count[atc->nodeid] + nsend),
                      pme->nodeid, 'x'+atc->dimind);
        }

        if (nsend > pme->buf_nalloc)
        {
            pme->buf_nalloc = over_alloc_dd(nsend);
            srenew(pme->bufv, pme->buf_nalloc);
            srenew(pme->bufr, pme->buf_nalloc);
        }

        /* Our new atom count: the atoms we keep plus all we receive */
        atc->n = atc->count[atc->nodeid];
        for (i = 0; i < nnodes_comm; i++)
        {
            scount = atc->count[commnode[i]];
            /* Communicate the count */
            if (debug)
            {
                fprintf(debug, "dimind %d PME rank %d send to rank %d: %d\n",
                        atc->dimind, atc->nodeid, commnode[i], scount);
            }
            pme_dd_sendrecv(atc, FALSE, i,
                            &scount, sizeof(int),
                            &atc->rcount[i], sizeof(int));
            atc->n += atc->rcount[i];
        }

        pme_realloc_atomcomm_things(atc);
    }

    /* Split the input: our own atoms go directly to the start of the atc
     * arrays, the others to their neighbour's segment of the send buffer.
     * This advances buf_index[node] to the end of each segment.
     */
    local_pos = 0;
    for (i = 0; i < n; i++)
    {
        node = atc->pd[i];
        if (node == atc->nodeid)
        {
            /* Copy direct to the receive buffer */
            if (bX)
            {
                copy_rvec(x[i], atc->x[local_pos]);
            }
            atc->coefficient[local_pos] = data[i];
            local_pos++;
        }
        else
        {
            /* Copy to the send buffer */
            if (bX)
            {
                copy_rvec(x[i], pme->bufv[buf_index[node]]);
            }
            pme->bufr[buf_index[node]] = data[i];
            buf_index[node]++;
        }
    }

    /* Exchange with each neighbour in turn; received atoms are appended
     * after the local ones.
     */
    buf_pos = 0;
    for (i = 0; i < nnodes_comm; i++)
    {
        scount = atc->count[commnode[i]];
        rcount = atc->rcount[i];
        if (scount > 0 || rcount > 0)
        {
            if (bX)
            {
                /* Communicate the coordinates */
                pme_dd_sendrecv(atc, FALSE, i,
                                pme->bufv[buf_pos], scount*sizeof(rvec),
                                atc->x[local_pos], rcount*sizeof(rvec));
            }
            /* Communicate the coefficients */
            pme_dd_sendrecv(atc, FALSE, i,
                            pme->bufr+buf_pos, scount*sizeof(real),
                            atc->coefficient+local_pos, rcount*sizeof(real));
            buf_pos   += scount;
            local_pos += atc->rcount[i];
        }
    }
}
849
850 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
851                            int n, rvec *f,
852                            gmx_bool bAddF)
853 {
854     int *commnode, *buf_index;
855     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
856
857     commnode  = atc->node_dest;
858     buf_index = atc->buf_index;
859
860     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
861
862     local_pos = atc->count[atc->nodeid];
863     buf_pos   = 0;
864     for (i = 0; i < nnodes_comm; i++)
865     {
866         scount = atc->rcount[i];
867         rcount = atc->count[commnode[i]];
868         if (scount > 0 || rcount > 0)
869         {
870             /* Communicate the forces */
871             pme_dd_sendrecv(atc, TRUE, i,
872                             atc->f[local_pos], scount*sizeof(rvec),
873                             pme->bufv[buf_pos], rcount*sizeof(rvec));
874             local_pos += scount;
875         }
876         buf_index[commnode[i]] = buf_pos;
877         buf_pos               += rcount;
878     }
879
880     local_pos = 0;
881     if (bAddF)
882     {
883         for (i = 0; i < n; i++)
884         {
885             node = atc->pd[i];
886             if (node == atc->nodeid)
887             {
888                 /* Add from the local force array */
889                 rvec_inc(f[i], atc->f[local_pos]);
890                 local_pos++;
891             }
892             else
893             {
894                 /* Add from the receive buffer */
895                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
896                 buf_index[node]++;
897             }
898         }
899     }
900     else
901     {
902         for (i = 0; i < n; i++)
903         {
904             node = atc->pd[i];
905             if (node == atc->nodeid)
906             {
907                 /* Copy from the local force array */
908                 copy_rvec(atc->f[local_pos], f[i]);
909                 local_pos++;
910             }
911             else
912             {
913                 /* Copy from the receive buffer */
914                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
915                 buf_index[node]++;
916             }
917         }
918     }
919 }
920
#ifdef GMX_MPI
/* Exchange the overlap regions of the local PME grid with the neighboring
 * PME ranks in both decomposition dimensions.
 *
 * First the minor dimension (pme->overlap[1], indexed in y relative to
 * pmegrid_start_iy) is handled, then the major dimension (pme->overlap[0],
 * indexed in x relative to pmegrid_start_ix). Each dimension does one
 * MPI_Sendrecv per pulse (tag = ipulse) on overlap->mpi_comm.
 *
 * direction == GMX_SUM_GRID_FORWARD: the send_* region is sent to send_id,
 *   and the recv_* region received from recv_id is ADDED into grid.
 * any other direction: the send/recv roles are swapped, and the received
 *   data OVERWRITES the corresponding region of grid.
 *
 * grid: local PME grid with strides pmegrid_ny*pmegrid_nz (x) and
 *       pmegrid_nz (y); modified in place.
 */
static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            ipulse, send_id, recv_id, datasize;
    real          *p;
    real          *sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            /* Reverse direction: swap the roles of the send and receive partners/regions */
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        /* Pack all x planes, send_nindex y lines, and z points 0..nkz-1 */
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        /* Number of reals per y line index in the packed buffers */
        datasize      = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        /* Unpack in the same (x, y, z) order as the packing above */
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    if (direction == GMX_SUM_GRID_FORWARD)
                    {
                        /* Forward: sum the neighbor's contribution */
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        /* Backward: overwrite with the neighbor's values */
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            /* Forward: receive into a separate buffer, summed below */
            recvptr       = overlap->recvbuf;
        }
        else
        {
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            /* Backward: receive directly into the grid (overwrite) */
            recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        /* x planes are contiguous, so send straight from the grid */
        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for (i = 0; i < recv_nindex*datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif
1089
1090
1091 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1092 {
1093     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1094     ivec    local_pme_size;
1095     int     i, ix, iy, iz;
1096     int     pmeidx, fftidx;
1097
1098     /* Dimensions should be identical for A/B grid, so we just use A here */
1099     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1100                                    local_fft_ndata,
1101                                    local_fft_offset,
1102                                    local_fft_size);
1103
1104     local_pme_size[0] = pme->pmegrid_nx;
1105     local_pme_size[1] = pme->pmegrid_ny;
1106     local_pme_size[2] = pme->pmegrid_nz;
1107
1108     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1109        the offset is identical, and the PME grid always has more data (due to overlap)
1110      */
1111     {
1112 #ifdef DEBUG_PME
1113         FILE *fp, *fp2;
1114         char  fn[STRLEN];
1115         real  val;
1116         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1117         fp = gmx_ffopen(fn, "w");
1118         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1119         fp2 = gmx_ffopen(fn, "w");
1120 #endif
1121
1122         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1123         {
1124             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1125             {
1126                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1127                 {
1128                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1129                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1130                     fftgrid[fftidx] = pmegrid[pmeidx];
1131 #ifdef DEBUG_PME
1132                     val = 100*pmegrid[pmeidx];
1133                     if (pmegrid[pmeidx] != 0)
1134                     {
1135                         gmx_fprintf_pdb_atomline(fp, epdbATOM, pmeidx, "CA", ' ', "GLY", ' ', pmeidx, ' ',
1136                                                  5.0*ix, 5.0*iy, 5.0*iz, 1.0, val, "");
1137                     }
1138                     if (pmegrid[pmeidx] != 0)
1139                     {
1140                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1141                                 "qgrid",
1142                                 pme->pmegrid_start_ix + ix,
1143                                 pme->pmegrid_start_iy + iy,
1144                                 pme->pmegrid_start_iz + iz,
1145                                 pmegrid[pmeidx]);
1146                     }
1147 #endif
1148                 }
1149             }
1150         }
1151 #ifdef DEBUG_PME
1152         gmx_ffclose(fp);
1153         gmx_ffclose(fp2);
1154 #endif
1155     }
1156     return 0;
1157 }
1158
1159
1160 static gmx_cycles_t omp_cyc_start()
1161 {
1162     return gmx_cycles_read();
1163 }
1164
1165 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1166 {
1167     return gmx_cycles_read() - c;
1168 }
1169
1170
1171 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1172                                    int nthread, int thread)
1173 {
1174     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1175     ivec          local_pme_size;
1176     int           ixy0, ixy1, ixy, ix, iy, iz;
1177     int           pmeidx, fftidx;
1178 #ifdef PME_TIME_THREADS
1179     gmx_cycles_t  c1;
1180     static double cs1 = 0;
1181     static int    cnt = 0;
1182 #endif
1183
1184 #ifdef PME_TIME_THREADS
1185     c1 = omp_cyc_start();
1186 #endif
1187     /* Dimensions should be identical for A/B grid, so we just use A here */
1188     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1189                                    local_fft_ndata,
1190                                    local_fft_offset,
1191                                    local_fft_size);
1192
1193     local_pme_size[0] = pme->pmegrid_nx;
1194     local_pme_size[1] = pme->pmegrid_ny;
1195     local_pme_size[2] = pme->pmegrid_nz;
1196
1197     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1198        the offset is identical, and the PME grid always has more data (due to overlap)
1199      */
1200     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1201     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1202
1203     for (ixy = ixy0; ixy < ixy1; ixy++)
1204     {
1205         ix = ixy/local_fft_ndata[YY];
1206         iy = ixy - ix*local_fft_ndata[YY];
1207
1208         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1209         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1210         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1211         {
1212             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1213         }
1214     }
1215
1216 #ifdef PME_TIME_THREADS
1217     c1   = omp_cyc_end(c1);
1218     cs1 += (double)c1;
1219     cnt++;
1220     if (cnt % 20 == 0)
1221     {
1222         printf("copy %.2f\n", cs1*1e-9);
1223     }
1224 #endif
1225
1226     return 0;
1227 }
1228
1229
1230 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1231 {
1232     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1233
1234     nx = pme->nkx;
1235     ny = pme->nky;
1236     nz = pme->nkz;
1237
1238     pnx = pme->pmegrid_nx;
1239     pny = pme->pmegrid_ny;
1240     pnz = pme->pmegrid_nz;
1241
1242     overlap = pme->pme_order - 1;
1243
1244     /* Add periodic overlap in z */
1245     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1246     {
1247         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1248         {
1249             for (iz = 0; iz < overlap; iz++)
1250             {
1251                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1252                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1253             }
1254         }
1255     }
1256
1257     if (pme->nnodes_minor == 1)
1258     {
1259         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1260         {
1261             for (iy = 0; iy < overlap; iy++)
1262             {
1263                 for (iz = 0; iz < nz; iz++)
1264                 {
1265                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1266                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1267                 }
1268             }
1269         }
1270     }
1271
1272     if (pme->nnodes_major == 1)
1273     {
1274         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1275
1276         for (ix = 0; ix < overlap; ix++)
1277         {
1278             for (iy = 0; iy < ny_x; iy++)
1279             {
1280                 for (iz = 0; iz < nz; iz++)
1281                 {
1282                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1283                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1284                 }
1285             }
1286         }
1287     }
1288 }
1289
1290
1291 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1292 {
1293     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1294
1295     nx = pme->nkx;
1296     ny = pme->nky;
1297     nz = pme->nkz;
1298
1299     pnx = pme->pmegrid_nx;
1300     pny = pme->pmegrid_ny;
1301     pnz = pme->pmegrid_nz;
1302
1303     overlap = pme->pme_order - 1;
1304
1305     if (pme->nnodes_major == 1)
1306     {
1307         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1308
1309         for (ix = 0; ix < overlap; ix++)
1310         {
1311             int iy, iz;
1312
1313             for (iy = 0; iy < ny_x; iy++)
1314             {
1315                 for (iz = 0; iz < nz; iz++)
1316                 {
1317                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1318                         pmegrid[(ix*pny+iy)*pnz+iz];
1319                 }
1320             }
1321         }
1322     }
1323
1324     if (pme->nnodes_minor == 1)
1325     {
1326 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1327         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1328         {
1329             int iy, iz;
1330
1331             for (iy = 0; iy < overlap; iy++)
1332             {
1333                 for (iz = 0; iz < nz; iz++)
1334                 {
1335                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1336                         pmegrid[(ix*pny+iy)*pnz+iz];
1337                 }
1338             }
1339         }
1340     }
1341
1342     /* Copy periodic overlap in z */
1343 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1344     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1345     {
1346         int iy, iz;
1347
1348         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1349         {
1350             for (iz = 0; iz < overlap; iz++)
1351             {
1352                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1353                     pmegrid[(ix*pny+iy)*pnz+iz];
1354             }
1355         }
1356     }
1357 }
1358
1359
/* Spread one coefficient onto the grid with B-spline weights:
 *   grid[(i0+ithx)*pny*pnz + (j0+ithy)*pnz + k0+ithz] +=
 *       coefficient*thx[ithx]*thy[ithy]*thz[ithz]
 * for ithx, ithy, ithz in [0, order).
 *
 * This has to be a macro to enable full compiler optimization with xlC (and
 * probably others too). It uses these variables from the enclosing scope:
 * ithx, ithy, ithz, i0, j0, k0, pny, pnz, coefficient, thx, thy, thz,
 * valx, valxy, index_x, index_xy, index_xyz and grid.
 * The argument is parenthesized in the expansion and is evaluated in every
 * loop test, so it must be free of side effects.
 */
#define DO_BSPLINE(order)                                \
    for (ithx = 0; (ithx < (order)); ithx++)             \
    {                                                    \
        index_x = (i0+ithx)*pny*pnz;                     \
        valx    = coefficient*thx[ithx];                 \
                                                         \
        for (ithy = 0; (ithy < (order)); ithy++)         \
        {                                                \
            valxy    = valx*thy[ithy];                   \
            index_xy = index_x+(j0+ithy)*pnz;            \
                                                         \
            for (ithz = 0; (ithz < (order)); ithz++)     \
            {                                            \
                index_xyz        = index_xy+(k0+ithz);   \
                grid[index_xyz] += valxy*thz[ithz];      \
            }                                            \
        }                                                \
    }
1379
1380
/* Spread the coefficients of the atoms assigned to this thread onto the
 * thread-local grid pmegrid.
 *
 * The whole grid (s[XX]*s[YY]*s[ZZ] points) is first zeroed. Then, for each
 * atom index n = spline->ind[nn] with a non-zero atc->coefficient[n], the
 * coefficient is added with B-spline weights spline->theta[d][nn*order ..]
 * starting at grid point atc->idx[n] minus the grid offset.
 * For order 4 and 5 with PME_SIMD4_SPREAD_GATHER, the SIMD4 kernel is
 * textually included from pme-simd4.h, configured by the PME_* macros
 * defined just before each include. Otherwise DO_BSPLINE is used.
 *
 * work is unused unless the SIMD kernel references it (hence gmx_unused).
 * NOTE(review): b, ol, localsize and bndsize are not referenced in the
 * visible code; they may be used by pme-simd4.h, so confirm before removing.
 */
static void spread_coefficients_bsplines_thread(pmegrid_t                    *pmegrid,
                                                pme_atomcomm_t               *atc,
                                                splinedata_t                 *spline,
                                                pme_spline_work_t gmx_unused *work)
{

    /* spread coefficients from home atoms to local grid */
    real          *grid;
    pme_overlap_t *ol;
    int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    int       *    idxptr;
    int            order, norder, index_x, index_xy, index_xyz;
    real           valx, valxy, coefficient;
    real          *thx, *thy, *thz;
    int            localsize, bndsize;
    int            pnx, pny, pnz, ndatatot;
    int            offx, offy, offz;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    /* Aligned scratch space for the z spline weights used by the aligned SIMD4 kernel */
    real           thz_buffer[GMX_SIMD4_WIDTH*3], *thz_aligned;

    thz_aligned = gmx_simd4_align_r(thz_buffer);
#endif

    /* Storage sizes (strides) of the thread-local grid */
    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    /* Global grid index of this grid's first point, per dimension */
    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the grid before accumulating */
    ndatatot = pnx*pny*pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n           = spline->ind[nn];
        coefficient = atc->coefficient[n];

        /* Atoms with a zero coefficient contribute nothing */
        if (coefficient != 0)
        {
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Start index of this atom's stencil relative to the local grid */
            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta[XX] + norder;
            thy = spline->theta[YY] + norder;
            thz = spline->theta[ZZ] + norder;

            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_SPREAD_SIMD4_ORDER4
#else
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
#else
                    DO_BSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 5
#include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
#else
                    DO_BSPLINE(5);
#endif
                    break;
                default:
                    /* Generic scalar path for any other interpolation order */
                    DO_BSPLINE(order);
                    break;
            }
        }
    }
}
1471
1472 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1473 {
1474 #ifdef PME_SIMD4_SPREAD_GATHER
1475     if (pme_order == 5
1476 #ifndef PME_SIMD4_UNALIGNED
1477         || pme_order == 4
1478 #endif
1479         )
1480     {
1481         /* Round nz up to a multiple of 4 to ensure alignment */
1482         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1483     }
1484 #endif
1485 }
1486
1487 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1488 {
1489 #ifdef PME_SIMD4_SPREAD_GATHER
1490 #ifndef PME_SIMD4_UNALIGNED
1491     if (pme_order == 4)
1492     {
1493         /* Add extra elements to ensured aligned operations do not go
1494          * beyond the allocated grid size.
1495          * Note that for pme_order=5, the pme grid z-size alignment
1496          * ensures that we will not go beyond the grid size.
1497          */
1498         *gridsize += 4;
1499     }
1500 #endif
1501 #endif
1502 }
1503
1504 static void pmegrid_init(pmegrid_t *grid,
1505                          int cx, int cy, int cz,
1506                          int x0, int y0, int z0,
1507                          int x1, int y1, int z1,
1508                          gmx_bool set_alignment,
1509                          int pme_order,
1510                          real *ptr)
1511 {
1512     int nz, gridsize;
1513
1514     grid->ci[XX]     = cx;
1515     grid->ci[YY]     = cy;
1516     grid->ci[ZZ]     = cz;
1517     grid->offset[XX] = x0;
1518     grid->offset[YY] = y0;
1519     grid->offset[ZZ] = z0;
1520     grid->n[XX]      = x1 - x0 + pme_order - 1;
1521     grid->n[YY]      = y1 - y0 + pme_order - 1;
1522     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1523     copy_ivec(grid->n, grid->s);
1524
1525     nz = grid->s[ZZ];
1526     set_grid_alignment(&nz, pme_order);
1527     if (set_alignment)
1528     {
1529         grid->s[ZZ] = nz;
1530     }
1531     else if (nz != grid->s[ZZ])
1532     {
1533         gmx_incons("pmegrid_init call with an unaligned z size");
1534     }
1535
1536     grid->order = pme_order;
1537     if (ptr == NULL)
1538     {
1539         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1540         set_gridsize_alignment(&gridsize, pme_order);
1541         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1542     }
1543     else
1544     {
1545         grid->grid = ptr;
1546     }
1547 }
1548
/* Integer division of numerator by denominator, rounding up for a
 * non-negative numerator and positive denominator.
 */
static int div_round_up(int numerator, int denominator)
{
    const int biased = numerator + (denominator - 1);

    return biased/denominator;
}
1553
1554 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1555                                   ivec nsub)
1556 {
1557     int gsize_opt, gsize;
1558     int nsx, nsy, nsz;
1559     char *env;
1560
1561     gsize_opt = -1;
1562     for (nsx = 1; nsx <= nthread; nsx++)
1563     {
1564         if (nthread % nsx == 0)
1565         {
1566             for (nsy = 1; nsy <= nthread; nsy++)
1567             {
1568                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1569                 {
1570                     nsz = nthread/(nsx*nsy);
1571
1572                     /* Determine the number of grid points per thread */
1573                     gsize =
1574                         (div_round_up(n[XX], nsx) + ovl)*
1575                         (div_round_up(n[YY], nsy) + ovl)*
1576                         (div_round_up(n[ZZ], nsz) + ovl);
1577
1578                     /* Minimize the number of grids points per thread
1579                      * and, secondarily, the number of cuts in minor dimensions.
1580                      */
1581                     if (gsize_opt == -1 ||
1582                         gsize < gsize_opt ||
1583                         (gsize == gsize_opt &&
1584                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1585                     {
1586                         nsub[XX]  = nsx;
1587                         nsub[YY]  = nsy;
1588                         nsub[ZZ]  = nsz;
1589                         gsize_opt = gsize;
1590                     }
1591                 }
1592             }
1593         }
1594     }
1595
1596     env = getenv("GMX_PME_THREAD_DIVISION");
1597     if (env != NULL)
1598     {
1599         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1600     }
1601
1602     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1603     {
1604         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1605     }
1606 }
1607
/* Initialize the PME grid set grids for a local grid of nx x ny x nz points.
 *
 * nx, ny, nz include the pme_order-1 overlap points. The interior size n
 * used below excludes them. nz_base replaces n[ZZ] only when choosing the
 * thread decomposition (make_subgrid_division).
 *
 * - grids->grid: the full local grid, allocated by pmegrid_init. Its z size
 *   must already be aligned (set_alignment is FALSE, so gmx_incons otherwise).
 * - With bUseThreads, grids->grid_th holds grids->nthread thread-local
 *   subgrids. All of them are carved out of one aligned allocation
 *   grids->grid_all and separated by GMX_CACHE_SEP elements. The thread
 *   index is t = (x*nc[YY] + y)*nc[ZZ] + z. Without threads, grid_th is NULL
 *   and grid_all is not set.
 * - grids->g2t[d][i]: for grid line i along dimension d, that dimension's
 *   contribution to the index of the thread owning the line. Summing the
 *   contributions over d gives t.
 * - grids->nthread_comm[d]: the smallest count c (capped at nc[d]) for which
 *   (n[d]*c)/nc[d] reaches max_comm_lines. max_comm_lines is overlap_x in x,
 *   overlap_y in y and pme_order-1 in z.
 */
static void pmegrids_init(pmegrids_t *grids,
                          int nx, int ny, int nz, int nz_base,
                          int pme_order,
                          gmx_bool bUseThreads,
                          int nthread,
                          int overlap_x,
                          int overlap_y)
{
    /* NOTE(review): g0 and g1 are not used in this function */
    ivec n, n_base, g0, g1;
    int t, x, y, z, d, i, tfac;
    int max_comm_lines = -1;

    /* Interior grid size, without the spline overlap */
    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
                 NULL);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);

    if (bUseThreads)
    {
        ivec nst;
        int gridsize;

        /* Maximum thread-local grid size over all thread cells */
        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
                    nx, ny, nz,
                    nst[XX], nst[YY], nst[ZZ]);
        }

        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        /* One allocation for all thread grids, with GMX_CACHE_SEP padding
         * before, between and after them
         */
        snew_aligned(grids->grid_all,
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                     SIMD4_ALIGNMENT);

        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    pmegrid_init(&grids->grid_th[t],
                                 x, y, z,
                                 (n[XX]*(x  ))/grids->nc[XX],
                                 (n[YY]*(y  ))/grids->nc[YY],
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
                                 (n[XX]*(x+1))/grids->nc[XX],
                                 (n[YY]*(y+1))/grids->nc[YY],
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }
    else
    {
        grids->grid_th = NULL;
    }

    snew(grids->g2t, DIM);
    tfac = 1;
    /* Loop from the minor (z) to the major (x) dimension, so tfac is the
     * product of nc over the more minor dimensions
     */
    for (d = DIM-1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t*tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x;     break;
            case YY: max_comm_lines = overlap_y;     break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        grids->nthread_comm[d] = 0;
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
               grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != NULL)
        {
            fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
                    'x'+d, grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         */
        /* NOTE(review): the while loop above never lets nthread_comm[d]
         * exceed nc[d], so this condition can never be true as written.
         * Given the comment above, '>=' may have been intended; confirm
         * before changing, since that would make more setups fatal.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
        }
    }
}
1735
1736
1737 static void pmegrids_destroy(pmegrids_t *grids)
1738 {
1739     int t;
1740
1741     if (grids->grid.grid != NULL)
1742     {
1743         sfree(grids->grid.grid);
1744
1745         if (grids->nthread > 0)
1746         {
1747             for (t = 0; t < grids->nthread; t++)
1748             {
1749                 sfree(grids->grid_th[t].grid);
1750             }
1751             sfree(grids->grid_th);
1752         }
1753     }
1754 }
1755
1756
1757 static void realloc_work(pme_work_t *work, int nkx)
1758 {
1759     int simd_width, i;
1760
1761     if (nkx > work->nalloc)
1762     {
1763         work->nalloc = nkx;
1764         srenew(work->mhx, work->nalloc);
1765         srenew(work->mhy, work->nalloc);
1766         srenew(work->mhz, work->nalloc);
1767         srenew(work->m2, work->nalloc);
1768         /* Allocate an aligned pointer for SIMD operations, including extra
1769          * elements at the end for padding.
1770          */
1771 #ifdef PME_SIMD_SOLVE
1772         simd_width = GMX_SIMD_REAL_WIDTH;
1773 #else
1774         /* We can use any alignment, apart from 0, so we use 4 */
1775         simd_width = 4;
1776 #endif
1777         sfree_aligned(work->denom);
1778         sfree_aligned(work->tmp1);
1779         sfree_aligned(work->tmp2);
1780         sfree_aligned(work->eterm);
1781         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1782         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1783         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1784         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1785         srenew(work->m2inv, work->nalloc);
1786 #ifndef NDEBUG
1787         for (i = 0; i < work->nalloc+simd_width; i++)
1788         {
1789             work->denom[i] = 1; /* init to 1 to avoid 1/0 exceptions of simd padded elements */
1790         }
1791 #endif
1792     }
1793 }
1794
1795
1796 static void free_work(pme_work_t *work)
1797 {
1798     sfree(work->mhx);
1799     sfree(work->mhy);
1800     sfree(work->mhz);
1801     sfree(work->m2);
1802     sfree_aligned(work->denom);
1803     sfree_aligned(work->tmp1);
1804     sfree_aligned(work->tmp2);
1805     sfree_aligned(work->eterm);
1806     sfree(work->m2inv);
1807 }
1808
1809
1810 #if defined PME_SIMD_SOLVE
1811 /* Calculate exponentials through SIMD */
1812 gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1813 {
1814     {
1815         const gmx_simd_real_t two = gmx_simd_set1_r(2.0);
1816         gmx_simd_real_t f_simd;
1817         gmx_simd_real_t lu;
1818         gmx_simd_real_t tmp_d1, d_inv, tmp_r, tmp_e;
1819         int kx;
1820         f_simd = gmx_simd_set1_r(f);
1821         /* We only need to calculate from start. But since start is 0 or 1
1822          * and we want to use aligned loads/stores, we always start from 0.
1823          */
1824         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1825         {
1826             tmp_d1   = gmx_simd_load_r(d_aligned+kx);
1827             d_inv    = gmx_simd_inv_r(tmp_d1);
1828             tmp_r    = gmx_simd_load_r(r_aligned+kx);
1829             tmp_r    = gmx_simd_exp_r(tmp_r);
1830             tmp_e    = gmx_simd_mul_r(f_simd, d_inv);
1831             tmp_e    = gmx_simd_mul_r(tmp_e, tmp_r);
1832             gmx_simd_store_r(e_aligned+kx, tmp_e);
1833         }
1834     }
1835 }
1836 #else
1837 gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1838 {
1839     int kx;
1840     for (kx = start; kx < end; kx++)
1841     {
1842         d[kx] = 1.0/d[kx];
1843     }
1844     for (kx = start; kx < end; kx++)
1845     {
1846         r[kx] = exp(r[kx]);
1847     }
1848     for (kx = start; kx < end; kx++)
1849     {
1850         e[kx] = f*r[kx]*d[kx];
1851     }
1852 }
1853 #endif
1854
1855 #if defined PME_SIMD_SOLVE
1856 /* Calculate exponentials through SIMD */
1857 gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1858 {
1859     gmx_simd_real_t tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1860     const gmx_simd_real_t sqr_PI = gmx_simd_sqrt_r(gmx_simd_set1_r(M_PI));
1861     int kx;
1862     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1863     {
1864         /* We only need to calculate from start. But since start is 0 or 1
1865          * and we want to use aligned loads/stores, we always start from 0.
1866          */
1867         tmp_d = gmx_simd_load_r(d_aligned+kx);
1868         d_inv = gmx_simd_inv_r(tmp_d);
1869         gmx_simd_store_r(d_aligned+kx, d_inv);
1870         tmp_r = gmx_simd_load_r(r_aligned+kx);
1871         tmp_r = gmx_simd_exp_r(tmp_r);
1872         gmx_simd_store_r(r_aligned+kx, tmp_r);
1873         tmp_mk  = gmx_simd_load_r(factor_aligned+kx);
1874         tmp_fac = gmx_simd_mul_r(sqr_PI, gmx_simd_mul_r(tmp_mk, gmx_simd_erfc_r(tmp_mk)));
1875         gmx_simd_store_r(factor_aligned+kx, tmp_fac);
1876     }
1877 }
1878 #else
1879 gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1880 {
1881     int kx;
1882     real mk;
1883     for (kx = start; kx < end; kx++)
1884     {
1885         d[kx] = 1.0/d[kx];
1886     }
1887
1888     for (kx = start; kx < end; kx++)
1889     {
1890         r[kx] = exp(r[kx]);
1891     }
1892
1893     for (kx = start; kx < end; kx++)
1894     {
1895         mk       = tmp2[kx];
1896         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1897     }
1898 }
1899 #endif
1900
/* Reciprocal-space solve step of PME for the Coulomb (charge) grid.
 *
 * Executed by one thread on its share of the local part of the complex grid.
 * Every complex element this thread handles is multiplied in place by
 *   eterm = elfac*exp(-factor*m2)/(m2*pi*vol*bsp_x*bsp_y*bsp_z)
 * with m the reciprocal-space vector of the element (integer indices
 * transformed with pme->recipbox), m2 = |m|^2, factor = pi^2/ewaldcoeff^2,
 * elfac = ONE_4PI_EPS0/epsilon_r and bsp_* the entries of pme->bsp_mod.
 * The k = (0,0,0) element is skipped (left unchanged).
 *
 * pme        PME data: grid sizes, recipbox, bsp_mod, per-thread work buffers
 * grid       local complex grid, y major, z middle, x minor
 * ewaldcoeff Ewald splitting coefficient
 * vol        box volume, used in the normalization above
 * bEnerVir   when TRUE, also store this thread's energy and virial
 *            contribution in pme->work[thread].energy_q / .vir_q;
 *            otherwise those fields are not written
 * nthread    number of threads sharing the local y*z lines
 * thread     index of the calling thread; selects pme->work[thread]
 *
 * Returns local_ndata[YY]*local_ndata[XX] (the "loop count").
 * Note: maxkz is computed but not used.
 */
static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
                         real ewaldcoeff, real vol,
                         gmx_bool bEnerVir,
                         int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    t_complex *p0;
    int     kx, ky, kz, maxkx, maxky, maxkz;
    int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
    real    mx, my, mz;
    real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real    ets2, struct2, vfactor, ets2vf;
    real    d1, d2, energy = 0;
    real    by, bz;
    real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    pme_work_t *work;
    real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
    real    mhxk, mhyk, mhzk, m2k;
    real    corner_fac;
    ivec    complex_order;
    ivec    local_ndata, local_offset, local_size;
    real    elfac;

    /* Electrostatic prefactor, divided by the relative dielectric constant */
    elfac = ONE_4PI_EPS0/pme->epsilon_r;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);

    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    /* Indices >= maxkx (x) and >= maxky (y) are mapped to negative
     * frequencies below.
     */
    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;
    maxkz = nz/2+1;

    /* Each thread uses its own scratch buffers in pme->work[thread] */
    work  = &pme->work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    eterm = work->eterm;
    m2inv = work->m2inv;

    /* Divide the local y*z lines over the threads (integer division);
     * this thread handles the flattened range [iyz0, iyz1).
     */
    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;

    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        /* Signed y frequency */
        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        by = M_PI*vol*pme->bsp_mod[YY][ky];

        kz = iz + local_offset[ZZ];

        /* kz is used unwrapped. NOTE(review): presumably only non-negative
         * z frequencies are stored (real-to-complex transform along z) -
         * confirm.
         */
        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        /* Start of the x line for this (iy, iz) */
        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];

        /* We should skip the k-space point (0,0,0) */
        /* Note that since here x is the minor index, local_offset[XX]=0 */
        if (local_offset[XX] > 0 || ky > 0 || kz > 0)
        {
            kxstart = local_offset[XX];
        }
        else
        {
            kxstart = local_offset[XX] + 1;
            p0++;
        }
        kxend = local_offset[XX] + local_ndata[XX];

        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the storage
             * of the mh elements in array's.
             * Because x is the minor grid index, all mh elements
             * depend on kx for triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            /* Negative x frequencies */
            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            /* The origin, where m2 is 0, was excluded through kxstart above */
            for (kx = kxstart; kx < kxend; kx++)
            {
                m2inv[kx] = 1.0/m2[kx];
            }

            /* eterm[kx] = elfac*exp(tmp1[kx])/denom[kx].
             * NOTE(review): the SIMD calc_exponentials_q always starts at
             * index 0. So when kxstart == 1 it also reads denom[0] and
             * tmp1[0], which were not written in this iteration and keep
             * whatever an earlier use of these per-thread buffers left
             * (solve_pme_lj_yzx uses the same work buffers and may store
             * struct2 sums in denom). eterm[0] is not used, but a stale 0 in
             * denom[0] could raise a 1/0 FP exception when those are
             * enabled - confirm.
             */
            calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);

            /* Apply eterm to the grid. tmp1 is reused to hold
             * eterm*struct2 for the energy/virial loop below.
             */
            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];

                struct2 = 2.0*(d1*d1+d2*d2);

                tmp1[kx] = eterm[kx]*struct2;
            }

            /* Accumulate energy and virial,
             * with vfactor = 2*(factor*m2 + 1)/m2.
             */
            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
                energy  += ets2;

                ets2vf   = ets2*vfactor;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             * In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */

            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            /* Negative x frequencies */
            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            /* Same SIMD start-at-index-0 caveat as in the energy branch */
            calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];
            }
        }
    }

    if (bEnerVir)
    {
        /* Update virial with local values.
         * The virial is symmetric by definition.
         * this virial seems ok for isotropic scaling, but I'm
         * experiencing problems on semiisotropic membranes.
         * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
         */
        work->vir_q[XX][XX] = 0.25*virxx;
        work->vir_q[YY][YY] = 0.25*viryy;
        work->vir_q[ZZ][ZZ] = 0.25*virzz;
        work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
        work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
        work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy_q = 0.5*energy;
    }

    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
2152
/* Reciprocal-space solve step of PME for the LJ (C6) grid(s).
 *
 * Executed by one thread on its share of the local part of the complex grid.
 * With m2 = |m|^2 (m as in solve_pme_yzx), factor = pi^2/ewaldcoeff^2 and
 * b = sqrt(factor*m2), every handled complex element is multiplied in place by
 *   eterm = -((1 - 2*b^2)*exp(-b^2) + 2*b^2*sqrt(pi)*b*erfc(b)) / denom,
 *   denom = bsp_x*bsp_z*3*vol*bsp_y/(pi*sqrt(pi)*ewaldcoeff^3).
 * Unlike the Coulomb solver, the k = (0,0,0) element is not skipped.
 *
 * pme        PME data: grid sizes, recipbox, bsp_mod, per-thread work buffers
 * grid       array of local complex grids (y major, z middle, x minor):
 *            only grid[0] is used when bLB is FALSE, grid[0..6] when TRUE
 * bLB        selects the 7-grid path (presumably the Lorentz-Berthelot
 *            combination rule - confirm); energy/virial then use the
 *            products grid[ig]*grid[6-ig] weighted by lb_scale_factor_symm
 * ewaldcoeff LJ Ewald splitting coefficient
 * vol        box volume
 * bEnerVir   when TRUE, also store this thread's energy and virial
 *            contribution in pme->work[thread].energy_lj / .vir_lj
 * nthread    number of threads sharing the local y*z lines
 * thread     index of the calling thread; selects pme->work[thread]
 *
 * Returns local_ndata[YY]*local_ndata[XX] (the "loop count").
 * Note: maxkz and mk are declared/computed but not used.
 */
static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
                            real ewaldcoeff, real vol,
                            gmx_bool bEnerVir, int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    int     ig, gcount;
    int     kx, ky, kz, maxkx, maxky, maxkz;
    int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
    real    mx, my, mz;
    real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real    ets2, ets2vf;
    real    eterm, vterm, d1, d2, energy = 0;
    real    by, bz;
    real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
    real    mhxk, mhyk, mhzk, m2k;
    real    mk;
    pme_work_t *work;
    real    corner_fac;
    ivec    complex_order;
    ivec    local_ndata, local_offset, local_size;
    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);
    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    /* Indices >= maxkx (x) and >= maxky (y) are mapped to negative
     * frequencies below.
     */
    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;
    maxkz = nz/2+1;

    /* Each thread uses its own scratch buffers in pme->work[thread];
     * these are the same buffers solve_pme_yzx uses.
     */
    work  = &pme->work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    tmp2  = work->tmp2;

    /* Divide the local y*z lines over the threads (integer division);
     * this thread handles the flattened range [iyz0, iyz1).
     */
    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;

    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        /* Signed y frequency */
        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        by = 3.0*vol*pme->bsp_mod[YY][ky]
            / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);

        kz = iz + local_offset[ZZ];

        /* kz is used unwrapped (see the same note in solve_pme_yzx) */
        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        /* The (0,0,0) point is not skipped for LJ */
        kxstart = local_offset[XX];
        kxend   = local_offset[XX] + local_ndata[XX];
        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the
             * storage of the mh elements in array's.  Because x is the
             * minor grid index, all mh elements depend on kx for
             * triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            /* Negative x frequencies */
            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            /* In place: denom -> 1/denom, tmp1 -> exp(-b^2),
             * tmp2 -> sqrt(pi)*b*erfc(b).
             */
            calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);

            /* tmp1 <- eterm/denom (grid scaling factor),
             * tmp2 <- vterm/denom (virial factor).
             */
            for (kx = kxstart; kx < kxend; kx++)
            {
                m2k   = factor*m2[kx];
                eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
                          + 2.0*m2k*tmp2[kx]);
                vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
                tmp1[kx] = eterm*denom[kx];
                tmp2[kx] = vterm*denom[kx];
            }

            if (!bLB)
            {
                t_complex *p0;
                real       struct2;

                /* Single grid: scale it and weight tmp1/tmp2 by 2|S|^2 */
                p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                for (kx = kxstart; kx < kxend; kx++, p0++)
                {
                    d1      = p0->re;
                    d2      = p0->im;

                    eterm   = tmp1[kx];
                    vterm   = tmp2[kx];
                    p0->re  = d1*eterm;
                    p0->im  = d2*eterm;

                    struct2 = 2.0*(d1*d1+d2*d2);

                    tmp1[kx] = eterm*struct2;
                    tmp2[kx] = vterm*struct2;
                }
            }
            else
            {
                /* denom has already been folded into tmp1/tmp2 above,
                 * so its buffer is reused to accumulate the structure factor.
                 */
                real *struct2 = denom;
                real  str2;

                for (kx = kxstart; kx < kxend; kx++)
                {
                    struct2[kx] = 0.0;
                }
                /* Due to symmetry we only need to calculate 4 of the 7 terms */
                for (ig = 0; ig <= 3; ++ig)
                {
                    t_complex *p0, *p1;
                    real       scale;

                    p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                    p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                    scale = 2.0*lb_scale_factor_symm[ig];
                    for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
                    {
                        struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
                    }

                }
                /* Scale all 7 grids (after struct2 used the unscaled values) */
                for (ig = 0; ig <= 6; ++ig)
                {
                    t_complex *p0;

                    p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                    for (kx = kxstart; kx < kxend; kx++, p0++)
                    {
                        d1     = p0->re;
                        d2     = p0->im;

                        eterm  = tmp1[kx];
                        p0->re = d1*eterm;
                        p0->im = d2*eterm;
                    }
                }
                for (kx = kxstart; kx < kxend; kx++)
                {
                    eterm    = tmp1[kx];
                    vterm    = tmp2[kx];
                    str2     = struct2[kx];
                    tmp1[kx] = eterm*str2;
                    tmp2[kx] = vterm*str2;
                }
            }

            /* Accumulate energy and virial */
            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vterm    = 2.0*factor*tmp2[kx];
                energy  += ets2;
                ets2vf   = corner_fac*vterm;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             *  In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */

            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            /* Negative x frequencies */
            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                m2[kx]    = m2k;
                denom[kx] = bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
                tmp2[kx]  = sqrt(factor*m2k);
            }

            calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);

            /* tmp1 <- eterm/denom, the grid scaling factor */
            for (kx = kxstart; kx < kxend; kx++)
            {
                m2k    = factor*m2[kx];
                eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
                           + 2.0*m2k*tmp2[kx]);
                tmp1[kx] = eterm*denom[kx];
            }
            /* Scale 7 grids with bLB, otherwise only grid[0] */
            gcount = (bLB ? 7 : 1);
            for (ig = 0; ig < gcount; ++ig)
            {
                t_complex *p0;

                p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
                for (kx = kxstart; kx < kxend; kx++, p0++)
                {
                    d1      = p0->re;
                    d2      = p0->im;

                    eterm   = tmp1[kx];

                    p0->re  = d1*eterm;
                    p0->im  = d2*eterm;
                }
            }
        }
    }
    if (bEnerVir)
    {
        /* The virial is symmetric: store both off-diagonal halves */
        work->vir_lj[XX][XX] = 0.25*virxx;
        work->vir_lj[YY][YY] = 0.25*viryy;
        work->vir_lj[ZZ][ZZ] = 0.25*virzz;
        work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
        work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
        work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy_lj = 0.5*energy;
    }
    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
2462
2463 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2464                                real *mesh_energy, matrix vir)
2465 {
2466     /* This function sums output over threads and should therefore
2467      * only be called after thread synchronization.
2468      */
2469     int thread;
2470
2471     *mesh_energy = pme->work[0].energy_q;
2472     copy_mat(pme->work[0].vir_q, vir);
2473
2474     for (thread = 1; thread < nthread; thread++)
2475     {
2476         *mesh_energy += pme->work[thread].energy_q;
2477         m_add(vir, pme->work[thread].vir_q, vir);
2478     }
2479 }
2480
2481 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2482                                 real *mesh_energy, matrix vir)
2483 {
2484     /* This function sums output over threads and should therefore
2485      * only be called after thread synchronization.
2486      */
2487     int thread;
2488
2489     *mesh_energy = pme->work[0].energy_lj;
2490     copy_mat(pme->work[0].vir_lj, vir);
2491
2492     for (thread = 1; thread < nthread; thread++)
2493     {
2494         *mesh_energy += pme->work[thread].energy_lj;
2495         m_add(vir, pme->work[thread].vir_lj, vir);
2496     }
2497 }
2498
2499
/* DO_FSPLINE(order): plain-C force-gathering kernel for one atom.
 *
 * Loops over the order^3 grid points starting at (i0, j0, k0) and accumulates
 *   fx += sum dthx*thy*thz*gval
 *   fy += sum thx*dthy*thz*gval
 *   fz += sum thx*thy*dthz*gval
 * with gval = grid[(i0+ithx)*pny*pnz + (j0+ithy)*pnz + (k0+ithz)], so z is
 * the fastest-varying grid index.
 *
 * The macro is not self-contained. It reads and writes variables that must be
 * declared at the expansion site: ithx, ithy, ithz, index_x, index_xy, i0, j0,
 * k0, pny, pnz, tx, ty, dx, dy, fxy1, fz1, gval, grid, thx, thy, thz, dthx,
 * dthy, dthz, fx, fy, fz. fx, fy and fz are only added to, so the caller must
 * initialize them. The argument is not parenthesized; pass a plain constant
 * or variable.
 */
#define DO_FSPLINE(order)                      \
    for (ithx = 0; (ithx < order); ithx++)              \
    {                                              \
        index_x = (i0+ithx)*pny*pnz;               \
        tx      = thx[ithx];                       \
        dx      = dthx[ithx];                      \
                                               \
        for (ithy = 0; (ithy < order); ithy++)          \
        {                                          \
            index_xy = index_x+(j0+ithy)*pnz;      \
            ty       = thy[ithy];                  \
            dy       = dthy[ithy];                 \
            fxy1     = fz1 = 0;                    \
                                               \
            for (ithz = 0; (ithz < order); ithz++)      \
            {                                      \
                gval  = grid[index_xy+(k0+ithz)];  \
                fxy1 += thz[ithz]*gval;            \
                fz1  += dthz[ithz]*gval;           \
            }                                      \
            fx += dx*ty*fxy1;                      \
            fy += tx*dy*fxy1;                      \
            fz += tx*ty*fz1;                       \
        }                                          \
    }
2525
2526
/* Gather PME forces onto atoms by B-spline interpolation of the grid gradient.
 *
 * For each atom n = spline->ind[nn], nn in [0, spline->n):
 *  - when bClearF, atc->f[n] is first set to zero (also for atoms whose
 *    coefficient is zero);
 *  - when coefficient = scale*atc->coefficient[n] is nonzero, the order^3 grid
 *    points starting at atc->idx[n] are combined with that atom's spline
 *    weights (spline->theta / spline->dtheta, order values per atom and
 *    dimension) into the gradient (fx, fy, fz) in grid units. That gradient is
 *    converted with the grid sizes and pme->recipbox, multiplied by
 *    coefficient, and subtracted from atc->f[n].
 *
 * pme      PME data (spline order, grid sizes, recipbox, spline_work)
 * grid     real-space grid; the plain-C kernel indexes it as
 *          (x*pmegrid_ny + y)*pmegrid_nz + z
 * bClearF  zero the forces of the processed atoms before accumulating
 * atc      atom data: coefficients, grid start indices and force output
 * spline   spline weights and the list of atoms to process
 * scale    factor applied to every atom coefficient
 */
static void gather_f_bsplines(gmx_pme_t pme, real *grid,
                              gmx_bool bClearF, pme_atomcomm_t *atc,
                              splinedata_t *spline,
                              real scale)
{
    /* sum forces for local particles */
    int     nn, n, ithx, ithy, ithz, i0, j0, k0;
    int     index_x, index_xy;
    int     nx, ny, nz, pnx, pny, pnz;
    int *   idxptr;
    real    tx, ty, dx, dy, coefficient;
    real    fx, fy, fz, gval;
    real    fxy1, fz1;
    real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
    int     norder;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    int     order;

    pme_spline_work_t *work;

    /* Aligned stack buffers for the aligned SIMD4 gather kernel */
#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    real           thz_buffer[GMX_SIMD4_WIDTH*3],  *thz_aligned;
    real           dthz_buffer[GMX_SIMD4_WIDTH*3], *dthz_aligned;

    thz_aligned  = gmx_simd4_align_r(thz_buffer);
    dthz_aligned = gmx_simd4_align_r(dthz_buffer);
#endif

    /* NOTE(review): work, pnx and the aligned buffers are not used in the
     * code visible here; presumably pme-simd4.h uses them - confirm.
     */
    work = pme->spline_work;

    /* NOTE(review): thx..dthz are reassigned per atom inside the loop
     * below before any use visible here, so these six initial
     * assignments look redundant.
     */
    order = pme->pme_order;
    thx   = spline->theta[XX];
    thy   = spline->theta[YY];
    thz   = spline->theta[ZZ];
    dthx  = spline->dtheta[XX];
    dthy  = spline->dtheta[YY];
    dthz  = spline->dtheta[ZZ];
    nx    = pme->nkx;
    ny    = pme->nky;
    nz    = pme->nkz;
    pnx   = pme->pmegrid_nx;
    pny   = pme->pmegrid_ny;
    pnz   = pme->pmegrid_nz;

    rxx   = pme->recipbox[XX][XX];
    ryx   = pme->recipbox[YY][XX];
    ryy   = pme->recipbox[YY][YY];
    rzx   = pme->recipbox[ZZ][XX];
    rzy   = pme->recipbox[ZZ][YY];
    rzz   = pme->recipbox[ZZ][ZZ];

    for (nn = 0; nn < spline->n; nn++)
    {
        n           = spline->ind[nn];
        coefficient = scale*atc->coefficient[n];

        if (bClearF)
        {
            atc->f[n][XX] = 0;
            atc->f[n][YY] = 0;
            atc->f[n][ZZ] = 0;
        }
        /* Atoms with zero coefficient get no force contribution */
        if (coefficient != 0)
        {
            fx     = 0;
            fy     = 0;
            fz     = 0;
            idxptr = atc->idx[n];
            /* Spline weights are stored per spline-list entry nn */
            norder = nn*order;

            i0   = idxptr[XX];
            j0   = idxptr[YY];
            k0   = idxptr[ZZ];

            /* Pointer arithmetic alert, next six statements */
            thx  = spline->theta[XX] + norder;
            thy  = spline->theta[YY] + norder;
            thz  = spline->theta[ZZ] + norder;
            dthx = spline->dtheta[XX] + norder;
            dthy = spline->dtheta[YY] + norder;
            dthz = spline->dtheta[ZZ] + norder;

            /* Orders 4 and 5 use the SIMD4 kernel from pme-simd4.h when
             * available. The #define lines select which kernel variant that
             * header expands to at this point. Other orders use the generic
             * DO_FSPLINE.
             */
            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_GATHER_F_SIMD4_ORDER4
#else
#define PME_GATHER_F_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
#else
                    DO_FSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_GATHER_F_SIMD4_ALIGNED
#define PME_ORDER 5
#include "gromacs/ewald/pme-simd4.h" /* IWYU pragma: keep */
#else
                    DO_FSPLINE(5);
#endif
                    break;
                default:
                    DO_FSPLINE(order);
                    break;
            }

            /* Convert the grid-unit gradient with the grid sizes and the
             * reciprocal box, then subtract it times the coefficient.
             */
            atc->f[n][XX] += -coefficient*( fx*nx*rxx );
            atc->f[n][YY] += -coefficient*( fx*nx*ryx + fy*ny*ryy );
            atc->f[n][ZZ] += -coefficient*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
        }
    }
    /* Since the energy and not forces are interpolated
     * the net force might not be exactly zero.
     * This can be solved by also interpolating F, but
     * that comes at a cost.
     * A better hack is to remove the net force every
     * step, but that must be done at a higher level
     * since this routine doesn't see all atoms if running
     * in parallel. Don't know how important it is?  EL 990726
     */
}
2653
2654
2655 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2656                                    pme_atomcomm_t *atc)
2657 {
2658     splinedata_t *spline;
2659     int     n, ithx, ithy, ithz, i0, j0, k0;
2660     int     index_x, index_xy;
2661     int *   idxptr;
2662     real    energy, pot, tx, ty, coefficient, gval;
2663     real    *thx, *thy, *thz;
2664     int     norder;
2665     int     order;
2666
2667     spline = &atc->spline[0];
2668
2669     order = pme->pme_order;
2670
2671     energy = 0;
2672     for (n = 0; (n < atc->n); n++)
2673     {
2674         coefficient      = atc->coefficient[n];
2675
2676         if (coefficient != 0)
2677         {
2678             idxptr = atc->idx[n];
2679             norder = n*order;
2680
2681             i0   = idxptr[XX];
2682             j0   = idxptr[YY];
2683             k0   = idxptr[ZZ];
2684
2685             /* Pointer arithmetic alert, next three statements */
2686             thx  = spline->theta[XX] + norder;
2687             thy  = spline->theta[YY] + norder;
2688             thz  = spline->theta[ZZ] + norder;
2689
2690             pot = 0;
2691             for (ithx = 0; (ithx < order); ithx++)
2692             {
2693                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2694                 tx      = thx[ithx];
2695
2696                 for (ithy = 0; (ithy < order); ithy++)
2697                 {
2698                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2699                     ty       = thy[ithy];
2700
2701                     for (ithz = 0; (ithz < order); ithz++)
2702                     {
2703                         gval  = grid[index_xy+(k0+ithz)];
2704                         pot  += tx*ty*thz[ithz]*gval;
2705                     }
2706
2707                 }
2708             }
2709
2710             energy += pot*coefficient;
2711         }
2712     }
2713
2714     return energy;
2715 }
2716
/* Macro to force loop unrolling by fixing order.
 * This gives a significant performance gain.
 *
 * CALC_SPLINE(order) computes, for every dimension j in [0, DIM), the
 * `order` B-spline weights and their derivatives for the fractional offset
 * dr = xptr[j] (the relative offset from the lower cell limit). The weights
 * are built by a recursion that raises the spline order by one per step,
 * starting from the linear weights (1 - dr, dr). The derivatives are taken
 * as differences of the order-1 weights, before the final recursion step.
 * Results are stored in theta[j][i*order+k] and dtheta[j][i*order+k] for
 * k = 0..order-1.
 *
 * The expansion site must provide `xptr` (DIM fractional offsets), `theta`,
 * `dtheta` and the slot index `i`. The macro declares its own locals
 * (j, k, l, dr, div, data, ddata) in an inner block. order must not exceed
 * PME_ORDER_MAX, the size of the local data/ddata arrays.
 */
#define CALC_SPLINE(order)                     \
    {                                              \
        int j, k, l;                                 \
        real dr, div;                               \
        real data[PME_ORDER_MAX];                  \
        real ddata[PME_ORDER_MAX];                 \
                                               \
        for (j = 0; (j < DIM); j++)                     \
        {                                          \
            dr  = xptr[j];                         \
                                               \
            /* dr is relative offset from lower cell limit */ \
            data[order-1] = 0;                     \
            data[1]       = dr;                          \
            data[0]       = 1 - dr;                      \
                                               \
            for (k = 3; (k < order); k++)               \
            {                                      \
                div       = 1.0/(k - 1.0);               \
                data[k-1] = div*dr*data[k-2];      \
                for (l = 1; (l < (k-1)); l++)           \
                {                                  \
                    data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
                                       data[k-l-1]);                \
                }                                  \
                data[0] = div*(1-dr)*data[0];      \
            }                                      \
            /* differentiate */                    \
            ddata[0] = -data[0];                   \
            for (k = 1; (k < order); k++)               \
            {                                      \
                ddata[k] = data[k-1] - data[k];    \
            }                                      \
                                               \
            div           = 1.0/(order - 1);                 \
            data[order-1] = div*dr*data[order-2];  \
            for (l = 1; (l < (order-1)); l++)           \
            {                                      \
                data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
                                       (order-l-dr)*data[order-l-1]); \
            }                                      \
            data[0] = div*(1 - dr)*data[0];        \
                                               \
            for (k = 0; k < order; k++)                 \
            {                                      \
                theta[j][i*order+k]  = data[k];    \
                dtheta[j][i*order+k] = ddata[k];   \
            }                                      \
        }                                          \
    }
2770
2771 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2772                    rvec fractx[], int nr, int ind[], real coefficient[],
2773                    gmx_bool bDoSplines)
2774 {
2775     /* construct splines for local atoms */
2776     int  i, ii;
2777     real *xptr;
2778
2779     for (i = 0; i < nr; i++)
2780     {
2781         /* With free energy we do not use the coefficient check.
2782          * In most cases this will be more efficient than calling make_bsplines
2783          * twice, since usually more than half the particles have non-zero coefficients.
2784          */
2785         ii = ind[i];
2786         if (bDoSplines || coefficient[ii] != 0.0)
2787         {
2788             xptr = fractx[ii];
2789             switch (order)
2790             {
2791                 case 4:  CALC_SPLINE(4);     break;
2792                 case 5:  CALC_SPLINE(5);     break;
2793                 default: CALC_SPLINE(order); break;
2794             }
2795         }
2796     }
2797 }
2798
2799
2800 void make_dft_mod(real *mod, real *data, int ndata)
2801 {
2802     int i, j;
2803     real sc, ss, arg;
2804
2805     for (i = 0; i < ndata; i++)
2806     {
2807         sc = ss = 0;
2808         for (j = 0; j < ndata; j++)
2809         {
2810             arg = (2.0*M_PI*i*j)/ndata;
2811             sc += data[j]*cos(arg);
2812             ss += data[j]*sin(arg);
2813         }
2814         mod[i] = sc*sc+ss*ss;
2815     }
2816     for (i = 0; i < ndata; i++)
2817     {
2818         if (mod[i] < 1e-7)
2819         {
2820             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2821         }
2822     }
2823 }
2824
2825
2826 static void make_bspline_moduli(splinevec bsp_mod,
2827                                 int nx, int ny, int nz, int order)
2828 {
2829     int nmax = max(nx, max(ny, nz));
2830     real *data, *ddata, *bsp_data;
2831     int i, k, l;
2832     real div;
2833
2834     snew(data, order);
2835     snew(ddata, order);
2836     snew(bsp_data, nmax);
2837
2838     data[order-1] = 0;
2839     data[1]       = 0;
2840     data[0]       = 1;
2841
2842     for (k = 3; k < order; k++)
2843     {
2844         div       = 1.0/(k-1.0);
2845         data[k-1] = 0;
2846         for (l = 1; l < (k-1); l++)
2847         {
2848             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2849         }
2850         data[0] = div*data[0];
2851     }
2852     /* differentiate */
2853     ddata[0] = -data[0];
2854     for (k = 1; k < order; k++)
2855     {
2856         ddata[k] = data[k-1]-data[k];
2857     }
2858     div           = 1.0/(order-1);
2859     data[order-1] = 0;
2860     for (l = 1; l < (order-1); l++)
2861     {
2862         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2863     }
2864     data[0] = div*data[0];
2865
2866     for (i = 0; i < nmax; i++)
2867     {
2868         bsp_data[i] = 0;
2869     }
2870     for (i = 1; i <= order; i++)
2871     {
2872         bsp_data[i] = data[i-1];
2873     }
2874
2875     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2876     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2877     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2878
2879     sfree(data);
2880     sfree(ddata);
2881     sfree(bsp_data);
2882 }
2883
2884
/* Return the P3M optimal influence function polynomial in z for the given
 * interpolation order (2..8). Any other order yields 0.0.
 *
 * The formula and most constants can be found in:
 * Ballenegger et al., JCTC 8, 936 (2012)
 */
static double do_p3m_influence(double z, int order)
{
    const double z2 = z*z;
    const double z4 = z2*z2;
    double       infl;

    switch (order)
    {
        case 2:
            infl = 1.0 - 2.0*z2/3.0;
            break;
        case 3:
            infl = 1.0 - z2 + 2.0*z4/15.0;
            break;
        case 4:
            infl = 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
            break;
        case 5:
            infl = 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
            break;
        case 6:
            infl = 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
            break;
        case 7:
            infl = 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
            break;
        case 8:
            infl = 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
            break;
        default:
            /* Unsupported order */
            infl = 0.0;
            break;
    }

    return infl;
}
2922
2923 /* Calculate the P3M B-spline moduli for one dimension */
2924 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2925 {
2926     double zarg, zai, sinzai, infl;
2927     int    maxk, i;
2928
2929     if (order > 8)
2930     {
2931         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2932     }
2933
2934     zarg = M_PI/n;
2935
2936     maxk = (n + 1)/2;
2937
2938     for (i = -maxk; i < 0; i++)
2939     {
2940         zai          = zarg*i;
2941         sinzai       = sin(zai);
2942         infl         = do_p3m_influence(sinzai, order);
2943         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2944     }
2945     bsp_mod[0] = 1.0;
2946     for (i = 1; i < maxk; i++)
2947     {
2948         zai        = zarg*i;
2949         sinzai     = sin(zai);
2950         infl       = do_p3m_influence(sinzai, order);
2951         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2952     }
2953 }
2954
2955 /* Calculate the P3M B-spline moduli */
2956 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2957                                     int nx, int ny, int nz, int order)
2958 {
2959     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2960     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2961     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2962 }
2963
2964
2965 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2966 {
2967     int nslab, n, i;
2968     int fw, bw;
2969
2970     nslab = atc->nslab;
2971
2972     n = 0;
2973     for (i = 1; i <= nslab/2; i++)
2974     {
2975         fw = (atc->nodeid + i) % nslab;
2976         bw = (atc->nodeid - i + nslab) % nslab;
2977         if (n < nslab - 1)
2978         {
2979             atc->node_dest[n] = fw;
2980             atc->node_src[n]  = bw;
2981             n++;
2982         }
2983         if (n < nslab - 1)
2984         {
2985             atc->node_dest[n] = bw;
2986             atc->node_src[n]  = fw;
2987             n++;
2988         }
2989     }
2990 }
2991
2992 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2993 {
2994     int thread, i;
2995
2996     if (NULL != log)
2997     {
2998         fprintf(log, "Destroying PME data structures.\n");
2999     }
3000
3001     sfree((*pmedata)->nnx);
3002     sfree((*pmedata)->nny);
3003     sfree((*pmedata)->nnz);
3004
3005     for (i = 0; i < (*pmedata)->ngrids; ++i)
3006     {
3007         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
3008         sfree((*pmedata)->fftgrid[i]);
3009         sfree((*pmedata)->cfftgrid[i]);
3010         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
3011     }
3012
3013     sfree((*pmedata)->lb_buf1);
3014     sfree((*pmedata)->lb_buf2);
3015
3016     for (thread = 0; thread < (*pmedata)->nthread; thread++)
3017     {
3018         free_work(&(*pmedata)->work[thread]);
3019     }
3020     sfree((*pmedata)->work);
3021
3022     sfree(*pmedata);
3023     *pmedata = NULL;
3024
3025     return 0;
3026 }
3027
/* Round n up to the nearest multiple of f (intended for n >= 0, f > 0). */
static int mult_up(int n, int f)
{
    int nblocks = (n + f - 1)/f;

    return nblocks*f;
}
3032
3033
3034 static double pme_load_imbalance(gmx_pme_t pme)
3035 {
3036     int    nma, nmi;
3037     double n1, n2, n3;
3038
3039     nma = pme->nnodes_major;
3040     nmi = pme->nnodes_minor;
3041
3042     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3043     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3044     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3045
3046     /* pme_solve is roughly double the cost of an fft */
3047
3048     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3049 }
3050
3051 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3052                           int dimind, gmx_bool bSpread)
3053 {
3054     int nk, k, s, thread;
3055
3056     atc->dimind    = dimind;
3057     atc->nslab     = 1;
3058     atc->nodeid    = 0;
3059     atc->pd_nalloc = 0;
3060 #ifdef GMX_MPI
3061     if (pme->nnodes > 1)
3062     {
3063         atc->mpi_comm = pme->mpi_comm_d[dimind];
3064         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3065         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3066     }
3067     if (debug)
3068     {
3069         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3070     }
3071 #endif
3072
3073     atc->bSpread   = bSpread;
3074     atc->pme_order = pme->pme_order;
3075
3076     if (atc->nslab > 1)
3077     {
3078         snew(atc->node_dest, atc->nslab);
3079         snew(atc->node_src, atc->nslab);
3080         setup_coordinate_communication(atc);
3081
3082         snew(atc->count_thread, pme->nthread);
3083         for (thread = 0; thread < pme->nthread; thread++)
3084         {
3085             snew(atc->count_thread[thread], atc->nslab);
3086         }
3087         atc->count = atc->count_thread[0];
3088         snew(atc->rcount, atc->nslab);
3089         snew(atc->buf_index, atc->nslab);
3090     }
3091
3092     atc->nthread = pme->nthread;
3093     if (atc->nthread > 1)
3094     {
3095         snew(atc->thread_plist, atc->nthread);
3096     }
3097     snew(atc->spline, atc->nthread);
3098     for (thread = 0; thread < atc->nthread; thread++)
3099     {
3100         if (atc->nthread > 1)
3101         {
3102             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3103             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3104         }
3105         snew(atc->spline[thread].thread_one, pme->nthread);
3106         atc->spline[thread].thread_one[thread] = 1;
3107     }
3108 }
3109
3110 static void
3111 init_overlap_comm(pme_overlap_t *  ol,
3112                   int              norder,
3113 #ifdef GMX_MPI
3114                   MPI_Comm         comm,
3115 #endif
3116                   int              nnodes,
3117                   int              nodeid,
3118                   int              ndata,
3119                   int              commplainsize)
3120 {
3121     int lbnd, rbnd, maxlr, b, i;
3122     int exten;
3123     int nn, nk;
3124     pme_grid_comm_t *pgc;
3125     gmx_bool bCont;
3126     int fft_start, fft_end, send_index1, recv_index1;
3127 #ifdef GMX_MPI
3128     MPI_Status stat;
3129
3130     ol->mpi_comm = comm;
3131 #endif
3132
3133     ol->nnodes = nnodes;
3134     ol->nodeid = nodeid;
3135
3136     /* Linear translation of the PME grid won't affect reciprocal space
3137      * calculations, so to optimize we only interpolate "upwards",
3138      * which also means we only have to consider overlap in one direction.
3139      * I.e., particles on this node might also be spread to grid indices
3140      * that belong to higher nodes (modulo nnodes)
3141      */
3142
3143     snew(ol->s2g0, ol->nnodes+1);
3144     snew(ol->s2g1, ol->nnodes);
3145     if (debug)
3146     {
3147         fprintf(debug, "PME slab boundaries:");
3148     }
3149     for (i = 0; i < nnodes; i++)
3150     {
3151         /* s2g0 the local interpolation grid start.
3152          * s2g1 the local interpolation grid end.
3153          * Since in calc_pidx we divide particles, and not grid lines,
3154          * spatially uniform along dimension x or y, we need to round
3155          * s2g0 down and s2g1 up.
3156          */
3157         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3158         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3159
3160         if (debug)
3161         {
3162             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3163         }
3164     }
3165     ol->s2g0[nnodes] = ndata;
3166     if (debug)
3167     {
3168         fprintf(debug, "\n");
3169     }
3170
3171     /* Determine with how many nodes we need to communicate the grid overlap */
3172     b = 0;
3173     do
3174     {
3175         b++;
3176         bCont = FALSE;
3177         for (i = 0; i < nnodes; i++)
3178         {
3179             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3180                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3181             {
3182                 bCont = TRUE;
3183             }
3184         }
3185     }
3186     while (bCont && b < nnodes);
3187     ol->noverlap_nodes = b - 1;
3188
3189     snew(ol->send_id, ol->noverlap_nodes);
3190     snew(ol->recv_id, ol->noverlap_nodes);
3191     for (b = 0; b < ol->noverlap_nodes; b++)
3192     {
3193         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3194         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3195     }
3196     snew(ol->comm_data, ol->noverlap_nodes);
3197
3198     ol->send_size = 0;
3199     for (b = 0; b < ol->noverlap_nodes; b++)
3200     {
3201         pgc = &ol->comm_data[b];
3202         /* Send */
3203         fft_start        = ol->s2g0[ol->send_id[b]];
3204         fft_end          = ol->s2g0[ol->send_id[b]+1];
3205         if (ol->send_id[b] < nodeid)
3206         {
3207             fft_start += ndata;
3208             fft_end   += ndata;
3209         }
3210         send_index1       = ol->s2g1[nodeid];
3211         send_index1       = min(send_index1, fft_end);
3212         pgc->send_index0  = fft_start;
3213         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3214         ol->send_size    += pgc->send_nindex;
3215
3216         /* We always start receiving to the first index of our slab */
3217         fft_start        = ol->s2g0[ol->nodeid];
3218         fft_end          = ol->s2g0[ol->nodeid+1];
3219         recv_index1      = ol->s2g1[ol->recv_id[b]];
3220         if (ol->recv_id[b] > nodeid)
3221         {
3222             recv_index1 -= ndata;
3223         }
3224         recv_index1      = min(recv_index1, fft_end);
3225         pgc->recv_index0 = fft_start;
3226         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3227     }
3228
3229 #ifdef GMX_MPI
3230     /* Communicate the buffer sizes to receive */
3231     for (b = 0; b < ol->noverlap_nodes; b++)
3232     {
3233         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3234                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3235                      ol->mpi_comm, &stat);
3236     }
3237 #endif
3238
3239     /* For non-divisible grid we need pme_order iso pme_order-1 */
3240     snew(ol->sendbuf, norder*commplainsize);
3241     snew(ol->recvbuf, norder*commplainsize);
3242 }
3243
3244 static void
3245 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3246                               int **global_to_local,
3247                               real **fraction_shift)
3248 {
3249     int i;
3250     int * gtl;
3251     real * fsh;
3252
3253     snew(gtl, 5*n);
3254     snew(fsh, 5*n);
3255     for (i = 0; (i < 5*n); i++)
3256     {
3257         /* Determine the global to local grid index */
3258         gtl[i] = (i - local_start + n) % n;
3259         /* For coordinates that fall within the local grid the fraction
3260          * is correct, we don't need to shift it.
3261          */
3262         fsh[i] = 0;
3263         if (local_range < n)
3264         {
3265             /* Due to rounding issues i could be 1 beyond the lower or
3266              * upper boundary of the local grid. Correct the index for this.
3267              * If we shift the index, we need to shift the fraction by
3268              * the same amount in the other direction to not affect
3269              * the weights.
3270              * Note that due to this shifting the weights at the end of
3271              * the spline might change, but that will only involve values
3272              * between zero and values close to the precision of a real,
3273              * which is anyhow the accuracy of the whole mesh calculation.
3274              */
3275             /* With local_range=0 we should not change i=local_start */
3276             if (i % n != local_start)
3277             {
3278                 if (gtl[i] == n-1)
3279                 {
3280                     gtl[i] = 0;
3281                     fsh[i] = -1;
3282                 }
3283                 else if (gtl[i] == local_range)
3284                 {
3285                     gtl[i] = local_range - 1;
3286                     fsh[i] = 1;
3287                 }
3288             }
3289         }
3290     }
3291
3292     *global_to_local = gtl;
3293     *fraction_shift  = fsh;
3294 }
3295
/* Allocate and initialise the work data for the SIMD4 spread/gather kernels.
 *
 * When PME_SIMD4_SPREAD_GATHER is defined, this builds a pair of SIMD4
 * boolean masks for every start offset "of" in [0, 2*GMX_SIMD4_WIDTH-(order-1)).
 * work->mask_S0[of] covers lanes 0..GMX_SIMD4_WIDTH-1 and work->mask_S1[of]
 * covers lanes GMX_SIMD4_WIDTH..2*GMX_SIMD4_WIDTH-1. Both are true exactly
 * for the lanes i with of <= i < of+order.
 * Without PME_SIMD4_SPREAD_GATHER no work data is needed and NULL is returned.
 *
 * order: the PME interpolation order (unused without SIMD4 support).
 * Returns the work struct (allocated with SIMD4_ALIGNMENT), or NULL.
 */
static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
{
    pme_spline_work_t *work;

#ifdef PME_SIMD4_SPREAD_GATHER
    /* tmp holds 3*GMX_SIMD4_WIDTH reals, of which 2*GMX_SIMD4_WIDTH are used
     * through tmp_aligned; presumably the extra width leaves room for
     * gmx_simd4_align_r to move the start to an aligned address.
     * NOTE(review): confirm against gmx_simd4_align_r.
     */
    real             tmp[GMX_SIMD4_WIDTH*3], *tmp_aligned;
    gmx_simd4_real_t zero_S;
    gmx_simd4_real_t real_mask_S0, real_mask_S1;
    int              of, i;

    snew_aligned(work, 1, SIMD4_ALIGNMENT);

    tmp_aligned = gmx_simd4_align_r(tmp);

    zero_S = gmx_simd4_setzero_r();

    /* Generate bit masks to mask out the unused grid entries,
     * as we only operate on order of the 8 grid entries that are
     * loaded into 2 SIMD registers.
     */
    for (of = 0; of < 2*GMX_SIMD4_WIDTH-(order-1); of++)
    {
        /* -1 marks the lanes in [of, of+order); the compare against zero
         * below turns that into a boolean mask.
         */
        for (i = 0; i < 2*GMX_SIMD4_WIDTH; i++)
        {
            tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
        }
        real_mask_S0      = gmx_simd4_load_r(tmp_aligned);
        real_mask_S1      = gmx_simd4_load_r(tmp_aligned+GMX_SIMD4_WIDTH);
        work->mask_S0[of] = gmx_simd4_cmplt_r(real_mask_S0, zero_S);
        work->mask_S1[of] = gmx_simd4_cmplt_r(real_mask_S1, zero_S);
    }
#else
    work = NULL;
#endif

    return work;
}
3333
3334 void gmx_pme_check_restrictions(int pme_order,
3335                                 int nkx, int nky, int nkz,
3336                                 int nnodes_major,
3337                                 int nnodes_minor,
3338                                 gmx_bool bUseThreads,
3339                                 gmx_bool bFatal,
3340                                 gmx_bool *bValidSettings)
3341 {
3342     if (pme_order > PME_ORDER_MAX)
3343     {
3344         if (!bFatal)
3345         {
3346             *bValidSettings = FALSE;
3347             return;
3348         }
3349         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3350                   pme_order, PME_ORDER_MAX);
3351     }
3352
3353     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3354         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3355         nkz <= pme_order)
3356     {
3357         if (!bFatal)
3358         {
3359             *bValidSettings = FALSE;
3360             return;
3361         }
3362         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3363                   pme_order);
3364     }
3365
3366     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3367      * We only allow multiple communication pulses in dim 1, not in dim 0.
3368      */
3369     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3370                         nkx != nnodes_major*(pme_order - 1)))
3371     {
3372         if (!bFatal)
3373         {
3374             *bValidSettings = FALSE;
3375             return;
3376         }
3377         gmx_fatal(FARGS, "The number of PME grid lines per rank along x is %g. But when using OpenMP threads, the number of grid lines per rank along x should be >= pme_order (%d) or = pmeorder-1. To resolve this issue, use fewer ranks along x (and possibly more along y and/or z) by specifying -dd manually.",
3378                   nkx/(double)nnodes_major, pme_order);
3379     }
3380
3381     if (bValidSettings != NULL)
3382     {
3383         *bValidSettings = TRUE;
3384     }
3385
3386     return;
3387 }
3388
3389 int gmx_pme_init(gmx_pme_t *         pmedata,
3390                  t_commrec *         cr,
3391                  int                 nnodes_major,
3392                  int                 nnodes_minor,
3393                  t_inputrec *        ir,
3394                  int                 homenr,
3395                  gmx_bool            bFreeEnergy_q,
3396                  gmx_bool            bFreeEnergy_lj,
3397                  gmx_bool            bReproducible,
3398                  int                 nthread)
3399 {
3400     gmx_pme_t pme = NULL;
3401
3402     int  use_threads, sum_use_threads, i;
3403     ivec ndata;
3404
3405     if (debug)
3406     {
3407         fprintf(debug, "Creating PME data structures.\n");
3408     }
3409     snew(pme, 1);
3410
3411     pme->sum_qgrid_tmp       = NULL;
3412     pme->sum_qgrid_dd_tmp    = NULL;
3413     pme->buf_nalloc          = 0;
3414
3415     pme->nnodes              = 1;
3416     pme->bPPnode             = TRUE;
3417
3418     pme->nnodes_major        = nnodes_major;
3419     pme->nnodes_minor        = nnodes_minor;
3420
3421 #ifdef GMX_MPI
3422     if (nnodes_major*nnodes_minor > 1)
3423     {
3424         pme->mpi_comm = cr->mpi_comm_mygroup;
3425
3426         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3427         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3428         if (pme->nnodes != nnodes_major*nnodes_minor)
3429         {
3430             gmx_incons("PME rank count mismatch");
3431         }
3432     }
3433     else
3434     {
3435         pme->mpi_comm = MPI_COMM_NULL;
3436     }
3437 #endif
3438
3439     if (pme->nnodes == 1)
3440     {
3441 #ifdef GMX_MPI
3442         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3443         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3444 #endif
3445         pme->ndecompdim   = 0;
3446         pme->nodeid_major = 0;
3447         pme->nodeid_minor = 0;
3448 #ifdef GMX_MPI
3449         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3450 #endif
3451     }
3452     else
3453     {
3454         if (nnodes_minor == 1)
3455         {
3456 #ifdef GMX_MPI
3457             pme->mpi_comm_d[0] = pme->mpi_comm;
3458             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3459 #endif
3460             pme->ndecompdim   = 1;
3461             pme->nodeid_major = pme->nodeid;
3462             pme->nodeid_minor = 0;
3463
3464         }
3465         else if (nnodes_major == 1)
3466         {
3467 #ifdef GMX_MPI
3468             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3469             pme->mpi_comm_d[1] = pme->mpi_comm;
3470 #endif
3471             pme->ndecompdim   = 1;
3472             pme->nodeid_major = 0;
3473             pme->nodeid_minor = pme->nodeid;
3474         }
3475         else
3476         {
3477             if (pme->nnodes % nnodes_major != 0)
3478             {
3479                 gmx_incons("For 2D PME decomposition, #PME ranks must be divisible by the number of ranks in the major dimension");
3480             }
3481             pme->ndecompdim = 2;
3482
3483 #ifdef GMX_MPI
3484             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3485                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3486             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3487                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3488
3489             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3490             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3491             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3492             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3493 #endif
3494         }
3495         pme->bPPnode = (cr->duty & DUTY_PP);
3496     }
3497
3498     pme->nthread = nthread;
3499
3500     /* Check if any of the PME MPI ranks uses threads */
3501     use_threads = (pme->nthread > 1 ? 1 : 0);
3502 #ifdef GMX_MPI
3503     if (pme->nnodes > 1)
3504     {
3505         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3506                       MPI_SUM, pme->mpi_comm);
3507     }
3508     else
3509 #endif
3510     {
3511         sum_use_threads = use_threads;
3512     }
3513     pme->bUseThreads = (sum_use_threads > 0);
3514
3515     if (ir->ePBC == epbcSCREW)
3516     {
3517         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3518     }
3519
3520     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3521     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3522     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3523     pme->nkx         = ir->nkx;
3524     pme->nky         = ir->nky;
3525     pme->nkz         = ir->nkz;
3526     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3527     pme->pme_order   = ir->pme_order;
3528
3529     /* Always constant electrostatics coefficients */
3530     pme->epsilon_r   = ir->epsilon_r;
3531
3532     /* Always constant LJ coefficients */
3533     pme->ljpme_combination_rule = ir->ljpme_combination_rule;
3534
3535     /* If we violate restrictions, generate a fatal error here */
3536     gmx_pme_check_restrictions(pme->pme_order,
3537                                pme->nkx, pme->nky, pme->nkz,
3538                                pme->nnodes_major,
3539                                pme->nnodes_minor,
3540                                pme->bUseThreads,
3541                                TRUE,
3542                                NULL);
3543
3544     if (pme->nnodes > 1)
3545     {
3546         double imbal;
3547
3548 #ifdef GMX_MPI
3549         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3550         MPI_Type_commit(&(pme->rvec_mpi));
3551 #endif
3552
3553         /* Note that the coefficient spreading and force gathering, which usually
3554          * takes about the same amount of time as FFT+solve_pme,
3555          * is always fully load balanced
3556          * (unless the coefficient distribution is inhomogeneous).
3557          */
3558
3559         imbal = pme_load_imbalance(pme);
3560         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3561         {
3562             fprintf(stderr,
3563                     "\n"
3564                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3565                     "      For optimal PME load balancing\n"
3566                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_ranks_x (%d)\n"
3567                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_ranks_y (%d)\n"
3568                     "\n",
3569                     (int)((imbal-1)*100 + 0.5),
3570                     pme->nkx, pme->nky, pme->nnodes_major,
3571                     pme->nky, pme->nkz, pme->nnodes_minor);
3572         }
3573     }
3574
3575     /* For non-divisible grid we need pme_order iso pme_order-1 */
3576     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3577      * y is always copied through a buffer: we don't need padding in z,
3578      * but we do need the overlap in x because of the communication order.
3579      */
3580     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3581 #ifdef GMX_MPI
3582                       pme->mpi_comm_d[0],
3583 #endif
3584                       pme->nnodes_major, pme->nodeid_major,
3585                       pme->nkx,
3586                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3587
3588     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3589      * We do this with an offset buffer of equal size, so we need to allocate
3590      * extra for the offset. That's what the (+1)*pme->nkz is for.
3591      */
3592     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3593 #ifdef GMX_MPI
3594                       pme->mpi_comm_d[1],
3595 #endif
3596                       pme->nnodes_minor, pme->nodeid_minor,
3597                       pme->nky,
3598                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3599
3600     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3601      * Note that gmx_pme_check_restrictions checked for this already.
3602      */
3603     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3604     {
3605         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3606     }
3607
3608     snew(pme->bsp_mod[XX], pme->nkx);
3609     snew(pme->bsp_mod[YY], pme->nky);
3610     snew(pme->bsp_mod[ZZ], pme->nkz);
3611
3612     /* The required size of the interpolation grid, including overlap.
3613      * The allocated size (pmegrid_n?) might be slightly larger.
3614      */
3615     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3616         pme->overlap[0].s2g0[pme->nodeid_major];
3617     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3618         pme->overlap[1].s2g0[pme->nodeid_minor];
3619     pme->pmegrid_nz_base = pme->nkz;
3620     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3621     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3622
3623     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3624     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3625     pme->pmegrid_start_iz = 0;
3626
3627     make_gridindex5_to_localindex(pme->nkx,
3628                                   pme->pmegrid_start_ix,
3629                                   pme->pmegrid_nx - (pme->pme_order-1),
3630                                   &pme->nnx, &pme->fshx);
3631     make_gridindex5_to_localindex(pme->nky,
3632                                   pme->pmegrid_start_iy,
3633                                   pme->pmegrid_ny - (pme->pme_order-1),
3634                                   &pme->nny, &pme->fshy);
3635     make_gridindex5_to_localindex(pme->nkz,
3636                                   pme->pmegrid_start_iz,
3637                                   pme->pmegrid_nz_base,
3638                                   &pme->nnz, &pme->fshz);
3639
3640     pme->spline_work = make_pme_spline_work(pme->pme_order);
3641
3642     ndata[0]    = pme->nkx;
3643     ndata[1]    = pme->nky;
3644     ndata[2]    = pme->nkz;
3645     /* It doesn't matter if we allocate too many grids here,
3646      * we only allocate and use the ones we need.
3647      */
3648     if (EVDW_PME(ir->vdwtype))
3649     {
3650         pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3651     }
3652     else
3653     {
3654         pme->ngrids = DO_Q;
3655     }
3656     snew(pme->fftgrid, pme->ngrids);
3657     snew(pme->cfftgrid, pme->ngrids);
3658     snew(pme->pfft_setup, pme->ngrids);
3659
3660     for (i = 0; i < pme->ngrids; ++i)
3661     {
3662         if ((i <  DO_Q && EEL_PME(ir->coulombtype) && (i == 0 ||
3663                                                        bFreeEnergy_q)) ||
3664             (i >= DO_Q && EVDW_PME(ir->vdwtype) && (i == 2 ||
3665                                                     bFreeEnergy_lj ||
3666                                                     ir->ljpme_combination_rule == eljpmeLB)))
3667         {
3668             pmegrids_init(&pme->pmegrid[i],
3669                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3670                           pme->pmegrid_nz_base,
3671                           pme->pme_order,
3672                           pme->bUseThreads,
3673                           pme->nthread,
3674                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3675                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3676             /* This routine will allocate the grid data to fit the FFTs */
3677             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3678                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3679                                     pme->mpi_comm_d,
3680                                     bReproducible, pme->nthread);
3681
3682         }
3683     }
3684
3685     if (!pme->bP3M)
3686     {
3687         /* Use plain SPME B-spline interpolation */
3688         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3689     }
3690     else
3691     {
3692         /* Use the P3M grid-optimized influence function */
3693         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3694     }
3695
3696     /* Use atc[0] for spreading */
3697     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3698     if (pme->ndecompdim >= 2)
3699     {
3700         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3701     }
3702
3703     if (pme->nnodes == 1)
3704     {
3705         pme->atc[0].n = homenr;
3706         pme_realloc_atomcomm_things(&pme->atc[0]);
3707     }
3708
3709     pme->lb_buf1       = NULL;
3710     pme->lb_buf2       = NULL;
3711     pme->lb_buf_nalloc = 0;
3712
3713     {
3714         int thread;
3715
3716         /* Use fft5d, order after FFT is y major, z, x minor */
3717
3718         snew(pme->work, pme->nthread);
3719         for (thread = 0; thread < pme->nthread; thread++)
3720         {
3721             realloc_work(&pme->work[thread], pme->nkx);
3722         }
3723     }
3724
3725     *pmedata = pme;
3726
3727     return 0;
3728 }
3729
3730 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3731 {
3732     int d, t;
3733
3734     for (d = 0; d < DIM; d++)
3735     {
3736         if (new->grid.n[d] > old->grid.n[d])
3737         {
3738             return;
3739         }
3740     }
3741
3742     sfree_aligned(new->grid.grid);
3743     new->grid.grid = old->grid.grid;
3744
3745     if (new->grid_th != NULL && new->nthread == old->nthread)
3746     {
3747         sfree_aligned(new->grid_all);
3748         for (t = 0; t < new->nthread; t++)
3749         {
3750             new->grid_th[t].grid = old->grid_th[t].grid;
3751         }
3752     }
3753 }
3754
3755 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3756                    t_commrec *         cr,
3757                    gmx_pme_t           pme_src,
3758                    const t_inputrec *  ir,
3759                    ivec                grid_size)
3760 {
3761     t_inputrec irc;
3762     int homenr;
3763     int ret;
3764
3765     irc     = *ir;
3766     irc.nkx = grid_size[XX];
3767     irc.nky = grid_size[YY];
3768     irc.nkz = grid_size[ZZ];
3769
3770     if (pme_src->nnodes == 1)
3771     {
3772         homenr = pme_src->atc[0].n;
3773     }
3774     else
3775     {
3776         homenr = -1;
3777     }
3778
3779     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3780                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3781
3782     if (ret == 0)
3783     {
3784         /* We can easily reuse the allocated pme grids in pme_src */
3785         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3786         /* We would like to reuse the fft grids, but that's harder */
3787     }
3788
3789     return ret;
3790 }
3791
3792
3793 static void copy_local_grid(gmx_pme_t pme, pmegrids_t *pmegrids,
3794                             int grid_index, int thread, real *fftgrid)
3795 {
3796     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3797     int  fft_my, fft_mz;
3798     int  nsx, nsy, nsz;
3799     ivec nf;
3800     int  offx, offy, offz, x, y, z, i0, i0t;
3801     int  d;
3802     pmegrid_t *pmegrid;
3803     real *grid_th;
3804
3805     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3806                                    local_fft_ndata,
3807                                    local_fft_offset,
3808                                    local_fft_size);
3809     fft_my = local_fft_size[YY];
3810     fft_mz = local_fft_size[ZZ];
3811
3812     pmegrid = &pmegrids->grid_th[thread];
3813
3814     nsx = pmegrid->s[XX];
3815     nsy = pmegrid->s[YY];
3816     nsz = pmegrid->s[ZZ];
3817
3818     for (d = 0; d < DIM; d++)
3819     {
3820         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3821                     local_fft_ndata[d] - pmegrid->offset[d]);
3822     }
3823
3824     offx = pmegrid->offset[XX];
3825     offy = pmegrid->offset[YY];
3826     offz = pmegrid->offset[ZZ];
3827
3828     /* Directly copy the non-overlapping parts of the local grids.
3829      * This also initializes the full grid.
3830      */
3831     grid_th = pmegrid->grid;
3832     for (x = 0; x < nf[XX]; x++)
3833     {
3834         for (y = 0; y < nf[YY]; y++)
3835         {
3836             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3837             i0t = (x*nsy + y)*nsz;
3838             for (z = 0; z < nf[ZZ]; z++)
3839             {
3840                 fftgrid[i0+z] = grid_th[i0t+z];
3841             }
3842         }
3843     }
3844 }
3845
/* Add the contributions that other threads spread into the grid region owned
 * by thread `thread`.
 *
 * The loops step from our own thread block back by up to
 * pmegrids->nthread_comm[d] blocks along x, y and z. Each source thread
 * block's overlapping part is handled in one of two ways:
 * - It is added directly into fftgrid.
 * - When the source block wraps around the rank boundary along x or y and
 *   there is more than one PME rank along that dimension, it is put into a
 *   communication buffer instead: commbuf_y (at an offset for the x+y
 *   corner) when wrapping along y, commbuf_x when wrapping only along x.
 * Each buffer is overwritten by the first block that writes to it and added
 * to by later blocks, so it does not have to be zeroed beforehand.
 * Our own block (sx == sy == sz == 0) is skipped; its contribution must
 * already be present in fftgrid.
 *
 * pme        - PME setup (decomposition, pme_order, overlap sizes)
 * pmegrids   - thread-local spreading grids
 * thread     - index of the thread whose target region is reduced
 * fftgrid    - rank-local real-space FFT grid, updated in place
 * commbuf_x  - buffer for the overlap along x (spread_on_grid passes
 *              pme->overlap[0].sendbuf)
 * commbuf_y  - buffer for the overlap along y (spread_on_grid passes
 *              pme->overlap[1].sendbuf)
 * grid_index - index of the grid (and FFT setup) to operate on
 */
static void
reduce_threadgrid_overlap(gmx_pme_t pme,
                          const pmegrids_t *pmegrids, int thread,
                          real *fftgrid, real *commbuf_x, real *commbuf_y,
                          int grid_index)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    int  fft_nx, fft_ny, fft_nz;
    int  fft_my, fft_mz;
    int  buf_my = -1;
    int  nsx, nsy, nsz;
    ivec localcopy_end;
    int  offx, offy, offz, x, y, z, i0, i0t;
    int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
    gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
    gmx_bool bCommX, bCommY;
    int  d;
    int  thread_f;
    const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
    const real *grid_th;
    real *commbuf = NULL;

    gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);
    /* Number of local FFT grid points along each dimension */
    fft_nx = local_fft_ndata[XX];
    fft_ny = local_fft_ndata[YY];
    fft_nz = local_fft_ndata[ZZ];

    /* Allocated y and z sizes of the FFT grid, used as index strides */
    fft_my = local_fft_size[YY];
    fft_mz = local_fft_size[ZZ];

    /* This routine is called when all threads have finished spreading.
     * Here each thread sums grid contributions calculated by other threads
     * to the thread local grid volume.
     * To minimize the number of grid copying operations,
     * this routine sums immediately from the pmegrid to the fftgrid.
     */

    /* Determine which part of the full node grid we should operate on,
     * this is our thread local part of the full grid.
     */
    pmegrid = &pmegrids->grid_th[thread];

    for (d = 0; d < DIM; d++)
    {
        /* Determine up to where our thread needs to copy from the
         * thread-local charge spreading grid to the rank-local FFT grid.
         * This is up to our spreading grid end minus order-1 and
         * not beyond the local FFT grid.
         */
        localcopy_end[d] =
            min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
                local_fft_ndata[d]);
    }

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];


    /* Each communication buffer is still "clear" until the first source
     * block writes to it; that first write overwrites instead of adding.
     */
    bClearBufX  = TRUE;
    bClearBufY  = TRUE;
    bClearBufXY = TRUE;

    /* Now loop over all the thread data blocks that contribute
     * to the grid region we (our thread) are operating on.
     */
    /* Note that fft_nx/y is equal to the number of grid points
     * between the first point of our node grid and the one of the next node.
     */
    for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
    {
        /* Source thread block index along x; a negative index wraps to the
         * last thread blocks, shifted back by the local grid size, and needs
         * communication when there are multiple ranks along x.
         */
        fx     = pmegrid->ci[XX] + sx;
        ox     = 0;
        bCommX = FALSE;
        if (fx < 0)
        {
            fx    += pmegrids->nc[XX];
            ox    -= fft_nx;
            bCommX = (pme->nnodes_major > 1);
        }
        pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
        ox       += pmegrid_g->offset[XX];
        /* Determine the end of our part of the source grid */
        if (!bCommX)
        {
            /* Use our thread local source grid and target grid part */
            tx1 = min(ox + pmegrid_g->n[XX], localcopy_end[XX]);
        }
        else
        {
            /* Use our thread local source grid and the spreading range */
            tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
        }

        for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
        {
            /* Same as for x above, now along y */
            fy     = pmegrid->ci[YY] + sy;
            oy     = 0;
            bCommY = FALSE;
            if (fy < 0)
            {
                fy    += pmegrids->nc[YY];
                oy    -= fft_ny;
                bCommY = (pme->nnodes_minor > 1);
            }
            pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
            oy       += pmegrid_g->offset[YY];
            /* Determine the end of our part of the source grid */
            if (!bCommY)
            {
                /* Use our thread local source grid and target grid part */
                ty1 = min(oy + pmegrid_g->n[YY], localcopy_end[YY]);
            }
            else
            {
                /* Use our thread local source grid and the spreading range */
                ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
            }

            for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
            {
                /* Along z there is no rank decomposition, so a wrapped
                 * source block is always added locally (no comm flag).
                 */
                fz = pmegrid->ci[ZZ] + sz;
                oz = 0;
                if (fz < 0)
                {
                    fz += pmegrids->nc[ZZ];
                    oz -= fft_nz;
                }
                pmegrid_g = &pmegrids->grid_th[fz];
                oz       += pmegrid_g->offset[ZZ];
                tz1       = min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);

                if (sx == 0 && sy == 0 && sz == 0)
                {
                    /* We have already added our local contribution
                     * before calling this routine, so skip it here.
                     */
                    continue;
                }

                /* Linear index of the source thread grid */
                thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;

                pmegrid_f = &pmegrids->grid_th[thread_f];

                grid_th = pmegrid_f->grid;

                /* Index strides of the source thread grid.
                 * NOTE(review): nsx is assigned but never used.
                 */
                nsx = pmegrid_f->s[XX];
                nsy = pmegrid_f->s[YY];
                nsz = pmegrid_f->s[ZZ];

#ifdef DEBUG_PME_REDUCE
                printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
                       pme->nodeid, thread, thread_f,
                       pme->pmegrid_start_ix,
                       pme->pmegrid_start_iy,
                       pme->pmegrid_start_iz,
                       sx, sy, sz,
                       offx-ox, tx1-ox, offx, tx1,
                       offy-oy, ty1-oy, offy, ty1,
                       offz-oz, tz1-oz, offz, tz1);
#endif

                if (!(bCommX || bCommY))
                {
                    /* Copy from the thread local grid to the node grid */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x*fft_my + y)*fft_mz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
                            for (z = offz; z < tz1; z++)
                            {
                                fftgrid[i0+z] += grid_th[i0t+z];
                            }
                        }
                    }
                }
                else
                {
                    /* The order of this conditional decides
                     * where the corner volume gets stored with x+y decomp.
                     */
                    if (bCommY)
                    {
                        commbuf = commbuf_y;
                        /* The y-size of the communication buffer is set by
                         * the overlap of the grid part of our local slab
                         * with the part starting at the next slab.
                         */
                        buf_my  =
                            pme->overlap[1].s2g1[pme->nodeid_minor] -
                            pme->overlap[1].s2g0[pme->nodeid_minor+1];
                        if (bCommX)
                        {
                            /* We index commbuf modulo the local grid size */
                            commbuf += buf_my*fft_nx*fft_nz;

                            bClearBuf   = bClearBufXY;
                            bClearBufXY = FALSE;
                        }
                        else
                        {
                            bClearBuf  = bClearBufY;
                            bClearBufY = FALSE;
                        }
                    }
                    else
                    {
                        commbuf    = commbuf_x;
                        buf_my     = fft_ny;
                        bClearBuf  = bClearBufX;
                        bClearBufX = FALSE;
                    }

                    /* Copy to the communication buffer */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x*buf_my + y)*fft_nz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;

                            if (bClearBuf)
                            {
                                /* First access of commbuf, initialize it */
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0+z]  = grid_th[i0t+z];
                                }
                            }
                            else
                            {
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0+z] += grid_th[i0t+z];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
4094
4095
4096 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid, int grid_index)
4097 {
4098     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4099     pme_overlap_t *overlap;
4100     int  send_index0, send_nindex;
4101     int  recv_nindex;
4102 #ifdef GMX_MPI
4103     MPI_Status stat;
4104 #endif
4105     int  send_size_y, recv_size_y;
4106     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
4107     real *sendptr, *recvptr;
4108     int  x, y, z, indg, indb;
4109
4110     /* Note that this routine is only used for forward communication.
4111      * Since the force gathering, unlike the coefficient spreading,
4112      * can be trivially parallelized over the particles,
4113      * the backwards process is much simpler and can use the "old"
4114      * communication setup.
4115      */
4116
4117     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
4118                                    local_fft_ndata,
4119                                    local_fft_offset,
4120                                    local_fft_size);
4121
4122     if (pme->nnodes_minor > 1)
4123     {
4124         /* Major dimension */
4125         overlap = &pme->overlap[1];
4126
4127         if (pme->nnodes_major > 1)
4128         {
4129             size_yx = pme->overlap[0].comm_data[0].send_nindex;
4130         }
4131         else
4132         {
4133             size_yx = 0;
4134         }
4135         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
4136
4137         send_size_y = overlap->send_size;
4138
4139         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
4140         {
4141             send_id       = overlap->send_id[ipulse];
4142             recv_id       = overlap->recv_id[ipulse];
4143             send_index0   =
4144                 overlap->comm_data[ipulse].send_index0 -
4145                 overlap->comm_data[0].send_index0;
4146             send_nindex   = overlap->comm_data[ipulse].send_nindex;
4147             /* We don't use recv_index0, as we always receive starting at 0 */
4148             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4149             recv_size_y   = overlap->comm_data[ipulse].recv_size;
4150
4151             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
4152             recvptr = overlap->recvbuf;
4153
4154             if (debug != NULL)
4155             {
4156                 fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n",
4157                         local_fft_ndata[XX], send_nindex, local_fft_ndata[ZZ]);
4158             }
4159
4160 #ifdef GMX_MPI
4161             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
4162                          send_id, ipulse,
4163                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
4164                          recv_id, ipulse,
4165                          overlap->mpi_comm, &stat);
4166 #endif
4167
4168             for (x = 0; x < local_fft_ndata[XX]; x++)
4169             {
4170                 for (y = 0; y < recv_nindex; y++)
4171                 {
4172                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
4173                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
4174                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
4175                     {
4176                         fftgrid[indg+z] += recvptr[indb+z];
4177                     }
4178                 }
4179             }
4180
4181             if (pme->nnodes_major > 1)
4182             {
4183                 /* Copy from the received buffer to the send buffer for dim 0 */
4184                 sendptr = pme->overlap[0].sendbuf;
4185                 for (x = 0; x < size_yx; x++)
4186                 {
4187                     for (y = 0; y < recv_nindex; y++)
4188                     {
4189                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4190                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
4191                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
4192                         {
4193                             sendptr[indg+z] += recvptr[indb+z];
4194                         }
4195                     }
4196                 }
4197             }
4198         }
4199     }
4200
4201     /* We only support a single pulse here.
4202      * This is not a severe limitation, as this code is only used
4203      * with OpenMP and with OpenMP the (PME) domains can be larger.
4204      */
4205     if (pme->nnodes_major > 1)
4206     {
4207         /* Major dimension */
4208         overlap = &pme->overlap[0];
4209
4210         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
4211         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
4212
4213         ipulse = 0;
4214
4215         send_id       = overlap->send_id[ipulse];
4216         recv_id       = overlap->recv_id[ipulse];
4217         send_nindex   = overlap->comm_data[ipulse].send_nindex;
4218         /* We don't use recv_index0, as we always receive starting at 0 */
4219         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4220
4221         sendptr = overlap->sendbuf;
4222         recvptr = overlap->recvbuf;
4223
4224         if (debug != NULL)
4225         {
4226             fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n",
4227                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
4228         }
4229
4230 #ifdef GMX_MPI
4231         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
4232                      send_id, ipulse,
4233                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
4234                      recv_id, ipulse,
4235                      overlap->mpi_comm, &stat);
4236 #endif
4237
4238         for (x = 0; x < recv_nindex; x++)
4239         {
4240             for (y = 0; y < local_fft_ndata[YY]; y++)
4241             {
4242                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
4243                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4244                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
4245                 {
4246                     fftgrid[indg+z] += recvptr[indb+z];
4247                 }
4248             }
4249         }
4250     }
4251 }
4252
4253
/* Compute interpolation indices and B-spline coefficients for the particles
 * in atc, and/or spread their coefficients onto the grid, using pme->nthread
 * OpenMP threads.
 *
 * pme          - PME setup
 * atc          - particle data for this rank
 * grids        - spreading grids. They are dereferenced whenever bSpread is
 *                TRUE, so grids may only be NULL when bSpread is FALSE.
 * bCalcSplines - compute interpolation indices (calc_interpolation_idx) and
 *                B-spline coefficients (make_bsplines)
 * bSpread      - spread the coefficients onto the grid(s)
 * fftgrid      - rank-local real-space FFT grid. It is only written here when
 *                pme->bUseThreads is set: per-thread copy, thread-overlap
 *                reduction and, with more than one PME rank, the inter-rank
 *                overlap sum. Without threads the spread result stays in
 *                grids->grid.
 * bDoSplines   - passed on to make_bsplines
 * grid_index   - index of the grid (and FFT setup) to operate on
 */
static void spread_on_grid(gmx_pme_t pme,
                           pme_atomcomm_t *atc, pmegrids_t *grids,
                           gmx_bool bCalcSplines, gmx_bool bSpread,
                           real *fftgrid, gmx_bool bDoSplines, int grid_index)
{
    int nthread, thread;
    /* Debug timing only.
     * NOTE(review): cs1a has 6 entries but is indexed by thread up to
     * nthread-1. ct1a (used under PME_TIME_SPREAD) is only declared under
     * PME_TIME_THREADS and is shared between the OpenMP threads.
     */
#ifdef PME_TIME_THREADS
    gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
    static double cs1     = 0, cs2 = 0, cs3 = 0;
    static double cs1a[6] = {0, 0, 0, 0, 0, 0};
    static int cnt        = 0;
#endif

    nthread = pme->nthread;
    assert(nthread > 0);

#ifdef PME_TIME_THREADS
    c1 = omp_cyc_start();
#endif
    if (bCalcSplines)
    {
        /* Each thread handles a contiguous block of about atc->n/nthread
         * particles.
         */
#pragma omp parallel for num_threads(nthread) schedule(static)
        for (thread = 0; thread < nthread; thread++)
        {
            int start, end;

            start = atc->n* thread   /nthread;
            end   = atc->n*(thread+1)/nthread;

            /* Compute fftgrid index for all atoms,
             * with help of some extra variables.
             */
            calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
        }
    }
#ifdef PME_TIME_THREADS
    c1   = omp_cyc_end(c1);
    cs1 += (double)c1;
#endif

#ifdef PME_TIME_THREADS
    c2 = omp_cyc_start();
#endif
    /* Compute B-splines and/or spread, per thread */
#pragma omp parallel for num_threads(nthread) schedule(static)
    for (thread = 0; thread < nthread; thread++)
    {
        splinedata_t *spline;
        pmegrid_t *grid = NULL;

        /* make local bsplines  */
        if (grids == NULL || !pme->bUseThreads)
        {
            /* All particles go into spline[0] and, when spreading, onto
             * the single grid grids->grid.
             * NOTE(review): gmx_pme_init only leaves bUseThreads FALSE when
             * no PME rank has more than one thread. So with nthread > 1 this
             * branch is only reached with grids == NULL, and then every
             * thread does the same work on spline[0]. Confirm with callers.
             */
            spline = &atc->spline[0];

            spline->n = atc->n;

            if (bSpread)
            {
                grid = &grids->grid;
            }
        }
        else
        {
            /* Each thread works on its own spline data and thread grid */
            spline = &atc->spline[thread];

            if (grids->nthread == 1)
            {
                /* One thread, we operate on all coefficients */
                spline->n = atc->n;
            }
            else
            {
                /* Get the indices our thread should operate on */
                make_thread_local_ind(atc, thread, spline);
            }

            grid = &grids->grid_th[thread];
        }

        if (bCalcSplines)
        {
            make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
                          atc->fractx, spline->n, spline->ind, atc->coefficient, bDoSplines);
        }

        if (bSpread)
        {
            /* put local atoms on grid. */
#ifdef PME_TIME_SPREAD
            ct1a = omp_cyc_start();
#endif
            spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);

            if (pme->bUseThreads)
            {
                /* Initialize our own (non-overlap) part of fftgrid */
                copy_local_grid(pme, grids, grid_index, thread, fftgrid);
            }
#ifdef PME_TIME_SPREAD
            ct1a          = omp_cyc_end(ct1a);
            cs1a[thread] += (double)ct1a;
#endif
        }
    }
#ifdef PME_TIME_THREADS
    c2   = omp_cyc_end(c2);
    cs2 += (double)c2;
#endif

    if (bSpread && pme->bUseThreads)
    {
#ifdef PME_TIME_THREADS
        c3 = omp_cyc_start();
#endif
        /* Add the overlap contributions of the other thread grids; this
         * must come after all threads have finished copy_local_grid above.
         */
#pragma omp parallel for num_threads(grids->nthread) schedule(static)
        for (thread = 0; thread < grids->nthread; thread++)
        {
            reduce_threadgrid_overlap(pme, grids, thread,
                                      fftgrid,
                                      pme->overlap[0].sendbuf,
                                      pme->overlap[1].sendbuf,
                                      grid_index);
        }
#ifdef PME_TIME_THREADS
        c3   = omp_cyc_end(c3);
        cs3 += (double)c3;
#endif

        if (pme->nnodes > 1)
        {
            /* Communicate the overlapping part of the fftgrid.
             * For this communication call we need to check pme->bUseThreads
             * to have all ranks communicate here, regardless of pme->nthread.
             */
            sum_fftgrid_dd(pme, fftgrid, grid_index);
        }
    }

#ifdef PME_TIME_THREADS
    cnt++;
    if (cnt % 20 == 0)
    {
        printf("idx %.2f spread %.2f red %.2f",
               cs1*1e-9, cs2*1e-9, cs3*1e-9);
#ifdef PME_TIME_SPREAD
        for (thread = 0; thread < nthread; thread++)
        {
            printf(" %.2f", cs1a[thread]*1e-9);
        }
#endif
        printf("\n");
    }
#endif
}
4407
4408
4409 static void dump_grid(FILE *fp,
4410                       int sx, int sy, int sz, int nx, int ny, int nz,
4411                       int my, int mz, const real *g)
4412 {
4413     int x, y, z;
4414
4415     for (x = 0; x < nx; x++)
4416     {
4417         for (y = 0; y < ny; y++)
4418         {
4419             for (z = 0; z < nz; z++)
4420             {
4421                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4422                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4423             }
4424         }
4425     }
4426 }
4427
4428 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4429 {
4430     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4431
4432     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4433                                    local_fft_ndata,
4434                                    local_fft_offset,
4435                                    local_fft_size);
4436
4437     dump_grid(stderr,
4438               pme->pmegrid_start_ix,
4439               pme->pmegrid_start_iy,
4440               pme->pmegrid_start_iz,
4441               pme->pmegrid_nx-pme->pme_order+1,
4442               pme->pmegrid_ny-pme->pme_order+1,
4443               pme->pmegrid_nz-pme->pme_order+1,
4444               local_fft_size[YY],
4445               local_fft_size[ZZ],
4446               fftgrid);
4447 }
4448
4449
4450 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4451 {
4452     pme_atomcomm_t *atc;
4453     pmegrids_t *grid;
4454
4455     if (pme->nnodes > 1)
4456     {
4457         gmx_incons("gmx_pme_calc_energy called in parallel");
4458     }
4459     if (pme->bFEP_q > 1)
4460     {
4461         gmx_incons("gmx_pme_calc_energy with free energy");
4462     }
4463
4464     atc            = &pme->atc_energy;
4465     atc->nthread   = 1;
4466     if (atc->spline == NULL)
4467     {
4468         snew(atc->spline, atc->nthread);
4469     }
4470     atc->nslab     = 1;
4471     atc->bSpread   = TRUE;
4472     atc->pme_order = pme->pme_order;
4473     atc->n         = n;
4474     pme_realloc_atomcomm_things(atc);
4475     atc->x           = x;
4476     atc->coefficient = q;
4477
4478     /* We only use the A-charges grid */
4479     grid = &pme->pmegrid[PME_GRID_QA];
4480
4481     /* Only calculate the spline coefficients, don't actually spread */
4482     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE, PME_GRID_QA);
4483
4484     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4485 }
4486
4487
4488 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4489                                    gmx_walltime_accounting_t walltime_accounting,
4490                                    t_nrnb *nrnb, t_inputrec *ir,
4491                                    gmx_int64_t step)
4492 {
4493     /* Reset all the counters related to performance over the run */
4494     wallcycle_stop(wcycle, ewcRUN);
4495     wallcycle_reset_all(wcycle);
4496     init_nrnb(nrnb);
4497     if (ir->nsteps >= 0)
4498     {
4499         /* ir->nsteps is not used here, but we update it for consistency */
4500         ir->nsteps -= step - ir->init_step;
4501     }
4502     ir->init_step = step;
4503     wallcycle_start(wcycle, ewcRUN);
4504     walltime_accounting_start(walltime_accounting);
4505 }
4506
4507
4508 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4509                                ivec grid_size,
4510                                t_commrec *cr, t_inputrec *ir,
4511                                gmx_pme_t *pme_ret)
4512 {
4513     int ind;
4514     gmx_pme_t pme = NULL;
4515
4516     ind = 0;
4517     while (ind < *npmedata)
4518     {
4519         pme = (*pmedata)[ind];
4520         if (pme->nkx == grid_size[XX] &&
4521             pme->nky == grid_size[YY] &&
4522             pme->nkz == grid_size[ZZ])
4523         {
4524             *pme_ret = pme;
4525
4526             return;
4527         }
4528
4529         ind++;
4530     }
4531
4532     (*npmedata)++;
4533     srenew(*pmedata, *npmedata);
4534
4535     /* Generate a new PME data structure, copying part of the old pointers */
4536     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4537
4538     *pme_ret = (*pmedata)[ind];
4539 }
4540
4541 int gmx_pmeonly(gmx_pme_t pme,
4542                 t_commrec *cr,    t_nrnb *mynrnb,
4543                 gmx_wallcycle_t wcycle,
4544                 gmx_walltime_accounting_t walltime_accounting,
4545                 real ewaldcoeff_q, real ewaldcoeff_lj,
4546                 t_inputrec *ir)
4547 {
4548     int npmedata;
4549     gmx_pme_t *pmedata;
4550     gmx_pme_pp_t pme_pp;
4551     int  ret;
4552     int  natoms;
4553     matrix box;
4554     rvec *x_pp      = NULL, *f_pp = NULL;
4555     real *chargeA   = NULL, *chargeB = NULL;
4556     real *c6A       = NULL, *c6B = NULL;
4557     real *sigmaA    = NULL, *sigmaB = NULL;
4558     real lambda_q   = 0;
4559     real lambda_lj  = 0;
4560     int  maxshift_x = 0, maxshift_y = 0;
4561     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4562     matrix vir_q, vir_lj;
4563     float cycles;
4564     int  count;
4565     gmx_bool bEnerVir;
4566     int pme_flags;
4567     gmx_int64_t step, step_rel;
4568     ivec grid_switch;
4569
4570     /* This data will only use with PME tuning, i.e. switching PME grids */
4571     npmedata = 1;
4572     snew(pmedata, npmedata);
4573     pmedata[0] = pme;
4574
4575     pme_pp = gmx_pme_pp_init(cr);
4576
4577     init_nrnb(mynrnb);
4578
4579     count = 0;
4580     do /****** this is a quasi-loop over time steps! */
4581     {
4582         /* The reason for having a loop here is PME grid tuning/switching */
4583         do
4584         {
4585             /* Domain decomposition */
4586             ret = gmx_pme_recv_coeffs_coords(pme_pp,
4587                                              &natoms,
4588                                              &chargeA, &chargeB,
4589                                              &c6A, &c6B,
4590                                              &sigmaA, &sigmaB,
4591                                              box, &x_pp, &f_pp,
4592                                              &maxshift_x, &maxshift_y,
4593                                              &pme->bFEP_q, &pme->bFEP_lj,
4594                                              &lambda_q, &lambda_lj,
4595                                              &bEnerVir,
4596                                              &pme_flags,
4597                                              &step,
4598                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4599
4600             if (ret == pmerecvqxSWITCHGRID)
4601             {
4602                 /* Switch the PME grid to grid_switch */
4603                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4604             }
4605
4606             if (ret == pmerecvqxRESETCOUNTERS)
4607             {
4608                 /* Reset the cycle and flop counters */
4609                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4610             }
4611         }
4612         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4613
4614         if (ret == pmerecvqxFINISH)
4615         {
4616             /* We should stop: break out of the loop */
4617             break;
4618         }
4619
4620         step_rel = step - ir->init_step;
4621
4622         if (count == 0)
4623         {
4624             wallcycle_start(wcycle, ewcRUN);
4625             walltime_accounting_start(walltime_accounting);
4626         }
4627
4628         wallcycle_start(wcycle, ewcPMEMESH);
4629
4630         dvdlambda_q  = 0;
4631         dvdlambda_lj = 0;
4632         clear_mat(vir_q);
4633         clear_mat(vir_lj);
4634
4635         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4636                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4637                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4638                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4639                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4640                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4641
4642         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4643
4644         gmx_pme_send_force_vir_ener(pme_pp,
4645                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4646                                     dvdlambda_q, dvdlambda_lj, cycles);
4647
4648         count++;
4649     } /***** end of quasi-loop, we stop with the break above */
4650     while (TRUE);
4651
4652     walltime_accounting_end(walltime_accounting);
4653
4654     return 0;
4655 }
4656
4657 static void
4658 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4659 {
4660     int  i;
4661
4662     for (i = 0; i < pme->atc[0].n; ++i)
4663     {
4664         real sigma4;
4665
4666         sigma4                     = local_sigma[i];
4667         sigma4                     = sigma4*sigma4;
4668         sigma4                     = sigma4*sigma4;
4669         pme->atc[0].coefficient[i] = local_c6[i] / sigma4;
4670     }
4671 }
4672
4673 static void
4674 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4675 {
4676     int  i;
4677
4678     for (i = 0; i < pme->atc[0].n; ++i)
4679     {
4680         pme->atc[0].coefficient[i] *= local_sigma[i];
4681     }
4682 }
4683
4684 static void
4685 do_redist_pos_coeffs(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4686                      gmx_bool bFirst, rvec x[], real *data)
4687 {
4688     int      d;
4689     pme_atomcomm_t *atc;
4690     atc = &pme->atc[0];
4691
4692     for (d = pme->ndecompdim - 1; d >= 0; d--)
4693     {
4694         int             n_d;
4695         rvec           *x_d;
4696         real           *param_d;
4697
4698         if (d == pme->ndecompdim - 1)
4699         {
4700             n_d     = homenr;
4701             x_d     = x + start;
4702             param_d = data;
4703         }
4704         else
4705         {
4706             n_d     = pme->atc[d + 1].n;
4707             x_d     = atc->x;
4708             param_d = atc->coefficient;
4709         }
4710         atc      = &pme->atc[d];
4711         atc->npd = n_d;
4712         if (atc->npd > atc->pd_nalloc)
4713         {
4714             atc->pd_nalloc = over_alloc_dd(atc->npd);
4715             srenew(atc->pd, atc->pd_nalloc);
4716         }
4717         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4718         where();
4719         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4720         if (DOMAINDECOMP(cr))
4721         {
4722             dd_pmeredist_pos_coeffs(pme, n_d, bFirst, x_d, param_d, atc);
4723         }
4724     }
4725 }
4726
4727 int gmx_pme_do(gmx_pme_t pme,
4728                int start,       int homenr,
4729                rvec x[],        rvec f[],
4730                real *chargeA,   real *chargeB,
4731                real *c6A,       real *c6B,
4732                real *sigmaA,    real *sigmaB,
4733                matrix box, t_commrec *cr,
4734                int  maxshift_x, int maxshift_y,
4735                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4736                matrix vir_q,      real ewaldcoeff_q,
4737                matrix vir_lj,   real ewaldcoeff_lj,
4738                real *energy_q,  real *energy_lj,
4739                real lambda_q, real lambda_lj,
4740                real *dvdlambda_q, real *dvdlambda_lj,
4741                int flags)
4742 {
4743     int     d, i, j, k, ntot, npme, grid_index, max_grid_index;
4744     int     nx, ny, nz;
4745     int     n_d, local_ny;
4746     pme_atomcomm_t *atc = NULL;
4747     pmegrids_t *pmegrid = NULL;
4748     real    *grid       = NULL;
4749     real    *ptr;
4750     rvec    *x_d, *f_d;
4751     real    *coefficient = NULL;
4752     real    energy_AB[4];
4753     matrix  vir_AB[4];
4754     real    scale, lambda;
4755     gmx_bool bClearF;
4756     gmx_parallel_3dfft_t pfft_setup;
4757     real *  fftgrid;
4758     t_complex * cfftgrid;
4759     int     thread;
4760     gmx_bool bFirst, bDoSplines;
4761     int fep_state;
4762     int fep_states_lj           = pme->bFEP_lj ? 2 : 1;
4763     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4764     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4765
4766     assert(pme->nnodes > 0);
4767     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4768
4769     if (pme->nnodes > 1)
4770     {
4771         atc      = &pme->atc[0];
4772         atc->npd = homenr;
4773         if (atc->npd > atc->pd_nalloc)
4774         {
4775             atc->pd_nalloc = over_alloc_dd(atc->npd);
4776             srenew(atc->pd, atc->pd_nalloc);
4777         }
4778         for (d = pme->ndecompdim-1; d >= 0; d--)
4779         {
4780             atc           = &pme->atc[d];
4781             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4782         }
4783     }
4784     else
4785     {
4786         atc = &pme->atc[0];
4787         /* This could be necessary for TPI */
4788         pme->atc[0].n = homenr;
4789         if (DOMAINDECOMP(cr))
4790         {
4791             pme_realloc_atomcomm_things(atc);
4792         }
4793         atc->x = x;
4794         atc->f = f;
4795     }
4796
4797     m_inv_ur0(box, pme->recipbox);
4798     bFirst = TRUE;
4799
4800     /* For simplicity, we construct the splines for all particles if
4801      * more than one PME calculations is needed. Some optimization
4802      * could be done by keeping track of which atoms have splines
4803      * constructed, and construct new splines on each pass for atoms
4804      * that don't yet have them.
4805      */
4806
4807     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4808
4809     /* We need a maximum of four separate PME calculations:
4810      * grid_index=0: Coulomb PME with charges from state A
4811      * grid_index=1: Coulomb PME with charges from state B
4812      * grid_index=2: LJ PME with C6 from state A
4813      * grid_index=3: LJ PME with C6 from state B
4814      * For Lorentz-Berthelot combination rules, a separate loop is used to
4815      * calculate all the terms
4816      */
4817
4818     /* If we are doing LJ-PME with LB, we only do Q here */
4819     max_grid_index = (pme->ljpme_combination_rule == eljpmeLB) ? DO_Q : DO_Q_AND_LJ;
4820
4821     for (grid_index = 0; grid_index < max_grid_index; ++grid_index)
4822     {
4823         /* Check if we should do calculations at this grid_index
4824          * If grid_index is odd we should be doing FEP
4825          * If grid_index < 2 we should be doing electrostatic PME
4826          * If grid_index >= 2 we should be doing LJ-PME
4827          */
4828         if ((grid_index <  DO_Q && (!(flags & GMX_PME_DO_COULOMB) ||
4829                                     (grid_index == 1 && !pme->bFEP_q))) ||
4830             (grid_index >= DO_Q && (!(flags & GMX_PME_DO_LJ) ||
4831                                     (grid_index == 3 && !pme->bFEP_lj))))
4832         {
4833             continue;
4834         }
4835         /* Unpack structure */
4836         pmegrid    = &pme->pmegrid[grid_index];
4837         fftgrid    = pme->fftgrid[grid_index];
4838         cfftgrid   = pme->cfftgrid[grid_index];
4839         pfft_setup = pme->pfft_setup[grid_index];
4840         switch (grid_index)
4841         {
4842             case 0: coefficient = chargeA + start; break;
4843             case 1: coefficient = chargeB + start; break;
4844             case 2: coefficient = c6A + start; break;
4845             case 3: coefficient = c6B + start; break;
4846         }
4847
4848         grid = pmegrid->grid.grid;
4849
4850         if (debug)
4851         {
4852             fprintf(debug, "PME: number of ranks = %d, rank = %d\n",
4853                     cr->nnodes, cr->nodeid);
4854             fprintf(debug, "Grid = %p\n", (void*)grid);
4855             if (grid == NULL)
4856             {
4857                 gmx_fatal(FARGS, "No grid!");
4858             }
4859         }
4860         where();
4861
4862         if (pme->nnodes == 1)
4863         {
4864             atc->coefficient = coefficient;
4865         }
4866         else
4867         {
4868             wallcycle_start(wcycle, ewcPME_REDISTXF);
4869             do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, coefficient);
4870             where();
4871
4872             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4873         }
4874
4875         if (debug)
4876         {
4877             fprintf(debug, "Rank= %6d, pme local particles=%6d\n",
4878                     cr->nodeid, atc->n);
4879         }
4880
4881         if (flags & GMX_PME_SPREAD)
4882         {
4883             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4884
4885             /* Spread the coefficients on a grid */
4886             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
4887
4888             if (bFirst)
4889             {
4890                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4891             }
4892             inc_nrnb(nrnb, eNR_SPREADBSP,
4893                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4894
4895             if (!pme->bUseThreads)
4896             {
4897                 wrap_periodic_pmegrid(pme, grid);
4898
4899                 /* sum contributions to local grid from other nodes */
4900 #ifdef GMX_MPI
4901                 if (pme->nnodes > 1)
4902                 {
4903                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
4904                     where();
4905                 }
4906 #endif
4907
4908                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
4909             }
4910
4911             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4912
4913             /*
4914                dump_local_fftgrid(pme,fftgrid);
4915                exit(0);
4916              */
4917         }
4918
4919         /* Here we start a large thread parallel region */
4920 #pragma omp parallel num_threads(pme->nthread) private(thread)
4921         {
4922             thread = gmx_omp_get_thread_num();
4923             if (flags & GMX_PME_SOLVE)
4924             {
4925                 int loop_count;
4926
4927                 /* do 3d-fft */
4928                 if (thread == 0)
4929                 {
4930                     wallcycle_start(wcycle, ewcPME_FFT);
4931                 }
4932                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4933                                            thread, wcycle);
4934                 if (thread == 0)
4935                 {
4936                     wallcycle_stop(wcycle, ewcPME_FFT);
4937                 }
4938                 where();
4939
4940                 /* solve in k-space for our local cells */
4941                 if (thread == 0)
4942                 {
4943                     wallcycle_start(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4944                 }
4945                 if (grid_index < DO_Q)
4946                 {
4947                     loop_count =
4948                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
4949                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4950                                       bCalcEnerVir,
4951                                       pme->nthread, thread);
4952                 }
4953                 else
4954                 {
4955                     loop_count =
4956                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
4957                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4958                                          bCalcEnerVir,
4959                                          pme->nthread, thread);
4960                 }
4961
4962                 if (thread == 0)
4963                 {
4964                     wallcycle_stop(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4965                     where();
4966                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4967                 }
4968             }
4969
4970             if (bCalcF)
4971             {
4972                 /* do 3d-invfft */
4973                 if (thread == 0)
4974                 {
4975                     where();
4976                     wallcycle_start(wcycle, ewcPME_FFT);
4977                 }
4978                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4979                                            thread, wcycle);
4980                 if (thread == 0)
4981                 {
4982                     wallcycle_stop(wcycle, ewcPME_FFT);
4983
4984                     where();
4985
4986                     if (pme->nodeid == 0)
4987                     {
4988                         ntot  = pme->nkx*pme->nky*pme->nkz;
4989                         npme  = ntot*log((real)ntot)/log(2.0);
4990                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4991                     }
4992
4993                     /* Note: this wallcycle region is closed below
4994                        outside an OpenMP region, so take care if
4995                        refactoring code here. */
4996                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4997                 }
4998
4999                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
5000             }
5001         }
5002         /* End of thread parallel section.
5003          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
5004          */
5005
5006         if (bCalcF)
5007         {
5008             /* distribute local grid to all nodes */
5009 #ifdef GMX_MPI
5010             if (pme->nnodes > 1)
5011             {
5012                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
5013             }
5014 #endif
5015             where();
5016
5017             unwrap_periodic_pmegrid(pme, grid);
5018
5019             /* interpolate forces for our local atoms */
5020
5021             where();
5022
5023             /* If we are running without parallelization,
5024              * atc->f is the actual force array, not a buffer,
5025              * therefore we should not clear it.
5026              */
5027             lambda  = grid_index < DO_Q ? lambda_q : lambda_lj;
5028             bClearF = (bFirst && PAR(cr));
5029 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5030             for (thread = 0; thread < pme->nthread; thread++)
5031             {
5032                 gather_f_bsplines(pme, grid, bClearF, atc,
5033                                   &atc->spline[thread],
5034                                   pme->bFEP ? (grid_index % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
5035             }
5036
5037             where();
5038
5039             inc_nrnb(nrnb, eNR_GATHERFBSP,
5040                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5041             /* Note: this wallcycle region is opened above inside an OpenMP
5042                region, so take care if refactoring code here. */
5043             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5044         }
5045
5046         if (bCalcEnerVir)
5047         {
5048             /* This should only be called on the master thread
5049              * and after the threads have synchronized.
5050              */
5051             if (grid_index < 2)
5052             {
5053                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5054             }
5055             else
5056             {
5057                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5058             }
5059         }
5060         bFirst = FALSE;
5061     } /* of grid_index-loop */
5062
5063     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
5064      * seven terms. */
5065
5066     if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB)
5067     {
5068         /* Loop over A- and B-state if we are doing FEP */
5069         for (fep_state = 0; fep_state < fep_states_lj; ++fep_state)
5070         {
5071             real *local_c6 = NULL, *local_sigma = NULL, *RedistC6 = NULL, *RedistSigma = NULL;
5072             if (pme->nnodes == 1)
5073             {
5074                 if (pme->lb_buf1 == NULL)
5075                 {
5076                     pme->lb_buf_nalloc = pme->atc[0].n;
5077                     snew(pme->lb_buf1, pme->lb_buf_nalloc);
5078                 }
5079                 pme->atc[0].coefficient = pme->lb_buf1;
5080                 switch (fep_state)
5081                 {
5082                     case 0:
5083                         local_c6      = c6A;
5084                         local_sigma   = sigmaA;
5085                         break;
5086                     case 1:
5087                         local_c6      = c6B;
5088                         local_sigma   = sigmaB;
5089                         break;
5090                     default:
5091                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5092                 }
5093             }
5094             else
5095             {
5096                 atc = &pme->atc[0];
5097                 switch (fep_state)
5098                 {
5099                     case 0:
5100                         RedistC6      = c6A;
5101                         RedistSigma   = sigmaA;
5102                         break;
5103                     case 1:
5104                         RedistC6      = c6B;
5105                         RedistSigma   = sigmaB;
5106                         break;
5107                     default:
5108                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5109                 }
5110                 wallcycle_start(wcycle, ewcPME_REDISTXF);
5111
5112                 do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, RedistC6);
5113                 if (pme->lb_buf_nalloc < atc->n)
5114                 {
5115                     pme->lb_buf_nalloc = atc->nalloc;
5116                     srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5117                     srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5118                 }
5119                 local_c6 = pme->lb_buf1;
5120                 for (i = 0; i < atc->n; ++i)
5121                 {
5122                     local_c6[i] = atc->coefficient[i];
5123                 }
5124                 where();
5125
5126                 do_redist_pos_coeffs(pme, cr, start, homenr, FALSE, x, RedistSigma);
5127                 local_sigma = pme->lb_buf2;
5128                 for (i = 0; i < atc->n; ++i)
5129                 {
5130                     local_sigma[i] = atc->coefficient[i];
5131                 }
5132                 where();
5133
5134                 wallcycle_stop(wcycle, ewcPME_REDISTXF);
5135             }
5136             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5137
5138             /*Seven terms in LJ-PME with LB, grid_index < 2 reserved for electrostatics*/
5139             for (grid_index = 2; grid_index < 9; ++grid_index)
5140             {
5141                 /* Unpack structure */
5142                 pmegrid    = &pme->pmegrid[grid_index];
5143                 fftgrid    = pme->fftgrid[grid_index];
5144                 cfftgrid   = pme->cfftgrid[grid_index];
5145                 pfft_setup = pme->pfft_setup[grid_index];
5146                 calc_next_lb_coeffs(pme, local_sigma);
5147                 grid = pmegrid->grid.grid;
5148                 where();
5149
5150                 if (flags & GMX_PME_SPREAD)
5151                 {
5152                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5153                     /* Spread the c6 on a grid */
5154                     spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
5155
5156                     if (bFirst)
5157                     {
5158                         inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5159                     }
5160
5161                     inc_nrnb(nrnb, eNR_SPREADBSP,
5162                              pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5163                     if (pme->nthread == 1)
5164                     {
5165                         wrap_periodic_pmegrid(pme, grid);
5166                         /* sum contributions to local grid from other nodes */
5167 #ifdef GMX_MPI
5168                         if (pme->nnodes > 1)
5169                         {
5170                             gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
5171                             where();
5172                         }
5173 #endif
5174                         copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
5175                     }
5176                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5177                 }
5178                 /*Here we start a large thread parallel region*/
5179 #pragma omp parallel num_threads(pme->nthread) private(thread)
5180                 {
5181                     thread = gmx_omp_get_thread_num();
5182                     if (flags & GMX_PME_SOLVE)
5183                     {
5184                         /* do 3d-fft */
5185                         if (thread == 0)
5186                         {
5187                             wallcycle_start(wcycle, ewcPME_FFT);
5188                         }
5189
5190                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5191                                                    thread, wcycle);
5192                         if (thread == 0)
5193                         {
5194                             wallcycle_stop(wcycle, ewcPME_FFT);
5195                         }
5196                         where();
5197                     }
5198                 }
5199                 bFirst = FALSE;
5200             }
5201             if (flags & GMX_PME_SOLVE)
5202             {
5203                 /* solve in k-space for our local cells */
5204 #pragma omp parallel num_threads(pme->nthread) private(thread)
5205                 {
5206                     int loop_count;
5207                     thread = gmx_omp_get_thread_num();
5208                     if (thread == 0)
5209                     {
5210                         wallcycle_start(wcycle, ewcLJPME);
5211                     }
5212
5213                     loop_count =
5214                         solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5215                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5216                                          bCalcEnerVir,
5217                                          pme->nthread, thread);
5218                     if (thread == 0)
5219                     {
5220                         wallcycle_stop(wcycle, ewcLJPME);
5221                         where();
5222                         inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5223                     }
5224                 }
5225             }
5226
5227             if (bCalcEnerVir)
5228             {
5229                 /* This should only be called on the master thread and
5230                  * after the threads have synchronized.
5231                  */
5232                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
5233             }
5234
5235             if (bCalcF)
5236             {
5237                 bFirst = !(flags & GMX_PME_DO_COULOMB);
5238                 calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5239                 for (grid_index = 8; grid_index >= 2; --grid_index)
5240                 {
5241                     /* Unpack structure */
5242                     pmegrid    = &pme->pmegrid[grid_index];
5243                     fftgrid    = pme->fftgrid[grid_index];
5244                     cfftgrid   = pme->cfftgrid[grid_index];
5245                     pfft_setup = pme->pfft_setup[grid_index];
5246                     grid       = pmegrid->grid.grid;
5247                     calc_next_lb_coeffs(pme, local_sigma);
5248                     where();
5249 #pragma omp parallel num_threads(pme->nthread) private(thread)
5250                     {
5251                         thread = gmx_omp_get_thread_num();
5252                         /* do 3d-invfft */
5253                         if (thread == 0)
5254                         {
5255                             where();
5256                             wallcycle_start(wcycle, ewcPME_FFT);
5257                         }
5258
5259                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5260                                                    thread, wcycle);
5261                         if (thread == 0)
5262                         {
5263                             wallcycle_stop(wcycle, ewcPME_FFT);
5264
5265                             where();
5266
5267                             if (pme->nodeid == 0)
5268                             {
5269                                 ntot  = pme->nkx*pme->nky*pme->nkz;
5270                                 npme  = ntot*log((real)ntot)/log(2.0);
5271                                 inc_nrnb(nrnb, eNR_FFT, 2*npme);
5272                             }
5273                             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5274                         }
5275
5276                         copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
5277
5278                     } /*#pragma omp parallel*/
5279
5280                     /* distribute local grid to all nodes */
5281 #ifdef GMX_MPI
5282                     if (pme->nnodes > 1)
5283                     {
5284                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
5285                     }
5286 #endif
5287                     where();
5288
5289                     unwrap_periodic_pmegrid(pme, grid);
5290
5291                     /* interpolate forces for our local atoms */
5292                     where();
5293                     bClearF = (bFirst && PAR(cr));
5294                     scale   = pme->bFEP ? (fep_state < 1 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5295                     scale  *= lb_scale_factor[grid_index-2];
5296 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5297                     for (thread = 0; thread < pme->nthread; thread++)
5298                     {
5299                         gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5300                                           &pme->atc[0].spline[thread],
5301                                           scale);
5302                     }
5303                     where();
5304
5305                     inc_nrnb(nrnb, eNR_GATHERFBSP,
5306                              pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5307                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5308
5309                     bFirst = FALSE;
5310                 } /* for (grid_index = 8; grid_index >= 2; --grid_index) */
5311             }     /* if (bCalcF) */
5312         }         /* for (fep_state = 0; fep_state < fep_states_lj; ++fep_state) */
5313     }             /* if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB) */
5314
5315     if (bCalcF && pme->nnodes > 1)
5316     {
5317         wallcycle_start(wcycle, ewcPME_REDISTXF);
5318         for (d = 0; d < pme->ndecompdim; d++)
5319         {
5320             atc = &pme->atc[d];
5321             if (d == pme->ndecompdim - 1)
5322             {
5323                 n_d = homenr;
5324                 f_d = f + start;
5325             }
5326             else
5327             {
5328                 n_d = pme->atc[d+1].n;
5329                 f_d = pme->atc[d+1].f;
5330             }
5331             if (DOMAINDECOMP(cr))
5332             {
5333                 dd_pmeredist_f(pme, atc, n_d, f_d,
5334                                d == pme->ndecompdim-1 && pme->bPPnode);
5335             }
5336         }
5337
5338         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5339     }
5340     where();
5341
5342     if (bCalcEnerVir)
5343     {
5344         if (flags & GMX_PME_DO_COULOMB)
5345         {
5346             if (!pme->bFEP_q)
5347             {
5348                 *energy_q = energy_AB[0];
5349                 m_add(vir_q, vir_AB[0], vir_q);
5350             }
5351             else
5352             {
5353                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5354                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5355                 for (i = 0; i < DIM; i++)
5356                 {
5357                     for (j = 0; j < DIM; j++)
5358                     {
5359                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5360                             lambda_q*vir_AB[1][i][j];
5361                     }
5362                 }
5363             }
5364             if (debug)
5365             {
5366                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5367             }
5368         }
5369         else
5370         {
5371             *energy_q = 0;
5372         }
5373
5374         if (flags & GMX_PME_DO_LJ)
5375         {
5376             if (!pme->bFEP_lj)
5377             {
5378                 *energy_lj = energy_AB[2];
5379                 m_add(vir_lj, vir_AB[2], vir_lj);
5380             }
5381             else
5382             {
5383                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5384                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5385                 for (i = 0; i < DIM; i++)
5386                 {
5387                     for (j = 0; j < DIM; j++)
5388                     {
5389                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5390                     }
5391                 }
5392             }
5393             if (debug)
5394             {
5395                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5396             }
5397         }
5398         else
5399         {
5400             *energy_lj = 0;
5401         }
5402     }
5403     return 0;
5404 }