1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted by a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
56  * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
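/* Editor's note: illustrative sketch only (not part of the original source),
 * kept disabled so it cannot affect compilation. It writes out the two test
 * boxes from the checklist above as row-major 3x3 matrices (each row is one
 * box vector), ready to be dropped into two otherwise identical PME test
 * runs whose energies, forces and virial are then compared to PME precision.
 */
#if 0
static const double test_box_rectangular[3][3] = {
    { 2.0, 0.0, 0.0 },
    { 0.0, 2.0, 0.0 },
    { 0.0, 0.0, 6.0 }
};
static const double test_box_triclinic[3][3] = {
    { 2.0, 0.0, 0.0 },
    { 0.0, 2.0, 0.0 },
    { 2.0, 2.0, 6.0 }
};
#endif
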
60 #ifdef HAVE_CONFIG_H
61 #include <config.h>
62 #endif
63
64 #include <stdio.h>
65 #include <string.h>
66 #include <math.h>
67 #include <assert.h>
68 #include "typedefs.h"
69 #include "txtdump.h"
70 #include "vec.h"
71 #include "smalloc.h"
72 #include "coulomb.h"
73 #include "gmx_fatal.h"
74 #include "pme.h"
75 #include "network.h"
76 #include "physics.h"
77 #include "nrnb.h"
78 #include "macros.h"
79
80 #include "gromacs/fft/parallel_3dfft.h"
81 #include "gromacs/fileio/futil.h"
82 #include "gromacs/fileio/pdbio.h"
83 #include "gromacs/math/gmxcomplex.h"
84 #include "gromacs/timing/cyclecounter.h"
85 #include "gromacs/timing/wallcycle.h"
86 #include "gromacs/utility/gmxmpi.h"
87 #include "gromacs/utility/gmxomp.h"
88
89 /* Include the SIMD macro file and then check for support */
90 #include "gromacs/simd/simd.h"
91 #include "gromacs/simd/simd_math.h"
92 #ifdef GMX_SIMD_HAVE_REAL
93 /* Turn on arbitrary width SIMD intrinsics for PME solve */
94 #    define PME_SIMD_SOLVE
95 #endif
96
97 #define PME_GRID_QA    0 /* Gridindex for A-state for Q */
98 #define PME_GRID_C6A   2 /* Gridindex for A-state for LJ */
99 #define DO_Q           2 /* Electrostatic grids have index q<2 */
100 #define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
101 #define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */
102
103 /* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules */
104 const real lb_scale_factor[] = {
105     1.0/64, 6.0/64, 15.0/64, 20.0/64,
106     15.0/64, 6.0/64, 1.0/64
107 };
108
109 /* Pascal triangle coefficients used in solve_pme_lj_yzx, only need to do 4 calculations due to symmetry */
110 const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
111
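/* Editor's note: illustrative check (not part of the original source), kept
 * disabled. The tables above hold the binomial coefficients C(6,k) scaled by
 * (1/2)^6, and lb_scale_factor_symm[k] folds the symmetric pair k and 6-k
 * together, with the central term k=3 counted once. The helper name is
 * hypothetical.
 */
#if 0
static void check_lb_scale_factors(void)
{
    double c = 1.0; /* C(6,0) */
    int    k;

    for (k = 0; k <= 6; k++)
    {
        assert(fabs(lb_scale_factor[k] - c/64.0) < 1e-6);
        if (k < 3)
        {
            assert(fabs(lb_scale_factor_symm[k] - 2.0*c/64.0) < 1e-6);
        }
        c = c*(6 - k)/(k + 1); /* C(6,k) -> C(6,k+1) */
    }
    assert(fabs(lb_scale_factor_symm[3] - 20.0/64.0) < 1e-6);
}
#endif
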
112 /* Check if we have 4-wide SIMD macro support */
113 #if (defined GMX_SIMD4_HAVE_REAL)
114 /* Do PME spread and gather with 4-wide SIMD.
115  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
116  */
117 #    define PME_SIMD4_SPREAD_GATHER
118
119 #    if (defined GMX_SIMD_HAVE_LOADU) && (defined GMX_SIMD_HAVE_STOREU)
120 /* With PME-order=4 on x86, unaligned load+store is slightly faster
121  * than doubling all SIMD operations when using aligned load+store.
122  */
123 #        define PME_SIMD4_UNALIGNED
124 #    endif
125 #endif
126
127 #define DFT_TOL 1e-7
128 /* #define PRT_FORCE */
129 /* conditions for on-the-fly time measurement */
130 /* #define TAKETIME (step > 1 && timesteps < 10) */
131 #define TAKETIME FALSE
132
133 /* #define PME_TIME_THREADS */
134
135 #ifdef GMX_DOUBLE
136 #define mpi_type MPI_DOUBLE
137 #else
138 #define mpi_type MPI_FLOAT
139 #endif
140
141 #ifdef PME_SIMD4_SPREAD_GATHER
142 #    define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
143 #else
144 /* We can use any alignment, apart from 0, so we use 4 reals */
145 #    define SIMD4_ALIGNMENT  (4*sizeof(real))
146 #endif
147
148 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
149  * to preserve alignment.
150  */
151 #define GMX_CACHE_SEP 64
152
153 /* We only define a maximum to be able to use local arrays without allocation.
154  * An order larger than 12 should never be needed, even for test cases.
155  * If needed it can be changed here.
156  */
157 #define PME_ORDER_MAX 12
158
159 /* Internal datastructures */
160 typedef struct {
161     int send_index0;
162     int send_nindex;
163     int recv_index0;
164     int recv_nindex;
165     int recv_size;   /* Receive buffer width, used with OpenMP */
166 } pme_grid_comm_t;
167
168 typedef struct {
169 #ifdef GMX_MPI
170     MPI_Comm         mpi_comm;
171 #endif
172     int              nnodes, nodeid;
173     int             *s2g0;
174     int             *s2g1;
175     int              noverlap_nodes;
176     int             *send_id, *recv_id;
177     int              send_size; /* Send buffer width, used with OpenMP */
178     pme_grid_comm_t *comm_data;
179     real            *sendbuf;
180     real            *recvbuf;
181 } pme_overlap_t;
182
183 typedef struct {
184     int *n;      /* Cumulative counts of the number of particles per thread */
185     int  nalloc; /* Allocation size of i */
186     int *i;      /* Particle indices ordered on thread index (n) */
187 } thread_plist_t;
188
189 typedef struct {
190     int      *thread_one;
191     int       n;
192     int      *ind;
193     splinevec theta;
194     real     *ptr_theta_z;
195     splinevec dtheta;
196     real     *ptr_dtheta_z;
197 } splinedata_t;
198
199 typedef struct {
200     int      dimind;        /* The index of the dimension, 0=x, 1=y */
201     int      nslab;
202     int      nodeid;
203 #ifdef GMX_MPI
204     MPI_Comm mpi_comm;
205 #endif
206
207     int     *node_dest;     /* The nodes to send x and q to with DD */
208     int     *node_src;      /* The nodes to receive x and q from with DD */
209     int     *buf_index;     /* Index for commnode into the buffers */
210
211     int      maxshift;
212
213     int      npd;
214     int      pd_nalloc;
215     int     *pd;
216     int     *count;         /* The number of atoms to send to each node */
217     int    **count_thread;
218     int     *rcount;        /* The number of atoms to receive */
219
220     int      n;
221     int      nalloc;
222     rvec    *x;
223     real    *q;
224     rvec    *f;
225     gmx_bool bSpread;       /* These coordinates are used for spreading */
226     int      pme_order;
227     ivec    *idx;
228     rvec    *fractx;            /* Fractional coordinate relative to
229                                  * the lower cell boundary
230                                  */
231     int             nthread;
232     int            *thread_idx; /* Which thread should spread which charge */
233     thread_plist_t *thread_plist;
234     splinedata_t   *spline;
235 } pme_atomcomm_t;
236
237 #define FLBS  3
238 #define FLBSZ 4
239
240 typedef struct {
241     ivec  ci;     /* The spatial location of this grid         */
242     ivec  n;      /* The used size of *grid, including order-1 */
243     ivec  offset; /* The grid offset from the full node grid   */
244     int   order;  /* PME spreading order                       */
245     ivec  s;      /* The allocated size of *grid, s >= n       */
246     real *grid;   /* The local grid for this thread, size n    */
247 } pmegrid_t;
248
249 typedef struct {
250     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
251     int        nthread;      /* The number of threads operating on this grid     */
252     ivec       nc;           /* The local spatial decomposition over the threads */
253     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
254     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
255     int      **g2t;          /* The grid to thread index                         */
256     ivec       nthread_comm; /* The number of threads to communicate with        */
257 } pmegrids_t;
258
259 typedef struct {
260 #ifdef PME_SIMD4_SPREAD_GATHER
261     /* Masks for 4-wide SIMD aligned spreading and gathering */
262     gmx_simd4_bool_t mask_S0[6], mask_S1[6];
263 #else
264     int              dummy; /* C89 requires that struct has at least one member */
265 #endif
266 } pme_spline_work_t;
267
268 typedef struct {
269     /* work data for solve_pme */
270     int      nalloc;
271     real *   mhx;
272     real *   mhy;
273     real *   mhz;
274     real *   m2;
275     real *   denom;
276     real *   tmp1_alloc;
277     real *   tmp1;
278     real *   tmp2;
279     real *   eterm;
280     real *   m2inv;
281
282     real     energy_q;
283     matrix   vir_q;
284     real     energy_lj;
285     matrix   vir_lj;
286 } pme_work_t;
287
288 typedef struct gmx_pme {
289     int           ndecompdim; /* The number of decomposition dimensions */
290     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
291     int           nodeid_major;
292     int           nodeid_minor;
293     int           nnodes;    /* The number of nodes doing PME */
294     int           nnodes_major;
295     int           nnodes_minor;
296
297     MPI_Comm      mpi_comm;
298     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
299 #ifdef GMX_MPI
300     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
301 #endif
302
303     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
304     int        nthread;       /* The number of threads doing PME on our rank */
305
306     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
307     gmx_bool   bFEP;          /* Compute Free energy contribution */
308     gmx_bool   bFEP_q;
309     gmx_bool   bFEP_lj;
310     int        nkx, nky, nkz; /* Grid dimensions */
311     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
312     int        pme_order;
313     real       epsilon_r;
314
315     int        ngrids;                  /* number of grids we maintain for pmegrid, (c)fftgrid and pfft_setups*/
316
317     pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
318                                          * includes overlap. Grid indices are ordered as
319                                          * follows:
320                                          * 0: Coulomb PME, state A
321                                          * 1: Coulomb PME, state B
322                                          * 2-8: LJ-PME
323                                          * This can probably be done in a better way
324                                          * but this simple hack works for now
325                                          */
326     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
327     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
328     /* pmegrid_nz might be larger than strictly necessary to ensure
329      * memory alignment, pmegrid_nz_base gives the real base size.
330      */
331     int     pmegrid_nz_base;
332     /* The local PME grid starting indices */
333     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
334
335     /* Work data for spreading and gathering */
336     pme_spline_work_t     *spline_work;
337
338     real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
339     /* inside the interpolation grid, but separate for 2D PME decomp. */
340     int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;
341
342     t_complex            **cfftgrid;  /* Grids for complex FFT data */
343
344     int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
345
346     gmx_parallel_3dfft_t  *pfft_setup;
347
348     int                   *nnx, *nny, *nnz;
349     real                  *fshx, *fshy, *fshz;
350
351     pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
352     matrix                 recipbox;
353     splinevec              bsp_mod;
354     /* Buffers to store data for local atoms for L-B combination rule
355      * calculations in LJ-PME. lb_buf1 stores either the coefficients
356      * for spreading/gathering (in serial), or the C6 parameters for
357      * local atoms (in parallel).  lb_buf2 is only used in parallel,
358      * and stores the sigma values for local atoms. */
359     real                 *lb_buf1, *lb_buf2;
360     int                   lb_buf_nalloc; /* Allocation size for the above buffers. */
361
362     pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */
363
364     pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */
365
366     rvec                 *bufv;          /* Communication buffer */
367     real                 *bufr;          /* Communication buffer */
368     int                   buf_nalloc;    /* The communication buffer size */
369
370     /* thread local work data for solve_pme */
371     pme_work_t *work;
372
373     /* Work data for sum_qgrid */
374     real *   sum_qgrid_tmp;
375     real *   sum_qgrid_dd_tmp;
376 } t_gmx_pme;
377
378 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
379                                    int start, int end, int thread)
380 {
381     int             i;
382     int            *idxptr, tix, tiy, tiz;
383     real           *xptr, *fptr, tx, ty, tz;
384     real            rxx, ryx, ryy, rzx, rzy, rzz;
385     int             nx, ny, nz;
386     int             start_ix, start_iy, start_iz;
387     int            *g2tx, *g2ty, *g2tz;
388     gmx_bool        bThreads;
389     int            *thread_idx = NULL;
390     thread_plist_t *tpl        = NULL;
391     int            *tpl_n      = NULL;
392     int             thread_i;
393
394     nx  = pme->nkx;
395     ny  = pme->nky;
396     nz  = pme->nkz;
397
398     start_ix = pme->pmegrid_start_ix;
399     start_iy = pme->pmegrid_start_iy;
400     start_iz = pme->pmegrid_start_iz;
401
402     rxx = pme->recipbox[XX][XX];
403     ryx = pme->recipbox[YY][XX];
404     ryy = pme->recipbox[YY][YY];
405     rzx = pme->recipbox[ZZ][XX];
406     rzy = pme->recipbox[ZZ][YY];
407     rzz = pme->recipbox[ZZ][ZZ];
408
409     g2tx = pme->pmegrid[PME_GRID_QA].g2t[XX];
410     g2ty = pme->pmegrid[PME_GRID_QA].g2t[YY];
411     g2tz = pme->pmegrid[PME_GRID_QA].g2t[ZZ];
412
413     bThreads = (atc->nthread > 1);
414     if (bThreads)
415     {
416         thread_idx = atc->thread_idx;
417
418         tpl   = &atc->thread_plist[thread];
419         tpl_n = tpl->n;
420         for (i = 0; i < atc->nthread; i++)
421         {
422             tpl_n[i] = 0;
423         }
424     }
425
426     for (i = start; i < end; i++)
427     {
428         xptr   = atc->x[i];
429         idxptr = atc->idx[i];
430         fptr   = atc->fractx[i];
431
432         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
433         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
434         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
435         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
436
437         tix = (int)(tx);
438         tiy = (int)(ty);
439         tiz = (int)(tz);
440
441         /* Because decomposition only occurs in x and y,
442          * we never have a fraction correction in z.
443          */
444         fptr[XX] = tx - tix + pme->fshx[tix];
445         fptr[YY] = ty - tiy + pme->fshy[tiy];
446         fptr[ZZ] = tz - tiz;
447
448         idxptr[XX] = pme->nnx[tix];
449         idxptr[YY] = pme->nny[tiy];
450         idxptr[ZZ] = pme->nnz[tiz];
451
452 #ifdef DEBUG
453         range_check(idxptr[XX], 0, pme->pmegrid_nx);
454         range_check(idxptr[YY], 0, pme->pmegrid_ny);
455         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
456 #endif
457
458         if (bThreads)
459         {
460             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
461             thread_idx[i] = thread_i;
462             tpl_n[thread_i]++;
463         }
464     }
465
466     if (bThreads)
467     {
468         /* Make a list of particle indices sorted on thread */
469
470         /* Get the cumulative count */
471         for (i = 1; i < atc->nthread; i++)
472         {
473             tpl_n[i] += tpl_n[i-1];
474         }
475         /* The current implementation distributes particles equally
476          * over the threads, so we could actually allocate for that
477          * in pme_realloc_atomcomm_things.
478          */
479         if (tpl_n[atc->nthread-1] > tpl->nalloc)
480         {
481             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
482             srenew(tpl->i, tpl->nalloc);
483         }
484         /* Set tpl_n to the cumulative start */
485         for (i = atc->nthread-1; i >= 1; i--)
486         {
487             tpl_n[i] = tpl_n[i-1];
488         }
489         tpl_n[0] = 0;
490
491         /* Fill our thread local array with indices sorted on thread */
492         for (i = start; i < end; i++)
493         {
494             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
495         }
496         /* Now tpl_n contains the cumulative count again */
497     }
498 }
499
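/* Editor's note: illustrative sketch (not part of the original source), kept
 * disabled. It shows the per-component index mapping used above in isolation:
 * the +2.0 shift keeps the argument of the int cast positive for triclinic
 * boxes where the fractional coordinate can be slightly negative, and the
 * table lookups (pme->nnx, pme->fshx) are replaced by a plain wrap for
 * clarity. The helper name and signature are hypothetical.
 */
#if 0
static void calc_idx_1d(real coord_dot_recip, int nk, int *ti, real *fr)
{
    real t = nk*(coord_dot_recip + 2.0); /* guaranteed positive */

    *ti = (int)t;    /* grid cell index (before wrapping)                  */
    *fr = t - *ti;   /* fractional position within the cell, in [0,1)      */
    *ti = *ti % nk;  /* wrap back onto the grid, cf. the nnx/nny/nnz tables */
}
#endif
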
500 static void make_thread_local_ind(pme_atomcomm_t *atc,
501                                   int thread, splinedata_t *spline)
502 {
503     int             n, t, i, start, end;
504     thread_plist_t *tpl;
505
506     /* Combine the indices made by each thread into one index */
507
508     n     = 0;
509     start = 0;
510     for (t = 0; t < atc->nthread; t++)
511     {
512         tpl = &atc->thread_plist[t];
513         /* Copy our part (start - end) from the list of thread t */
514         if (thread > 0)
515         {
516             start = tpl->n[thread-1];
517         }
518         end = tpl->n[thread];
519         for (i = start; i < end; i++)
520         {
521             spline->ind[n++] = tpl->i[i];
522         }
523     }
524
525     spline->n = n;
526 }
527
528
529 static void pme_calc_pidx(int start, int end,
530                           matrix recipbox, rvec x[],
531                           pme_atomcomm_t *atc, int *count)
532 {
533     int   nslab, i;
534     int   si;
535     real *xptr, s;
536     real  rxx, ryx, rzx, ryy, rzy;
537     int  *pd;
538
539     /* Calculate the PME task index (pidx) for each atom.
540      * Here we always assign equally sized slabs to each node
541      * for load balancing reasons (the PME grid spacing is not used).
542      */
543
544     nslab = atc->nslab;
545     pd    = atc->pd;
546
547     /* Reset the count */
548     for (i = 0; i < nslab; i++)
549     {
550         count[i] = 0;
551     }
552
553     if (atc->dimind == 0)
554     {
555         rxx = recipbox[XX][XX];
556         ryx = recipbox[YY][XX];
557         rzx = recipbox[ZZ][XX];
558         /* Calculate the node index in x-dimension */
559         for (i = start; i < end; i++)
560         {
561             xptr   = x[i];
562             /* Fractional coordinates along box vectors */
563             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
564             si    = (int)(s + 2*nslab) % nslab;
565             pd[i] = si;
566             count[si]++;
567         }
568     }
569     else
570     {
571         ryy = recipbox[YY][YY];
572         rzy = recipbox[ZZ][YY];
573         /* Calculate the node index in y-dimension */
574         for (i = start; i < end; i++)
575         {
576             xptr   = x[i];
577             /* Fractional coordinates along box vectors */
578             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
579             si    = (int)(s + 2*nslab) % nslab;
580             pd[i] = si;
581             count[si]++;
582         }
583     }
584 }
585
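/* Editor's note: illustrative sketch (not part of the original source), kept
 * disabled. The slab assignment above reduces a scaled fractional coordinate
 * s (which can fall slightly outside [0, nslab) for triclinic boxes or atoms
 * that have just left the home cell) to a valid slab index: adding 2*nslab
 * before the int cast avoids negative results from truncation toward zero,
 * and the modulo wraps the value into 0..nslab-1. Hypothetical helper.
 */
#if 0
static int calc_slab_index(real s, int nslab)
{
    return (int)(s + 2*nslab) % nslab;
}
#endif
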
586 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
587                                   pme_atomcomm_t *atc)
588 {
589     int nthread, thread, slab;
590
591     nthread = atc->nthread;
592
593 #pragma omp parallel for num_threads(nthread) schedule(static)
594     for (thread = 0; thread < nthread; thread++)
595     {
596         pme_calc_pidx(natoms* thread   /nthread,
597                       natoms*(thread+1)/nthread,
598                       recipbox, x, atc, atc->count_thread[thread]);
599     }
600     /* Non-parallel reduction, since nslab is small */
601
602     for (thread = 1; thread < nthread; thread++)
603     {
604         for (slab = 0; slab < atc->nslab; slab++)
605         {
606             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
607         }
608     }
609 }
610
611 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
612 {
613     const int padding = 4;
614     int       i;
615
616     srenew(th[XX], nalloc);
617     srenew(th[YY], nalloc);
618     /* In z we add padding; this is only required for the aligned SIMD code */
619     sfree_aligned(*ptr_z);
620     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
621     th[ZZ] = *ptr_z + padding;
622
623     for (i = 0; i < padding; i++)
624     {
625         (*ptr_z)[               i] = 0;
626         (*ptr_z)[padding+nalloc+i] = 0;
627     }
628 }
629
630 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
631 {
632     int i, d;
633
634     srenew(spline->ind, atc->nalloc);
635     /* Initialize the index to identity so it works without threads */
636     for (i = 0; i < atc->nalloc; i++)
637     {
638         spline->ind[i] = i;
639     }
640
641     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
642                       atc->pme_order*atc->nalloc);
643     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
644                       atc->pme_order*atc->nalloc);
645 }
646
647 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
648 {
649     int nalloc_old, i, j, nalloc_tpl;
650
651     /* We must avoid a NULL pointer for atc->x to prevent
652      * possible fatal errors in MPI routines.
653      */
654     if (atc->n > atc->nalloc || atc->nalloc == 0)
655     {
656         nalloc_old  = atc->nalloc;
657         atc->nalloc = over_alloc_dd(max(atc->n, 1));
658
659         if (atc->nslab > 1)
660         {
661             srenew(atc->x, atc->nalloc);
662             srenew(atc->q, atc->nalloc);
663             srenew(atc->f, atc->nalloc);
664             for (i = nalloc_old; i < atc->nalloc; i++)
665             {
666                 clear_rvec(atc->f[i]);
667             }
668         }
669         if (atc->bSpread)
670         {
671             srenew(atc->fractx, atc->nalloc);
672             srenew(atc->idx, atc->nalloc);
673
674             if (atc->nthread > 1)
675             {
676                 srenew(atc->thread_idx, atc->nalloc);
677             }
678
679             for (i = 0; i < atc->nthread; i++)
680             {
681                 pme_realloc_splinedata(&atc->spline[i], atc);
682             }
683         }
684     }
685 }
686
687 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
688                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
689                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
690                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
691 {
692 #ifdef GMX_MPI
693     int        dest, src;
694     MPI_Status stat;
695
696     if (bBackward == FALSE)
697     {
698         dest = atc->node_dest[shift];
699         src  = atc->node_src[shift];
700     }
701     else
702     {
703         dest = atc->node_src[shift];
704         src  = atc->node_dest[shift];
705     }
706
707     if (nbyte_s > 0 && nbyte_r > 0)
708     {
709         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
710                      dest, shift,
711                      buf_r, nbyte_r, MPI_BYTE,
712                      src, shift,
713                      atc->mpi_comm, &stat);
714     }
715     else if (nbyte_s > 0)
716     {
717         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
718                  dest, shift,
719                  atc->mpi_comm);
720     }
721     else if (nbyte_r > 0)
722     {
723         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
724                  src, shift,
725                  atc->mpi_comm, &stat);
726     }
727 #endif
728 }
729
730 static void dd_pmeredist_x_q(gmx_pme_t pme,
731                              int n, gmx_bool bX, rvec *x, real *charge,
732                              pme_atomcomm_t *atc)
733 {
734     int *commnode, *buf_index;
735     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
736
737     commnode  = atc->node_dest;
738     buf_index = atc->buf_index;
739
740     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
741
742     nsend = 0;
743     for (i = 0; i < nnodes_comm; i++)
744     {
745         buf_index[commnode[i]] = nsend;
746         nsend                 += atc->count[commnode[i]];
747     }
748     if (bX)
749     {
750         if (atc->count[atc->nodeid] + nsend != n)
751         {
752             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
753                       "This usually means that your system is not well equilibrated.",
754                       n - (atc->count[atc->nodeid] + nsend),
755                       pme->nodeid, 'x'+atc->dimind);
756         }
757
758         if (nsend > pme->buf_nalloc)
759         {
760             pme->buf_nalloc = over_alloc_dd(nsend);
761             srenew(pme->bufv, pme->buf_nalloc);
762             srenew(pme->bufr, pme->buf_nalloc);
763         }
764
765         atc->n = atc->count[atc->nodeid];
766         for (i = 0; i < nnodes_comm; i++)
767         {
768             scount = atc->count[commnode[i]];
769             /* Communicate the count */
770             if (debug)
771             {
772                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
773                         atc->dimind, atc->nodeid, commnode[i], scount);
774             }
775             pme_dd_sendrecv(atc, FALSE, i,
776                             &scount, sizeof(int),
777                             &atc->rcount[i], sizeof(int));
778             atc->n += atc->rcount[i];
779         }
780
781         pme_realloc_atomcomm_things(atc);
782     }
783
784     local_pos = 0;
785     for (i = 0; i < n; i++)
786     {
787         node = atc->pd[i];
788         if (node == atc->nodeid)
789         {
790             /* Copy directly to the receive buffer */
791             if (bX)
792             {
793                 copy_rvec(x[i], atc->x[local_pos]);
794             }
795             atc->q[local_pos] = charge[i];
796             local_pos++;
797         }
798         else
799         {
800             /* Copy to the send buffer */
801             if (bX)
802             {
803                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
804             }
805             pme->bufr[buf_index[node]] = charge[i];
806             buf_index[node]++;
807         }
808     }
809
810     buf_pos = 0;
811     for (i = 0; i < nnodes_comm; i++)
812     {
813         scount = atc->count[commnode[i]];
814         rcount = atc->rcount[i];
815         if (scount > 0 || rcount > 0)
816         {
817             if (bX)
818             {
819                 /* Communicate the coordinates */
820                 pme_dd_sendrecv(atc, FALSE, i,
821                                 pme->bufv[buf_pos], scount*sizeof(rvec),
822                                 atc->x[local_pos], rcount*sizeof(rvec));
823             }
824             /* Communicate the charges */
825             pme_dd_sendrecv(atc, FALSE, i,
826                             pme->bufr+buf_pos, scount*sizeof(real),
827                             atc->q+local_pos, rcount*sizeof(real));
828             buf_pos   += scount;
829             local_pos += atc->rcount[i];
830         }
831     }
832 }
833
834 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
835                            int n, rvec *f,
836                            gmx_bool bAddF)
837 {
838     int *commnode, *buf_index;
839     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
840
841     commnode  = atc->node_dest;
842     buf_index = atc->buf_index;
843
844     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
845
846     local_pos = atc->count[atc->nodeid];
847     buf_pos   = 0;
848     for (i = 0; i < nnodes_comm; i++)
849     {
850         scount = atc->rcount[i];
851         rcount = atc->count[commnode[i]];
852         if (scount > 0 || rcount > 0)
853         {
854             /* Communicate the forces */
855             pme_dd_sendrecv(atc, TRUE, i,
856                             atc->f[local_pos], scount*sizeof(rvec),
857                             pme->bufv[buf_pos], rcount*sizeof(rvec));
858             local_pos += scount;
859         }
860         buf_index[commnode[i]] = buf_pos;
861         buf_pos               += rcount;
862     }
863
864     local_pos = 0;
865     if (bAddF)
866     {
867         for (i = 0; i < n; i++)
868         {
869             node = atc->pd[i];
870             if (node == atc->nodeid)
871             {
872                 /* Add from the local force array */
873                 rvec_inc(f[i], atc->f[local_pos]);
874                 local_pos++;
875             }
876             else
877             {
878                 /* Add from the receive buffer */
879                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
880                 buf_index[node]++;
881             }
882         }
883     }
884     else
885     {
886         for (i = 0; i < n; i++)
887         {
888             node = atc->pd[i];
889             if (node == atc->nodeid)
890             {
891                 /* Copy from the local force array */
892                 copy_rvec(atc->f[local_pos], f[i]);
893                 local_pos++;
894             }
895             else
896             {
897                 /* Copy from the receive buffer */
898                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
899                 buf_index[node]++;
900             }
901         }
902     }
903 }
904
905 #ifdef GMX_MPI
906 static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
907 {
908     pme_overlap_t *overlap;
909     int            send_index0, send_nindex;
910     int            recv_index0, recv_nindex;
911     MPI_Status     stat;
912     int            i, j, k, ix, iy, iz, icnt;
913     int            ipulse, send_id, recv_id, datasize;
914     real          *p;
915     real          *sendptr, *recvptr;
916
917     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
918     overlap = &pme->overlap[1];
919
920     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
921     {
922         /* Since we have already (un)wrapped the overlap in the z-dimension,
923          * we only have to communicate 0 to nkz (not pmegrid_nz).
924          */
925         if (direction == GMX_SUM_QGRID_FORWARD)
926         {
927             send_id       = overlap->send_id[ipulse];
928             recv_id       = overlap->recv_id[ipulse];
929             send_index0   = overlap->comm_data[ipulse].send_index0;
930             send_nindex   = overlap->comm_data[ipulse].send_nindex;
931             recv_index0   = overlap->comm_data[ipulse].recv_index0;
932             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
933         }
934         else
935         {
936             send_id       = overlap->recv_id[ipulse];
937             recv_id       = overlap->send_id[ipulse];
938             send_index0   = overlap->comm_data[ipulse].recv_index0;
939             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
940             recv_index0   = overlap->comm_data[ipulse].send_index0;
941             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
942         }
943
944         /* Copy data to contiguous send buffer */
945         if (debug)
946         {
947             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
948                     pme->nodeid, overlap->nodeid, send_id,
949                     pme->pmegrid_start_iy,
950                     send_index0-pme->pmegrid_start_iy,
951                     send_index0-pme->pmegrid_start_iy+send_nindex);
952         }
953         icnt = 0;
954         for (i = 0; i < pme->pmegrid_nx; i++)
955         {
956             ix = i;
957             for (j = 0; j < send_nindex; j++)
958             {
959                 iy = j + send_index0 - pme->pmegrid_start_iy;
960                 for (k = 0; k < pme->nkz; k++)
961                 {
962                     iz = k;
963                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
964                 }
965             }
966         }
967
968         datasize      = pme->pmegrid_nx * pme->nkz;
969
970         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
971                      send_id, ipulse,
972                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
973                      recv_id, ipulse,
974                      overlap->mpi_comm, &stat);
975
976         /* Get data from contiguous recv buffer */
977         if (debug)
978         {
979             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
980                     pme->nodeid, overlap->nodeid, recv_id,
981                     pme->pmegrid_start_iy,
982                     recv_index0-pme->pmegrid_start_iy,
983                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
984         }
985         icnt = 0;
986         for (i = 0; i < pme->pmegrid_nx; i++)
987         {
988             ix = i;
989             for (j = 0; j < recv_nindex; j++)
990             {
991                 iy = j + recv_index0 - pme->pmegrid_start_iy;
992                 for (k = 0; k < pme->nkz; k++)
993                 {
994                     iz = k;
995                     if (direction == GMX_SUM_QGRID_FORWARD)
996                     {
997                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
998                     }
999                     else
1000                     {
1001                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1002                     }
1003                 }
1004             }
1005         }
1006     }
1007
1008     /* Major dimension is easier, no copying required,
1009      * but we might have to sum into a separate array.
1010      * Since we don't copy, we have to communicate up to pmegrid_nz,
1011      * not nkz as for the minor direction.
1012      */
1013     overlap = &pme->overlap[0];
1014
1015     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1016     {
1017         if (direction == GMX_SUM_QGRID_FORWARD)
1018         {
1019             send_id       = overlap->send_id[ipulse];
1020             recv_id       = overlap->recv_id[ipulse];
1021             send_index0   = overlap->comm_data[ipulse].send_index0;
1022             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1023             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1024             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1025             recvptr       = overlap->recvbuf;
1026         }
1027         else
1028         {
1029             send_id       = overlap->recv_id[ipulse];
1030             recv_id       = overlap->send_id[ipulse];
1031             send_index0   = overlap->comm_data[ipulse].recv_index0;
1032             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1033             recv_index0   = overlap->comm_data[ipulse].send_index0;
1034             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1035             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1036         }
1037
1038         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1039         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1040
1041         if (debug)
1042         {
1043             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1044                     pme->nodeid, overlap->nodeid, send_id,
1045                     pme->pmegrid_start_ix,
1046                     send_index0-pme->pmegrid_start_ix,
1047                     send_index0-pme->pmegrid_start_ix+send_nindex);
1048             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1049                     pme->nodeid, overlap->nodeid, recv_id,
1050                     pme->pmegrid_start_ix,
1051                     recv_index0-pme->pmegrid_start_ix,
1052                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1053         }
1054
1055         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1056                      send_id, ipulse,
1057                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1058                      recv_id, ipulse,
1059                      overlap->mpi_comm, &stat);
1060
1061         /* ADD data from contiguous recv buffer */
1062         if (direction == GMX_SUM_QGRID_FORWARD)
1063         {
1064             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1065             for (i = 0; i < recv_nindex*datasize; i++)
1066             {
1067                 p[i] += overlap->recvbuf[i];
1068             }
1069         }
1070     }
1071 }
1072 #endif
1073
1074
1075 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1076 {
1077     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1078     ivec    local_pme_size;
1079     int     i, ix, iy, iz;
1080     int     pmeidx, fftidx;
1081
1082     /* Dimensions should be identical for A/B grid, so we just use A here */
1083     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1084                                    local_fft_ndata,
1085                                    local_fft_offset,
1086                                    local_fft_size);
1087
1088     local_pme_size[0] = pme->pmegrid_nx;
1089     local_pme_size[1] = pme->pmegrid_ny;
1090     local_pme_size[2] = pme->pmegrid_nz;
1091
1092     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid:
1093        the offset is identical, and the PME grid always has more data (due to overlap)
1094      */
1095     {
1096 #ifdef DEBUG_PME
1097         FILE *fp, *fp2;
1098         char  fn[STRLEN], format[STRLEN];
1099         real  val;
1100         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1101         fp = gmx_ffopen(fn, "w");
1102         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1103         fp2 = gmx_ffopen(fn, "w");
1104         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1105 #endif
1106
1107         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1108         {
1109             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1110             {
1111                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1112                 {
1113                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1114                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1115                     fftgrid[fftidx] = pmegrid[pmeidx];
1116 #ifdef DEBUG_PME
1117                     val = 100*pmegrid[pmeidx];
1118                     if (pmegrid[pmeidx] != 0)
1119                     {
1120                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1121                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1122                     }
1123                     if (pmegrid[pmeidx] != 0)
1124                     {
1125                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1126                                 "qgrid",
1127                                 pme->pmegrid_start_ix + ix,
1128                                 pme->pmegrid_start_iy + iy,
1129                                 pme->pmegrid_start_iz + iz,
1130                                 pmegrid[pmeidx]);
1131                     }
1132 #endif
1133                 }
1134             }
1135         }
1136 #ifdef DEBUG_PME
1137         gmx_ffclose(fp);
1138         gmx_ffclose(fp2);
1139 #endif
1140     }
1141     return 0;
1142 }
1143
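/* Editor's note: illustrative sketch (not part of the original source), kept
 * disabled. Both the PME grid and the FFT grid above use C row-major storage,
 * so the copy is purely an index translation between two sets of strides.
 * Hypothetical helper spelling out the indexing used for pmeidx and fftidx.
 */
#if 0
static int rowmajor_index(int ix, int iy, int iz, const ivec size)
{
    return (ix*size[YY] + iy)*size[ZZ] + iz;
}
#endif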
1144
1145 static gmx_cycles_t omp_cyc_start()
1146 {
1147     return gmx_cycles_read();
1148 }
1149
1150 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1151 {
1152     return gmx_cycles_read() - c;
1153 }
1154
1155
1156 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1157                                    int nthread, int thread)
1158 {
1159     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1160     ivec          local_pme_size;
1161     int           ixy0, ixy1, ixy, ix, iy, iz;
1162     int           pmeidx, fftidx;
1163 #ifdef PME_TIME_THREADS
1164     gmx_cycles_t  c1;
1165     static double cs1 = 0;
1166     static int    cnt = 0;
1167 #endif
1168
1169 #ifdef PME_TIME_THREADS
1170     c1 = omp_cyc_start();
1171 #endif
1172     /* Dimensions should be identical for A/B grid, so we just use A here */
1173     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1174                                    local_fft_ndata,
1175                                    local_fft_offset,
1176                                    local_fft_size);
1177
1178     local_pme_size[0] = pme->pmegrid_nx;
1179     local_pme_size[1] = pme->pmegrid_ny;
1180     local_pme_size[2] = pme->pmegrid_nz;
1181
1182     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid:
1183        the offset is identical, and the PME grid always has more data (due to overlap)
1184      */
1185     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1186     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1187
1188     for (ixy = ixy0; ixy < ixy1; ixy++)
1189     {
1190         ix = ixy/local_fft_ndata[YY];
1191         iy = ixy - ix*local_fft_ndata[YY];
1192
1193         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1194         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1195         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1196         {
1197             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1198         }
1199     }
1200
1201 #ifdef PME_TIME_THREADS
1202     c1   = omp_cyc_end(c1);
1203     cs1 += (double)c1;
1204     cnt++;
1205     if (cnt % 20 == 0)
1206     {
1207         printf("copy %.2f\n", cs1*1e-9);
1208     }
1209 #endif
1210
1211     return 0;
1212 }
1213
1214
1215 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1216 {
1217     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1218
1219     nx = pme->nkx;
1220     ny = pme->nky;
1221     nz = pme->nkz;
1222
1223     pnx = pme->pmegrid_nx;
1224     pny = pme->pmegrid_ny;
1225     pnz = pme->pmegrid_nz;
1226
1227     overlap = pme->pme_order - 1;
1228
1229     /* Add periodic overlap in z */
1230     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1231     {
1232         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1233         {
1234             for (iz = 0; iz < overlap; iz++)
1235             {
1236                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1237                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1238             }
1239         }
1240     }
1241
1242     if (pme->nnodes_minor == 1)
1243     {
1244         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1245         {
1246             for (iy = 0; iy < overlap; iy++)
1247             {
1248                 for (iz = 0; iz < nz; iz++)
1249                 {
1250                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1251                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1252                 }
1253             }
1254         }
1255     }
1256
1257     if (pme->nnodes_major == 1)
1258     {
1259         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1260
1261         for (ix = 0; ix < overlap; ix++)
1262         {
1263             for (iy = 0; iy < ny_x; iy++)
1264             {
1265                 for (iz = 0; iz < nz; iz++)
1266                 {
1267                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1268                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1269                 }
1270             }
1271         }
1272     }
1273 }
1274
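/* Editor's note: illustrative sketch (not part of the original source), kept
 * disabled. The wrapping above, reduced to one dimension: charge that was
 * spread into the (order-1) overlap cells beyond the last grid point is
 * folded back onto the first cells, i.e. onto its periodic image. The
 * unwrap routine below performs the inverse copy before force gathering.
 * Hypothetical helper.
 */
#if 0
static void wrap_overlap_1d(real *line, int nk, int overlap)
{
    int i;

    for (i = 0; i < overlap; i++)
    {
        line[i] += line[nk + i];
    }
}
#endif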
1275
1276 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1277 {
1278     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1279
1280     nx = pme->nkx;
1281     ny = pme->nky;
1282     nz = pme->nkz;
1283
1284     pnx = pme->pmegrid_nx;
1285     pny = pme->pmegrid_ny;
1286     pnz = pme->pmegrid_nz;
1287
1288     overlap = pme->pme_order - 1;
1289
1290     if (pme->nnodes_major == 1)
1291     {
1292         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1293
1294         for (ix = 0; ix < overlap; ix++)
1295         {
1296             int iy, iz;
1297
1298             for (iy = 0; iy < ny_x; iy++)
1299             {
1300                 for (iz = 0; iz < nz; iz++)
1301                 {
1302                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1303                         pmegrid[(ix*pny+iy)*pnz+iz];
1304                 }
1305             }
1306         }
1307     }
1308
1309     if (pme->nnodes_minor == 1)
1310     {
1311 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1312         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1313         {
1314             int iy, iz;
1315
1316             for (iy = 0; iy < overlap; iy++)
1317             {
1318                 for (iz = 0; iz < nz; iz++)
1319                 {
1320                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1321                         pmegrid[(ix*pny+iy)*pnz+iz];
1322                 }
1323             }
1324         }
1325     }
1326
1327     /* Copy periodic overlap in z */
1328 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1329     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1330     {
1331         int iy, iz;
1332
1333         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1334         {
1335             for (iz = 0; iz < overlap; iz++)
1336             {
1337                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1338                     pmegrid[(ix*pny+iy)*pnz+iz];
1339             }
1340         }
1341     }
1342 }
1343
1344
1345 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1346 #define DO_BSPLINE(order)                            \
1347     for (ithx = 0; (ithx < order); ithx++)                    \
1348     {                                                    \
1349         index_x = (i0+ithx)*pny*pnz;                     \
1350         valx    = qn*thx[ithx];                          \
1351                                                      \
1352         for (ithy = 0; (ithy < order); ithy++)                \
1353         {                                                \
1354             valxy    = valx*thy[ithy];                   \
1355             index_xy = index_x+(j0+ithy)*pnz;            \
1356                                                      \
1357             for (ithz = 0; (ithz < order); ithz++)            \
1358             {                                            \
1359                 index_xyz        = index_xy+(k0+ithz);   \
1360                 grid[index_xyz] += valxy*thz[ithz];      \
1361             }                                            \
1362         }                                                \
1363     }
1364
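/* Editor's note: illustrative sketch (not part of the original source), kept
 * disabled. It spells out what DO_BSPLINE(order) does for a single charge as
 * a plain function: every grid point in the order^3 neighbourhood of the
 * charge receives qn*thx[ix]*thy[iy]*thz[iz], the tensor product of the 1D
 * B-spline weights. Hypothetical helper with the same index conventions as
 * the macro (pny and pnz are the y and z strides of the local grid).
 */
#if 0
static void spread_one_charge(real *grid, int pny, int pnz,
                              int i0, int j0, int k0, int order, real qn,
                              const real *thx, const real *thy, const real *thz)
{
    int ithx, ithy, ithz;

    for (ithx = 0; ithx < order; ithx++)
    {
        for (ithy = 0; ithy < order; ithy++)
        {
            real valxy = qn*thx[ithx]*thy[ithy];

            for (ithz = 0; ithz < order; ithz++)
            {
                grid[((i0+ithx)*pny + (j0+ithy))*pnz + (k0+ithz)] += valxy*thz[ithz];
            }
        }
    }
}
#endif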
1365
1366 static void spread_q_bsplines_thread(pmegrid_t                    *pmegrid,
1367                                      pme_atomcomm_t               *atc,
1368                                      splinedata_t                 *spline,
1369                                      pme_spline_work_t gmx_unused *work)
1370 {
1371
1372     /* spread charges from home atoms to local grid */
1373     real          *grid;
1374     pme_overlap_t *ol;
1375     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1376     int       *    idxptr;
1377     int            order, norder, index_x, index_xy, index_xyz;
1378     real           valx, valxy, qn;
1379     real          *thx, *thy, *thz;
1380     int            localsize, bndsize;
1381     int            pnx, pny, pnz, ndatatot;
1382     int            offx, offy, offz;
1383
1384 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1385     real           thz_buffer[GMX_SIMD4_WIDTH*3], *thz_aligned;
1386
1387     thz_aligned = gmx_simd4_align_r(thz_buffer);
1388 #endif
1389
1390     pnx = pmegrid->s[XX];
1391     pny = pmegrid->s[YY];
1392     pnz = pmegrid->s[ZZ];
1393
1394     offx = pmegrid->offset[XX];
1395     offy = pmegrid->offset[YY];
1396     offz = pmegrid->offset[ZZ];
1397
1398     ndatatot = pnx*pny*pnz;
1399     grid     = pmegrid->grid;
1400     for (i = 0; i < ndatatot; i++)
1401     {
1402         grid[i] = 0;
1403     }
1404
1405     order = pmegrid->order;
1406
1407     for (nn = 0; nn < spline->n; nn++)
1408     {
1409         n  = spline->ind[nn];
1410         qn = atc->q[n];
1411
1412         if (qn != 0)
1413         {
1414             idxptr = atc->idx[n];
1415             norder = nn*order;
1416
1417             i0   = idxptr[XX] - offx;
1418             j0   = idxptr[YY] - offy;
1419             k0   = idxptr[ZZ] - offz;
1420
1421             thx = spline->theta[XX] + norder;
1422             thy = spline->theta[YY] + norder;
1423             thz = spline->theta[ZZ] + norder;
1424
1425             switch (order)
1426             {
1427                 case 4:
1428 #ifdef PME_SIMD4_SPREAD_GATHER
1429 #ifdef PME_SIMD4_UNALIGNED
1430 #define PME_SPREAD_SIMD4_ORDER4
1431 #else
1432 #define PME_SPREAD_SIMD4_ALIGNED
1433 #define PME_ORDER 4
1434 #endif
1435 #include "pme_simd4.h"
1436 #else
1437                     DO_BSPLINE(4);
1438 #endif
1439                     break;
1440                 case 5:
1441 #ifdef PME_SIMD4_SPREAD_GATHER
1442 #define PME_SPREAD_SIMD4_ALIGNED
1443 #define PME_ORDER 5
1444 #include "pme_simd4.h"
1445 #else
1446                     DO_BSPLINE(5);
1447 #endif
1448                     break;
1449                 default:
1450                     DO_BSPLINE(order);
1451                     break;
1452             }
1453         }
1454     }
1455 }
1456
1457 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1458 {
1459 #ifdef PME_SIMD4_SPREAD_GATHER
1460     if (pme_order == 5
1461 #ifndef PME_SIMD4_UNALIGNED
1462         || pme_order == 4
1463 #endif
1464         )
1465     {
1466         /* Round nz up to a multiple of 4 to ensure alignment */
1467         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1468     }
1469 #endif
1470 }
1471
1472 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1473 {
1474 #ifdef PME_SIMD4_SPREAD_GATHER
1475 #ifndef PME_SIMD4_UNALIGNED
1476     if (pme_order == 4)
1477     {
1478         /* Add extra elements to ensure that aligned operations do not go
1479          * beyond the allocated grid size.
1480          * Note that for pme_order=5, the pme grid z-size alignment
1481          * ensures that we will not go beyond the grid size.
1482          */
1483         *gridsize += 4;
1484     }
1485 #endif
1486 #endif
1487 }
1488
1489 static void pmegrid_init(pmegrid_t *grid,
1490                          int cx, int cy, int cz,
1491                          int x0, int y0, int z0,
1492                          int x1, int y1, int z1,
1493                          gmx_bool set_alignment,
1494                          int pme_order,
1495                          real *ptr)
1496 {
1497     int nz, gridsize;
1498
1499     grid->ci[XX]     = cx;
1500     grid->ci[YY]     = cy;
1501     grid->ci[ZZ]     = cz;
1502     grid->offset[XX] = x0;
1503     grid->offset[YY] = y0;
1504     grid->offset[ZZ] = z0;
1505     grid->n[XX]      = x1 - x0 + pme_order - 1;
1506     grid->n[YY]      = y1 - y0 + pme_order - 1;
1507     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1508     copy_ivec(grid->n, grid->s);
1509
1510     nz = grid->s[ZZ];
1511     set_grid_alignment(&nz, pme_order);
1512     if (set_alignment)
1513     {
1514         grid->s[ZZ] = nz;
1515     }
1516     else if (nz != grid->s[ZZ])
1517     {
1518         gmx_incons("pmegrid_init call with an unaligned z size");
1519     }
1520
1521     grid->order = pme_order;
1522     if (ptr == NULL)
1523     {
1524         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1525         set_gridsize_alignment(&gridsize, pme_order);
1526         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1527     }
1528     else
1529     {
1530         grid->grid = ptr;
1531     }
1532 }
1533
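/* Editor's note: integer ceiling division, e.g. div_round_up(10, 4) == 3;
 * used below to size the per-thread subgrids. */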
1534 static int div_round_up(int enumerator, int denominator)
1535 {
1536     return (enumerator + denominator - 1)/denominator;
1537 }
1538
1539 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1540                                   ivec nsub)
1541 {
1542     int gsize_opt, gsize;
1543     int nsx, nsy, nsz;
1544     char *env;
1545
1546     gsize_opt = -1;
1547     for (nsx = 1; nsx <= nthread; nsx++)
1548     {
1549         if (nthread % nsx == 0)
1550         {
1551             for (nsy = 1; nsy <= nthread; nsy++)
1552             {
1553                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1554                 {
1555                     nsz = nthread/(nsx*nsy);
1556
1557                     /* Determine the number of grid points per thread */
1558                     gsize =
1559                         (div_round_up(n[XX], nsx) + ovl)*
1560                         (div_round_up(n[YY], nsy) + ovl)*
1561                         (div_round_up(n[ZZ], nsz) + ovl);
1562
1563                     /* Minimize the number of grid points per thread
1564                      * and, secondarily, the number of cuts in minor dimensions.
1565                      */
1566                     if (gsize_opt == -1 ||
1567                         gsize < gsize_opt ||
1568                         (gsize == gsize_opt &&
1569                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1570                     {
1571                         nsub[XX]  = nsx;
1572                         nsub[YY]  = nsy;
1573                         nsub[ZZ]  = nsz;
1574                         gsize_opt = gsize;
1575                     }
1576                 }
1577             }
1578         }
1579     }
1580
1581     env = getenv("GMX_PME_THREAD_DIVISION");
1582     if (env != NULL)
1583     {
1584         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1585     }
1586
1587     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1588     {
1589         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1590     }
1591 }
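/* Worked example (illustration only, values assumed): with nthread = 6,
 * ovl = 3 (pme_order 4), n = {20, 20, 20} and GMX_PME_THREAD_DIVISION unset,
 * the search above evaluates every factorization of 6. The divisions
 * 1x2x3, 1x3x2, 2x1x3, 2x3x1, 3x1x2 and 3x2x1 all give 2990 grid points
 * per thread, while 6x1x1, 1x6x1 and 1x1x6 give 3703. The tie-breaks on
 * nsz and then nsy select nsub = {3, 2, 1}, i.e. the fewest cuts in the
 * minor (z) dimension, then in y.
 */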
1592
1593 static void pmegrids_init(pmegrids_t *grids,
1594                           int nx, int ny, int nz, int nz_base,
1595                           int pme_order,
1596                           gmx_bool bUseThreads,
1597                           int nthread,
1598                           int overlap_x,
1599                           int overlap_y)
1600 {
1601     ivec n, n_base, g0, g1;
1602     int t, x, y, z, d, i, tfac;
1603     int max_comm_lines = -1;
1604
1605     n[XX] = nx - (pme_order - 1);
1606     n[YY] = ny - (pme_order - 1);
1607     n[ZZ] = nz - (pme_order - 1);
1608
1609     copy_ivec(n, n_base);
1610     n_base[ZZ] = nz_base;
1611
1612     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1613                  NULL);
1614
1615     grids->nthread = nthread;
1616
1617     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1618
1619     if (bUseThreads)
1620     {
1621         ivec nst;
1622         int gridsize;
1623
1624         for (d = 0; d < DIM; d++)
1625         {
1626             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1627         }
1628         set_grid_alignment(&nst[ZZ], pme_order);
1629
1630         if (debug)
1631         {
1632             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1633                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1634             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1635                     nx, ny, nz,
1636                     nst[XX], nst[YY], nst[ZZ]);
1637         }
1638
1639         snew(grids->grid_th, grids->nthread);
1640         t        = 0;
1641         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1642         set_gridsize_alignment(&gridsize, pme_order);
1643         snew_aligned(grids->grid_all,
1644                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1645                      SIMD4_ALIGNMENT);
1646
1647         for (x = 0; x < grids->nc[XX]; x++)
1648         {
1649             for (y = 0; y < grids->nc[YY]; y++)
1650             {
1651                 for (z = 0; z < grids->nc[ZZ]; z++)
1652                 {
1653                     pmegrid_init(&grids->grid_th[t],
1654                                  x, y, z,
1655                                  (n[XX]*(x  ))/grids->nc[XX],
1656                                  (n[YY]*(y  ))/grids->nc[YY],
1657                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1658                                  (n[XX]*(x+1))/grids->nc[XX],
1659                                  (n[YY]*(y+1))/grids->nc[YY],
1660                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1661                                  TRUE,
1662                                  pme_order,
1663                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1664                     t++;
1665                 }
1666             }
1667         }
1668     }
1669     else
1670     {
1671         grids->grid_th = NULL;
1672     }
1673
1674     snew(grids->g2t, DIM);
1675     tfac = 1;
1676     for (d = DIM-1; d >= 0; d--)
1677     {
1678         snew(grids->g2t[d], n[d]);
1679         t = 0;
1680         for (i = 0; i < n[d]; i++)
1681         {
1682             /* The second check should match the parameters
1683              * of the pmegrid_init call above.
1684              */
1685             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1686             {
1687                 t++;
1688             }
1689             grids->g2t[d][i] = t*tfac;
1690         }
1691
1692         tfac *= grids->nc[d];
1693
1694         switch (d)
1695         {
1696             case XX: max_comm_lines = overlap_x;     break;
1697             case YY: max_comm_lines = overlap_y;     break;
1698             case ZZ: max_comm_lines = pme_order - 1; break;
1699         }
1700         grids->nthread_comm[d] = 0;
1701         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1702                grids->nthread_comm[d] < grids->nc[d])
1703         {
1704             grids->nthread_comm[d]++;
1705         }
1706         if (debug != NULL)
1707         {
1708             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1709                     'x'+d, grids->nthread_comm[d]);
1710         }
1711         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1712          * work, but this restriction is not a problem in practice.
1713          */
1714         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1715         {
1716             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1717         }
1718     }
1719 }
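/* Note on the tables set up above: g2t[d][i] holds, scaled by tfac, the
 * contribution of dimension d to the index of the thread owning grid line i,
 * so the owning thread of grid cell (ix, iy, iz) is
 * g2t[XX][ix] + g2t[YY][iy] + g2t[ZZ][iz], matching the z-fastest ordering
 * used for the per-thread grids above. For example (illustration only),
 * with nc = {2, 2, 1} this evaluates to 2*tx + ty.
 * nthread_comm[d] is the number of neighboring thread blocks in the +d
 * direction needed to cover the spreading overlap of max_comm_lines lines.
 */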
1720
1721
1722 static void pmegrids_destroy(pmegrids_t *grids)
1723 {
1724     int t;
1725
1726     if (grids->grid.grid != NULL)
1727     {
1728         sfree(grids->grid.grid);
1729
1730         if (grids->nthread > 0)
1731         {
1732             for (t = 0; t < grids->nthread; t++)
1733             {
1734                 sfree(grids->grid_th[t].grid);
1735             }
1736             sfree(grids->grid_th);
1737         }
1738     }
1739 }
1740
1741
1742 static void realloc_work(pme_work_t *work, int nkx)
1743 {
1744     int simd_width;
1745
1746     if (nkx > work->nalloc)
1747     {
1748         work->nalloc = nkx;
1749         srenew(work->mhx, work->nalloc);
1750         srenew(work->mhy, work->nalloc);
1751         srenew(work->mhz, work->nalloc);
1752         srenew(work->m2, work->nalloc);
1753         /* Allocate an aligned pointer for SIMD operations, including extra
1754          * elements at the end for padding.
1755          */
1756 #ifdef PME_SIMD_SOLVE
1757         simd_width = GMX_SIMD_REAL_WIDTH;
1758 #else
1759         /* We can use any alignment, apart from 0, so we use 4 */
1760         simd_width = 4;
1761 #endif
1762         sfree_aligned(work->denom);
1763         sfree_aligned(work->tmp1);
1764         sfree_aligned(work->tmp2);
1765         sfree_aligned(work->eterm);
1766         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1767         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1768         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1769         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1770         srenew(work->m2inv, work->nalloc);
1771     }
1772 }
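/* Note on the padding above (illustration only, sizes assumed): the
 * simd_width extra elements let the SIMD loops in calc_exponentials_q()
 * and calc_exponentials_lj() round the iteration count up to whole
 * vectors. E.g. with GMX_SIMD_REAL_WIDTH = 4 and nkx = 54, the last
 * vector load starts at kx = 52 and touches elements 52..55, which lie
 * inside the 54+4 = 58 element allocation. The alignment argument of
 * snew_aligned() is in bytes, hence the sizeof(real) factor.
 */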
1773
1774
1775 static void free_work(pme_work_t *work)
1776 {
1777     sfree(work->mhx);
1778     sfree(work->mhy);
1779     sfree(work->mhz);
1780     sfree(work->m2);
1781     sfree_aligned(work->denom);
1782     sfree_aligned(work->tmp1);
1783     sfree_aligned(work->tmp2);
1784     sfree_aligned(work->eterm);
1785     sfree(work->m2inv);
1786 }
1787
1788
1789 #if defined PME_SIMD_SOLVE
1790 /* Calculate exponentials through SIMD */
1791 inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1792 {
1793     {
1794         const gmx_simd_real_t two = gmx_simd_set1_r(2.0);
1795         gmx_simd_real_t f_simd;
1796         gmx_simd_real_t lu;
1797         gmx_simd_real_t tmp_d1, d_inv, tmp_r, tmp_e;
1798         int kx;
1799         f_simd = gmx_simd_set1_r(f);
1800         /* We only need to calculate from start. But since start is 0 or 1
1801          * and we want to use aligned loads/stores, we always start from 0.
1802          */
1803         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1804         {
1805             tmp_d1   = gmx_simd_load_r(d_aligned+kx);
1806             d_inv    = gmx_simd_inv_r(tmp_d1);
1807             tmp_r    = gmx_simd_load_r(r_aligned+kx);
1808             tmp_r    = gmx_simd_exp_r(tmp_r);
1809             tmp_e    = gmx_simd_mul_r(f_simd, d_inv);
1810             tmp_e    = gmx_simd_mul_r(tmp_e, tmp_r);
1811             gmx_simd_store_r(e_aligned+kx, tmp_e);
1812         }
1813     }
1814 }
1815 #else
1816 inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1817 {
1818     int kx;
1819     for (kx = start; kx < end; kx++)
1820     {
1821         d[kx] = 1.0/d[kx];
1822     }
1823     for (kx = start; kx < end; kx++)
1824     {
1825         r[kx] = exp(r[kx]);
1826     }
1827     for (kx = start; kx < end; kx++)
1828     {
1829         e[kx] = f*r[kx]*d[kx];
1830     }
1831 }
1832 #endif
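/* Both versions above compute e[kx] = f*exp(r[kx])/d[kx]. The SIMD branch
 * ignores start and always processes whole vectors from kx = 0, which is
 * safe because the work arrays are padded in realloc_work(); the solve
 * routines below pass denom, tmp1 and eterm for d, r and e.
 */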
1833
1834 #if defined PME_SIMD_SOLVE
1835 /* Calculate exponentials through SIMD */
1836 inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1837 {
1838     gmx_simd_real_t tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1839     const gmx_simd_real_t sqr_PI = gmx_simd_sqrt_r(gmx_simd_set1_r(M_PI));
1840     int kx;
1841     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1842     {
1843         /* We only need to calculate from start. But since start is 0 or 1
1844          * and we want to use aligned loads/stores, we always start from 0.
1845          */
1846         tmp_d = gmx_simd_load_r(d_aligned+kx);
1847         d_inv = gmx_simd_inv_r(tmp_d);
1848         gmx_simd_store_r(d_aligned+kx, d_inv);
1849         tmp_r = gmx_simd_load_r(r_aligned+kx);
1850         tmp_r = gmx_simd_exp_r(tmp_r);
1851         gmx_simd_store_r(r_aligned+kx, tmp_r);
1852         tmp_mk  = gmx_simd_load_r(factor_aligned+kx);
1853         tmp_fac = gmx_simd_mul_r(sqr_PI, gmx_simd_mul_r(tmp_mk, gmx_simd_erfc_r(tmp_mk)));
1854         gmx_simd_store_r(factor_aligned+kx, tmp_fac);
1855     }
1856 }
1857 #else
1858 inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1859 {
1860     int kx;
1861     real mk;
1862     for (kx = start; kx < end; kx++)
1863     {
1864         d[kx] = 1.0/d[kx];
1865     }
1866
1867     for (kx = start; kx < end; kx++)
1868     {
1869         r[kx] = exp(r[kx]);
1870     }
1871
1872     for (kx = start; kx < end; kx++)
1873     {
1874         mk       = tmp2[kx];
1875         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1876     }
1877 }
1878 #endif
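/* On return from either version above, d holds 1/denom, r holds exp(r) and
 * the factor array holds sqrt(pi)*mk*erfc(mk), where
 * mk = sqrt(factor*m2) = pi*|m|/ewaldcoeff (call it b). In
 * solve_pme_lj_yzx() below these are combined into
 *   eterm = -((1 - 2*b^2)*exp(-b^2) + 2*sqrt(pi)*b^3*erfc(b)),
 * the dispersion (r^-6) analogue of the Coulomb exp(-pi^2 m^2/beta^2)/m^2
 * reciprocal-space factor.
 */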
1879
1880 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1881                          real ewaldcoeff, real vol,
1882                          gmx_bool bEnerVir,
1883                          int nthread, int thread)
1884 {
1885     /* Do the reciprocal-space sum over the local cells in the grid */
1886     /* Grid layout: y major, z middle, x minor (contiguous) */
1887     t_complex *p0;
1888     int     kx, ky, kz, maxkx, maxky, maxkz;
1889     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1890     real    mx, my, mz;
1891     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1892     real    ets2, struct2, vfactor, ets2vf;
1893     real    d1, d2, energy = 0;
1894     real    by, bz;
1895     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1896     real    rxx, ryx, ryy, rzx, rzy, rzz;
1897     pme_work_t *work;
1898     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1899     real    mhxk, mhyk, mhzk, m2k;
1900     real    corner_fac;
1901     ivec    complex_order;
1902     ivec    local_ndata, local_offset, local_size;
1903     real    elfac;
1904
1905     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1906
1907     nx = pme->nkx;
1908     ny = pme->nky;
1909     nz = pme->nkz;
1910
1911     /* Dimensions should be identical for A/B grid, so we just use A here */
1912     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
1913                                       complex_order,
1914                                       local_ndata,
1915                                       local_offset,
1916                                       local_size);
1917
1918     rxx = pme->recipbox[XX][XX];
1919     ryx = pme->recipbox[YY][XX];
1920     ryy = pme->recipbox[YY][YY];
1921     rzx = pme->recipbox[ZZ][XX];
1922     rzy = pme->recipbox[ZZ][YY];
1923     rzz = pme->recipbox[ZZ][ZZ];
1924
1925     maxkx = (nx+1)/2;
1926     maxky = (ny+1)/2;
1927     maxkz = nz/2+1;
1928
1929     work  = &pme->work[thread];
1930     mhx   = work->mhx;
1931     mhy   = work->mhy;
1932     mhz   = work->mhz;
1933     m2    = work->m2;
1934     denom = work->denom;
1935     tmp1  = work->tmp1;
1936     eterm = work->eterm;
1937     m2inv = work->m2inv;
1938
1939     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1940     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1941
1942     for (iyz = iyz0; iyz < iyz1; iyz++)
1943     {
1944         iy = iyz/local_ndata[ZZ];
1945         iz = iyz - iy*local_ndata[ZZ];
1946
1947         ky = iy + local_offset[YY];
1948
1949         if (ky < maxky)
1950         {
1951             my = ky;
1952         }
1953         else
1954         {
1955             my = (ky - ny);
1956         }
1957
1958         by = M_PI*vol*pme->bsp_mod[YY][ky];
1959
1960         kz = iz + local_offset[ZZ];
1961
1962         mz = kz;
1963
1964         bz = pme->bsp_mod[ZZ][kz];
1965
1966         /* 0.5 correction for corner points */
1967         corner_fac = 1;
1968         if (kz == 0 || kz == (nz+1)/2)
1969         {
1970             corner_fac = 0.5;
1971         }
1972
1973         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1974
1975         /* We should skip the k-space point (0,0,0) */
1976         /* Note that since here x is the minor index, local_offset[XX]=0 */
1977         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1978         {
1979             kxstart = local_offset[XX];
1980         }
1981         else
1982         {
1983             kxstart = local_offset[XX] + 1;
1984             p0++;
1985         }
1986         kxend = local_offset[XX] + local_ndata[XX];
1987
1988         if (bEnerVir)
1989         {
1990             /* More expensive inner loop, especially because of the storage
1991              * of the mh elements in arrays.
1992              * Because x is the minor grid index, all mh elements
1993              * depend on kx for triclinic unit cells.
1994              */
1995
1996             /* Two explicit loops to avoid a conditional inside the loop */
1997             for (kx = kxstart; kx < maxkx; kx++)
1998             {
1999                 mx = kx;
2000
2001                 mhxk      = mx * rxx;
2002                 mhyk      = mx * ryx + my * ryy;
2003                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2004                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2005                 mhx[kx]   = mhxk;
2006                 mhy[kx]   = mhyk;
2007                 mhz[kx]   = mhzk;
2008                 m2[kx]    = m2k;
2009                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2010                 tmp1[kx]  = -factor*m2k;
2011             }
2012
2013             for (kx = maxkx; kx < kxend; kx++)
2014             {
2015                 mx = (kx - nx);
2016
2017                 mhxk      = mx * rxx;
2018                 mhyk      = mx * ryx + my * ryy;
2019                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2020                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2021                 mhx[kx]   = mhxk;
2022                 mhy[kx]   = mhyk;
2023                 mhz[kx]   = mhzk;
2024                 m2[kx]    = m2k;
2025                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2026                 tmp1[kx]  = -factor*m2k;
2027             }
2028
2029             for (kx = kxstart; kx < kxend; kx++)
2030             {
2031                 m2inv[kx] = 1.0/m2[kx];
2032             }
2033
2034             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2035
2036             for (kx = kxstart; kx < kxend; kx++, p0++)
2037             {
2038                 d1      = p0->re;
2039                 d2      = p0->im;
2040
2041                 p0->re  = d1*eterm[kx];
2042                 p0->im  = d2*eterm[kx];
2043
2044                 struct2 = 2.0*(d1*d1+d2*d2);
2045
2046                 tmp1[kx] = eterm[kx]*struct2;
2047             }
2048
2049             for (kx = kxstart; kx < kxend; kx++)
2050             {
2051                 ets2     = corner_fac*tmp1[kx];
2052                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2053                 energy  += ets2;
2054
2055                 ets2vf   = ets2*vfactor;
2056                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2057                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2058                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2059                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2060                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2061                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2062             }
2063         }
2064         else
2065         {
2066             /* We don't need to calculate the energy and the virial.
2067              * In this case the triclinic overhead is small.
2068              */
2069
2070             /* Two explicit loops to avoid a conditional inside the loop */
2071
2072             for (kx = kxstart; kx < maxkx; kx++)
2073             {
2074                 mx = kx;
2075
2076                 mhxk      = mx * rxx;
2077                 mhyk      = mx * ryx + my * ryy;
2078                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2079                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2080                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2081                 tmp1[kx]  = -factor*m2k;
2082             }
2083
2084             for (kx = maxkx; kx < kxend; kx++)
2085             {
2086                 mx = (kx - nx);
2087
2088                 mhxk      = mx * rxx;
2089                 mhyk      = mx * ryx + my * ryy;
2090                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2091                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2092                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2093                 tmp1[kx]  = -factor*m2k;
2094             }
2095
2096             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2097
2098             for (kx = kxstart; kx < kxend; kx++, p0++)
2099             {
2100                 d1      = p0->re;
2101                 d2      = p0->im;
2102
2103                 p0->re  = d1*eterm[kx];
2104                 p0->im  = d2*eterm[kx];
2105             }
2106         }
2107     }
2108
2109     if (bEnerVir)
2110     {
2111         /* Update virial with local values.
2112          * The virial is symmetric by definition.
2113          * This virial seems ok for isotropic scaling, but I'm
2114          * experiencing problems with semi-isotropic membranes.
2115          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2116          */
2117         work->vir_q[XX][XX] = 0.25*virxx;
2118         work->vir_q[YY][YY] = 0.25*viryy;
2119         work->vir_q[ZZ][ZZ] = 0.25*virzz;
2120         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
2121         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
2122         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
2123
2124         /* This energy should be corrected for a charged system */
2125         work->energy_q = 0.5*energy;
2126     }
2127
2128     /* Return the loop count */
2129     return local_ndata[YY]*local_ndata[XX];
2130 }
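/* A note on the expression evaluated above: with beta = ewaldcoeff, the
 * loop computes
 *   eterm(m) = elfac*exp(-pi^2 m^2/beta^2)
 *              /(pi*V*m^2*bsp_mod_x(kx)*bsp_mod_y(ky)*bsp_mod_z(kz))
 * and, with bEnerVir, accumulates 0.5*sum_m corner_fac*eterm(m)*2*|grid(m)|^2,
 * the smooth-PME form of the Ewald reciprocal-space energy (cf. Essmann et
 * al., J. Chem. Phys. 103, 8577 (1995)); the bsp_mod factors deconvolute the
 * B-spline interpolation. The 2*|grid(m)|^2 together with corner_fac = 0.5
 * at kz = 0 and kz = (nz+1)/2 accounts for the half-complex z-storage
 * keeping only one member of each +/-m pair.
 */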
2131
2132 static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
2133                             real ewaldcoeff, real vol,
2134                             gmx_bool bEnerVir, int nthread, int thread)
2135 {
2136     /* Do the reciprocal-space sum over the local cells in the grid */
2137     /* Grid layout: y major, z middle, x minor (contiguous) */
2138     int     ig, gcount;
2139     int     kx, ky, kz, maxkx, maxky, maxkz;
2140     int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
2141     real    mx, my, mz;
2142     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2143     real    ets2, ets2vf;
2144     real    eterm, vterm, d1, d2, energy = 0;
2145     real    by, bz;
2146     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2147     real    rxx, ryx, ryy, rzx, rzy, rzz;
2148     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
2149     real    mhxk, mhyk, mhzk, m2k;
2150     real    mk;
2151     pme_work_t *work;
2152     real    corner_fac;
2153     ivec    complex_order;
2154     ivec    local_ndata, local_offset, local_size;
2155     nx = pme->nkx;
2156     ny = pme->nky;
2157     nz = pme->nkz;
2158
2159     /* Dimensions should be identical for A/B grid, so we just use A here */
2160     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
2161                                       complex_order,
2162                                       local_ndata,
2163                                       local_offset,
2164                                       local_size);
2165     rxx = pme->recipbox[XX][XX];
2166     ryx = pme->recipbox[YY][XX];
2167     ryy = pme->recipbox[YY][YY];
2168     rzx = pme->recipbox[ZZ][XX];
2169     rzy = pme->recipbox[ZZ][YY];
2170     rzz = pme->recipbox[ZZ][ZZ];
2171
2172     maxkx = (nx+1)/2;
2173     maxky = (ny+1)/2;
2174     maxkz = nz/2+1;
2175
2176     work  = &pme->work[thread];
2177     mhx   = work->mhx;
2178     mhy   = work->mhy;
2179     mhz   = work->mhz;
2180     m2    = work->m2;
2181     denom = work->denom;
2182     tmp1  = work->tmp1;
2183     tmp2  = work->tmp2;
2184
2185     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2186     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2187
2188     for (iyz = iyz0; iyz < iyz1; iyz++)
2189     {
2190         iy = iyz/local_ndata[ZZ];
2191         iz = iyz - iy*local_ndata[ZZ];
2192
2193         ky = iy + local_offset[YY];
2194
2195         if (ky < maxky)
2196         {
2197             my = ky;
2198         }
2199         else
2200         {
2201             my = (ky - ny);
2202         }
2203
2204         by = 3.0*vol*pme->bsp_mod[YY][ky]
2205             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
2206
2207         kz = iz + local_offset[ZZ];
2208
2209         mz = kz;
2210
2211         bz = pme->bsp_mod[ZZ][kz];
2212
2213         /* 0.5 correction for corner points */
2214         corner_fac = 1;
2215         if (kz == 0 || kz == (nz+1)/2)
2216         {
2217             corner_fac = 0.5;
2218         }
2219
2220         kxstart = local_offset[XX];
2221         kxend   = local_offset[XX] + local_ndata[XX];
2222         if (bEnerVir)
2223         {
2224             /* More expensive inner loop, especially because of the
2225              * storage of the mh elements in arrays. Because x is the
2226              * minor grid index, all mh elements depend on kx for
2227              * triclinic unit cells.
2228              */
2229
2230             /* Two explicit loops to avoid a conditional inside the loop */
2231             for (kx = kxstart; kx < maxkx; kx++)
2232             {
2233                 mx = kx;
2234
2235                 mhxk      = mx * rxx;
2236                 mhyk      = mx * ryx + my * ryy;
2237                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2238                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2239                 mhx[kx]   = mhxk;
2240                 mhy[kx]   = mhyk;
2241                 mhz[kx]   = mhzk;
2242                 m2[kx]    = m2k;
2243                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2244                 tmp1[kx]  = -factor*m2k;
2245                 tmp2[kx]  = sqrt(factor*m2k);
2246             }
2247
2248             for (kx = maxkx; kx < kxend; kx++)
2249             {
2250                 mx = (kx - nx);
2251
2252                 mhxk      = mx * rxx;
2253                 mhyk      = mx * ryx + my * ryy;
2254                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2255                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2256                 mhx[kx]   = mhxk;
2257                 mhy[kx]   = mhyk;
2258                 mhz[kx]   = mhzk;
2259                 m2[kx]    = m2k;
2260                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2261                 tmp1[kx]  = -factor*m2k;
2262                 tmp2[kx]  = sqrt(factor*m2k);
2263             }
2264
2265             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2266
2267             for (kx = kxstart; kx < kxend; kx++)
2268             {
2269                 m2k   = factor*m2[kx];
2270                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
2271                           + 2.0*m2k*tmp2[kx]);
2272                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
2273                 tmp1[kx] = eterm*denom[kx];
2274                 tmp2[kx] = vterm*denom[kx];
2275             }
2276
2277             if (!bLB)
2278             {
2279                 t_complex *p0;
2280                 real       struct2;
2281
2282                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2283                 for (kx = kxstart; kx < kxend; kx++, p0++)
2284                 {
2285                     d1      = p0->re;
2286                     d2      = p0->im;
2287
2288                     eterm   = tmp1[kx];
2289                     vterm   = tmp2[kx];
2290                     p0->re  = d1*eterm;
2291                     p0->im  = d2*eterm;
2292
2293                     struct2 = 2.0*(d1*d1+d2*d2);
2294
2295                     tmp1[kx] = eterm*struct2;
2296                     tmp2[kx] = vterm*struct2;
2297                 }
2298             }
2299             else
2300             {
2301                 real *struct2 = denom;
2302                 real  str2;
2303
2304                 for (kx = kxstart; kx < kxend; kx++)
2305                 {
2306                     struct2[kx] = 0.0;
2307                 }
2308                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
2309                 for (ig = 0; ig <= 3; ++ig)
2310                 {
2311                     t_complex *p0, *p1;
2312                     real       scale;
2313
2314                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2315                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2316                     scale = 2.0*lb_scale_factor_symm[ig];
2317                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
2318                     {
2319                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
2320                     }
2321
2322                 }
2323                 for (ig = 0; ig <= 6; ++ig)
2324                 {
2325                     t_complex *p0;
2326
2327                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2328                     for (kx = kxstart; kx < kxend; kx++, p0++)
2329                     {
2330                         d1     = p0->re;
2331                         d2     = p0->im;
2332
2333                         eterm  = tmp1[kx];
2334                         p0->re = d1*eterm;
2335                         p0->im = d2*eterm;
2336                     }
2337                 }
2338                 for (kx = kxstart; kx < kxend; kx++)
2339                 {
2340                     eterm    = tmp1[kx];
2341                     vterm    = tmp2[kx];
2342                     str2     = struct2[kx];
2343                     tmp1[kx] = eterm*str2;
2344                     tmp2[kx] = vterm*str2;
2345                 }
2346             }
2347
2348             for (kx = kxstart; kx < kxend; kx++)
2349             {
2350                 ets2     = corner_fac*tmp1[kx];
2351                 vterm    = 2.0*factor*tmp2[kx];
2352                 energy  += ets2;
2353                 ets2vf   = corner_fac*vterm;
2354                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2355                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2356                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2357                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2358                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2359                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2360             }
2361         }
2362         else
2363         {
2364             /* We don't need to calculate the energy and the virial.
2365              * In this case the triclinic overhead is small.
2366              */
2367
2368             /* Two explicit loops to avoid a conditional inside the loop */
2369
2370             for (kx = kxstart; kx < maxkx; kx++)
2371             {
2372                 mx = kx;
2373
2374                 mhxk      = mx * rxx;
2375                 mhyk      = mx * ryx + my * ryy;
2376                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2377                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2378                 m2[kx]    = m2k;
2379                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2380                 tmp1[kx]  = -factor*m2k;
2381                 tmp2[kx]  = sqrt(factor*m2k);
2382             }
2383
2384             for (kx = maxkx; kx < kxend; kx++)
2385             {
2386                 mx = (kx - nx);
2387
2388                 mhxk      = mx * rxx;
2389                 mhyk      = mx * ryx + my * ryy;
2390                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2391                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2392                 m2[kx]    = m2k;
2393                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2394                 tmp1[kx]  = -factor*m2k;
2395                 tmp2[kx]  = sqrt(factor*m2k);
2396             }
2397
2398             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2399
2400             for (kx = kxstart; kx < kxend; kx++)
2401             {
2402                 m2k    = factor*m2[kx];
2403                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
2404                            + 2.0*m2k*tmp2[kx]);
2405                 tmp1[kx] = eterm*denom[kx];
2406             }
2407             gcount = (bLB ? 7 : 1);
2408             for (ig = 0; ig < gcount; ++ig)
2409             {
2410                 t_complex *p0;
2411
2412                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2413                 for (kx = kxstart; kx < kxend; kx++, p0++)
2414                 {
2415                     d1      = p0->re;
2416                     d2      = p0->im;
2417
2418                     eterm   = tmp1[kx];
2419
2420                     p0->re  = d1*eterm;
2421                     p0->im  = d2*eterm;
2422                 }
2423             }
2424         }
2425     }
2426     if (bEnerVir)
2427     {
2428         work->vir_lj[XX][XX] = 0.25*virxx;
2429         work->vir_lj[YY][YY] = 0.25*viryy;
2430         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
2431         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
2432         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
2433         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
2434         /* This energy should be corrected for a charged system */
2435         work->energy_lj = 0.5*energy;
2436     }
2437     /* Return the loop count */
2438     return local_ndata[YY]*local_ndata[XX];
2439 }
2440
2441 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2442                                real *mesh_energy, matrix vir)
2443 {
2444     /* This function sums output over threads and should therefore
2445      * only be called after thread synchronization.
2446      */
2447     int thread;
2448
2449     *mesh_energy = pme->work[0].energy_q;
2450     copy_mat(pme->work[0].vir_q, vir);
2451
2452     for (thread = 1; thread < nthread; thread++)
2453     {
2454         *mesh_energy += pme->work[thread].energy_q;
2455         m_add(vir, pme->work[thread].vir_q, vir);
2456     }
2457 }
2458
2459 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2460                                 real *mesh_energy, matrix vir)
2461 {
2462     /* This function sums output over threads and should therefore
2463      * only be called after thread synchronization.
2464      */
2465     int thread;
2466
2467     *mesh_energy = pme->work[0].energy_lj;
2468     copy_mat(pme->work[0].vir_lj, vir);
2469
2470     for (thread = 1; thread < nthread; thread++)
2471     {
2472         *mesh_energy += pme->work[thread].energy_lj;
2473         m_add(vir, pme->work[thread].vir_lj, vir);
2474     }
2475 }
2476
2477
2478 #define DO_FSPLINE(order)                      \
2479     for (ithx = 0; (ithx < order); ithx++)              \
2480     {                                              \
2481         index_x = (i0+ithx)*pny*pnz;               \
2482         tx      = thx[ithx];                       \
2483         dx      = dthx[ithx];                      \
2484                                                \
2485         for (ithy = 0; (ithy < order); ithy++)          \
2486         {                                          \
2487             index_xy = index_x+(j0+ithy)*pnz;      \
2488             ty       = thy[ithy];                  \
2489             dy       = dthy[ithy];                 \
2490             fxy1     = fz1 = 0;                    \
2491                                                \
2492             for (ithz = 0; (ithz < order); ithz++)      \
2493             {                                      \
2494                 gval  = grid[index_xy+(k0+ithz)];  \
2495                 fxy1 += thz[ithz]*gval;            \
2496                 fz1  += dthz[ithz]*gval;           \
2497             }                                      \
2498             fx += dx*ty*fxy1;                      \
2499             fy += tx*dy*fxy1;                      \
2500             fz += tx*ty*fz1;                       \
2501         }                                          \
2502     }
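/* DO_FSPLINE accumulates, for a single atom, the gradient of the
 * interpolated potential with respect to the fractional grid coordinates:
 *   fx = sum_xyz dthx*thy*thz*grid,  fy = sum_xyz thx*dthy*thz*grid,
 *   fz = sum_xyz thx*thy*dthz*grid.
 * gather_f_bsplines() below converts this to a Cartesian force by scaling
 * with the charge, the grid dimensions and the reciprocal box, which is
 * where the -qn*(fx*nx*rxx + ...) expressions come from.
 */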
2503
2504
2505 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2506                               gmx_bool bClearF, pme_atomcomm_t *atc,
2507                               splinedata_t *spline,
2508                               real scale)
2509 {
2510     /* sum forces for local particles */
2511     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2512     int     index_x, index_xy;
2513     int     nx, ny, nz, pnx, pny, pnz;
2514     int *   idxptr;
2515     real    tx, ty, dx, dy, qn;
2516     real    fx, fy, fz, gval;
2517     real    fxy1, fz1;
2518     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2519     int     norder;
2520     real    rxx, ryx, ryy, rzx, rzy, rzz;
2521     int     order;
2522
2523     pme_spline_work_t *work;
2524
2525 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2526     real           thz_buffer[GMX_SIMD4_WIDTH*3],  *thz_aligned;
2527     real           dthz_buffer[GMX_SIMD4_WIDTH*3], *dthz_aligned;
2528
2529     thz_aligned  = gmx_simd4_align_r(thz_buffer);
2530     dthz_aligned = gmx_simd4_align_r(dthz_buffer);
2531 #endif
2532
2533     work = pme->spline_work;
2534
2535     order = pme->pme_order;
2536     thx   = spline->theta[XX];
2537     thy   = spline->theta[YY];
2538     thz   = spline->theta[ZZ];
2539     dthx  = spline->dtheta[XX];
2540     dthy  = spline->dtheta[YY];
2541     dthz  = spline->dtheta[ZZ];
2542     nx    = pme->nkx;
2543     ny    = pme->nky;
2544     nz    = pme->nkz;
2545     pnx   = pme->pmegrid_nx;
2546     pny   = pme->pmegrid_ny;
2547     pnz   = pme->pmegrid_nz;
2548
2549     rxx   = pme->recipbox[XX][XX];
2550     ryx   = pme->recipbox[YY][XX];
2551     ryy   = pme->recipbox[YY][YY];
2552     rzx   = pme->recipbox[ZZ][XX];
2553     rzy   = pme->recipbox[ZZ][YY];
2554     rzz   = pme->recipbox[ZZ][ZZ];
2555
2556     for (nn = 0; nn < spline->n; nn++)
2557     {
2558         n  = spline->ind[nn];
2559         qn = scale*atc->q[n];
2560
2561         if (bClearF)
2562         {
2563             atc->f[n][XX] = 0;
2564             atc->f[n][YY] = 0;
2565             atc->f[n][ZZ] = 0;
2566         }
2567         if (qn != 0)
2568         {
2569             fx     = 0;
2570             fy     = 0;
2571             fz     = 0;
2572             idxptr = atc->idx[n];
2573             norder = nn*order;
2574
2575             i0   = idxptr[XX];
2576             j0   = idxptr[YY];
2577             k0   = idxptr[ZZ];
2578
2579             /* Pointer arithmetic alert, next six statements */
2580             thx  = spline->theta[XX] + norder;
2581             thy  = spline->theta[YY] + norder;
2582             thz  = spline->theta[ZZ] + norder;
2583             dthx = spline->dtheta[XX] + norder;
2584             dthy = spline->dtheta[YY] + norder;
2585             dthz = spline->dtheta[ZZ] + norder;
2586
2587             switch (order)
2588             {
2589                 case 4:
2590 #ifdef PME_SIMD4_SPREAD_GATHER
2591 #ifdef PME_SIMD4_UNALIGNED
2592 #define PME_GATHER_F_SIMD4_ORDER4
2593 #else
2594 #define PME_GATHER_F_SIMD4_ALIGNED
2595 #define PME_ORDER 4
2596 #endif
2597 #include "pme_simd4.h"
2598 #else
2599                     DO_FSPLINE(4);
2600 #endif
2601                     break;
2602                 case 5:
2603 #ifdef PME_SIMD4_SPREAD_GATHER
2604 #define PME_GATHER_F_SIMD4_ALIGNED
2605 #define PME_ORDER 5
2606 #include "pme_simd4.h"
2607 #else
2608                     DO_FSPLINE(5);
2609 #endif
2610                     break;
2611                 default:
2612                     DO_FSPLINE(order);
2613                     break;
2614             }
2615
2616             atc->f[n][XX] += -qn*( fx*nx*rxx );
2617             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2618             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2619         }
2620     }
2621     /* Since the energy, and not the forces, is interpolated,
2622      * the net force might not be exactly zero.
2623      * This could be solved by also interpolating F, but
2624      * that comes at a cost.
2625      * A better workaround is to remove the net force every
2626      * step, but that must be done at a higher level,
2627      * since this routine doesn't see all atoms when running
2628      * in parallel. It is unclear how important this is.  EL 990726
2629      */
2630 }
2631
2632
2633 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2634                                    pme_atomcomm_t *atc)
2635 {
2636     splinedata_t *spline;
2637     int     n, ithx, ithy, ithz, i0, j0, k0;
2638     int     index_x, index_xy;
2639     int *   idxptr;
2640     real    energy, pot, tx, ty, qn, gval;
2641     real    *thx, *thy, *thz;
2642     int     norder;
2643     int     order;
2644
2645     spline = &atc->spline[0];
2646
2647     order = pme->pme_order;
2648
2649     energy = 0;
2650     for (n = 0; (n < atc->n); n++)
2651     {
2652         qn      = atc->q[n];
2653
2654         if (qn != 0)
2655         {
2656             idxptr = atc->idx[n];
2657             norder = n*order;
2658
2659             i0   = idxptr[XX];
2660             j0   = idxptr[YY];
2661             k0   = idxptr[ZZ];
2662
2663             /* Pointer arithmetic alert, next three statements */
2664             thx  = spline->theta[XX] + norder;
2665             thy  = spline->theta[YY] + norder;
2666             thz  = spline->theta[ZZ] + norder;
2667
2668             pot = 0;
2669             for (ithx = 0; (ithx < order); ithx++)
2670             {
2671                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2672                 tx      = thx[ithx];
2673
2674                 for (ithy = 0; (ithy < order); ithy++)
2675                 {
2676                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2677                     ty       = thy[ithy];
2678
2679                     for (ithz = 0; (ithz < order); ithz++)
2680                     {
2681                         gval  = grid[index_xy+(k0+ithz)];
2682                         pot  += tx*ty*thz[ithz]*gval;
2683                     }
2684
2685                 }
2686             }
2687
2688             energy += pot*qn;
2689         }
2690     }
2691
2692     return energy;
2693 }
2694
2695 /* Macro to force loop unrolling by fixing order.
2696  * This gives a significant performance gain.
2697  */
2698 #define CALC_SPLINE(order)                     \
2699     {                                              \
2700         int j, k, l;                                 \
2701         real dr, div;                               \
2702         real data[PME_ORDER_MAX];                  \
2703         real ddata[PME_ORDER_MAX];                 \
2704                                                \
2705         for (j = 0; (j < DIM); j++)                     \
2706         {                                          \
2707             dr  = xptr[j];                         \
2708                                                \
2709             /* dr is relative offset from lower cell limit */ \
2710             data[order-1] = 0;                     \
2711             data[1]       = dr;                          \
2712             data[0]       = 1 - dr;                      \
2713                                                \
2714             for (k = 3; (k < order); k++)               \
2715             {                                      \
2716                 div       = 1.0/(k - 1.0);               \
2717                 data[k-1] = div*dr*data[k-2];      \
2718                 for (l = 1; (l < (k-1)); l++)           \
2719                 {                                  \
2720                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2721                                        data[k-l-1]);                \
2722                 }                                  \
2723                 data[0] = div*(1-dr)*data[0];      \
2724             }                                      \
2725             /* differentiate */                    \
2726             ddata[0] = -data[0];                   \
2727             for (k = 1; (k < order); k++)               \
2728             {                                      \
2729                 ddata[k] = data[k-1] - data[k];    \
2730             }                                      \
2731                                                \
2732             div           = 1.0/(order - 1);                 \
2733             data[order-1] = div*dr*data[order-2];  \
2734             for (l = 1; (l < (order-1)); l++)           \
2735             {                                      \
2736                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2737                                        (order-l-dr)*data[order-l-1]); \
2738             }                                      \
2739             data[0] = div*(1 - dr)*data[0];        \
2740                                                \
2741             for (k = 0; k < order; k++)                 \
2742             {                                      \
2743                 theta[j][i*order+k]  = data[k];    \
2744                 dtheta[j][i*order+k] = ddata[k];   \
2745             }                                      \
2746         }                                          \
2747     }
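/* Worked example (illustration only): for order = 4 and dr = 0 the
 * recursion above yields the cubic B-spline weights
 *   data  = {1/6, 2/3, 1/6, 0}
 * and the derivative weights (computed from the order-3 values)
 *   ddata = {-1/2, 0, 1/2, 0},
 * so the theta weights sum to 1 and the dtheta weights sum to 0 for any dr.
 */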
2748
2749 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2750                    rvec fractx[], int nr, int ind[], real charge[],
2751                    gmx_bool bDoSplines)
2752 {
2753     /* construct splines for local atoms */
2754     int  i, ii;
2755     real *xptr;
2756
2757     for (i = 0; i < nr; i++)
2758     {
2759         /* With free energy we do not use the charge check.
2760          * In most cases this will be more efficient than calling make_bsplines
2761          * twice, since usually more than half the particles have charges.
2762          */
2763         ii = ind[i];
2764         if (bDoSplines || charge[ii] != 0.0)
2765         {
2766             xptr = fractx[ii];
2767             switch (order)
2768             {
2769                 case 4:  CALC_SPLINE(4);     break;
2770                 case 5:  CALC_SPLINE(5);     break;
2771                 default: CALC_SPLINE(order); break;
2772             }
2773         }
2774     }
2775 }
2776
2777
2778 void make_dft_mod(real *mod, real *data, int ndata)
2779 {
2780     int i, j;
2781     real sc, ss, arg;
2782
2783     for (i = 0; i < ndata; i++)
2784     {
2785         sc = ss = 0;
2786         for (j = 0; j < ndata; j++)
2787         {
2788             arg = (2.0*M_PI*i*j)/ndata;
2789             sc += data[j]*cos(arg);
2790             ss += data[j]*sin(arg);
2791         }
2792         mod[i] = sc*sc+ss*ss;
2793     }
2794     for (i = 0; i < ndata; i++)
2795     {
2796         if (mod[i] < 1e-7)
2797         {
2798             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2799         }
2800     }
2801 }
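/* make_dft_mod() computes mod[i] = |sum_j data[j]*exp(2*pi*I*i*j/ndata)|^2,
 * the squared modulus of the discrete Fourier transform of the B-spline
 * values; these are the bsp_mod[] factors divided out in the solve routines
 * to deconvolute the B-spline interpolation. The second loop patches
 * (near-)zero moduli, which cannot be divided out, by averaging the two
 * neighboring values.
 */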
2802
2803
2804 static void make_bspline_moduli(splinevec bsp_mod,
2805                                 int nx, int ny, int nz, int order)
2806 {
2807     int nmax = max(nx, max(ny, nz));
2808     real *data, *ddata, *bsp_data;
2809     int i, k, l;
2810     real div;
2811
2812     snew(data, order);
2813     snew(ddata, order);
2814     snew(bsp_data, nmax);
2815
2816     data[order-1] = 0;
2817     data[1]       = 0;
2818     data[0]       = 1;
2819
2820     for (k = 3; k < order; k++)
2821     {
2822         div       = 1.0/(k-1.0);
2823         data[k-1] = 0;
2824         for (l = 1; l < (k-1); l++)
2825         {
2826             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2827         }
2828         data[0] = div*data[0];
2829     }
2830     /* differentiate */
2831     ddata[0] = -data[0];
2832     for (k = 1; k < order; k++)
2833     {
2834         ddata[k] = data[k-1]-data[k];
2835     }
2836     div           = 1.0/(order-1);
2837     data[order-1] = 0;
2838     for (l = 1; l < (order-1); l++)
2839     {
2840         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2841     }
2842     data[0] = div*data[0];
2843
2844     for (i = 0; i < nmax; i++)
2845     {
2846         bsp_data[i] = 0;
2847     }
2848     for (i = 1; i <= order; i++)
2849     {
2850         bsp_data[i] = data[i-1];
2851     }
2852
2853     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2854     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2855     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2856
2857     sfree(data);
2858     sfree(ddata);
2859     sfree(bsp_data);
2860 }
2861
2862
2863 /* Return the P3M optimal influence function */
2864 static double do_p3m_influence(double z, int order)
2865 {
2866     double z2, z4;
2867
2868     z2 = z*z;
2869     z4 = z2*z2;
2870
2871     /* The formula and most constants can be found in:
2872      * Ballenegger et al., JCTC 8, 936 (2012)
2873      */
2874     switch (order)
2875     {
2876         case 2:
2877             return 1.0 - 2.0*z2/3.0;
2878             break;
2879         case 3:
2880             return 1.0 - z2 + 2.0*z4/15.0;
2881             break;
2882         case 4:
2883             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2884             break;
2885         case 5:
2886             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2887             break;
2888         case 6:
2889             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2890             break;
2891         case 7:
2892             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2893         case 8:
2894             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2895             break;
2896     }
2897
2898     return 0.0;
2899 }
2900
2901 /* Calculate the P3M B-spline moduli for one dimension */
2902 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2903 {
2904     double zarg, zai, sinzai, infl;
2905     int    maxk, i;
2906
2907     if (order > 8)
2908     {
2909         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2910     }
2911
2912     zarg = M_PI/n;
2913
2914     maxk = (n + 1)/2;
2915
2916     for (i = -maxk; i < 0; i++)
2917     {
2918         zai          = zarg*i;
2919         sinzai       = sin(zai);
2920         infl         = do_p3m_influence(sinzai, order);
2921         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2922     }
2923     bsp_mod[0] = 1.0;
2924     for (i = 1; i < maxk; i++)
2925     {
2926         zai        = zarg*i;
2927         sinzai     = sin(zai);
2928         infl       = do_p3m_influence(sinzai, order);
2929         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2930     }
2931 }
2932
2933 /* Calculate the P3M B-spline moduli */
2934 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2935                                     int nx, int ny, int nz, int order)
2936 {
2937     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2938     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2939     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2940 }
2941
2942
2943 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2944 {
2945     int nslab, n, i;
2946     int fw, bw;
2947
2948     nslab = atc->nslab;
2949
2950     n = 0;
2951     for (i = 1; i <= nslab/2; i++)
2952     {
2953         fw = (atc->nodeid + i) % nslab;
2954         bw = (atc->nodeid - i + nslab) % nslab;
2955         if (n < nslab - 1)
2956         {
2957             atc->node_dest[n] = fw;
2958             atc->node_src[n]  = bw;
2959             n++;
2960         }
2961         if (n < nslab - 1)
2962         {
2963             atc->node_dest[n] = bw;
2964             atc->node_src[n]  = fw;
2965             n++;
2966         }
2967     }
2968 }
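/* Example (illustration only): for nslab = 4 and nodeid = 0 this yields
 * node_dest = {1, 3, 2} and node_src = {3, 1, 2}, i.e. the communication
 * pulses alternate forward and backward around the ring so that after p
 * pulses the data of the p nearest slabs on either side has arrived.
 */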
2969
2970 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2971 {
2972     int thread, i;
2973
2974     if (NULL != log)
2975     {
2976         fprintf(log, "Destroying PME data structures.\n");
2977     }
2978
2979     sfree((*pmedata)->nnx);
2980     sfree((*pmedata)->nny);
2981     sfree((*pmedata)->nnz);
2982
2983     for (i = 0; i < (*pmedata)->ngrids; ++i)
2984     {
2985         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
2986         sfree((*pmedata)->fftgrid[i]);
2987         sfree((*pmedata)->cfftgrid[i]);
2988         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
2989     }
2990
2991     sfree((*pmedata)->lb_buf1);
2992     sfree((*pmedata)->lb_buf2);
2993
2994     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2995     {
2996         free_work(&(*pmedata)->work[thread]);
2997     }
2998     sfree((*pmedata)->work);
2999
3000     sfree(*pmedata);
3001     *pmedata = NULL;
3002
3003     return 0;
3004 }
3005
3006 static int mult_up(int n, int f)
3007 {
3008     return ((n + f - 1)/f)*f;
3009 }
3010
3011
3012 static double pme_load_imbalance(gmx_pme_t pme)
3013 {
3014     int    nma, nmi;
3015     double n1, n2, n3;
3016
3017     nma = pme->nnodes_major;
3018     nmi = pme->nnodes_minor;
3019
3020     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3021     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3022     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3023
3024     /* pme_solve is roughly double the cost of an fft */
3025
3026     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3027 }
3028
3029 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3030                           int dimind, gmx_bool bSpread)
3031 {
3032     int nk, k, s, thread;
3033
3034     atc->dimind    = dimind;
3035     atc->nslab     = 1;
3036     atc->nodeid    = 0;
3037     atc->pd_nalloc = 0;
3038 #ifdef GMX_MPI
3039     if (pme->nnodes > 1)
3040     {
3041         atc->mpi_comm = pme->mpi_comm_d[dimind];
3042         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3043         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3044     }
3045     if (debug)
3046     {
3047         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3048     }
3049 #endif
3050
3051     atc->bSpread   = bSpread;
3052     atc->pme_order = pme->pme_order;
3053
3054     if (atc->nslab > 1)
3055     {
3056         snew(atc->node_dest, atc->nslab);
3057         snew(atc->node_src, atc->nslab);
3058         setup_coordinate_communication(atc);
3059
3060         snew(atc->count_thread, pme->nthread);
3061         for (thread = 0; thread < pme->nthread; thread++)
3062         {
3063             snew(atc->count_thread[thread], atc->nslab);
3064         }
3065         atc->count = atc->count_thread[0];
3066         snew(atc->rcount, atc->nslab);
3067         snew(atc->buf_index, atc->nslab);
3068     }
3069
3070     atc->nthread = pme->nthread;
3071     if (atc->nthread > 1)
3072     {
3073         snew(atc->thread_plist, atc->nthread);
3074     }
3075     snew(atc->spline, atc->nthread);
3076     for (thread = 0; thread < atc->nthread; thread++)
3077     {
3078         if (atc->nthread > 1)
3079         {
3080             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3081             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3082         }
3083         snew(atc->spline[thread].thread_one, pme->nthread);
3084         atc->spline[thread].thread_one[thread] = 1;
3085     }
3086 }
3087
3088 static void
3089 init_overlap_comm(pme_overlap_t *  ol,
3090                   int              norder,
3091 #ifdef GMX_MPI
3092                   MPI_Comm         comm,
3093 #endif
3094                   int              nnodes,
3095                   int              nodeid,
3096                   int              ndata,
3097                   int              commplainsize)
3098 {
3099     int lbnd, rbnd, maxlr, b, i;
3100     int exten;
3101     int nn, nk;
3102     pme_grid_comm_t *pgc;
3103     gmx_bool bCont;
3104     int fft_start, fft_end, send_index1, recv_index1;
3105 #ifdef GMX_MPI
3106     MPI_Status stat;
3107
3108     ol->mpi_comm = comm;
3109 #endif
3110
3111     ol->nnodes = nnodes;
3112     ol->nodeid = nodeid;
3113
3114     /* Linear translation of the PME grid won't affect reciprocal space
3115      * calculations, so to optimize we only interpolate "upwards",
3116      * which also means we only have to consider overlap in one direction.
3117      * I.e., particles on this node might also be spread to grid indices
3118      * that belong to higher nodes (modulo nnodes)
3119      */
3120
3121     snew(ol->s2g0, ol->nnodes+1);
3122     snew(ol->s2g1, ol->nnodes);
3123     if (debug)
3124     {
3125         fprintf(debug, "PME slab boundaries:");
3126     }
3127     for (i = 0; i < nnodes; i++)
3128     {
3129         /* s2g0 is the local interpolation grid start.
3130          * s2g1 is the local interpolation grid end.
3131          * Because grid overlap communication only goes forward,
3132          * the grid slabs for the FFTs should be rounded down.
3133          */
3134         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3135         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3136
3137         if (debug)
3138         {
3139             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3140         }
3141     }
3142     ol->s2g0[nnodes] = ndata;
3143     if (debug)
3144     {
3145         fprintf(debug, "\n");
3146     }
3147
3148     /* Determine how many nodes we need to communicate the grid overlap with */
3149     b = 0;
3150     do
3151     {
3152         b++;
3153         bCont = FALSE;
3154         for (i = 0; i < nnodes; i++)
3155         {
3156             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3157                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3158             {
3159                 bCont = TRUE;
3160             }
3161         }
3162     }
3163     while (bCont && b < nnodes);
3164     ol->noverlap_nodes = b - 1;
3165
3166     snew(ol->send_id, ol->noverlap_nodes);
3167     snew(ol->recv_id, ol->noverlap_nodes);
3168     for (b = 0; b < ol->noverlap_nodes; b++)
3169     {
3170         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3171         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3172     }
3173     snew(ol->comm_data, ol->noverlap_nodes);
3174
3175     ol->send_size = 0;
3176     for (b = 0; b < ol->noverlap_nodes; b++)
3177     {
3178         pgc = &ol->comm_data[b];
3179         /* Send */
3180         fft_start        = ol->s2g0[ol->send_id[b]];
3181         fft_end          = ol->s2g0[ol->send_id[b]+1];
3182         if (ol->send_id[b] < nodeid)
3183         {
3184             fft_start += ndata;
3185             fft_end   += ndata;
3186         }
3187         send_index1       = ol->s2g1[nodeid];
3188         send_index1       = min(send_index1, fft_end);
3189         pgc->send_index0  = fft_start;
3190         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3191         ol->send_size    += pgc->send_nindex;
3192
3193         /* We always start receiving at the first index of our slab */
3194         fft_start        = ol->s2g0[ol->nodeid];
3195         fft_end          = ol->s2g0[ol->nodeid+1];
3196         recv_index1      = ol->s2g1[ol->recv_id[b]];
3197         if (ol->recv_id[b] > nodeid)
3198         {
3199             recv_index1 -= ndata;
3200         }
3201         recv_index1      = min(recv_index1, fft_end);
3202         pgc->recv_index0 = fft_start;
3203         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3204     }
3205
3206 #ifdef GMX_MPI
3207     /* Communicate the buffer sizes to receive */
3208     for (b = 0; b < ol->noverlap_nodes; b++)
3209     {
3210         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3211                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3212                      ol->mpi_comm, &stat);
3213     }
3214 #endif
3215
3216     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3217     snew(ol->sendbuf, norder*commplainsize);
3218     snew(ol->recvbuf, norder*commplainsize);
3219 }
3220
3221 static void
3222 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3223                               int **global_to_local,
3224                               real **fraction_shift)
3225 {
3226     int i;
3227     int * gtl;
3228     real * fsh;
3229
3230     snew(gtl, 5*n);
3231     snew(fsh, 5*n);
3232     for (i = 0; (i < 5*n); i++)
3233     {
3234         /* Determine the global to local grid index */
3235         gtl[i] = (i - local_start + n) % n;
3236         /* For coordinates that fall within the local grid the fraction
3237          * is correct, we don't need to shift it.
3238          */
3239         fsh[i] = 0;
3240         if (local_range < n)
3241         {
3242             /* Due to rounding issues i could be 1 beyond the lower or
3243              * upper boundary of the local grid. Correct the index for this.
3244              * If we shift the index, we need to shift the fraction by
3245              * the same amount in the other direction so as not to affect
3246              * the weights.
3247              * Note that due to this shifting the weights at the end of
3248              * the spline might change, but that will only involve values
3249              * between zero and values close to the precision of a real,
3250              * which is the accuracy of the whole mesh calculation anyway.
3251              */
3252             /* With local_range=0 we should not change i=local_start */
3253             if (i % n != local_start)
3254             {
3255                 if (gtl[i] == n-1)
3256                 {
3257                     gtl[i] = 0;
3258                     fsh[i] = -1;
3259                 }
3260                 else if (gtl[i] == local_range)
3261                 {
3262                     gtl[i] = local_range - 1;
3263                     fsh[i] = 1;
3264                 }
3265             }
3266         }
3267     }
3268
3269     *global_to_local = gtl;
3270     *fraction_shift  = fsh;
3271 }
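
/* Illustrative sketch, not part of the original file: how the tables built by
 * make_gridindex5_to_localindex() are typically consumed when a particle is
 * assigned to the local grid. The names tix, fx and local_ix are hypothetical;
 * the real lookups happen in calc_interpolation_idx().
 */
static void gmx_unused gridindex5_lookup_example(const int *nnx, const real *fshx,
                                                 int tix, real *fx)
{
    /* tix is a (shifted) global grid index in [0, 5*n); fold it onto the
     * local grid and correct the interpolation fraction for indices that
     * were clamped at a slab boundary.
     */
    int local_ix = nnx[tix];

    *fx += fshx[tix];

    (void)local_ix;
}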
3272
3273 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
3274 {
3275     pme_spline_work_t *work;
3276
3277 #ifdef PME_SIMD4_SPREAD_GATHER
3278     real             tmp[GMX_SIMD4_WIDTH*3], *tmp_aligned;
3279     gmx_simd4_real_t zero_S;
3280     gmx_simd4_real_t real_mask_S0, real_mask_S1;
3281     int              of, i;
3282
3283     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3284
3285     tmp_aligned = gmx_simd4_align_r(tmp);
3286
3287     zero_S = gmx_simd4_setzero_r();
3288
3289     /* Generate bit masks to mask out the unused grid entries,
3290      * as we only operate on order of the 8 grid entries that are
3291      * loaded into 2 SIMD registers.
3292      */
3293     for (of = 0; of < 2*GMX_SIMD4_WIDTH-(order-1); of++)
3294     {
3295         for (i = 0; i < 2*GMX_SIMD4_WIDTH; i++)
3296         {
3297             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3298         }
3299         real_mask_S0      = gmx_simd4_load_r(tmp_aligned);
3300         real_mask_S1      = gmx_simd4_load_r(tmp_aligned+GMX_SIMD4_WIDTH);
3301         work->mask_S0[of] = gmx_simd4_cmplt_r(real_mask_S0, zero_S);
3302         work->mask_S1[of] = gmx_simd4_cmplt_r(real_mask_S1, zero_S);
3303     }
3304 #else
3305     work = NULL;
3306 #endif
3307
3308     return work;
3309 }
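
/* Worked example (illustrative) for order = 4 and GMX_SIMD4_WIDTH = 4:
 * the offset "of" of the grid chunk within the register pair runs from 0 to 4,
 * and for of = 1 the masks generated above are
 *   mask_S0 = { F, T, T, T }   (lanes 1-3 of the first register)
 *   mask_S1 = { T, F, F, F }   (lane 0 of the second register)
 * so exactly order = 4 contiguous grid entries are selected.
 */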
3310
3311 void gmx_pme_check_restrictions(int pme_order,
3312                                 int nkx, int nky, int nkz,
3313                                 int nnodes_major,
3314                                 int nnodes_minor,
3315                                 gmx_bool bUseThreads,
3316                                 gmx_bool bFatal,
3317                                 gmx_bool *bValidSettings)
3318 {
3319     if (pme_order > PME_ORDER_MAX)
3320     {
3321         if (!bFatal)
3322         {
3323             *bValidSettings = FALSE;
3324             return;
3325         }
3326         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3327                   pme_order, PME_ORDER_MAX);
3328     }
3329
3330     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3331         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3332         nkz <= pme_order)
3333     {
3334         if (!bFatal)
3335         {
3336             *bValidSettings = FALSE;
3337             return;
3338         }
3339         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3340                   pme_order);
3341     }
3342
3343     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3344      * We only allow multiple communication pulses in dim 1, not in dim 0.
3345      */
3346     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3347                         nkx != nnodes_major*(pme_order - 1)))
3348     {
3349         if (!bFatal)
3350         {
3351             *bValidSettings = FALSE;
3352             return;
3353         }
3354         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3355                   nkx/(double)nnodes_major, pme_order);
3356     }
3357
3358     if (bValidSettings != NULL)
3359     {
3360         *bValidSettings = TRUE;
3361     }
3362
3363     return;
3364 }
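
/* Illustrative sketch, not part of the original file: probing whether a
 * candidate PME setup is acceptable without aborting, by calling the checker
 * above with bFatal = FALSE. The function name is hypothetical; PME tuning
 * code performs a similar query.
 */
static gmx_bool gmx_unused pme_grid_is_valid_example(int pme_order,
                                                     int nkx, int nky, int nkz,
                                                     int nnodes_major,
                                                     int nnodes_minor,
                                                     gmx_bool bUseThreads)
{
    gmx_bool bValid;

    /* With bFatal = FALSE the checker only reports through bValid */
    gmx_pme_check_restrictions(pme_order, nkx, nky, nkz,
                               nnodes_major, nnodes_minor,
                               bUseThreads, FALSE, &bValid);

    return bValid;
}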
3365
3366 int gmx_pme_init(gmx_pme_t *         pmedata,
3367                  t_commrec *         cr,
3368                  int                 nnodes_major,
3369                  int                 nnodes_minor,
3370                  t_inputrec *        ir,
3371                  int                 homenr,
3372                  gmx_bool            bFreeEnergy_q,
3373                  gmx_bool            bFreeEnergy_lj,
3374                  gmx_bool            bReproducible,
3375                  int                 nthread)
3376 {
3377     gmx_pme_t pme = NULL;
3378
3379     int  use_threads, sum_use_threads, i;
3380     ivec ndata;
3381
3382     if (debug)
3383     {
3384         fprintf(debug, "Creating PME data structures.\n");
3385     }
3386     snew(pme, 1);
3387
3388     pme->sum_qgrid_tmp       = NULL;
3389     pme->sum_qgrid_dd_tmp    = NULL;
3390     pme->buf_nalloc          = 0;
3391
3392     pme->nnodes              = 1;
3393     pme->bPPnode             = TRUE;
3394
3395     pme->nnodes_major        = nnodes_major;
3396     pme->nnodes_minor        = nnodes_minor;
3397
3398 #ifdef GMX_MPI
3399     if (nnodes_major*nnodes_minor > 1)
3400     {
3401         pme->mpi_comm = cr->mpi_comm_mygroup;
3402
3403         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3404         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3405         if (pme->nnodes != nnodes_major*nnodes_minor)
3406         {
3407             gmx_incons("PME node count mismatch");
3408         }
3409     }
3410     else
3411     {
3412         pme->mpi_comm = MPI_COMM_NULL;
3413     }
3414 #endif
3415
3416     if (pme->nnodes == 1)
3417     {
3418 #ifdef GMX_MPI
3419         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3420         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3421 #endif
3422         pme->ndecompdim   = 0;
3423         pme->nodeid_major = 0;
3424         pme->nodeid_minor = 0;
3428     }
3429     else
3430     {
3431         if (nnodes_minor == 1)
3432         {
3433 #ifdef GMX_MPI
3434             pme->mpi_comm_d[0] = pme->mpi_comm;
3435             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3436 #endif
3437             pme->ndecompdim   = 1;
3438             pme->nodeid_major = pme->nodeid;
3439             pme->nodeid_minor = 0;
3440
3441         }
3442         else if (nnodes_major == 1)
3443         {
3444 #ifdef GMX_MPI
3445             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3446             pme->mpi_comm_d[1] = pme->mpi_comm;
3447 #endif
3448             pme->ndecompdim   = 1;
3449             pme->nodeid_major = 0;
3450             pme->nodeid_minor = pme->nodeid;
3451         }
3452         else
3453         {
3454             if (pme->nnodes % nnodes_major != 0)
3455             {
3456                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3457             }
3458             pme->ndecompdim = 2;
3459
3460 #ifdef GMX_MPI
3461             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3462                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3463             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3464                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3465
3466             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3467             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3468             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3469             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3470 #endif
3471         }
3472         pme->bPPnode = (cr->duty & DUTY_PP);
3473     }
3474
3475     pme->nthread = nthread;
3476
3477     /* Check if any of the PME MPI ranks uses threads */
3478     use_threads = (pme->nthread > 1 ? 1 : 0);
3479 #ifdef GMX_MPI
3480     if (pme->nnodes > 1)
3481     {
3482         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3483                       MPI_SUM, pme->mpi_comm);
3484     }
3485     else
3486 #endif
3487     {
3488         sum_use_threads = use_threads;
3489     }
3490     pme->bUseThreads = (sum_use_threads > 0);
3491
3492     if (ir->ePBC == epbcSCREW)
3493     {
3494         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3495     }
3496
3497     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3498     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3499     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3500     pme->nkx         = ir->nkx;
3501     pme->nky         = ir->nky;
3502     pme->nkz         = ir->nkz;
3503     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3504     pme->pme_order   = ir->pme_order;
3505     pme->epsilon_r   = ir->epsilon_r;
3506
3507     /* If we violate restrictions, generate a fatal error here */
3508     gmx_pme_check_restrictions(pme->pme_order,
3509                                pme->nkx, pme->nky, pme->nkz,
3510                                pme->nnodes_major,
3511                                pme->nnodes_minor,
3512                                pme->bUseThreads,
3513                                TRUE,
3514                                NULL);
3515
3516     if (pme->nnodes > 1)
3517     {
3518         double imbal;
3519
3520 #ifdef GMX_MPI
3521         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3522         MPI_Type_commit(&(pme->rvec_mpi));
3523 #endif
3524
3525         /* Note that the charge spreading and force gathering, which usually
3526          * takes about the same amount of time as FFT+solve_pme,
3527          * is always fully load balanced
3528          * (unless the charge distribution is inhomogeneous).
3529          */
3530
3531         imbal = pme_load_imbalance(pme);
3532         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3533         {
3534             fprintf(stderr,
3535                     "\n"
3536                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3537                     "      For optimal PME load balancing\n"
3538                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3539                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3540                     "\n",
3541                     (int)((imbal-1)*100 + 0.5),
3542                     pme->nkx, pme->nky, pme->nnodes_major,
3543                     pme->nky, pme->nkz, pme->nnodes_minor);
3544         }
3545     }
3546
3547     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3548     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3549      * y is always copied through a buffer: we don't need padding in z,
3550      * but we do need the overlap in x because of the communication order.
3551      */
3552     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3553 #ifdef GMX_MPI
3554                       pme->mpi_comm_d[0],
3555 #endif
3556                       pme->nnodes_major, pme->nodeid_major,
3557                       pme->nkx,
3558                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3559
3560     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3561      * We do this with an offset buffer of equal size, so we need to allocate
3562      * extra for the offset. That's what the (+1)*pme->nkz is for.
3563      */
3564     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3565 #ifdef GMX_MPI
3566                       pme->mpi_comm_d[1],
3567 #endif
3568                       pme->nnodes_minor, pme->nodeid_minor,
3569                       pme->nky,
3570                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3571
3572     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3573      * Note that gmx_pme_check_restrictions checked for this already.
3574      */
3575     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3576     {
3577         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3578     }
3579
3580     snew(pme->bsp_mod[XX], pme->nkx);
3581     snew(pme->bsp_mod[YY], pme->nky);
3582     snew(pme->bsp_mod[ZZ], pme->nkz);
3583
3584     /* The required size of the interpolation grid, including overlap.
3585      * The allocated size (pmegrid_n?) might be slightly larger.
3586      */
3587     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3588         pme->overlap[0].s2g0[pme->nodeid_major];
3589     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3590         pme->overlap[1].s2g0[pme->nodeid_minor];
3591     pme->pmegrid_nz_base = pme->nkz;
3592     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3593     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3594
3595     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3596     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3597     pme->pmegrid_start_iz = 0;
3598
3599     make_gridindex5_to_localindex(pme->nkx,
3600                                   pme->pmegrid_start_ix,
3601                                   pme->pmegrid_nx - (pme->pme_order-1),
3602                                   &pme->nnx, &pme->fshx);
3603     make_gridindex5_to_localindex(pme->nky,
3604                                   pme->pmegrid_start_iy,
3605                                   pme->pmegrid_ny - (pme->pme_order-1),
3606                                   &pme->nny, &pme->fshy);
3607     make_gridindex5_to_localindex(pme->nkz,
3608                                   pme->pmegrid_start_iz,
3609                                   pme->pmegrid_nz_base,
3610                                   &pme->nnz, &pme->fshz);
3611
3612     pme->spline_work = make_pme_spline_work(pme->pme_order);
3613
3614     ndata[0]    = pme->nkx;
3615     ndata[1]    = pme->nky;
3616     ndata[2]    = pme->nkz;
3617     pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3618     snew(pme->fftgrid, pme->ngrids);
3619     snew(pme->cfftgrid, pme->ngrids);
3620     snew(pme->pfft_setup, pme->ngrids);
3621
3622     for (i = 0; i < pme->ngrids; ++i)
3623     {
3624         if (((ir->ljpme_combination_rule == eljpmeLB) && i >= 2) || i % 2 == 0 || bFreeEnergy_q || bFreeEnergy_lj)
3625         {
3626             pmegrids_init(&pme->pmegrid[i],
3627                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3628                           pme->pmegrid_nz_base,
3629                           pme->pme_order,
3630                           pme->bUseThreads,
3631                           pme->nthread,
3632                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3633                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3634             /* This routine will allocate the grid data to fit the FFTs */
3635             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3636                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3637                                     pme->mpi_comm_d,
3638                                     bReproducible, pme->nthread);
3639
3640         }
3641     }
3642
3643     if (!pme->bP3M)
3644     {
3645         /* Use plain SPME B-spline interpolation */
3646         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3647     }
3648     else
3649     {
3650         /* Use the P3M grid-optimized influence function */
3651         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3652     }
3653
3654     /* Use atc[0] for spreading */
3655     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3656     if (pme->ndecompdim >= 2)
3657     {
3658         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3659     }
3660
3661     if (pme->nnodes == 1)
3662     {
3663         pme->atc[0].n = homenr;
3664         pme_realloc_atomcomm_things(&pme->atc[0]);
3665     }
3666
3667     pme->lb_buf1       = NULL;
3668     pme->lb_buf2       = NULL;
3669     pme->lb_buf_nalloc = 0;
3670
3671     {
3672         int thread;
3673
3674         /* Use fft5d, order after FFT is y major, z, x minor */
3675
3676         snew(pme->work, pme->nthread);
3677         for (thread = 0; thread < pme->nthread; thread++)
3678         {
3679             realloc_work(&pme->work[thread], pme->nkx);
3680         }
3681     }
3682
3683     *pmedata = pme;
3684
3685     return 0;
3686 }
3687
3688 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3689 {
3690     int d, t;
3691
3692     for (d = 0; d < DIM; d++)
3693     {
3694         if (new->grid.n[d] > old->grid.n[d])
3695         {
3696             return;
3697         }
3698     }
3699
3700     sfree_aligned(new->grid.grid);
3701     new->grid.grid = old->grid.grid;
3702
3703     if (new->grid_th != NULL && new->nthread == old->nthread)
3704     {
3705         sfree_aligned(new->grid_all);
3706         for (t = 0; t < new->nthread; t++)
3707         {
3708             new->grid_th[t].grid = old->grid_th[t].grid;
3709         }
3710     }
3711 }
3712
3713 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3714                    t_commrec *         cr,
3715                    gmx_pme_t           pme_src,
3716                    const t_inputrec *  ir,
3717                    ivec                grid_size)
3718 {
3719     t_inputrec irc;
3720     int homenr;
3721     int ret;
3722
3723     irc     = *ir;
3724     irc.nkx = grid_size[XX];
3725     irc.nky = grid_size[YY];
3726     irc.nkz = grid_size[ZZ];
3727
3728     if (pme_src->nnodes == 1)
3729     {
3730         homenr = pme_src->atc[0].n;
3731     }
3732     else
3733     {
3734         homenr = -1;
3735     }
3736
3737     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3738                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3739
3740     if (ret == 0)
3741     {
3742         /* We can easily reuse the allocated pme grids in pme_src */
3743         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3744         /* We would like to reuse the fft grids, but that's harder */
3745     }
3746
3747     return ret;
3748 }
3749
3750
3751 static void copy_local_grid(gmx_pme_t pme,
3752                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3753 {
3754     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3755     int  fft_my, fft_mz;
3756     int  nsx, nsy, nsz;
3757     ivec nf;
3758     int  offx, offy, offz, x, y, z, i0, i0t;
3759     int  d;
3760     pmegrid_t *pmegrid;
3761     real *grid_th;
3762
3763     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
3764                                    local_fft_ndata,
3765                                    local_fft_offset,
3766                                    local_fft_size);
3767     fft_my = local_fft_size[YY];
3768     fft_mz = local_fft_size[ZZ];
3769
3770     pmegrid = &pmegrids->grid_th[thread];
3771
3772     nsx = pmegrid->s[XX];
3773     nsy = pmegrid->s[YY];
3774     nsz = pmegrid->s[ZZ];
3775
3776     for (d = 0; d < DIM; d++)
3777     {
3778         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3779                     local_fft_ndata[d] - pmegrid->offset[d]);
3780     }
3781
3782     offx = pmegrid->offset[XX];
3783     offy = pmegrid->offset[YY];
3784     offz = pmegrid->offset[ZZ];
3785
3786     /* Directly copy the non-overlapping parts of the local grids.
3787      * This also initializes the full grid.
3788      */
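    /* Below, i0 indexes the padded FFT grid (row strides fft_my, fft_mz),
     * while i0t indexes the compact per-thread grid (row strides nsy, nsz).
     */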
3789     grid_th = pmegrid->grid;
3790     for (x = 0; x < nf[XX]; x++)
3791     {
3792         for (y = 0; y < nf[YY]; y++)
3793         {
3794             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3795             i0t = (x*nsy + y)*nsz;
3796             for (z = 0; z < nf[ZZ]; z++)
3797             {
3798                 fftgrid[i0+z] = grid_th[i0t+z];
3799             }
3800         }
3801     }
3802 }
3803
3804 static void
3805 reduce_threadgrid_overlap(gmx_pme_t pme,
3806                           const pmegrids_t *pmegrids, int thread,
3807                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3808 {
3809     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3810     int  fft_nx, fft_ny, fft_nz;
3811     int  fft_my, fft_mz;
3812     int  buf_my = -1;
3813     int  nsx, nsy, nsz;
3814     ivec ne;
3815     int  offx, offy, offz, x, y, z, i0, i0t;
3816     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3817     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3818     gmx_bool bCommX, bCommY;
3819     int  d;
3820     int  thread_f;
3821     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3822     const real *grid_th;
3823     real *commbuf = NULL;
3824
3825     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
3826                                    local_fft_ndata,
3827                                    local_fft_offset,
3828                                    local_fft_size);
3829     fft_nx = local_fft_ndata[XX];
3830     fft_ny = local_fft_ndata[YY];
3831     fft_nz = local_fft_ndata[ZZ];
3832
3833     fft_my = local_fft_size[YY];
3834     fft_mz = local_fft_size[ZZ];
3835
3836     /* This routine is called when all threads have finished spreading.
3837      * Here each thread sums the grid contributions calculated by other
3838      * threads into its thread-local grid volume.
3839      * To minimize the number of grid copying operations,
3840      * this routine sums directly from the pmegrid into the fftgrid.
3841      */
3842
3843     /* Determine which part of the full node grid we should operate on,
3844      * this is our thread local part of the full grid.
3845      */
3846     pmegrid = &pmegrids->grid_th[thread];
3847
3848     for (d = 0; d < DIM; d++)
3849     {
3850         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3851                     local_fft_ndata[d]);
3852     }
3853
3854     offx = pmegrid->offset[XX];
3855     offy = pmegrid->offset[YY];
3856     offz = pmegrid->offset[ZZ];
3857
3858
3859     bClearBufX  = TRUE;
3860     bClearBufY  = TRUE;
3861     bClearBufXY = TRUE;
3862
3863     /* Now loop over all the thread data blocks that contribute
3864      * to the grid region we (our thread) are operating on.
3865      */
3866     /* Note that fft_nx/fft_ny is equal to the number of grid points
3867      * between the first point of our node grid and the one of the next node.
3868      */
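    /* The sx/sy/sz loops run downwards (0, -1, ...) because charge is only
     * spread "upwards": only thread blocks with lower indices (and, across
     * the node boundary, the previous node) can have written into our local
     * grid region.
     */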
3869     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3870     {
3871         fx     = pmegrid->ci[XX] + sx;
3872         ox     = 0;
3873         bCommX = FALSE;
3874         if (fx < 0)
3875         {
3876             fx    += pmegrids->nc[XX];
3877             ox    -= fft_nx;
3878             bCommX = (pme->nnodes_major > 1);
3879         }
3880         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3881         ox       += pmegrid_g->offset[XX];
3882         if (!bCommX)
3883         {
3884             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3885         }
3886         else
3887         {
3888             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3889         }
3890
3891         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3892         {
3893             fy     = pmegrid->ci[YY] + sy;
3894             oy     = 0;
3895             bCommY = FALSE;
3896             if (fy < 0)
3897             {
3898                 fy    += pmegrids->nc[YY];
3899                 oy    -= fft_ny;
3900                 bCommY = (pme->nnodes_minor > 1);
3901             }
3902             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3903             oy       += pmegrid_g->offset[YY];
3904             if (!bCommY)
3905             {
3906                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3907             }
3908             else
3909             {
3910                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3911             }
3912
3913             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3914             {
3915                 fz = pmegrid->ci[ZZ] + sz;
3916                 oz = 0;
3917                 if (fz < 0)
3918                 {
3919                     fz += pmegrids->nc[ZZ];
3920                     oz -= fft_nz;
3921                 }
3922                 pmegrid_g = &pmegrids->grid_th[fz];
3923                 oz       += pmegrid_g->offset[ZZ];
3924                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3925
3926                 if (sx == 0 && sy == 0 && sz == 0)
3927                 {
3928                     /* We have already added our local contribution
3929                      * before calling this routine, so skip it here.
3930                      */
3931                     continue;
3932                 }
3933
3934                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3935
3936                 pmegrid_f = &pmegrids->grid_th[thread_f];
3937
3938                 grid_th = pmegrid_f->grid;
3939
3940                 nsx = pmegrid_f->s[XX];
3941                 nsy = pmegrid_f->s[YY];
3942                 nsz = pmegrid_f->s[ZZ];
3943
3944 #ifdef DEBUG_PME_REDUCE
3945                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3946                        pme->nodeid, thread, thread_f,
3947                        pme->pmegrid_start_ix,
3948                        pme->pmegrid_start_iy,
3949                        pme->pmegrid_start_iz,
3950                        sx, sy, sz,
3951                        offx-ox, tx1-ox, offx, tx1,
3952                        offy-oy, ty1-oy, offy, ty1,
3953                        offz-oz, tz1-oz, offz, tz1);
3954 #endif
3955
3956                 if (!(bCommX || bCommY))
3957                 {
3958                     /* Copy from the thread local grid to the node grid */
3959                     for (x = offx; x < tx1; x++)
3960                     {
3961                         for (y = offy; y < ty1; y++)
3962                         {
3963                             i0  = (x*fft_my + y)*fft_mz;
3964                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3965                             for (z = offz; z < tz1; z++)
3966                             {
3967                                 fftgrid[i0+z] += grid_th[i0t+z];
3968                             }
3969                         }
3970                     }
3971                 }
3972                 else
3973                 {
3974                     /* The order of this conditional decides
3975                      * where the corner volume gets stored with x+y decomp.
3976                      */
3977                     if (bCommY)
3978                     {
3979                         commbuf = commbuf_y;
3980                         buf_my  = ty1 - offy;
3981                         if (bCommX)
3982                         {
3983                             /* We index commbuf modulo the local grid size */
3984                             commbuf += buf_my*fft_nx*fft_nz;
3985
3986                             bClearBuf   = bClearBufXY;
3987                             bClearBufXY = FALSE;
3988                         }
3989                         else
3990                         {
3991                             bClearBuf  = bClearBufY;
3992                             bClearBufY = FALSE;
3993                         }
3994                     }
3995                     else
3996                     {
3997                         commbuf    = commbuf_x;
3998                         buf_my     = fft_ny;
3999                         bClearBuf  = bClearBufX;
4000                         bClearBufX = FALSE;
4001                     }
4002
4003                     /* Copy to the communication buffer */
4004                     for (x = offx; x < tx1; x++)
4005                     {
4006                         for (y = offy; y < ty1; y++)
4007                         {
4008                             i0  = (x*buf_my + y)*fft_nz;
4009                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4010
4011                             if (bClearBuf)
4012                             {
4013                                 /* First access of commbuf, initialize it */
4014                                 for (z = offz; z < tz1; z++)
4015                                 {
4016                                     commbuf[i0+z]  = grid_th[i0t+z];
4017                                 }
4018                             }
4019                             else
4020                             {
4021                                 for (z = offz; z < tz1; z++)
4022                                 {
4023                                     commbuf[i0+z] += grid_th[i0t+z];
4024                                 }
4025                             }
4026                         }
4027                     }
4028                 }
4029             }
4030         }
4031     }
4032 }
4033
4034
4035 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
4036 {
4037     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4038     pme_overlap_t *overlap;
4039     int  send_index0, send_nindex;
4040     int  recv_nindex;
4041 #ifdef GMX_MPI
4042     MPI_Status stat;
4043 #endif
4044     int  send_size_y, recv_size_y;
4045     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
4046     real *sendptr, *recvptr;
4047     int  x, y, z, indg, indb;
4048
4049     /* Note that this routine is only used for forward communication.
4050      * Since the force gathering, unlike the charge spreading,
4051      * can be trivially parallelized over the particles,
4052      * the backwards process is much simpler and can use the "old"
4053      * communication setup.
4054      */
4055
4056     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4057                                    local_fft_ndata,
4058                                    local_fft_offset,
4059                                    local_fft_size);
4060
4061     if (pme->nnodes_minor > 1)
4062     {
4063         /* Minor dimension */
4064         overlap = &pme->overlap[1];
4065
4066         if (pme->nnodes_major > 1)
4067         {
4068             size_yx = pme->overlap[0].comm_data[0].send_nindex;
4069         }
4070         else
4071         {
4072             size_yx = 0;
4073         }
4074         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
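        /* The send/receive buffers are laid out as local_fft_ndata[XX] x-rows
         * for our own x range plus size_yx extra rows holding the corner
         * volume that is forwarded along the major (x) dimension below.
         */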
4075
4076         send_size_y = overlap->send_size;
4077
4078         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
4079         {
4080             send_id       = overlap->send_id[ipulse];
4081             recv_id       = overlap->recv_id[ipulse];
4082             send_index0   =
4083                 overlap->comm_data[ipulse].send_index0 -
4084                 overlap->comm_data[0].send_index0;
4085             send_nindex   = overlap->comm_data[ipulse].send_nindex;
4086             /* We don't use recv_index0, as we always receive starting at 0 */
4087             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4088             recv_size_y   = overlap->comm_data[ipulse].recv_size;
4089
4090             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
4091             recvptr = overlap->recvbuf;
4092
4093 #ifdef GMX_MPI
4094             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
4095                          send_id, ipulse,
4096                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
4097                          recv_id, ipulse,
4098                          overlap->mpi_comm, &stat);
4099 #endif
4100
4101             for (x = 0; x < local_fft_ndata[XX]; x++)
4102             {
4103                 for (y = 0; y < recv_nindex; y++)
4104                 {
4105                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
4106                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
4107                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
4108                     {
4109                         fftgrid[indg+z] += recvptr[indb+z];
4110                     }
4111                 }
4112             }
4113
4114             if (pme->nnodes_major > 1)
4115             {
4116                 /* Copy from the received buffer to the send buffer for dim 0 */
4117                 sendptr = pme->overlap[0].sendbuf;
4118                 for (x = 0; x < size_yx; x++)
4119                 {
4120                     for (y = 0; y < recv_nindex; y++)
4121                     {
4122                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4123                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
4124                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
4125                         {
4126                             sendptr[indg+z] += recvptr[indb+z];
4127                         }
4128                     }
4129                 }
4130             }
4131         }
4132     }
4133
4134     /* We only support a single pulse here.
4135      * This is not a severe limitation, as this code is only used
4136      * with OpenMP and with OpenMP the (PME) domains can be larger.
4137      */
4138     if (pme->nnodes_major > 1)
4139     {
4140         /* Major dimension */
4141         overlap = &pme->overlap[0];
4142
4143         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
4144         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
4145
4146         ipulse = 0;
4147
4148         send_id       = overlap->send_id[ipulse];
4149         recv_id       = overlap->recv_id[ipulse];
4150         send_nindex   = overlap->comm_data[ipulse].send_nindex;
4151         /* We don't use recv_index0, as we always receive starting at 0 */
4152         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4153
4154         sendptr = overlap->sendbuf;
4155         recvptr = overlap->recvbuf;
4156
4157         if (debug != NULL)
4158         {
4159             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
4160                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
4161         }
4162
4163 #ifdef GMX_MPI
4164         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
4165                      send_id, ipulse,
4166                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
4167                      recv_id, ipulse,
4168                      overlap->mpi_comm, &stat);
4169 #endif
4170
4171         for (x = 0; x < recv_nindex; x++)
4172         {
4173             for (y = 0; y < local_fft_ndata[YY]; y++)
4174             {
4175                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
4176                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4177                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
4178                 {
4179                     fftgrid[indg+z] += recvptr[indb+z];
4180                 }
4181             }
4182         }
4183     }
4184 }
4185
4186
4187 static void spread_on_grid(gmx_pme_t pme,
4188                            pme_atomcomm_t *atc, pmegrids_t *grids,
4189                            gmx_bool bCalcSplines, gmx_bool bSpread,
4190                            real *fftgrid, gmx_bool bDoSplines )
4191 {
4192     int nthread, thread;
4193 #ifdef PME_TIME_THREADS
4194     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
4195     static double cs1     = 0, cs2 = 0, cs3 = 0;
4196     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
4197     static int cnt        = 0;
4198 #endif
4199
4200     nthread = pme->nthread;
4201     assert(nthread > 0);
4202
4203 #ifdef PME_TIME_THREADS
4204     c1 = omp_cyc_start();
4205 #endif
4206     if (bCalcSplines)
4207     {
4208 #pragma omp parallel for num_threads(nthread) schedule(static)
4209         for (thread = 0; thread < nthread; thread++)
4210         {
4211             int start, end;
4212
4213             start = atc->n* thread   /nthread;
4214             end   = atc->n*(thread+1)/nthread;
4215
4216             /* Compute fftgrid index for all atoms,
4217              * with help of some extra variables.
4218              */
4219             calc_interpolation_idx(pme, atc, start, end, thread);
4220         }
4221     }
4222 #ifdef PME_TIME_THREADS
4223     c1   = omp_cyc_end(c1);
4224     cs1 += (double)c1;
4225 #endif
4226
4227 #ifdef PME_TIME_THREADS
4228     c2 = omp_cyc_start();
4229 #endif
4230 #pragma omp parallel for num_threads(nthread) schedule(static)
4231     for (thread = 0; thread < nthread; thread++)
4232     {
4233         splinedata_t *spline;
4234         pmegrid_t *grid = NULL;
4235
4236         /* make local bsplines  */
4237         if (grids == NULL || !pme->bUseThreads)
4238         {
4239             spline = &atc->spline[0];
4240
4241             spline->n = atc->n;
4242
4243             if (bSpread)
4244             {
4245                 grid = &grids->grid;
4246             }
4247         }
4248         else
4249         {
4250             spline = &atc->spline[thread];
4251
4252             if (grids->nthread == 1)
4253             {
4254                 /* One thread, we operate on all charges */
4255                 spline->n = atc->n;
4256             }
4257             else
4258             {
4259                 /* Get the indices our thread should operate on */
4260                 make_thread_local_ind(atc, thread, spline);
4261             }
4262
4263             grid = &grids->grid_th[thread];
4264         }
4265
4266         if (bCalcSplines)
4267         {
4268             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4269                           atc->fractx, spline->n, spline->ind, atc->q, bDoSplines);
4270         }
4271
4272         if (bSpread)
4273         {
4274             /* put local atoms on grid. */
4275 #ifdef PME_TIME_SPREAD
4276             ct1a = omp_cyc_start();
4277 #endif
4278             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4279
4280             if (pme->bUseThreads)
4281             {
4282                 copy_local_grid(pme, grids, thread, fftgrid);
4283             }
4284 #ifdef PME_TIME_SPREAD
4285             ct1a          = omp_cyc_end(ct1a);
4286             cs1a[thread] += (double)ct1a;
4287 #endif
4288         }
4289     }
4290 #ifdef PME_TIME_THREADS
4291     c2   = omp_cyc_end(c2);
4292     cs2 += (double)c2;
4293 #endif
4294
4295     if (bSpread && pme->bUseThreads)
4296     {
4297 #ifdef PME_TIME_THREADS
4298         c3 = omp_cyc_start();
4299 #endif
4300 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4301         for (thread = 0; thread < grids->nthread; thread++)
4302         {
4303             reduce_threadgrid_overlap(pme, grids, thread,
4304                                       fftgrid,
4305                                       pme->overlap[0].sendbuf,
4306                                       pme->overlap[1].sendbuf);
4307         }
4308 #ifdef PME_TIME_THREADS
4309         c3   = omp_cyc_end(c3);
4310         cs3 += (double)c3;
4311 #endif
4312
4313         if (pme->nnodes > 1)
4314         {
4315             /* Communicate the overlapping part of the fftgrid.
4316              * For this communication call we need to check pme->bUseThreads
4317              * to have all ranks communicate here, regardless of pme->nthread.
4318              */
4319             sum_fftgrid_dd(pme, fftgrid);
4320         }
4321     }
4322
4323 #ifdef PME_TIME_THREADS
4324     cnt++;
4325     if (cnt % 20 == 0)
4326     {
4327         printf("idx %.2f spread %.2f red %.2f",
4328                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4329 #ifdef PME_TIME_SPREAD
4330         for (thread = 0; thread < nthread; thread++)
4331         {
4332             printf(" %.2f", cs1a[thread]*1e-9);
4333         }
4334 #endif
4335         printf("\n");
4336     }
4337 #endif
4338 }
4339
4340
4341 static void dump_grid(FILE *fp,
4342                       int sx, int sy, int sz, int nx, int ny, int nz,
4343                       int my, int mz, const real *g)
4344 {
4345     int x, y, z;
4346
4347     for (x = 0; x < nx; x++)
4348     {
4349         for (y = 0; y < ny; y++)
4350         {
4351             for (z = 0; z < nz; z++)
4352             {
4353                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4354                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4355             }
4356         }
4357     }
4358 }
4359
4360 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4361 {
4362     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4363
4364     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4365                                    local_fft_ndata,
4366                                    local_fft_offset,
4367                                    local_fft_size);
4368
4369     dump_grid(stderr,
4370               pme->pmegrid_start_ix,
4371               pme->pmegrid_start_iy,
4372               pme->pmegrid_start_iz,
4373               pme->pmegrid_nx-pme->pme_order+1,
4374               pme->pmegrid_ny-pme->pme_order+1,
4375               pme->pmegrid_nz-pme->pme_order+1,
4376               local_fft_size[YY],
4377               local_fft_size[ZZ],
4378               fftgrid);
4379 }
4380
4381
4382 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4383 {
4384     pme_atomcomm_t *atc;
4385     pmegrids_t *grid;
4386
4387     if (pme->nnodes > 1)
4388     {
4389         gmx_incons("gmx_pme_calc_energy called in parallel");
4390     }
4391     if (pme->bFEP)
4392     {
4393         gmx_incons("gmx_pme_calc_energy with free energy");
4394     }
4395
4396     atc            = &pme->atc_energy;
4397     atc->nthread   = 1;
4398     if (atc->spline == NULL)
4399     {
4400         snew(atc->spline, atc->nthread);
4401     }
4402     atc->nslab     = 1;
4403     atc->bSpread   = TRUE;
4404     atc->pme_order = pme->pme_order;
4405     atc->n         = n;
4406     pme_realloc_atomcomm_things(atc);
4407     atc->x         = x;
4408     atc->q         = q;
4409
4410     /* We only use the A-charges grid */
4411     grid = &pme->pmegrid[PME_GRID_QA];
4412
4413     /* Only calculate the spline coefficients, don't actually spread */
4414     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE);
4415
4416     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4417 }
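
/* Illustrative sketch, not part of the original file: how a single-rank
 * caller, e.g. test-particle insertion, might evaluate the reciprocal-space
 * energy of a trial configuration with the routine above. The function name
 * is hypothetical.
 */
static real gmx_unused pme_trial_energy_example(gmx_pme_t pme, int n,
                                                rvec *x, real *q)
{
    real V;

    gmx_pme_calc_energy(pme, n, x, q, &V);

    return V;
}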
4418
4419
4420 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4421                                    gmx_walltime_accounting_t walltime_accounting,
4422                                    t_nrnb *nrnb, t_inputrec *ir,
4423                                    gmx_int64_t step)
4424 {
4425     /* Reset all the counters related to performance over the run */
4426     wallcycle_stop(wcycle, ewcRUN);
4427     wallcycle_reset_all(wcycle);
4428     init_nrnb(nrnb);
4429     if (ir->nsteps >= 0)
4430     {
4431         /* ir->nsteps is not used here, but we update it for consistency */
4432         ir->nsteps -= step - ir->init_step;
4433     }
4434     ir->init_step = step;
4435     wallcycle_start(wcycle, ewcRUN);
4436     walltime_accounting_start(walltime_accounting);
4437 }
4438
4439
4440 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4441                                ivec grid_size,
4442                                t_commrec *cr, t_inputrec *ir,
4443                                gmx_pme_t *pme_ret)
4444 {
4445     int ind;
4446     gmx_pme_t pme = NULL;
4447
4448     ind = 0;
4449     while (ind < *npmedata)
4450     {
4451         pme = (*pmedata)[ind];
4452         if (pme->nkx == grid_size[XX] &&
4453             pme->nky == grid_size[YY] &&
4454             pme->nkz == grid_size[ZZ])
4455         {
4456             *pme_ret = pme;
4457
4458             return;
4459         }
4460
4461         ind++;
4462     }
4463
4464     (*npmedata)++;
4465     srenew(*pmedata, *npmedata);
4466
4467     /* Generate a new PME data structure, copying part of the old pointers */
4468     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4469
4470     *pme_ret = (*pmedata)[ind];
4471 }
4472
4473 int gmx_pmeonly(gmx_pme_t pme,
4474                 t_commrec *cr,    t_nrnb *mynrnb,
4475                 gmx_wallcycle_t wcycle,
4476                 gmx_walltime_accounting_t walltime_accounting,
4477                 real ewaldcoeff_q, real ewaldcoeff_lj,
4478                 t_inputrec *ir)
4479 {
4480     int npmedata;
4481     gmx_pme_t *pmedata;
4482     gmx_pme_pp_t pme_pp;
4483     int  ret;
4484     int  natoms;
4485     matrix box;
4486     rvec *x_pp      = NULL, *f_pp = NULL;
4487     real *chargeA   = NULL, *chargeB = NULL;
4488     real *c6A       = NULL, *c6B = NULL;
4489     real *sigmaA    = NULL, *sigmaB = NULL;
4490     real lambda_q   = 0;
4491     real lambda_lj  = 0;
4492     int  maxshift_x = 0, maxshift_y = 0;
4493     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4494     matrix vir_q, vir_lj;
4495     float cycles;
4496     int  count;
4497     gmx_bool bEnerVir;
4498     int pme_flags;
4499     gmx_int64_t step, step_rel;
4500     ivec grid_switch;
4501
4502     /* These data are only used with PME tuning, i.e. switching PME grids */
4503     npmedata = 1;
4504     snew(pmedata, npmedata);
4505     pmedata[0] = pme;
4506
4507     pme_pp = gmx_pme_pp_init(cr);
4508
4509     init_nrnb(mynrnb);
4510
4511     count = 0;
4512     do /****** this is a quasi-loop over time steps! */
4513     {
4514         /* The reason for having a loop here is PME grid tuning/switching */
4515         do
4516         {
4517             /* Domain decomposition */
4518             ret = gmx_pme_recv_params_coords(pme_pp,
4519                                              &natoms,
4520                                              &chargeA, &chargeB,
4521                                              &c6A, &c6B,
4522                                              &sigmaA, &sigmaB,
4523                                              box, &x_pp, &f_pp,
4524                                              &maxshift_x, &maxshift_y,
4525                                              &pme->bFEP_q, &pme->bFEP_lj,
4526                                              &lambda_q, &lambda_lj,
4527                                              &bEnerVir,
4528                                              &pme_flags,
4529                                              &step,
4530                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4531
4532             if (ret == pmerecvqxSWITCHGRID)
4533             {
4534                 /* Switch the PME grid to grid_switch */
4535                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4536             }
4537
4538             if (ret == pmerecvqxRESETCOUNTERS)
4539             {
4540                 /* Reset the cycle and flop counters */
4541                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4542             }
4543         }
4544         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4545
4546         if (ret == pmerecvqxFINISH)
4547         {
4548             /* We should stop: break out of the loop */
4549             break;
4550         }
4551
4552         step_rel = step - ir->init_step;
4553
4554         if (count == 0)
4555         {
4556             wallcycle_start(wcycle, ewcRUN);
4557             walltime_accounting_start(walltime_accounting);
4558         }
4559
4560         wallcycle_start(wcycle, ewcPMEMESH);
4561
4562         dvdlambda_q  = 0;
4563         dvdlambda_lj = 0;
4564         clear_mat(vir_q);
4565         clear_mat(vir_lj);
4566
4567         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4568                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4569                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4570                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4571                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4572                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4573
4574         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4575
4576         gmx_pme_send_force_vir_ener(pme_pp,
4577                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4578                                     dvdlambda_q, dvdlambda_lj, cycles);
4579
4580         count++;
4581     } /***** end of quasi-loop, we stop with the break above */
4582     while (TRUE);
4583
4584     walltime_accounting_end(walltime_accounting);
4585
4586     return 0;
4587 }
4588
4589 static void
4590 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4591 {
4592     int  i;
4593
4594     for (i = 0; i < pme->atc[0].n; ++i)
4595     {
4596         real sigma4;
4597
4598         sigma4           = local_sigma[i];
4599         sigma4           = sigma4*sigma4;
4600         sigma4           = sigma4*sigma4;
4601         pme->atc[0].q[i] = local_c6[i] / sigma4;
4602     }
4603 }
4604
4605 static void
4606 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4607 {
4608     int  i;
4609
4610     for (i = 0; i < pme->atc[0].n; ++i)
4611     {
4612         pme->atc[0].q[i] *= local_sigma[i];
4613     }
4614 }
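
/* Taken together, the two helpers above prepare the per-grid coefficients for
 * Lorentz-Berthelot LJ-PME: calc_initial_lb_coeffs() sets
 * q_i = C6_i / sigma_i^4 and every subsequent calc_next_lb_coeffs() call
 * multiplies by sigma_i, so after k calls q_i = C6_i * sigma_i^(k-4).
 * Successive passes of the Lorentz-Berthelot loop mentioned in gmx_pme_do()
 * thus work with successive powers of sigma_i.
 */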
4615
4616 static void
4617 do_redist_x_q(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4618               gmx_bool bFirst, rvec x[], real *data)
4619 {
4620     int      d;
4621     pme_atomcomm_t *atc;
4622     atc = &pme->atc[0];
4623
4624     for (d = pme->ndecompdim - 1; d >= 0; d--)
4625     {
4626         int             n_d;
4627         rvec           *x_d;
4628         real           *q_d;
4629
4630         if (d == pme->ndecompdim - 1)
4631         {
4632             n_d = homenr;
4633             x_d = x + start;
4634             q_d = data;
4635         }
4636         else
4637         {
4638             n_d = pme->atc[d + 1].n;
4639             x_d = atc->x;
4640             q_d = atc->q;
4641         }
4642         atc      = &pme->atc[d];
4643         atc->npd = n_d;
4644         if (atc->npd > atc->pd_nalloc)
4645         {
4646             atc->pd_nalloc = over_alloc_dd(atc->npd);
4647             srenew(atc->pd, atc->pd_nalloc);
4648         }
4649         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4650         where();
4651         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4652         if (DOMAINDECOMP(cr))
4653         {
4654             dd_pmeredist_x_q(pme, n_d, bFirst, x_d, q_d, atc);
4655         }
4656     }
4657 }
4658
4659 /* TODO: when adding free-energy support, sigmaB will no longer be
4660  * unused */
4661 int gmx_pme_do(gmx_pme_t pme,
4662                int start,       int homenr,
4663                rvec x[],        rvec f[],
4664                real *chargeA,   real *chargeB,
4665                real *c6A,       real *c6B,
4666                real *sigmaA,    real gmx_unused *sigmaB,
4667                matrix box, t_commrec *cr,
4668                int  maxshift_x, int maxshift_y,
4669                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4670                matrix vir_q,      real ewaldcoeff_q,
4671                matrix vir_lj,   real ewaldcoeff_lj,
4672                real *energy_q,  real *energy_lj,
4673                real lambda_q, real lambda_lj,
4674                real *dvdlambda_q, real *dvdlambda_lj,
4675                int flags)
4676 {
4677     int     q, d, i, j, ntot, npme, qmax;
4678     int     nx, ny, nz;
4679     int     n_d, local_ny;
4680     pme_atomcomm_t *atc = NULL;
4681     pmegrids_t *pmegrid = NULL;
4682     real    *grid       = NULL;
4683     real    *ptr;
4684     rvec    *x_d, *f_d;
4685     real    *charge = NULL, *q_d;
4686     real    energy_AB[4];
4687     matrix  vir_AB[4];
4688     real    scale, lambda;
4689     gmx_bool bClearF;
4690     gmx_parallel_3dfft_t pfft_setup;
4691     real *  fftgrid;
4692     t_complex * cfftgrid;
4693     int     thread;
4694     gmx_bool bFirst, bDoSplines;
4695     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4696     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4697
4698     assert(pme->nnodes > 0);
4699     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4700
4701     if (pme->nnodes > 1)
4702     {
4703         atc      = &pme->atc[0];
4704         atc->npd = homenr;
4705         if (atc->npd > atc->pd_nalloc)
4706         {
4707             atc->pd_nalloc = over_alloc_dd(atc->npd);
4708             srenew(atc->pd, atc->pd_nalloc);
4709         }
4710         for (d = pme->ndecompdim-1; d >= 0; d--)
4711         {
4712             atc           = &pme->atc[d];
4713             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4714         }
4715     }
4716     else
4717     {
4718         atc = &pme->atc[0];
4719         /* This could be necessary for TPI */
4720         pme->atc[0].n = homenr;
4721         if (DOMAINDECOMP(cr))
4722         {
4723             pme_realloc_atomcomm_things(atc);
4724         }
4725         atc->x = x;
4726         atc->f = f;
4727     }
4728
4729     m_inv_ur0(box, pme->recipbox);
4730     bFirst = TRUE;
4731
4732     /* For simplicity, we construct the splines for all particles if
4733      * more than one PME calculation is needed. Some optimization
4734      * could be done by keeping track of which atoms already have splines
4735      * constructed, and constructing new splines on each pass only for the
4736      * atoms that don't yet have them.
4737      */
4738
4739     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4740
4741     /* We need a maximum of four separate PME calculations:
4742      * q=0: Coulomb PME with charges from state A
4743      * q=1: Coulomb PME with charges from state B
4744      * q=2: LJ PME with C6 from state A
4745      * q=3: LJ PME with C6 from state B
4746      * For Lorentz-Berthelot combination rules, a separate loop is used to
4747      * calculate all seven terms.
4748      */
4749
4750     /* If we are doing LJ-PME with LB, we only do Q here */
4751     qmax = (flags & GMX_PME_LJ_LB) ? DO_Q : DO_Q_AND_LJ;
4752
4753     for (q = 0; q < qmax; ++q)
4754     {
4755         /* Check whether we should do calculations at this q:
4756          * if q is odd, we should be doing FEP;
4757          * if q < 2, we should be doing electrostatic PME;
4758          * if q >= 2, we should be doing LJ-PME.
4759          */
4760         if ((!pme->bFEP_q && q == 1)
4761             || (!pme->bFEP_lj && q == 3)
4762             || (!(flags & GMX_PME_DO_COULOMB) && q < 2)
4763             || (!(flags & GMX_PME_DO_LJ) && q >= 2))
4764         {
4765             continue;
4766         }
4767         /* Unpack structure */
4768         pmegrid    = &pme->pmegrid[q];
4769         fftgrid    = pme->fftgrid[q];
4770         cfftgrid   = pme->cfftgrid[q];
4771         pfft_setup = pme->pfft_setup[q];
4772         switch (q)
4773         {
4774             case 0: charge = chargeA + start; break;
4775             case 1: charge = chargeB + start; break;
4776             case 2: charge = c6A + start; break;
4777             case 3: charge = c6B + start; break;
4778         }
4779
4780         grid = pmegrid->grid.grid;
4781
4782         if (debug)
4783         {
4784             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4785                     cr->nnodes, cr->nodeid);
4786             fprintf(debug, "Grid = %p\n", (void*)grid);
4787             if (grid == NULL)
4788             {
4789                 gmx_fatal(FARGS, "No grid!");
4790             }
4791         }
4792         where();
4793
4794         if (pme->nnodes == 1)
4795         {
4796             atc->q = charge;
4797         }
4798         else
4799         {
4800             wallcycle_start(wcycle, ewcPME_REDISTXF);
4801             do_redist_x_q(pme, cr, start, homenr, bFirst, x, charge);
4802             where();
4803
4804             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4805         }
4806
4807         if (debug)
4808         {
4809             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4810                     cr->nodeid, atc->n);
4811         }
4812
4813         if (flags & GMX_PME_SPREAD_Q)
4814         {
4815             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4816
4817             /* Spread the charges on a grid */
4818             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines);
4819
4820             if (bFirst)
4821             {
4822                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4823             }
4824             inc_nrnb(nrnb, eNR_SPREADQBSP,
4825                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4826
4827             if (!pme->bUseThreads)
4828             {
4829                 wrap_periodic_pmegrid(pme, grid);
4830
4831                 /* sum contributions to local grid from other nodes */
4832 #ifdef GMX_MPI
4833                 if (pme->nnodes > 1)
4834                 {
4835                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4836                     where();
4837                 }
4838 #endif
4839
4840                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, q);
4841             }
4842
4843             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4844
4845             /*
4846                dump_local_fftgrid(pme,fftgrid);
4847                exit(0);
4848              */
4849         }
4850
4851         /* Here we start a large thread parallel region */
4852 #pragma omp parallel num_threads(pme->nthread) private(thread)
4853         {
4854             thread = gmx_omp_get_thread_num();
4855             if (flags & GMX_PME_SOLVE)
4856             {
4857                 int loop_count;
4858
4859                 /* do 3d-fft */
4860                 if (thread == 0)
4861                 {
4862                     wallcycle_start(wcycle, ewcPME_FFT);
4863                 }
4864                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4865                                            thread, wcycle);
4866                 if (thread == 0)
4867                 {
4868                     wallcycle_stop(wcycle, ewcPME_FFT);
4869                 }
4870                 where();
4871
4872                 /* solve in k-space for our local cells */
4873                 if (thread == 0)
4874                 {
4875                     wallcycle_start(wcycle, (q < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4876                 }
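                /* The box volume is passed as the product of the diagonal
                 * elements: the GROMACS box matrix is triangular, so its
                 * determinant is box[XX][XX]*box[YY][YY]*box[ZZ][ZZ].
                 */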
4877                 if (q < DO_Q)
4878                 {
4879                     loop_count =
4880                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
4881                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4882                                       bCalcEnerVir,
4883                                       pme->nthread, thread);
4884                 }
4885                 else
4886                 {
4887                     loop_count =
4888                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
4889                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4890                                          bCalcEnerVir,
4891                                          pme->nthread, thread);
4892                 }
4893
4894                 if (thread == 0)
4895                 {
4896                     wallcycle_stop(wcycle, (q < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4897                     where();
4898                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4899                 }
4900             }
4901
4902             if (bCalcF)
4903             {
4904                 /* do 3d-invfft */
4905                 if (thread == 0)
4906                 {
4907                     where();
4908                     wallcycle_start(wcycle, ewcPME_FFT);
4909                 }
4910                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4911                                            thread, wcycle);
4912                 if (thread == 0)
4913                 {
4914                     wallcycle_stop(wcycle, ewcPME_FFT);
4915
4916                     where();
4917
4918                     if (pme->nodeid == 0)
4919                     {
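                        /* Rough 3D-FFT operation count, ~ntot*log2(ntot) per
                         * transform; the factor 2 below presumably covers the
                         * forward and backward transforms together.
                         */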
4920                         ntot  = pme->nkx*pme->nky*pme->nkz;
4921                         npme  = ntot*log((real)ntot)/log(2.0);
4922                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4923                     }
4924
4925                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4926                 }
4927
4928                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, q, pme->nthread, thread);
4929             }
4930         }
4931         /* End of thread parallel section.
4932          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4933          */
4934
4935         if (bCalcF)
4936         {
4937             /* distribute local grid to all nodes */
4938 #ifdef GMX_MPI
4939             if (pme->nnodes > 1)
4940             {
4941                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4942             }
4943 #endif
4944             where();
4945
4946             unwrap_periodic_pmegrid(pme, grid);
4947
4948             /* interpolate forces for our local atoms */
4949
4950             where();
4951
4952             /* If we are running without parallelization,
4953              * atc->f is the actual force array, not a buffer;
4954              * therefore we should not clear it.
4955              */
4956             lambda  = q < DO_Q ? lambda_q : lambda_lj;
4957             bClearF = (bFirst && PAR(cr));
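            /* With FEP, even q holds state A data, weighted by 1-lambda,
             * and odd q holds state B data, weighted by lambda; this enters
             * through the scale argument of gather_f_bsplines below.
             */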
4958 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4959             for (thread = 0; thread < pme->nthread; thread++)
4960             {
4961                 gather_f_bsplines(pme, grid, bClearF, atc,
4962                                   &atc->spline[thread],
4963                                   pme->bFEP ? (q % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
4964             }
4965
4966             where();
4967
4968             inc_nrnb(nrnb, eNR_GATHERFBSP,
4969                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4970             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4971         }
4972
4973         if (bCalcEnerVir)
4974         {
4975             /* This should only be called on the master thread
4976              * and after the threads have synchronized.
4977              */
4978             if (q < 2)
4979             {
4980                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4981             }
4982             else
4983             {
4984                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4985             }
4986         }
4987         bFirst = FALSE;
4988     } /* of q-loop */
4989
4990     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
4991      * seven terms. */
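    /* Sketch of why seven terms are needed, assuming standard LB rules:
     * with eps_ij = sqrt(eps_i*eps_j) and sigma_ij = (sigma_i + sigma_j)/2,
     *   C6_ij = 4*eps_ij*sigma_ij^6
     *         = (1/16)*sqrt(eps_i*eps_j)*sum_{k=0..6} C(6,k)*sigma_i^k*sigma_j^(6-k).
     * Each binomial term factorizes into single-particle coefficients, so
     * each term gets its own grid (q = 2..8 below), and the binomial weights
     * enter via lb_scale_factor[] in the force gather.
     */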
4992
4993     if (flags & GMX_PME_LJ_LB)
4994     {
4995         real *local_c6 = NULL, *local_sigma = NULL;
4996         if (pme->nnodes == 1)
4997         {
4998             if (pme->lb_buf1 == NULL)
4999             {
5000                 pme->lb_buf_nalloc = pme->atc[0].n;
5001                 snew(pme->lb_buf1, pme->lb_buf_nalloc);
5002             }
5003             pme->atc[0].q = pme->lb_buf1;
5004             local_c6      = c6A;
5005             local_sigma   = sigmaA;
5006         }
5007         else
5008         {
5009             atc = &pme->atc[0];
5010
5011             wallcycle_start(wcycle, ewcPME_REDISTXF);
5012
5013             do_redist_x_q(pme, cr, start, homenr, bFirst, x, c6A);
5014             if (pme->lb_buf_nalloc < atc->n)
5015             {
5016                 pme->lb_buf_nalloc = atc->nalloc;
5017                 srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5018                 srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5019             }
5020             local_c6 = pme->lb_buf1;
5021             for (i = 0; i < atc->n; ++i)
5022             {
5023                 local_c6[i] = atc->q[i];
5024             }
5025             where();
5026
5027             do_redist_x_q(pme, cr, start, homenr, FALSE, x, sigmaA);
5028             local_sigma = pme->lb_buf2;
5029             for (i = 0; i < atc->n; ++i)
5030             {
5031                 local_sigma[i] = atc->q[i];
5032             }
5033             where();
5034
5035             wallcycle_stop(wcycle, ewcPME_REDISTXF);
5036         }
5037         calc_initial_lb_coeffs(pme, local_c6, local_sigma);
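        /* Presumably calc_initial_lb_coeffs and calc_next_lb_coeffs (called
         * in the loop below) fill atc->q with the per-atom coefficient of the
         * current binomial term, which is then spread on grid q like a charge.
         */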
5038
5039         /* Seven terms in LJ-PME with LB; q < 2 is reserved for electrostatics */
5040         for (q = 2; q < 9; ++q)
5041         {
5042             /* Unpack structure */
5043             pmegrid    = &pme->pmegrid[q];
5044             fftgrid    = pme->fftgrid[q];
5045             cfftgrid   = pme->cfftgrid[q];
5046             pfft_setup = pme->pfft_setup[q];
5047             calc_next_lb_coeffs(pme, local_sigma);
5048             grid = pmegrid->grid.grid;
5049             where();
5050
5051             if (flags & GMX_PME_SPREAD_Q)
5052             {
5053                 wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5054                 /* Spread the charges on a grid */
5055                 spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines);
5056
5057                 if (bFirst)
5058                 {
5059                     inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5060                 }
5061                 inc_nrnb(nrnb, eNR_SPREADQBSP,
5062                          pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5063                 if (pme->nthread == 1)
5064                 {
5065                     wrap_periodic_pmegrid(pme, grid);
5066
5067                     /* sum contributions to local grid from other nodes */
5068 #ifdef GMX_MPI
5069                     if (pme->nnodes > 1)
5070                     {
5071                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
5072                         where();
5073                     }
5074 #endif
5075                     copy_pmegrid_to_fftgrid(pme, grid, fftgrid, q);
5076                 }
5077
5078                 wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5079             }
5080
5081             /* Here we start a large thread parallel region */
5082 #pragma omp parallel num_threads(pme->nthread) private(thread)
5083             {
5084                 thread = gmx_omp_get_thread_num();
5085                 if (flags & GMX_PME_SOLVE)
5086                 {
5087                     /* do 3d-fft */
5088                     if (thread == 0)
5089                     {
5090                         wallcycle_start(wcycle, ewcPME_FFT);
5091                     }
5092
5093                     gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5094                                                thread, wcycle);
5095                     if (thread == 0)
5096                     {
5097                         wallcycle_stop(wcycle, ewcPME_FFT);
5098                     }
5099                     where();
5100                 }
5101             }
5102             bFirst = FALSE;
5103         }
5104         if (flags & GMX_PME_SOLVE)
5105         {
5106             int loop_count;
5107             /* solve in k-space for our local cells */
5108 #pragma omp parallel num_threads(pme->nthread) private(thread)
5109             {
5110                 thread = gmx_omp_get_thread_num();
5111                 if (thread == 0)
5112                 {
5113                     wallcycle_start(wcycle, ewcLJPME);
5114                 }
5115
5116                 loop_count =
5117                     solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5118                                      box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5119                                      bCalcEnerVir,
5120                                      pme->nthread, thread);
5121                 if (thread == 0)
5122                 {
5123                     wallcycle_stop(wcycle, ewcLJPME);
5124                     where();
5125                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5126                 }
5127             }
5128         }
5129
5130         if (bCalcEnerVir)
5131         {
5132             /* This should only be called on the master thread and
5133              * after the threads have synchronized.
5134              */
5135             get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2], vir_AB[2]);
5136         }
5137
5138         if (bCalcF)
5139         {
5140             bFirst = !(flags & GMX_PME_DO_COULOMB);
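            /* bFirst is TRUE here only if no Coulomb PME was done above, so
             * bClearF below clears the (parallel) force buffer only in that
             * case; otherwise the LJ forces are added to the Coulomb ones.
             */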
5141             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5142             for (q = 8; q >= 2; --q)
5143             {
5144                 /* Unpack structure */
5145                 pmegrid    = &pme->pmegrid[q];
5146                 fftgrid    = pme->fftgrid[q];
5147                 cfftgrid   = pme->cfftgrid[q];
5148                 pfft_setup = pme->pfft_setup[q];
5149                 grid       = pmegrid->grid.grid;
5150                 calc_next_lb_coeffs(pme, local_sigma);
5151                 where();
5152 #pragma omp parallel num_threads(pme->nthread) private(thread)
5153                 {
5154                     thread = gmx_omp_get_thread_num();
5155                     /* do 3d-invfft */
5156                     if (thread == 0)
5157                     {
5158                         where();
5159                         wallcycle_start(wcycle, ewcPME_FFT);
5160                     }
5161
5162                     gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5163                                                thread, wcycle);
5164                     if (thread == 0)
5165                     {
5166                         wallcycle_stop(wcycle, ewcPME_FFT);
5167
5168                         where();
5169
5170                         if (pme->nodeid == 0)
5171                         {
5172                             ntot  = pme->nkx*pme->nky*pme->nkz;
5173                             npme  = ntot*log((real)ntot)/log(2.0);
5174                             inc_nrnb(nrnb, eNR_FFT, 2*npme);
5175                         }
5176                         wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5177                     }
5178
5179                     copy_fftgrid_to_pmegrid(pme, fftgrid, grid, q, pme->nthread, thread);
5180
5181                 } /*#pragma omp parallel*/
5182
5183                 /* distribute local grid to all nodes */
5184 #ifdef GMX_MPI
5185                 if (pme->nnodes > 1)
5186                 {
5187                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
5188                 }
5189 #endif
5190                 where();
5191
5192                 unwrap_periodic_pmegrid(pme, grid);
5193
5194                 /* interpolate forces for our local atoms */
5195                 where();
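                /* The gather weight combines the FEP lambda factor with
                 * lb_scale_factor[q-2], which presumably holds the binomial
                 * weight C(6,q-2)/2^6 of this term of the LB expansion.
                 */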
5196                 bClearF = (bFirst && PAR(cr));
5197                 scale   = pme->bFEP ? (q < 9 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5198                 scale  *= lb_scale_factor[q-2];
5199 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5200                 for (thread = 0; thread < pme->nthread; thread++)
5201                 {
5202                     gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5203                                       &pme->atc[0].spline[thread],
5204                                       scale);
5205                 }
5206                 where();
5207
5208                 inc_nrnb(nrnb, eNR_GATHERFBSP,
5209                          pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5210                 wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5211
5212                 bFirst = FALSE;
5213             } /*for (q = 8; q >= 2; --q)*/
5214         }     /*if (bCalcF)*/
5215     }         /*if (flags & GMX_PME_LJ_LB)*/
5216
5217     if (bCalcF && pme->nnodes > 1)
5218     {
5219         wallcycle_start(wcycle, ewcPME_REDISTXF);
5220         for (d = 0; d < pme->ndecompdim; d++)
5221         {
5222             atc = &pme->atc[d];
5223             if (d == pme->ndecompdim - 1)
5224             {
5225                 n_d = homenr;
5226                 f_d = f + start;
5227             }
5228             else
5229             {
5230                 n_d = pme->atc[d+1].n;
5231                 f_d = pme->atc[d+1].f;
5232             }
5233             if (DOMAINDECOMP(cr))
5234             {
5235                 dd_pmeredist_f(pme, atc, n_d, f_d,
5236                                d == pme->ndecompdim-1 && pme->bPPnode);
5237             }
5238         }
5239
5240         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5241     }
5242     where();
5243
5244     if (bCalcEnerVir)
5245     {
5246         if (flags & GMX_PME_DO_COULOMB)
5247         {
5248             if (!pme->bFEP_q)
5249             {
5250                 *energy_q = energy_AB[0];
5251                 m_add(vir_q, vir_AB[0], vir_q);
5252             }
5253             else
5254             {
5255                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5256                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5257                 for (i = 0; i < DIM; i++)
5258                 {
5259                     for (j = 0; j < DIM; j++)
5260                     {
5261                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5262                             lambda_q*vir_AB[1][i][j];
5263                     }
5264                 }
5265             }
5266             if (debug)
5267             {
5268                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5269             }
5270         }
5271         else
5272         {
5273             *energy_q = 0;
5274         }
5275
5276         if (flags & GMX_PME_DO_LJ)
5277         {
5278             if (!pme->bFEP_lj)
5279             {
5280                 *energy_lj = energy_AB[2];
5281                 m_add(vir_lj, vir_AB[2], vir_lj);
5282             }
5283             else
5284             {
5285                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5286                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5287                 for (i = 0; i < DIM; i++)
5288                 {
5289                     for (j = 0; j < DIM; j++)
5290                     {
5291                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5292                     }
5293                 }
5294             }
5295             if (debug)
5296             {
5297                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5298             }
5299         }
5300         else
5301         {
5302             *energy_lj = 0;
5303         }
5304     }
5305     return 0;
5306 }