1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
56  * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #include "gmxpre.h"
61
62 #include "gromacs/legacyheaders/pme.h"
63
64 #include "config.h"
65
66 #include <assert.h>
67 #include <math.h>
68 #include <stdio.h>
69 #include <stdlib.h>
70 #include <string.h>
71
72 #include "gromacs/fft/parallel_3dfft.h"
73 #include "gromacs/fileio/pdbio.h"
74 #include "gromacs/legacyheaders/coulomb.h"
75 #include "gromacs/legacyheaders/macros.h"
76 #include "gromacs/legacyheaders/network.h"
77 #include "gromacs/legacyheaders/nrnb.h"
78 #include "gromacs/legacyheaders/txtdump.h"
79 #include "gromacs/legacyheaders/typedefs.h"
80 #include "gromacs/legacyheaders/types/commrec.h"
81 #include "gromacs/math/gmxcomplex.h"
82 #include "gromacs/math/units.h"
83 #include "gromacs/math/vec.h"
84 #include "gromacs/timing/cyclecounter.h"
85 #include "gromacs/timing/wallcycle.h"
86 #include "gromacs/utility/fatalerror.h"
87 #include "gromacs/utility/futil.h"
88 #include "gromacs/utility/gmxmpi.h"
89 #include "gromacs/utility/gmxomp.h"
90 #include "gromacs/utility/smalloc.h"
91
92 /* Include the SIMD macro file and then check for support */
93 #include "gromacs/simd/simd.h"
94 #include "gromacs/simd/simd_math.h"
95 #ifdef GMX_SIMD_HAVE_REAL
96 /* Turn on arbitrary width SIMD intrinsics for PME solve */
97 #    define PME_SIMD_SOLVE
98 #endif
99
100 #define PME_GRID_QA    0 /* Grid index for A-state for Q */
101 #define PME_GRID_C6A   2 /* Grid index for A-state for LJ */
102 #define DO_Q           2 /* Electrostatic grids have index q<2 */
103 #define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
104 #define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */
105
106 /* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules */
107 const real lb_scale_factor[] = {
108     1.0/64, 6.0/64, 15.0/64, 20.0/64,
109     15.0/64, 6.0/64, 1.0/64
110 };
111
112 /* Pascal triangle coefficients used in solve_pme_lj_yzx, only need to do 4 calculations due to symmetry */
113 const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
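/* Note (added for clarity): the seven entries of lb_scale_factor are the
 * binomial coefficients 1,6,15,20,15,6,1 divided by 2^6 = 64, so they sum to
 * exactly 1. lb_scale_factor_symm folds the symmetric pairs (k and 6-k)
 * together: 2*1/64, 2*6/64, 2*15/64 and the central 20/64.
 */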
114
115 /* Check if we have 4-wide SIMD macro support */
116 #if (defined GMX_SIMD4_HAVE_REAL)
117 /* Do PME spread and gather with 4-wide SIMD.
118  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
119  */
120 #    define PME_SIMD4_SPREAD_GATHER
121
122 #    if (defined GMX_SIMD_HAVE_LOADU) && (defined GMX_SIMD_HAVE_STOREU)
123 /* With PME-order=4 on x86, unaligned load+store is slightly faster
124  * than doubling all SIMD operations when using aligned load+store.
125  */
126 #        define PME_SIMD4_UNALIGNED
127 #    endif
128 #endif
129
130 #define DFT_TOL 1e-7
131 /* #define PRT_FORCE */
132 /* Conditions for on-the-fly time measurement */
133 /* #define TAKETIME (step > 1 && timesteps < 10) */
134 #define TAKETIME FALSE
135
136 /* #define PME_TIME_THREADS */
137
138 #ifdef GMX_DOUBLE
139 #define mpi_type MPI_DOUBLE
140 #else
141 #define mpi_type MPI_FLOAT
142 #endif
143
144 #ifdef PME_SIMD4_SPREAD_GATHER
145 #    define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
146 #else
147 /* We can use any alignment, apart from 0, so we use 4 reals */
148 #    define SIMD4_ALIGNMENT  (4*sizeof(real))
149 #endif
150
151 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
152  * to preserve alignment.
153  */
154 #define GMX_CACHE_SEP 64
155
156 /* We only define a maximum to be able to use local arrays without allocation.
157  * An order larger than 12 should never be needed, even for test cases.
158  * If needed it can be changed here.
159  */
160 #define PME_ORDER_MAX 12
161
162 /* Internal datastructures */
163 typedef struct {
164     int send_index0;
165     int send_nindex;
166     int recv_index0;
167     int recv_nindex;
168     int recv_size;   /* Receive buffer width, used with OpenMP */
169 } pme_grid_comm_t;
170
171 typedef struct {
172 #ifdef GMX_MPI
173     MPI_Comm         mpi_comm;
174 #endif
175     int              nnodes, nodeid;
176     int             *s2g0;
177     int             *s2g1;
178     int              noverlap_nodes;
179     int             *send_id, *recv_id;
180     int              send_size; /* Send buffer width, used with OpenMP */
181     pme_grid_comm_t *comm_data;
182     real            *sendbuf;
183     real            *recvbuf;
184 } pme_overlap_t;
185
186 typedef struct {
187     int *n;      /* Cumulative counts of the number of particles per thread */
188     int  nalloc; /* Allocation size of i */
189     int *i;      /* Particle indices ordered on thread index (n) */
190 } thread_plist_t;
191
192 typedef struct {
193     int      *thread_one;
194     int       n;
195     int      *ind;
196     splinevec theta;
197     real     *ptr_theta_z;
198     splinevec dtheta;
199     real     *ptr_dtheta_z;
200 } splinedata_t;
201
202 typedef struct {
203     int      dimind;        /* The index of the dimension, 0=x, 1=y */
204     int      nslab;
205     int      nodeid;
206 #ifdef GMX_MPI
207     MPI_Comm mpi_comm;
208 #endif
209
210     int     *node_dest;     /* The nodes to send x and q to with DD */
211     int     *node_src;      /* The nodes to receive x and q from with DD */
212     int     *buf_index;     /* Index for commnode into the buffers */
213
214     int      maxshift;
215
216     int      npd;
217     int      pd_nalloc;
218     int     *pd;
219     int     *count;         /* The number of atoms to send to each node */
220     int    **count_thread;
221     int     *rcount;        /* The number of atoms to receive */
222
223     int      n;
224     int      nalloc;
225     rvec    *x;
226     real    *coefficient;
227     rvec    *f;
228     gmx_bool bSpread;       /* These coordinates are used for spreading */
229     int      pme_order;
230     ivec    *idx;
231     rvec    *fractx;            /* Fractional coordinate relative to
232                                  * the lower cell boundary
233                                  */
234     int             nthread;
235     int            *thread_idx; /* Which thread should spread which coefficient */
236     thread_plist_t *thread_plist;
237     splinedata_t   *spline;
238 } pme_atomcomm_t;
239
240 #define FLBS  3
241 #define FLBSZ 4
242
243 typedef struct {
244     ivec  ci;     /* The spatial location of this grid         */
245     ivec  n;      /* The used size of *grid, including order-1 */
246     ivec  offset; /* The grid offset from the full node grid   */
247     int   order;  /* PME spreading order                       */
248     ivec  s;      /* The allocated size of *grid, s >= n       */
249     real *grid;   /* The grid of the local thread, size n      */
250 } pmegrid_t;
251
252 typedef struct {
253     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
254     int        nthread;      /* The number of threads operating on this grid     */
255     ivec       nc;           /* The local spatial decomposition over the threads */
256     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
257     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
258     int      **g2t;          /* The grid to thread index                         */
259     ivec       nthread_comm; /* The number of threads to communicate with        */
260 } pmegrids_t;
261
262 typedef struct {
263 #ifdef PME_SIMD4_SPREAD_GATHER
264     /* Masks for 4-wide SIMD aligned spreading and gathering */
265     gmx_simd4_bool_t mask_S0[6], mask_S1[6];
266 #else
267     int              dummy; /* C89 requires that struct has at least one member */
268 #endif
269 } pme_spline_work_t;
270
271 typedef struct {
272     /* work data for solve_pme */
273     int      nalloc;
274     real *   mhx;
275     real *   mhy;
276     real *   mhz;
277     real *   m2;
278     real *   denom;
279     real *   tmp1_alloc;
280     real *   tmp1;
281     real *   tmp2;
282     real *   eterm;
283     real *   m2inv;
284
285     real     energy_q;
286     matrix   vir_q;
287     real     energy_lj;
288     matrix   vir_lj;
289 } pme_work_t;
290
291 typedef struct gmx_pme {
292     int           ndecompdim; /* The number of decomposition dimensions */
293     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
294     int           nodeid_major;
295     int           nodeid_minor;
296     int           nnodes;    /* The number of nodes doing PME */
297     int           nnodes_major;
298     int           nnodes_minor;
299
300     MPI_Comm      mpi_comm;
301     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
302 #ifdef GMX_MPI
303     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
304 #endif
305
306     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
307     int        nthread;       /* The number of threads doing PME on our rank */
308
309     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
310     gmx_bool   bFEP;          /* Compute Free energy contribution */
311     gmx_bool   bFEP_q;
312     gmx_bool   bFEP_lj;
313     int        nkx, nky, nkz; /* Grid dimensions */
314     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
315     int        pme_order;
316     real       epsilon_r;
317
318     int        ljpme_combination_rule;  /* Type of combination rule in LJ-PME */
319
320     int        ngrids;                  /* number of grids we maintain for pmegrid, (c)fftgrid and pfft_setups */
321
322     pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
323                                          * includes overlap. Grid indices are ordered as
324                                          * follows:
325                                          * 0: Coulomb PME, state A
326                                          * 1: Coulomb PME, state B
327                                          * 2-8: LJ-PME
328                                          * This can probably be done in a better way
329                                          * but this simple hack works for now
330                                          */
331     /* The PME coefficient spreading grid sizes/strides, includes pme_order-1 */
332     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
333     /* pmegrid_nz might be larger than strictly necessary to ensure
334      * memory alignment, pmegrid_nz_base gives the real base size.
335      */
336     int     pmegrid_nz_base;
337     /* The local PME grid starting indices */
338     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
339
340     /* Work data for spreading and gathering */
341     pme_spline_work_t     *spline_work;
342
343     real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
344     /* inside the interpolation grid, but separate for 2D PME decomp. */
345     int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;
346
347     t_complex            **cfftgrid;  /* Grids for complex FFT data */
348
349     int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
350
351     gmx_parallel_3dfft_t  *pfft_setup;
352
353     int                   *nnx, *nny, *nnz;
354     real                  *fshx, *fshy, *fshz;
355
356     pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
357     matrix                 recipbox;
358     splinevec              bsp_mod;
359     /* Buffers to store data for local atoms for L-B combination rule
360      * calculations in LJ-PME. lb_buf1 stores either the coefficients
361      * for spreading/gathering (in serial), or the C6 coefficient for
362      * local atoms (in parallel).  lb_buf2 is only used in parallel,
363      * and stores the sigma values for local atoms. */
364     real                 *lb_buf1, *lb_buf2;
365     int                   lb_buf_nalloc; /* Allocation size for the above buffers. */
366
367     pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */
368
369     pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */
370
371     rvec                 *bufv;          /* Communication buffer */
372     real                 *bufr;          /* Communication buffer */
373     int                   buf_nalloc;    /* The communication buffer size */
374
375     /* thread local work data for solve_pme */
376     pme_work_t *work;
377
378     /* Work data for sum_qgrid */
379     real *   sum_qgrid_tmp;
380     real *   sum_qgrid_dd_tmp;
381 } t_gmx_pme;
382
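/* For each atom in [start,end), compute the grid cell indices (atc->idx) and
 * the fractional offsets within that cell (atc->fractx) used for B-spline
 * spreading. With more than one thread per atomcomm, also record which thread
 * will spread each atom and count the atoms per destination thread.
 */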
383 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
384                                    int start, int grid_index, int end, int thread)
385 {
386     int             i;
387     int            *idxptr, tix, tiy, tiz;
388     real           *xptr, *fptr, tx, ty, tz;
389     real            rxx, ryx, ryy, rzx, rzy, rzz;
390     int             nx, ny, nz;
391     int             start_ix, start_iy, start_iz;
392     int            *g2tx, *g2ty, *g2tz;
393     gmx_bool        bThreads;
394     int            *thread_idx = NULL;
395     thread_plist_t *tpl        = NULL;
396     int            *tpl_n      = NULL;
397     int             thread_i;
398
399     nx  = pme->nkx;
400     ny  = pme->nky;
401     nz  = pme->nkz;
402
403     start_ix = pme->pmegrid_start_ix;
404     start_iy = pme->pmegrid_start_iy;
405     start_iz = pme->pmegrid_start_iz;
406
407     rxx = pme->recipbox[XX][XX];
408     ryx = pme->recipbox[YY][XX];
409     ryy = pme->recipbox[YY][YY];
410     rzx = pme->recipbox[ZZ][XX];
411     rzy = pme->recipbox[ZZ][YY];
412     rzz = pme->recipbox[ZZ][ZZ];
413
414     g2tx = pme->pmegrid[grid_index].g2t[XX];
415     g2ty = pme->pmegrid[grid_index].g2t[YY];
416     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
417
418     bThreads = (atc->nthread > 1);
419     if (bThreads)
420     {
421         thread_idx = atc->thread_idx;
422
423         tpl   = &atc->thread_plist[thread];
424         tpl_n = tpl->n;
425         for (i = 0; i < atc->nthread; i++)
426         {
427             tpl_n[i] = 0;
428         }
429     }
430
431     for (i = start; i < end; i++)
432     {
433         xptr   = atc->x[i];
434         idxptr = atc->idx[i];
435         fptr   = atc->fractx[i];
436
437         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
438         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
439         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
440         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
441
442         tix = (int)(tx);
443         tiy = (int)(ty);
444         tiz = (int)(tz);
445
446         /* Because decomposition only occurs in x and y,
447          * we never have a fraction correction in z.
448          */
449         fptr[XX] = tx - tix + pme->fshx[tix];
450         fptr[YY] = ty - tiy + pme->fshy[tiy];
451         fptr[ZZ] = tz - tiz;
452
453         idxptr[XX] = pme->nnx[tix];
454         idxptr[YY] = pme->nny[tiy];
455         idxptr[ZZ] = pme->nnz[tiz];
456
457 #ifdef DEBUG
458         range_check(idxptr[XX], 0, pme->pmegrid_nx);
459         range_check(idxptr[YY], 0, pme->pmegrid_ny);
460         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
461 #endif
462
463         if (bThreads)
464         {
465             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
466             thread_idx[i] = thread_i;
467             tpl_n[thread_i]++;
468         }
469     }
470
471     if (bThreads)
472     {
473         /* Make a list of particle indices sorted on thread */
474
475         /* Get the cumulative count */
476         for (i = 1; i < atc->nthread; i++)
477         {
478             tpl_n[i] += tpl_n[i-1];
479         }
480         /* The current implementation distributes particles equally
481          * over the threads, so we could actually allocate for that
482          * in pme_realloc_atomcomm_things.
483          */
484         if (tpl_n[atc->nthread-1] > tpl->nalloc)
485         {
486             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
487             srenew(tpl->i, tpl->nalloc);
488         }
489         /* Set tpl_n to the cumulative start */
490         for (i = atc->nthread-1; i >= 1; i--)
491         {
492             tpl_n[i] = tpl_n[i-1];
493         }
494         tpl_n[0] = 0;
495
496         /* Fill our thread local array with indices sorted on thread */
497         for (i = start; i < end; i++)
498         {
499             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
500         }
501         /* Now tpl_n contains the cumulative count again */
502     }
503 }
504
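/* Collect, for this thread, the atom indices binned by calc_interpolation_idx
 * on all threads into a single list in spline->ind.
 */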
505 static void make_thread_local_ind(pme_atomcomm_t *atc,
506                                   int thread, splinedata_t *spline)
507 {
508     int             n, t, i, start, end;
509     thread_plist_t *tpl;
510
511     /* Combine the index lists made by each thread into one list */
512
513     n     = 0;
514     start = 0;
515     for (t = 0; t < atc->nthread; t++)
516     {
517         tpl = &atc->thread_plist[t];
518         /* Copy our part (start - end) from the list of thread t */
519         if (thread > 0)
520         {
521             start = tpl->n[thread-1];
522         }
523         end = tpl->n[thread];
524         for (i = start; i < end; i++)
525         {
526             spline->ind[n++] = tpl->i[i];
527         }
528     }
529
530     spline->n = n;
531 }
532
533
534 static void pme_calc_pidx(int start, int end,
535                           matrix recipbox, rvec x[],
536                           pme_atomcomm_t *atc, int *count)
537 {
538     int   nslab, i;
539     int   si;
540     real *xptr, s;
541     real  rxx, ryx, rzx, ryy, rzy;
542     int  *pd;
543
544     /* Calculate the PME task index (pidx) for each atom.
545      * Here we always assign equally sized slabs to each node
546      * for load balancing reasons (the PME grid spacing is not used).
547      */
548
549     nslab = atc->nslab;
550     pd    = atc->pd;
551
552     /* Reset the count */
553     for (i = 0; i < nslab; i++)
554     {
555         count[i] = 0;
556     }
557
558     if (atc->dimind == 0)
559     {
560         rxx = recipbox[XX][XX];
561         ryx = recipbox[YY][XX];
562         rzx = recipbox[ZZ][XX];
563         /* Calculate the node index in x-dimension */
564         for (i = start; i < end; i++)
565         {
566             xptr   = x[i];
567             /* Fractional coordinates along box vectors */
568             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
569             si    = (int)(s + 2*nslab) % nslab;
570             pd[i] = si;
571             count[si]++;
572         }
573     }
574     else
575     {
576         ryy = recipbox[YY][YY];
577         rzy = recipbox[ZZ][YY];
578         /* Calculate the node index in y-dimension */
579         for (i = start; i < end; i++)
580         {
581             xptr   = x[i];
582             /* Fractional coordinates along box vectors */
583             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
584             si    = (int)(s + 2*nslab) % nslab;
585             pd[i] = si;
586             count[si]++;
587         }
588     }
589 }
590
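/* Thread-parallel wrapper around pme_calc_pidx: each thread handles an equal
 * share of the atoms, after which the per-thread slab counts are reduced
 * serially into atc->count_thread[0].
 */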
591 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
592                                   pme_atomcomm_t *atc)
593 {
594     int nthread, thread, slab;
595
596     nthread = atc->nthread;
597
598 #pragma omp parallel for num_threads(nthread) schedule(static)
599     for (thread = 0; thread < nthread; thread++)
600     {
601         pme_calc_pidx(natoms* thread   /nthread,
602                       natoms*(thread+1)/nthread,
603                       recipbox, x, atc, atc->count_thread[thread]);
604     }
605     /* Non-parallel reduction, since nslab is small */
606
607     for (thread = 1; thread < nthread; thread++)
608     {
609         for (slab = 0; slab < atc->nslab; slab++)
610         {
611             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
612         }
613     }
614 }
615
616 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
617 {
618     const int padding = 4;
619     int       i;
620
621     srenew(th[XX], nalloc);
622     srenew(th[YY], nalloc);
623     /* In z we add padding, this is only required for the aligned SIMD code */
624     sfree_aligned(*ptr_z);
625     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
626     th[ZZ] = *ptr_z + padding;
627
628     for (i = 0; i < padding; i++)
629     {
630         (*ptr_z)[               i] = 0;
631         (*ptr_z)[padding+nalloc+i] = 0;
632     }
633 }
634
635 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
636 {
637     int i, d;
638
639     srenew(spline->ind, atc->nalloc);
640     /* Initialize the index to identity so it works without threads */
641     for (i = 0; i < atc->nalloc; i++)
642     {
643         spline->ind[i] = i;
644     }
645
646     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
647                       atc->pme_order*atc->nalloc);
648     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
649                       atc->pme_order*atc->nalloc);
650 }
651
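/* Ensure all per-atom arrays in atc can hold atc->n atoms; (re)allocation
 * only happens when the current allocation is exceeded, and over-allocation
 * is used to limit the number of reallocations.
 */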
652 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
653 {
654     int nalloc_old, i, j, nalloc_tpl;
655
656     /* We have to avoid a NULL pointer for atc->x to avoid
657      * possible fatal errors in MPI routines.
658      */
659     if (atc->n > atc->nalloc || atc->nalloc == 0)
660     {
661         nalloc_old  = atc->nalloc;
662         atc->nalloc = over_alloc_dd(max(atc->n, 1));
663
664         if (atc->nslab > 1)
665         {
666             srenew(atc->x, atc->nalloc);
667             srenew(atc->coefficient, atc->nalloc);
668             srenew(atc->f, atc->nalloc);
669             for (i = nalloc_old; i < atc->nalloc; i++)
670             {
671                 clear_rvec(atc->f[i]);
672             }
673         }
674         if (atc->bSpread)
675         {
676             srenew(atc->fractx, atc->nalloc);
677             srenew(atc->idx, atc->nalloc);
678
679             if (atc->nthread > 1)
680             {
681                 srenew(atc->thread_idx, atc->nalloc);
682             }
683
684             for (i = 0; i < atc->nthread; i++)
685             {
686                 pme_realloc_splinedata(&atc->spline[i], atc);
687             }
688         }
689     }
690 }
691
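/* Point-to-point exchange with the PME neighbor at distance 'shift' along the
 * slab dimension. With bBackward==FALSE data flows to node_dest and from
 * node_src; with bBackward==TRUE the roles are swapped, which is used to send
 * forces back along the path the coordinates came from.
 */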
692 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
693                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
694                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
695                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
696 {
697 #ifdef GMX_MPI
698     int        dest, src;
699     MPI_Status stat;
700
701     if (bBackward == FALSE)
702     {
703         dest = atc->node_dest[shift];
704         src  = atc->node_src[shift];
705     }
706     else
707     {
708         dest = atc->node_src[shift];
709         src  = atc->node_dest[shift];
710     }
711
712     if (nbyte_s > 0 && nbyte_r > 0)
713     {
714         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
715                      dest, shift,
716                      buf_r, nbyte_r, MPI_BYTE,
717                      src, shift,
718                      atc->mpi_comm, &stat);
719     }
720     else if (nbyte_s > 0)
721     {
722         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
723                  dest, shift,
724                  atc->mpi_comm);
725     }
726     else if (nbyte_r > 0)
727     {
728         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
729                  src, shift,
730                  atc->mpi_comm, &stat);
731     }
732 #endif
733 }
734
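/* Redistribute coordinates (when bX) and spreading coefficients from the
 * domain-decomposition layout to the PME slab layout of atc, based on the
 * destination slabs stored in atc->pd by pme_calc_pidx.
 */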
735 static void dd_pmeredist_pos_coeffs(gmx_pme_t pme,
736                                     int n, gmx_bool bX, rvec *x, real *data,
737                                     pme_atomcomm_t *atc)
738 {
739     int *commnode, *buf_index;
740     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
741
742     commnode  = atc->node_dest;
743     buf_index = atc->buf_index;
744
745     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
746
747     nsend = 0;
748     for (i = 0; i < nnodes_comm; i++)
749     {
750         buf_index[commnode[i]] = nsend;
751         nsend                 += atc->count[commnode[i]];
752     }
753     if (bX)
754     {
755         if (atc->count[atc->nodeid] + nsend != n)
756         {
757             gmx_fatal(FARGS, "%d particles communicated to PME rank %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
758                       "This usually means that your system is not well equilibrated.",
759                       n - (atc->count[atc->nodeid] + nsend),
760                       pme->nodeid, 'x'+atc->dimind);
761         }
762
763         if (nsend > pme->buf_nalloc)
764         {
765             pme->buf_nalloc = over_alloc_dd(nsend);
766             srenew(pme->bufv, pme->buf_nalloc);
767             srenew(pme->bufr, pme->buf_nalloc);
768         }
769
770         atc->n = atc->count[atc->nodeid];
771         for (i = 0; i < nnodes_comm; i++)
772         {
773             scount = atc->count[commnode[i]];
774             /* Communicate the count */
775             if (debug)
776             {
777                 fprintf(debug, "dimind %d PME rank %d send to rank %d: %d\n",
778                         atc->dimind, atc->nodeid, commnode[i], scount);
779             }
780             pme_dd_sendrecv(atc, FALSE, i,
781                             &scount, sizeof(int),
782                             &atc->rcount[i], sizeof(int));
783             atc->n += atc->rcount[i];
784         }
785
786         pme_realloc_atomcomm_things(atc);
787     }
788
789     local_pos = 0;
790     for (i = 0; i < n; i++)
791     {
792         node = atc->pd[i];
793         if (node == atc->nodeid)
794         {
795             /* Copy direct to the receive buffer */
796             if (bX)
797             {
798                 copy_rvec(x[i], atc->x[local_pos]);
799             }
800             atc->coefficient[local_pos] = data[i];
801             local_pos++;
802         }
803         else
804         {
805             /* Copy to the send buffer */
806             if (bX)
807             {
808                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
809             }
810             pme->bufr[buf_index[node]] = data[i];
811             buf_index[node]++;
812         }
813     }
814
815     buf_pos = 0;
816     for (i = 0; i < nnodes_comm; i++)
817     {
818         scount = atc->count[commnode[i]];
819         rcount = atc->rcount[i];
820         if (scount > 0 || rcount > 0)
821         {
822             if (bX)
823             {
824                 /* Communicate the coordinates */
825                 pme_dd_sendrecv(atc, FALSE, i,
826                                 pme->bufv[buf_pos], scount*sizeof(rvec),
827                                 atc->x[local_pos], rcount*sizeof(rvec));
828             }
829             /* Communicate the coefficients */
830             pme_dd_sendrecv(atc, FALSE, i,
831                             pme->bufr+buf_pos, scount*sizeof(real),
832                             atc->coefficient+local_pos, rcount*sizeof(real));
833             buf_pos   += scount;
834             local_pos += atc->rcount[i];
835         }
836     }
837 }
838
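/* Route the PME forces back from the slab layout of atc to the original atom
 * order; depending on bAddF the result is added to or overwrites f[].
 */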
839 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
840                            int n, rvec *f,
841                            gmx_bool bAddF)
842 {
843     int *commnode, *buf_index;
844     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
845
846     commnode  = atc->node_dest;
847     buf_index = atc->buf_index;
848
849     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
850
851     local_pos = atc->count[atc->nodeid];
852     buf_pos   = 0;
853     for (i = 0; i < nnodes_comm; i++)
854     {
855         scount = atc->rcount[i];
856         rcount = atc->count[commnode[i]];
857         if (scount > 0 || rcount > 0)
858         {
859             /* Communicate the forces */
860             pme_dd_sendrecv(atc, TRUE, i,
861                             atc->f[local_pos], scount*sizeof(rvec),
862                             pme->bufv[buf_pos], rcount*sizeof(rvec));
863             local_pos += scount;
864         }
865         buf_index[commnode[i]] = buf_pos;
866         buf_pos               += rcount;
867     }
868
869     local_pos = 0;
870     if (bAddF)
871     {
872         for (i = 0; i < n; i++)
873         {
874             node = atc->pd[i];
875             if (node == atc->nodeid)
876             {
877                 /* Add from the local force array */
878                 rvec_inc(f[i], atc->f[local_pos]);
879                 local_pos++;
880             }
881             else
882             {
883                 /* Add from the receive buffer */
884                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
885                 buf_index[node]++;
886             }
887         }
888     }
889     else
890     {
891         for (i = 0; i < n; i++)
892         {
893             node = atc->pd[i];
894             if (node == atc->nodeid)
895             {
896                 /* Copy from the local force array */
897                 copy_rvec(atc->f[local_pos], f[i]);
898                 local_pos++;
899             }
900             else
901             {
902                 /* Copy from the receive buffer */
903                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
904                 buf_index[node]++;
905             }
906         }
907     }
908 }
909
910 #ifdef GMX_MPI
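/* Communicate grid overlap regions between neighboring PME ranks. With
 * direction GMX_SUM_GRID_FORWARD the received overlap is added onto the local
 * grid; in the other direction the local data is overwritten with the remote
 * data. The minor (y) dimension requires packing into contiguous send/receive
 * buffers, while the major (x) dimension can send grid sections in place.
 */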
911 static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
912 {
913     pme_overlap_t *overlap;
914     int            send_index0, send_nindex;
915     int            recv_index0, recv_nindex;
916     MPI_Status     stat;
917     int            i, j, k, ix, iy, iz, icnt;
918     int            ipulse, send_id, recv_id, datasize;
919     real          *p;
920     real          *sendptr, *recvptr;
921
922     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
923     overlap = &pme->overlap[1];
924
925     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
926     {
927         /* Since we have already (un)wrapped the overlap in the z-dimension,
928          * we only have to communicate 0 to nkz (not pmegrid_nz).
929          */
930         if (direction == GMX_SUM_GRID_FORWARD)
931         {
932             send_id       = overlap->send_id[ipulse];
933             recv_id       = overlap->recv_id[ipulse];
934             send_index0   = overlap->comm_data[ipulse].send_index0;
935             send_nindex   = overlap->comm_data[ipulse].send_nindex;
936             recv_index0   = overlap->comm_data[ipulse].recv_index0;
937             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
938         }
939         else
940         {
941             send_id       = overlap->recv_id[ipulse];
942             recv_id       = overlap->send_id[ipulse];
943             send_index0   = overlap->comm_data[ipulse].recv_index0;
944             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
945             recv_index0   = overlap->comm_data[ipulse].send_index0;
946             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
947         }
948
949         /* Copy data to contiguous send buffer */
950         if (debug)
951         {
952             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
953                     pme->nodeid, overlap->nodeid, send_id,
954                     pme->pmegrid_start_iy,
955                     send_index0-pme->pmegrid_start_iy,
956                     send_index0-pme->pmegrid_start_iy+send_nindex);
957         }
958         icnt = 0;
959         for (i = 0; i < pme->pmegrid_nx; i++)
960         {
961             ix = i;
962             for (j = 0; j < send_nindex; j++)
963             {
964                 iy = j + send_index0 - pme->pmegrid_start_iy;
965                 for (k = 0; k < pme->nkz; k++)
966                 {
967                     iz = k;
968                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
969                 }
970             }
971         }
972
973         datasize      = pme->pmegrid_nx * pme->nkz;
974
975         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
976                      send_id, ipulse,
977                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
978                      recv_id, ipulse,
979                      overlap->mpi_comm, &stat);
980
981         /* Get data from contiguous recv buffer */
982         if (debug)
983         {
984             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
985                     pme->nodeid, overlap->nodeid, recv_id,
986                     pme->pmegrid_start_iy,
987                     recv_index0-pme->pmegrid_start_iy,
988                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
989         }
990         icnt = 0;
991         for (i = 0; i < pme->pmegrid_nx; i++)
992         {
993             ix = i;
994             for (j = 0; j < recv_nindex; j++)
995             {
996                 iy = j + recv_index0 - pme->pmegrid_start_iy;
997                 for (k = 0; k < pme->nkz; k++)
998                 {
999                     iz = k;
1000                     if (direction == GMX_SUM_GRID_FORWARD)
1001                     {
1002                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1003                     }
1004                     else
1005                     {
1006                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1007                     }
1008                 }
1009             }
1010         }
1011     }
1012
1013     /* Major dimension is easier, no copying required,
1014      * but we might have to sum to a separate array.
1015      * Since we don't copy, we have to communicate up to pmegrid_nz,
1016      * not nkz as for the minor direction.
1017      */
1018     overlap = &pme->overlap[0];
1019
1020     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1021     {
1022         if (direction == GMX_SUM_GRID_FORWARD)
1023         {
1024             send_id       = overlap->send_id[ipulse];
1025             recv_id       = overlap->recv_id[ipulse];
1026             send_index0   = overlap->comm_data[ipulse].send_index0;
1027             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1028             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1029             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1030             recvptr       = overlap->recvbuf;
1031         }
1032         else
1033         {
1034             send_id       = overlap->recv_id[ipulse];
1035             recv_id       = overlap->send_id[ipulse];
1036             send_index0   = overlap->comm_data[ipulse].recv_index0;
1037             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1038             recv_index0   = overlap->comm_data[ipulse].send_index0;
1039             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1040             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1041         }
1042
1043         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1044         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1045
1046         if (debug)
1047         {
1048             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
1049                     pme->nodeid, overlap->nodeid, send_id,
1050                     pme->pmegrid_start_ix,
1051                     send_index0-pme->pmegrid_start_ix,
1052                     send_index0-pme->pmegrid_start_ix+send_nindex);
1053             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
1054                     pme->nodeid, overlap->nodeid, recv_id,
1055                     pme->pmegrid_start_ix,
1056                     recv_index0-pme->pmegrid_start_ix,
1057                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1058         }
1059
1060         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1061                      send_id, ipulse,
1062                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1063                      recv_id, ipulse,
1064                      overlap->mpi_comm, &stat);
1065
1066         /* ADD data from contiguous recv buffer */
1067         if (direction == GMX_SUM_GRID_FORWARD)
1068         {
1069             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1070             for (i = 0; i < recv_nindex*datasize; i++)
1071             {
1072                 p[i] += overlap->recvbuf[i];
1073             }
1074         }
1075     }
1076 }
1077 #endif
1078
1079
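/* Copy the local part of the (padded) PME spreading grid into the tightly
 * packed real-space FFT grid; both grids share the same origin, the PME grid
 * is only larger because of the spreading overlap.
 */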
1080 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1081 {
1082     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1083     ivec    local_pme_size;
1084     int     i, ix, iy, iz;
1085     int     pmeidx, fftidx;
1086
1087     /* Dimensions should be identical for A/B grid, so we just use A here */
1088     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1089                                    local_fft_ndata,
1090                                    local_fft_offset,
1091                                    local_fft_size);
1092
1093     local_pme_size[0] = pme->pmegrid_nx;
1094     local_pme_size[1] = pme->pmegrid_ny;
1095     local_pme_size[2] = pme->pmegrid_nz;
1096
1097     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1098        the offset is identical, and the PME grid always has more data (due to overlap)
1099      */
1100     {
1101 #ifdef DEBUG_PME
1102         FILE *fp, *fp2;
1103         char  fn[STRLEN];
1104         real  val;
1105         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1106         fp = gmx_ffopen(fn, "w");
1107         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1108         fp2 = gmx_ffopen(fn, "w");
1109 #endif
1110
1111         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1112         {
1113             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1114             {
1115                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1116                 {
1117                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1118                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1119                     fftgrid[fftidx] = pmegrid[pmeidx];
1120 #ifdef DEBUG_PME
1121                     val = 100*pmegrid[pmeidx];
1122                     if (pmegrid[pmeidx] != 0)
1123                     {
1124                         gmx_fprintf_pdb_atomline(fp, epdbATOM, pmeidx, "CA", ' ', "GLY", ' ', pmeidx, ' ',
1125                                                  5.0*ix, 5.0*iy, 5.0*iz, 1.0, val, "");
1126                     }
1127                     if (pmegrid[pmeidx] != 0)
1128                     {
1129                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1130                                 "qgrid",
1131                                 pme->pmegrid_start_ix + ix,
1132                                 pme->pmegrid_start_iy + iy,
1133                                 pme->pmegrid_start_iz + iz,
1134                                 pmegrid[pmeidx]);
1135                     }
1136 #endif
1137                 }
1138             }
1139         }
1140 #ifdef DEBUG_PME
1141         gmx_ffclose(fp);
1142         gmx_ffclose(fp2);
1143 #endif
1144     }
1145     return 0;
1146 }
1147
1148
1149 static gmx_cycles_t omp_cyc_start()
1150 {
1151     return gmx_cycles_read();
1152 }
1153
1154 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1155 {
1156     return gmx_cycles_read() - c;
1157 }
1158
1159
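/* Inverse of copy_pmegrid_to_fftgrid: copy this thread's share of the (x,y)
 * columns of the real-space FFT grid back into the padded PME grid.
 */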
1160 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1161                                    int nthread, int thread)
1162 {
1163     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1164     ivec          local_pme_size;
1165     int           ixy0, ixy1, ixy, ix, iy, iz;
1166     int           pmeidx, fftidx;
1167 #ifdef PME_TIME_THREADS
1168     gmx_cycles_t  c1;
1169     static double cs1 = 0;
1170     static int    cnt = 0;
1171 #endif
1172
1173 #ifdef PME_TIME_THREADS
1174     c1 = omp_cyc_start();
1175 #endif
1176     /* Dimensions should be identical for A/B grid, so we just use A here */
1177     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1178                                    local_fft_ndata,
1179                                    local_fft_offset,
1180                                    local_fft_size);
1181
1182     local_pme_size[0] = pme->pmegrid_nx;
1183     local_pme_size[1] = pme->pmegrid_ny;
1184     local_pme_size[2] = pme->pmegrid_nz;
1185
1186     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1187        the offset is identical, and the PME grid always has more data (due to overlap)
1188      */
1189     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1190     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1191
1192     for (ixy = ixy0; ixy < ixy1; ixy++)
1193     {
1194         ix = ixy/local_fft_ndata[YY];
1195         iy = ixy - ix*local_fft_ndata[YY];
1196
1197         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1198         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1199         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1200         {
1201             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1202         }
1203     }
1204
1205 #ifdef PME_TIME_THREADS
1206     c1   = omp_cyc_end(c1);
1207     cs1 += (double)c1;
1208     cnt++;
1209     if (cnt % 20 == 0)
1210     {
1211         printf("copy %.2f\n", cs1*1e-9);
1212     }
1213 #endif
1214
1215     return 0;
1216 }
1217
1218
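/* Fold the periodic overlap region (pme_order-1 extra planes) of the spreading
 * grid back onto the primary cell: always in z, and also in y and x when that
 * dimension is not decomposed over MPI ranks.
 */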
1219 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1220 {
1221     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1222
1223     nx = pme->nkx;
1224     ny = pme->nky;
1225     nz = pme->nkz;
1226
1227     pnx = pme->pmegrid_nx;
1228     pny = pme->pmegrid_ny;
1229     pnz = pme->pmegrid_nz;
1230
1231     overlap = pme->pme_order - 1;
1232
1233     /* Add periodic overlap in z */
1234     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1235     {
1236         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1237         {
1238             for (iz = 0; iz < overlap; iz++)
1239             {
1240                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1241                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1242             }
1243         }
1244     }
1245
1246     if (pme->nnodes_minor == 1)
1247     {
1248         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1249         {
1250             for (iy = 0; iy < overlap; iy++)
1251             {
1252                 for (iz = 0; iz < nz; iz++)
1253                 {
1254                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1255                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1256                 }
1257             }
1258         }
1259     }
1260
1261     if (pme->nnodes_major == 1)
1262     {
1263         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1264
1265         for (ix = 0; ix < overlap; ix++)
1266         {
1267             for (iy = 0; iy < ny_x; iy++)
1268             {
1269                 for (iz = 0; iz < nz; iz++)
1270                 {
1271                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1272                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1273                 }
1274             }
1275         }
1276     }
1277 }
1278
1279
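/* Inverse of wrap_periodic_pmegrid: replicate the primary-cell boundary planes
 * into the periodic overlap region of the spreading grid.
 */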
1280 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1281 {
1282     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1283
1284     nx = pme->nkx;
1285     ny = pme->nky;
1286     nz = pme->nkz;
1287
1288     pnx = pme->pmegrid_nx;
1289     pny = pme->pmegrid_ny;
1290     pnz = pme->pmegrid_nz;
1291
1292     overlap = pme->pme_order - 1;
1293
1294     if (pme->nnodes_major == 1)
1295     {
1296         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1297
1298         for (ix = 0; ix < overlap; ix++)
1299         {
1300             int iy, iz;
1301
1302             for (iy = 0; iy < ny_x; iy++)
1303             {
1304                 for (iz = 0; iz < nz; iz++)
1305                 {
1306                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1307                         pmegrid[(ix*pny+iy)*pnz+iz];
1308                 }
1309             }
1310         }
1311     }
1312
1313     if (pme->nnodes_minor == 1)
1314     {
1315 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1316         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1317         {
1318             int iy, iz;
1319
1320             for (iy = 0; iy < overlap; iy++)
1321             {
1322                 for (iz = 0; iz < nz; iz++)
1323                 {
1324                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1325                         pmegrid[(ix*pny+iy)*pnz+iz];
1326                 }
1327             }
1328         }
1329     }
1330
1331     /* Copy periodic overlap in z */
1332 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1333     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1334     {
1335         int iy, iz;
1336
1337         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1338         {
1339             for (iz = 0; iz < overlap; iz++)
1340             {
1341                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1342                     pmegrid[(ix*pny+iy)*pnz+iz];
1343             }
1344         }
1345     }
1346 }
1347
1348
1349 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1350 #define DO_BSPLINE(order)                            \
1351     for (ithx = 0; (ithx < order); ithx++)                    \
1352     {                                                    \
1353         index_x = (i0+ithx)*pny*pnz;                     \
1354         valx    = coefficient*thx[ithx];                          \
1355                                                      \
1356         for (ithy = 0; (ithy < order); ithy++)                \
1357         {                                                \
1358             valxy    = valx*thy[ithy];                   \
1359             index_xy = index_x+(j0+ithy)*pnz;            \
1360                                                      \
1361             for (ithz = 0; (ithz < order); ithz++)            \
1362             {                                            \
1363                 index_xyz        = index_xy+(k0+ithz);   \
1364                 grid[index_xyz] += valxy*thz[ithz];      \
1365             }                                            \
1366         }                                                \
1367     }
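/* DO_BSPLINE(order) adds coefficient*thx[ithx]*thy[ithy]*thz[ithz] to the
 * order^3 block of grid points starting at (i0,j0,k0); it is invoked with the
 * literal orders 4 and 5 below so those loops can be fully optimized.
 */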
1368
1369
1370 static void spread_coefficients_bsplines_thread(pmegrid_t                    *pmegrid,
1371                                                 pme_atomcomm_t               *atc,
1372                                                 splinedata_t                 *spline,
1373                                                 pme_spline_work_t gmx_unused *work)
1374 {
1375
1376     /* spread coefficients from home atoms to local grid */
1377     real          *grid;
1378     pme_overlap_t *ol;
1379     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1380     int       *    idxptr;
1381     int            order, norder, index_x, index_xy, index_xyz;
1382     real           valx, valxy, coefficient;
1383     real          *thx, *thy, *thz;
1384     int            localsize, bndsize;
1385     int            pnx, pny, pnz, ndatatot;
1386     int            offx, offy, offz;
1387
1388 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1389     real           thz_buffer[GMX_SIMD4_WIDTH*3], *thz_aligned;
1390
1391     thz_aligned = gmx_simd4_align_r(thz_buffer);
1392 #endif
1393
1394     pnx = pmegrid->s[XX];
1395     pny = pmegrid->s[YY];
1396     pnz = pmegrid->s[ZZ];
1397
1398     offx = pmegrid->offset[XX];
1399     offy = pmegrid->offset[YY];
1400     offz = pmegrid->offset[ZZ];
1401
1402     ndatatot = pnx*pny*pnz;
1403     grid     = pmegrid->grid;
1404     for (i = 0; i < ndatatot; i++)
1405     {
1406         grid[i] = 0;
1407     }
1408
1409     order = pmegrid->order;
1410
1411     for (nn = 0; nn < spline->n; nn++)
1412     {
1413         n           = spline->ind[nn];
1414         coefficient = atc->coefficient[n];
1415
1416         if (coefficient != 0)
1417         {
1418             idxptr = atc->idx[n];
1419             norder = nn*order;
1420
1421             i0   = idxptr[XX] - offx;
1422             j0   = idxptr[YY] - offy;
1423             k0   = idxptr[ZZ] - offz;
1424
1425             thx = spline->theta[XX] + norder;
1426             thy = spline->theta[YY] + norder;
1427             thz = spline->theta[ZZ] + norder;
1428
1429             switch (order)
1430             {
1431                 case 4:
1432 #ifdef PME_SIMD4_SPREAD_GATHER
1433 #ifdef PME_SIMD4_UNALIGNED
1434 #define PME_SPREAD_SIMD4_ORDER4
1435 #else
1436 #define PME_SPREAD_SIMD4_ALIGNED
1437 #define PME_ORDER 4
1438 #endif
1439 #include "gromacs/mdlib/pme_simd4.h"
1440 #else
1441                     DO_BSPLINE(4);
1442 #endif
1443                     break;
1444                 case 5:
1445 #ifdef PME_SIMD4_SPREAD_GATHER
1446 #define PME_SPREAD_SIMD4_ALIGNED
1447 #define PME_ORDER 5
1448 #include "gromacs/mdlib/pme_simd4.h"
1449 #else
1450                     DO_BSPLINE(5);
1451 #endif
1452                     break;
1453                 default:
1454                     DO_BSPLINE(order);
1455                     break;
1456             }
1457         }
1458     }
1459 }
1460
1461 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1462 {
1463 #ifdef PME_SIMD4_SPREAD_GATHER
1464     if (pme_order == 5
1465 #ifndef PME_SIMD4_UNALIGNED
1466         || pme_order == 4
1467 #endif
1468         )
1469     {
1470         /* Round nz up to a multiple of 4 to ensure alignment */
1471         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1472     }
1473 #endif
1474 }
1475
1476 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1477 {
1478 #ifdef PME_SIMD4_SPREAD_GATHER
1479 #ifndef PME_SIMD4_UNALIGNED
1480     if (pme_order == 4)
1481     {
1482          * Add extra elements to ensure that aligned operations do not go
1483          * beyond the allocated grid size.
1484          * Note that for pme_order=5, the pme grid z-size alignment
1485          * ensures that we will not go beyond the grid size.
1486          */
1487         *gridsize += 4;
1488     }
1489 #endif
1490 #endif
1491 }
1492
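/* Initialize a single (sub-)grid: (cx,cy,cz) is its position in the thread
 * decomposition, (x0,y0,z0) its offset in the node grid and x1/y1/z1 the end
 * of its interpolation region; n holds the used size including the order-1
 * overlap and s the allocated size, which may be z-padded for SIMD alignment.
 */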
1493 static void pmegrid_init(pmegrid_t *grid,
1494                          int cx, int cy, int cz,
1495                          int x0, int y0, int z0,
1496                          int x1, int y1, int z1,
1497                          gmx_bool set_alignment,
1498                          int pme_order,
1499                          real *ptr)
1500 {
1501     int nz, gridsize;
1502
1503     grid->ci[XX]     = cx;
1504     grid->ci[YY]     = cy;
1505     grid->ci[ZZ]     = cz;
1506     grid->offset[XX] = x0;
1507     grid->offset[YY] = y0;
1508     grid->offset[ZZ] = z0;
1509     grid->n[XX]      = x1 - x0 + pme_order - 1;
1510     grid->n[YY]      = y1 - y0 + pme_order - 1;
1511     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1512     copy_ivec(grid->n, grid->s);
1513
1514     nz = grid->s[ZZ];
1515     set_grid_alignment(&nz, pme_order);
1516     if (set_alignment)
1517     {
1518         grid->s[ZZ] = nz;
1519     }
1520     else if (nz != grid->s[ZZ])
1521     {
1522         gmx_incons("pmegrid_init call with an unaligned z size");
1523     }
1524
1525     grid->order = pme_order;
1526     if (ptr == NULL)
1527     {
1528         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1529         set_gridsize_alignment(&gridsize, pme_order);
1530         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1531     }
1532     else
1533     {
1534         grid->grid = ptr;
1535     }
1536 }
1537
1538 static int div_round_up(int numerator, int denominator)
1539 {
1540     return (numerator + denominator - 1)/denominator;
1541 }
1542
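/* Choose a decomposition with nsub[XX]*nsub[YY]*nsub[ZZ] == nthread of the
 * local grid that minimizes the number of grid points (including overlap) per
 * thread; the choice can be overridden with the GMX_PME_THREAD_DIVISION
 * environment variable.
 */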
1543 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1544                                   ivec nsub)
1545 {
1546     int gsize_opt, gsize;
1547     int nsx, nsy, nsz;
1548     char *env;
1549
1550     gsize_opt = -1;
1551     for (nsx = 1; nsx <= nthread; nsx++)
1552     {
1553         if (nthread % nsx == 0)
1554         {
1555             for (nsy = 1; nsy <= nthread; nsy++)
1556             {
1557                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1558                 {
1559                     nsz = nthread/(nsx*nsy);
1560
1561                     /* Determine the number of grid points per thread */
1562                     gsize =
1563                         (div_round_up(n[XX], nsx) + ovl)*
1564                         (div_round_up(n[YY], nsy) + ovl)*
1565                         (div_round_up(n[ZZ], nsz) + ovl);
1566
1567                     /* Minimize the number of grid points per thread
1568                      * and, secondarily, the number of cuts in minor dimensions.
1569                      */
1570                     if (gsize_opt == -1 ||
1571                         gsize < gsize_opt ||
1572                         (gsize == gsize_opt &&
1573                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1574                     {
1575                         nsub[XX]  = nsx;
1576                         nsub[YY]  = nsy;
1577                         nsub[ZZ]  = nsz;
1578                         gsize_opt = gsize;
1579                     }
1580                 }
1581             }
1582         }
1583     }
1584
1585     env = getenv("GMX_PME_THREAD_DIVISION");
1586     if (env != NULL)
1587     {
1588         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1589     }
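    /* Example (hypothetical values): with nthread = 4, setting the environment
     * variable GMX_PME_THREAD_DIVISION to "2 2 1" forces a 2 x 2 x 1 division;
     * the product of the three values must equal nthread or we exit with a
     * fatal error below.
     */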
1590
1591     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1592     {
1593         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1594     }
1595 }
1596
1597 static void pmegrids_init(pmegrids_t *grids,
1598                           int nx, int ny, int nz, int nz_base,
1599                           int pme_order,
1600                           gmx_bool bUseThreads,
1601                           int nthread,
1602                           int overlap_x,
1603                           int overlap_y)
1604 {
1605     ivec n, n_base, g0, g1;
1606     int t, x, y, z, d, i, tfac;
1607     int max_comm_lines = -1;
1608
1609     n[XX] = nx - (pme_order - 1);
1610     n[YY] = ny - (pme_order - 1);
1611     n[ZZ] = nz - (pme_order - 1);
1612
1613     copy_ivec(n, n_base);
1614     n_base[ZZ] = nz_base;
1615
1616     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1617                  NULL);
1618
1619     grids->nthread = nthread;
1620
1621     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1622
1623     if (bUseThreads)
1624     {
1625         ivec nst;
1626         int gridsize;
1627
1628         for (d = 0; d < DIM; d++)
1629         {
1630             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1631         }
1632         set_grid_alignment(&nst[ZZ], pme_order);
1633
1634         if (debug)
1635         {
1636             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1637                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1638             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1639                     nx, ny, nz,
1640                     nst[XX], nst[YY], nst[ZZ]);
1641         }
1642
1643         snew(grids->grid_th, grids->nthread);
1644         t        = 0;
1645         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1646         set_gridsize_alignment(&gridsize, pme_order);
1647         snew_aligned(grids->grid_all,
1648                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1649                      SIMD4_ALIGNMENT);
1650
1651         for (x = 0; x < grids->nc[XX]; x++)
1652         {
1653             for (y = 0; y < grids->nc[YY]; y++)
1654             {
1655                 for (z = 0; z < grids->nc[ZZ]; z++)
1656                 {
1657                     pmegrid_init(&grids->grid_th[t],
1658                                  x, y, z,
1659                                  (n[XX]*(x  ))/grids->nc[XX],
1660                                  (n[YY]*(y  ))/grids->nc[YY],
1661                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1662                                  (n[XX]*(x+1))/grids->nc[XX],
1663                                  (n[YY]*(y+1))/grids->nc[YY],
1664                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1665                                  TRUE,
1666                                  pme_order,
1667                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1668                     t++;
1669                 }
1670             }
1671         }
1672     }
1673     else
1674     {
1675         grids->grid_th = NULL;
1676     }
1677
1678     snew(grids->g2t, DIM);
1679     tfac = 1;
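    /* g2t[d][i] stores, for grid line i along dimension d, that dimension's
     * contribution (t*tfac) to the flattened index of the owning thread
     * sub-grid; tfac is the stride, with z varying fastest.
     */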
1680     for (d = DIM-1; d >= 0; d--)
1681     {
1682         snew(grids->g2t[d], n[d]);
1683         t = 0;
1684         for (i = 0; i < n[d]; i++)
1685         {
1686             /* The second check should match the parameters
1687              * of the pmegrid_init call above.
1688              */
1689             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1690             {
1691                 t++;
1692             }
1693             grids->g2t[d][i] = t*tfac;
1694         }
1695
1696         tfac *= grids->nc[d];
1697
1698         switch (d)
1699         {
1700             case XX: max_comm_lines = overlap_x;     break;
1701             case YY: max_comm_lines = overlap_y;     break;
1702             case ZZ: max_comm_lines = pme_order - 1; break;
1703         }
1704         grids->nthread_comm[d] = 0;
1705         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1706                grids->nthread_comm[d] < grids->nc[d])
1707         {
1708             grids->nthread_comm[d]++;
1709         }
1710         if (debug != NULL)
1711         {
1712             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1713                     'x'+d, grids->nthread_comm[d]);
1714         }
1715         /* It should be possible to make grids->nthread_comm[d] == grids->nc[d]
1716          * work, but in practice this restriction is not a problem.
1717          */
1718         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1719         {
1720             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1721         }
1722     }
1723 }
1724
1725
1726 static void pmegrids_destroy(pmegrids_t *grids)
1727 {
1728     int t;
1729
1730     if (grids->grid.grid != NULL)
1731     {
1732         sfree(grids->grid.grid);
1733
1734         if (grids->nthread > 0)
1735         {
1736             for (t = 0; t < grids->nthread; t++)
1737             {
1738                 sfree(grids->grid_th[t].grid);
1739             }
1740             sfree(grids->grid_th);
1741         }
1742     }
1743 }
1744
1745
1746 static void realloc_work(pme_work_t *work, int nkx)
1747 {
1748     int simd_width;
1749
1750     if (nkx > work->nalloc)
1751     {
1752         work->nalloc = nkx;
1753         srenew(work->mhx, work->nalloc);
1754         srenew(work->mhy, work->nalloc);
1755         srenew(work->mhz, work->nalloc);
1756         srenew(work->m2, work->nalloc);
1757         /* Allocate an aligned pointer for SIMD operations, including extra
1758          * elements at the end for padding.
1759          */
1760 #ifdef PME_SIMD_SOLVE
1761         simd_width = GMX_SIMD_REAL_WIDTH;
1762 #else
1763         /* We can use any alignment, apart from 0, so we use 4 */
1764         simd_width = 4;
1765 #endif
1766         sfree_aligned(work->denom);
1767         sfree_aligned(work->tmp1);
1768         sfree_aligned(work->tmp2);
1769         sfree_aligned(work->eterm);
1770         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1771         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1772         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1773         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1774         srenew(work->m2inv, work->nalloc);
1775     }
1776 }
1777
1778
1779 static void free_work(pme_work_t *work)
1780 {
1781     sfree(work->mhx);
1782     sfree(work->mhy);
1783     sfree(work->mhz);
1784     sfree(work->m2);
1785     sfree_aligned(work->denom);
1786     sfree_aligned(work->tmp1);
1787     sfree_aligned(work->tmp2);
1788     sfree_aligned(work->eterm);
1789     sfree(work->m2inv);
1790 }
1791
1792
1793 #if defined PME_SIMD_SOLVE
1794 /* Calculate exponentials through SIMD */
1795 gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1796 {
1797     {
1798         const gmx_simd_real_t two = gmx_simd_set1_r(2.0);
1799         gmx_simd_real_t f_simd;
1800         gmx_simd_real_t lu;
1801         gmx_simd_real_t tmp_d1, d_inv, tmp_r, tmp_e;
1802         int kx;
1803         f_simd = gmx_simd_set1_r(f);
1804         /* We only need to calculate from start. But since start is 0 or 1
1805          * and we want to use aligned loads/stores, we always start from 0.
1806          */
1807         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1808         {
1809             tmp_d1   = gmx_simd_load_r(d_aligned+kx);
1810             d_inv    = gmx_simd_inv_r(tmp_d1);
1811             tmp_r    = gmx_simd_load_r(r_aligned+kx);
1812             tmp_r    = gmx_simd_exp_r(tmp_r);
1813             tmp_e    = gmx_simd_mul_r(f_simd, d_inv);
1814             tmp_e    = gmx_simd_mul_r(tmp_e, tmp_r);
1815             gmx_simd_store_r(e_aligned+kx, tmp_e);
1816         }
1817     }
1818 }
1819 #else
1820 gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1821 {
1822     int kx;
1823     for (kx = start; kx < end; kx++)
1824     {
1825         d[kx] = 1.0/d[kx];
1826     }
1827     for (kx = start; kx < end; kx++)
1828     {
1829         r[kx] = exp(r[kx]);
1830     }
1831     for (kx = start; kx < end; kx++)
1832     {
1833         e[kx] = f*r[kx]*d[kx];
1834     }
1835 }
1836 #endif
1837
1838 #if defined PME_SIMD_SOLVE
1839 /* Calculate exponentials through SIMD */
1840 gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1841 {
1842     gmx_simd_real_t tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1843     const gmx_simd_real_t sqr_PI = gmx_simd_sqrt_r(gmx_simd_set1_r(M_PI));
1844     int kx;
1845     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1846     {
1847         /* We only need to calculate from start. But since start is 0 or 1
1848          * and we want to use aligned loads/stores, we always start from 0.
1849          */
1850         tmp_d = gmx_simd_load_r(d_aligned+kx);
1851         d_inv = gmx_simd_inv_r(tmp_d);
1852         gmx_simd_store_r(d_aligned+kx, d_inv);
1853         tmp_r = gmx_simd_load_r(r_aligned+kx);
1854         tmp_r = gmx_simd_exp_r(tmp_r);
1855         gmx_simd_store_r(r_aligned+kx, tmp_r);
1856         tmp_mk  = gmx_simd_load_r(factor_aligned+kx);
1857         tmp_fac = gmx_simd_mul_r(sqr_PI, gmx_simd_mul_r(tmp_mk, gmx_simd_erfc_r(tmp_mk)));
1858         gmx_simd_store_r(factor_aligned+kx, tmp_fac);
1859     }
1860 }
1861 #else
1862 gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1863 {
1864     int kx;
1865     real mk;
1866     for (kx = start; kx < end; kx++)
1867     {
1868         d[kx] = 1.0/d[kx];
1869     }
1870
1871     for (kx = start; kx < end; kx++)
1872     {
1873         r[kx] = exp(r[kx]);
1874     }
1875
1876     for (kx = start; kx < end; kx++)
1877     {
1878         mk       = tmp2[kx];
1879         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1880     }
1881 }
1882 #endif
1883
1884 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1885                          real ewaldcoeff, real vol,
1886                          gmx_bool bEnerVir,
1887                          int nthread, int thread)
1888 {
1889     /* do recip sum over local cells in grid */
1890     /* y major, z middle, x minor or continuous */
1891     t_complex *p0;
1892     int     kx, ky, kz, maxkx, maxky, maxkz;
1893     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1894     real    mx, my, mz;
1895     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1896     real    ets2, struct2, vfactor, ets2vf;
1897     real    d1, d2, energy = 0;
1898     real    by, bz;
1899     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1900     real    rxx, ryx, ryy, rzx, rzy, rzz;
1901     pme_work_t *work;
1902     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1903     real    mhxk, mhyk, mhzk, m2k;
1904     real    corner_fac;
1905     ivec    complex_order;
1906     ivec    local_ndata, local_offset, local_size;
1907     real    elfac;
1908
1909     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1910
1911     nx = pme->nkx;
1912     ny = pme->nky;
1913     nz = pme->nkz;
1914
1915     /* Dimensions should be identical for A/B grid, so we just use A here */
1916     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
1917                                       complex_order,
1918                                       local_ndata,
1919                                       local_offset,
1920                                       local_size);
1921
1922     rxx = pme->recipbox[XX][XX];
1923     ryx = pme->recipbox[YY][XX];
1924     ryy = pme->recipbox[YY][YY];
1925     rzx = pme->recipbox[ZZ][XX];
1926     rzy = pme->recipbox[ZZ][YY];
1927     rzz = pme->recipbox[ZZ][ZZ];
1928
1929     maxkx = (nx+1)/2;
1930     maxky = (ny+1)/2;
1931     maxkz = nz/2+1;
1932
1933     work  = &pme->work[thread];
1934     mhx   = work->mhx;
1935     mhy   = work->mhy;
1936     mhz   = work->mhz;
1937     m2    = work->m2;
1938     denom = work->denom;
1939     tmp1  = work->tmp1;
1940     eterm = work->eterm;
1941     m2inv = work->m2inv;
1942
1943     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1944     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1945
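    /* For each k-vector m the complex grid value is scaled by
     * eterm = elfac*exp(-pi^2 m^2/ewaldcoeff^2)/(pi*vol*m^2*bsp_mod_x*bsp_mod_y*bsp_mod_z),
     * the reciprocal-space Ewald convolution factor; denom, tmp1 and
     * calc_exponentials_q below assemble exactly these pieces.
     */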
1946     for (iyz = iyz0; iyz < iyz1; iyz++)
1947     {
1948         iy = iyz/local_ndata[ZZ];
1949         iz = iyz - iy*local_ndata[ZZ];
1950
1951         ky = iy + local_offset[YY];
1952
1953         if (ky < maxky)
1954         {
1955             my = ky;
1956         }
1957         else
1958         {
1959             my = (ky - ny);
1960         }
1961
1962         by = M_PI*vol*pme->bsp_mod[YY][ky];
1963
1964         kz = iz + local_offset[ZZ];
1965
1966         mz = kz;
1967
1968         bz = pme->bsp_mod[ZZ][kz];
1969
1970         /* 0.5 correction for corner points */
1971         corner_fac = 1;
1972         if (kz == 0 || kz == (nz+1)/2)
1973         {
1974             corner_fac = 0.5;
1975         }
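        /* The factor 2 in struct2 further down counts each stored mode together
         * with its complex conjugate; the kz == 0 and kz == (nz+1)/2 planes of
         * the half-complex z dimension are self-conjugate, hence the 0.5 above.
         */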
1976
1977         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1978
1979         /* We should skip the k-space point (0,0,0) */
1980         /* Note that since here x is the minor index, local_offset[XX]=0 */
1981         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1982         {
1983             kxstart = local_offset[XX];
1984         }
1985         else
1986         {
1987             kxstart = local_offset[XX] + 1;
1988             p0++;
1989         }
1990         kxend = local_offset[XX] + local_ndata[XX];
1991
1992         if (bEnerVir)
1993         {
1994             /* More expensive inner loop, especially because of the storage
1995              * of the mh elements in arrays.
1996              * Because x is the minor grid index, all mh elements
1997              * depend on kx for triclinic unit cells.
1998              */
1999
2000             /* Two explicit loops to avoid a conditional inside the loop */
2001             for (kx = kxstart; kx < maxkx; kx++)
2002             {
2003                 mx = kx;
2004
2005                 mhxk      = mx * rxx;
2006                 mhyk      = mx * ryx + my * ryy;
2007                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2008                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2009                 mhx[kx]   = mhxk;
2010                 mhy[kx]   = mhyk;
2011                 mhz[kx]   = mhzk;
2012                 m2[kx]    = m2k;
2013                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2014                 tmp1[kx]  = -factor*m2k;
2015             }
2016
2017             for (kx = maxkx; kx < kxend; kx++)
2018             {
2019                 mx = (kx - nx);
2020
2021                 mhxk      = mx * rxx;
2022                 mhyk      = mx * ryx + my * ryy;
2023                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2024                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2025                 mhx[kx]   = mhxk;
2026                 mhy[kx]   = mhyk;
2027                 mhz[kx]   = mhzk;
2028                 m2[kx]    = m2k;
2029                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2030                 tmp1[kx]  = -factor*m2k;
2031             }
2032
2033             for (kx = kxstart; kx < kxend; kx++)
2034             {
2035                 m2inv[kx] = 1.0/m2[kx];
2036             }
2037
2038             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2039
2040             for (kx = kxstart; kx < kxend; kx++, p0++)
2041             {
2042                 d1      = p0->re;
2043                 d2      = p0->im;
2044
2045                 p0->re  = d1*eterm[kx];
2046                 p0->im  = d2*eterm[kx];
2047
2048                 struct2 = 2.0*(d1*d1+d2*d2);
2049
2050                 tmp1[kx] = eterm[kx]*struct2;
2051             }
2052
2053             for (kx = kxstart; kx < kxend; kx++)
2054             {
2055                 ets2     = corner_fac*tmp1[kx];
2056                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2057                 energy  += ets2;
2058
2059                 ets2vf   = ets2*vfactor;
2060                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2061                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2062                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2063                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2064                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2065                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2066             }
2067         }
2068         else
2069         {
2070             /* We don't need to calculate the energy and the virial.
2071              * In this case the triclinic overhead is small.
2072              */
2073
2074             /* Two explicit loops to avoid a conditional inside the loop */
2075
2076             for (kx = kxstart; kx < maxkx; kx++)
2077             {
2078                 mx = kx;
2079
2080                 mhxk      = mx * rxx;
2081                 mhyk      = mx * ryx + my * ryy;
2082                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2083                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2084                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2085                 tmp1[kx]  = -factor*m2k;
2086             }
2087
2088             for (kx = maxkx; kx < kxend; kx++)
2089             {
2090                 mx = (kx - nx);
2091
2092                 mhxk      = mx * rxx;
2093                 mhyk      = mx * ryx + my * ryy;
2094                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2095                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2096                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2097                 tmp1[kx]  = -factor*m2k;
2098             }
2099
2100             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2101
2102             for (kx = kxstart; kx < kxend; kx++, p0++)
2103             {
2104                 d1      = p0->re;
2105                 d2      = p0->im;
2106
2107                 p0->re  = d1*eterm[kx];
2108                 p0->im  = d2*eterm[kx];
2109             }
2110         }
2111     }
2112
2113     if (bEnerVir)
2114     {
2115         /* Update virial with local values.
2116          * The virial is symmetric by definition.
2117          * This virial seems OK for isotropic scaling, but I'm
2118          * experiencing problems with semi-isotropic membranes.
2119          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2120          */
2121         work->vir_q[XX][XX] = 0.25*virxx;
2122         work->vir_q[YY][YY] = 0.25*viryy;
2123         work->vir_q[ZZ][ZZ] = 0.25*virzz;
2124         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
2125         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
2126         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
2127
2128         /* This energy should be corrected for a charged system */
2129         work->energy_q = 0.5*energy;
2130     }
2131
2132     /* Return the loop count */
2133     return local_ndata[YY]*local_ndata[XX];
2134 }
2135
2136 static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
2137                             real ewaldcoeff, real vol,
2138                             gmx_bool bEnerVir, int nthread, int thread)
2139 {
2140     /* do recip sum over local cells in grid */
2141     /* y major, z middle, x minor or continuous */
2142     int     ig, gcount;
2143     int     kx, ky, kz, maxkx, maxky, maxkz;
2144     int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
2145     real    mx, my, mz;
2146     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2147     real    ets2, ets2vf;
2148     real    eterm, vterm, d1, d2, energy = 0;
2149     real    by, bz;
2150     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2151     real    rxx, ryx, ryy, rzx, rzy, rzz;
2152     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
2153     real    mhxk, mhyk, mhzk, m2k;
2154     real    mk;
2155     pme_work_t *work;
2156     real    corner_fac;
2157     ivec    complex_order;
2158     ivec    local_ndata, local_offset, local_size;
2159     nx = pme->nkx;
2160     ny = pme->nky;
2161     nz = pme->nkz;
2162
2163     /* Dimensions should be identical for A/B grid, so we just use A here */
2164     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
2165                                       complex_order,
2166                                       local_ndata,
2167                                       local_offset,
2168                                       local_size);
2169     rxx = pme->recipbox[XX][XX];
2170     ryx = pme->recipbox[YY][XX];
2171     ryy = pme->recipbox[YY][YY];
2172     rzx = pme->recipbox[ZZ][XX];
2173     rzy = pme->recipbox[ZZ][YY];
2174     rzz = pme->recipbox[ZZ][ZZ];
2175
2176     maxkx = (nx+1)/2;
2177     maxky = (ny+1)/2;
2178     maxkz = nz/2+1;
2179
2180     work  = &pme->work[thread];
2181     mhx   = work->mhx;
2182     mhy   = work->mhy;
2183     mhz   = work->mhz;
2184     m2    = work->m2;
2185     denom = work->denom;
2186     tmp1  = work->tmp1;
2187     tmp2  = work->tmp2;
2188
2189     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2190     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2191
2192     for (iyz = iyz0; iyz < iyz1; iyz++)
2193     {
2194         iy = iyz/local_ndata[ZZ];
2195         iz = iyz - iy*local_ndata[ZZ];
2196
2197         ky = iy + local_offset[YY];
2198
2199         if (ky < maxky)
2200         {
2201             my = ky;
2202         }
2203         else
2204         {
2205             my = (ky - ny);
2206         }
2207
2208         by = 3.0*vol*pme->bsp_mod[YY][ky]
2209             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
2210
2211         kz = iz + local_offset[ZZ];
2212
2213         mz = kz;
2214
2215         bz = pme->bsp_mod[ZZ][kz];
2216
2217         /* 0.5 correction for corner points */
2218         corner_fac = 1;
2219         if (kz == 0 || kz == (nz+1)/2)
2220         {
2221             corner_fac = 0.5;
2222         }
2223
2224         kxstart = local_offset[XX];
2225         kxend   = local_offset[XX] + local_ndata[XX];
2226         if (bEnerVir)
2227         {
2228             /* More expensive inner loop, especially because of the
2229              * storage of the mh elements in arrays.  Because x is the
2230              * minor grid index, all mh elements depend on kx for
2231              * triclinic unit cells.
2232              */
2233
2234             /* Two explicit loops to avoid a conditional inside the loop */
2235             for (kx = kxstart; kx < maxkx; kx++)
2236             {
2237                 mx = kx;
2238
2239                 mhxk      = mx * rxx;
2240                 mhyk      = mx * ryx + my * ryy;
2241                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2242                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2243                 mhx[kx]   = mhxk;
2244                 mhy[kx]   = mhyk;
2245                 mhz[kx]   = mhzk;
2246                 m2[kx]    = m2k;
2247                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2248                 tmp1[kx]  = -factor*m2k;
2249                 tmp2[kx]  = sqrt(factor*m2k);
2250             }
2251
2252             for (kx = maxkx; kx < kxend; kx++)
2253             {
2254                 mx = (kx - nx);
2255
2256                 mhxk      = mx * rxx;
2257                 mhyk      = mx * ryx + my * ryy;
2258                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2259                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2260                 mhx[kx]   = mhxk;
2261                 mhy[kx]   = mhyk;
2262                 mhz[kx]   = mhzk;
2263                 m2[kx]    = m2k;
2264                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2265                 tmp1[kx]  = -factor*m2k;
2266                 tmp2[kx]  = sqrt(factor*m2k);
2267             }
2268
2269             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
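            /* After this call denom[kx] holds 1/denom, tmp1[kx] holds
             * exp(-factor*m2[kx]) and tmp2[kx] holds sqrt(pi)*mk*erfc(mk) with
             * mk = sqrt(factor*m2[kx]): the ingredients of the r^-6 dispersion
             * Ewald kernel assembled below.
             */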
2270
2271             for (kx = kxstart; kx < kxend; kx++)
2272             {
2273                 m2k   = factor*m2[kx];
2274                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
2275                           + 2.0*m2k*tmp2[kx]);
2276                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
2277                 tmp1[kx] = eterm*denom[kx];
2278                 tmp2[kx] = vterm*denom[kx];
2279             }
2280
2281             if (!bLB)
2282             {
2283                 t_complex *p0;
2284                 real       struct2;
2285
2286                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2287                 for (kx = kxstart; kx < kxend; kx++, p0++)
2288                 {
2289                     d1      = p0->re;
2290                     d2      = p0->im;
2291
2292                     eterm   = tmp1[kx];
2293                     vterm   = tmp2[kx];
2294                     p0->re  = d1*eterm;
2295                     p0->im  = d2*eterm;
2296
2297                     struct2 = 2.0*(d1*d1+d2*d2);
2298
2299                     tmp1[kx] = eterm*struct2;
2300                     tmp2[kx] = vterm*struct2;
2301                 }
2302             }
2303             else
2304             {
2305                 real *struct2 = denom;
2306                 real  str2;
2307
2308                 for (kx = kxstart; kx < kxend; kx++)
2309                 {
2310                     struct2[kx] = 0.0;
2311                 }
2312                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
2313                 for (ig = 0; ig <= 3; ++ig)
2314                 {
2315                     t_complex *p0, *p1;
2316                     real       scale;
2317
2318                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2319                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2320                     scale = 2.0*lb_scale_factor_symm[ig];
2321                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
2322                     {
2323                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
2324                     }
2325
2326                 }
2327                 for (ig = 0; ig <= 6; ++ig)
2328                 {
2329                     t_complex *p0;
2330
2331                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2332                     for (kx = kxstart; kx < kxend; kx++, p0++)
2333                     {
2334                         d1     = p0->re;
2335                         d2     = p0->im;
2336
2337                         eterm  = tmp1[kx];
2338                         p0->re = d1*eterm;
2339                         p0->im = d2*eterm;
2340                     }
2341                 }
2342                 for (kx = kxstart; kx < kxend; kx++)
2343                 {
2344                     eterm    = tmp1[kx];
2345                     vterm    = tmp2[kx];
2346                     str2     = struct2[kx];
2347                     tmp1[kx] = eterm*str2;
2348                     tmp2[kx] = vterm*str2;
2349                 }
2350             }
2351
2352             for (kx = kxstart; kx < kxend; kx++)
2353             {
2354                 ets2     = corner_fac*tmp1[kx];
2355                 vterm    = 2.0*factor*tmp2[kx];
2356                 energy  += ets2;
2357                 ets2vf   = corner_fac*vterm;
2358                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2359                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2360                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2361                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2362                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2363                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2364             }
2365         }
2366         else
2367         {
2368             /* We don't need to calculate the energy and the virial.
2369              * In this case the triclinic overhead is small.
2370              */
2371
2372             /* Two explicit loops to avoid a conditional inside the loop */
2373
2374             for (kx = kxstart; kx < maxkx; kx++)
2375             {
2376                 mx = kx;
2377
2378                 mhxk      = mx * rxx;
2379                 mhyk      = mx * ryx + my * ryy;
2380                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2381                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2382                 m2[kx]    = m2k;
2383                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2384                 tmp1[kx]  = -factor*m2k;
2385                 tmp2[kx]  = sqrt(factor*m2k);
2386             }
2387
2388             for (kx = maxkx; kx < kxend; kx++)
2389             {
2390                 mx = (kx - nx);
2391
2392                 mhxk      = mx * rxx;
2393                 mhyk      = mx * ryx + my * ryy;
2394                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2395                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2396                 m2[kx]    = m2k;
2397                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2398                 tmp1[kx]  = -factor*m2k;
2399                 tmp2[kx]  = sqrt(factor*m2k);
2400             }
2401
2402             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2403
2404             for (kx = kxstart; kx < kxend; kx++)
2405             {
2406                 m2k    = factor*m2[kx];
2407                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
2408                            + 2.0*m2k*tmp2[kx]);
2409                 tmp1[kx] = eterm*denom[kx];
2410             }
2411             gcount = (bLB ? 7 : 1);
2412             for (ig = 0; ig < gcount; ++ig)
2413             {
2414                 t_complex *p0;
2415
2416                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2417                 for (kx = kxstart; kx < kxend; kx++, p0++)
2418                 {
2419                     d1      = p0->re;
2420                     d2      = p0->im;
2421
2422                     eterm   = tmp1[kx];
2423
2424                     p0->re  = d1*eterm;
2425                     p0->im  = d2*eterm;
2426                 }
2427             }
2428         }
2429     }
2430     if (bEnerVir)
2431     {
2432         work->vir_lj[XX][XX] = 0.25*virxx;
2433         work->vir_lj[YY][YY] = 0.25*viryy;
2434         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
2435         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
2436         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
2437         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
2438
2439         /* This energy should be corrected for a charged system */
2440         work->energy_lj = 0.5*energy;
2441     }
2442     /* Return the loop count */
2443     return local_ndata[YY]*local_ndata[XX];
2444 }
2445
2446 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2447                                real *mesh_energy, matrix vir)
2448 {
2449     /* This function sums output over threads and should therefore
2450      * only be called after thread synchronization.
2451      */
2452     int thread;
2453
2454     *mesh_energy = pme->work[0].energy_q;
2455     copy_mat(pme->work[0].vir_q, vir);
2456
2457     for (thread = 1; thread < nthread; thread++)
2458     {
2459         *mesh_energy += pme->work[thread].energy_q;
2460         m_add(vir, pme->work[thread].vir_q, vir);
2461     }
2462 }
2463
2464 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2465                                 real *mesh_energy, matrix vir)
2466 {
2467     /* This function sums output over threads and should therefore
2468      * only be called after thread synchronization.
2469      */
2470     int thread;
2471
2472     *mesh_energy = pme->work[0].energy_lj;
2473     copy_mat(pme->work[0].vir_lj, vir);
2474
2475     for (thread = 1; thread < nthread; thread++)
2476     {
2477         *mesh_energy += pme->work[thread].energy_lj;
2478         m_add(vir, pme->work[thread].vir_lj, vir);
2479     }
2480 }
2481
2482
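/* DO_FSPLINE accumulates, for one atom, the gradient of the interpolated grid
 * potential over an order^3 block of grid points: each fractional-coordinate
 * component combines the spline derivative (dthx/dthy/dthz) in that dimension
 * with the spline values (thx/thy/thz) in the other two.
 */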
2483 #define DO_FSPLINE(order)                      \
2484     for (ithx = 0; (ithx < order); ithx++)              \
2485     {                                              \
2486         index_x = (i0+ithx)*pny*pnz;               \
2487         tx      = thx[ithx];                       \
2488         dx      = dthx[ithx];                      \
2489                                                \
2490         for (ithy = 0; (ithy < order); ithy++)          \
2491         {                                          \
2492             index_xy = index_x+(j0+ithy)*pnz;      \
2493             ty       = thy[ithy];                  \
2494             dy       = dthy[ithy];                 \
2495             fxy1     = fz1 = 0;                    \
2496                                                \
2497             for (ithz = 0; (ithz < order); ithz++)      \
2498             {                                      \
2499                 gval  = grid[index_xy+(k0+ithz)];  \
2500                 fxy1 += thz[ithz]*gval;            \
2501                 fz1  += dthz[ithz]*gval;           \
2502             }                                      \
2503             fx += dx*ty*fxy1;                      \
2504             fy += tx*dy*fxy1;                      \
2505             fz += tx*ty*fz1;                       \
2506         }                                          \
2507     }
2508
2509
2510 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2511                               gmx_bool bClearF, pme_atomcomm_t *atc,
2512                               splinedata_t *spline,
2513                               real scale)
2514 {
2515     /* sum forces for local particles */
2516     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2517     int     index_x, index_xy;
2518     int     nx, ny, nz, pnx, pny, pnz;
2519     int *   idxptr;
2520     real    tx, ty, dx, dy, coefficient;
2521     real    fx, fy, fz, gval;
2522     real    fxy1, fz1;
2523     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2524     int     norder;
2525     real    rxx, ryx, ryy, rzx, rzy, rzz;
2526     int     order;
2527
2528     pme_spline_work_t *work;
2529
2530 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2531     real           thz_buffer[GMX_SIMD4_WIDTH*3],  *thz_aligned;
2532     real           dthz_buffer[GMX_SIMD4_WIDTH*3], *dthz_aligned;
2533
2534     thz_aligned  = gmx_simd4_align_r(thz_buffer);
2535     dthz_aligned = gmx_simd4_align_r(dthz_buffer);
2536 #endif
2537
2538     work = pme->spline_work;
2539
2540     order = pme->pme_order;
2541     thx   = spline->theta[XX];
2542     thy   = spline->theta[YY];
2543     thz   = spline->theta[ZZ];
2544     dthx  = spline->dtheta[XX];
2545     dthy  = spline->dtheta[YY];
2546     dthz  = spline->dtheta[ZZ];
2547     nx    = pme->nkx;
2548     ny    = pme->nky;
2549     nz    = pme->nkz;
2550     pnx   = pme->pmegrid_nx;
2551     pny   = pme->pmegrid_ny;
2552     pnz   = pme->pmegrid_nz;
2553
2554     rxx   = pme->recipbox[XX][XX];
2555     ryx   = pme->recipbox[YY][XX];
2556     ryy   = pme->recipbox[YY][YY];
2557     rzx   = pme->recipbox[ZZ][XX];
2558     rzy   = pme->recipbox[ZZ][YY];
2559     rzz   = pme->recipbox[ZZ][ZZ];
2560
2561     for (nn = 0; nn < spline->n; nn++)
2562     {
2563         n           = spline->ind[nn];
2564         coefficient = scale*atc->coefficient[n];
2565
2566         if (bClearF)
2567         {
2568             atc->f[n][XX] = 0;
2569             atc->f[n][YY] = 0;
2570             atc->f[n][ZZ] = 0;
2571         }
2572         if (coefficient != 0)
2573         {
2574             fx     = 0;
2575             fy     = 0;
2576             fz     = 0;
2577             idxptr = atc->idx[n];
2578             norder = nn*order;
2579
2580             i0   = idxptr[XX];
2581             j0   = idxptr[YY];
2582             k0   = idxptr[ZZ];
2583
2584             /* Pointer arithmetic alert, next six statements */
2585             thx  = spline->theta[XX] + norder;
2586             thy  = spline->theta[YY] + norder;
2587             thz  = spline->theta[ZZ] + norder;
2588             dthx = spline->dtheta[XX] + norder;
2589             dthy = spline->dtheta[YY] + norder;
2590             dthz = spline->dtheta[ZZ] + norder;
2591
2592             switch (order)
2593             {
2594                 case 4:
2595 #ifdef PME_SIMD4_SPREAD_GATHER
2596 #ifdef PME_SIMD4_UNALIGNED
2597 #define PME_GATHER_F_SIMD4_ORDER4
2598 #else
2599 #define PME_GATHER_F_SIMD4_ALIGNED
2600 #define PME_ORDER 4
2601 #endif
2602 #include "gromacs/mdlib/pme_simd4.h"
2603 #else
2604                     DO_FSPLINE(4);
2605 #endif
2606                     break;
2607                 case 5:
2608 #ifdef PME_SIMD4_SPREAD_GATHER
2609 #define PME_GATHER_F_SIMD4_ALIGNED
2610 #define PME_ORDER 5
2611 #include "gromacs/mdlib/pme_simd4.h"
2612 #else
2613                     DO_FSPLINE(5);
2614 #endif
2615                     break;
2616                 default:
2617                     DO_FSPLINE(order);
2618                     break;
2619             }
2620
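            /* fx, fy and fz are derivatives of the interpolated potential with
             * respect to the fractional grid coordinates; multiplying by the
             * grid dimensions, the reciprocal box and -coefficient gives the
             * Cartesian force on the atom.
             */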
2621             atc->f[n][XX] += -coefficient*( fx*nx*rxx );
2622             atc->f[n][YY] += -coefficient*( fx*nx*ryx + fy*ny*ryy );
2623             atc->f[n][ZZ] += -coefficient*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2624         }
2625     }
2626     /* Since the energy, and not the forces, is interpolated,
2627      * the net force might not be exactly zero.
2628      * This could be solved by also interpolating F, but
2629      * that comes at a cost.
2630      * A better hack is to remove the net force every
2631      * step, but that must be done at a higher level,
2632      * since this routine doesn't see all atoms if running
2633      * in parallel. It is unclear how important this is.  EL 990726
2634      */
2635 }
2636
2637
2638 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2639                                    pme_atomcomm_t *atc)
2640 {
2641     splinedata_t *spline;
2642     int     n, ithx, ithy, ithz, i0, j0, k0;
2643     int     index_x, index_xy;
2644     int *   idxptr;
2645     real    energy, pot, tx, ty, coefficient, gval;
2646     real    *thx, *thy, *thz;
2647     int     norder;
2648     int     order;
2649
2650     spline = &atc->spline[0];
2651
2652     order = pme->pme_order;
2653
2654     energy = 0;
2655     for (n = 0; (n < atc->n); n++)
2656     {
2657         coefficient      = atc->coefficient[n];
2658
2659         if (coefficient != 0)
2660         {
2661             idxptr = atc->idx[n];
2662             norder = n*order;
2663
2664             i0   = idxptr[XX];
2665             j0   = idxptr[YY];
2666             k0   = idxptr[ZZ];
2667
2668             /* Pointer arithmetic alert, next three statements */
2669             thx  = spline->theta[XX] + norder;
2670             thy  = spline->theta[YY] + norder;
2671             thz  = spline->theta[ZZ] + norder;
2672
2673             pot = 0;
2674             for (ithx = 0; (ithx < order); ithx++)
2675             {
2676                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2677                 tx      = thx[ithx];
2678
2679                 for (ithy = 0; (ithy < order); ithy++)
2680                 {
2681                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2682                     ty       = thy[ithy];
2683
2684                     for (ithz = 0; (ithz < order); ithz++)
2685                     {
2686                         gval  = grid[index_xy+(k0+ithz)];
2687                         pot  += tx*ty*thz[ithz]*gval;
2688                     }
2689
2690                 }
2691             }
2692
2693             energy += pot*coefficient;
2694         }
2695     }
2696
2697     return energy;
2698 }
2699
2700 /* Macro to force loop unrolling by fixing order.
2701  * This gives a significant performance gain.
2702  */
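/* CALC_SPLINE evaluates the cardinal B-spline weights with the standard
 * recursion M_n(u) = [u*M_{n-1}(u) + (n-u)*M_{n-1}(u-1)]/(n-1) and the
 * derivatives as differences M_{n-1}(u) - M_{n-1}(u-1), as used in smooth PME.
 */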
2703 #define CALC_SPLINE(order)                     \
2704     {                                              \
2705         int j, k, l;                                 \
2706         real dr, div;                               \
2707         real data[PME_ORDER_MAX];                  \
2708         real ddata[PME_ORDER_MAX];                 \
2709                                                \
2710         for (j = 0; (j < DIM); j++)                     \
2711         {                                          \
2712             dr  = xptr[j];                         \
2713                                                \
2714             /* dr is relative offset from lower cell limit */ \
2715             data[order-1] = 0;                     \
2716             data[1]       = dr;                          \
2717             data[0]       = 1 - dr;                      \
2718                                                \
2719             for (k = 3; (k < order); k++)               \
2720             {                                      \
2721                 div       = 1.0/(k - 1.0);               \
2722                 data[k-1] = div*dr*data[k-2];      \
2723                 for (l = 1; (l < (k-1)); l++)           \
2724                 {                                  \
2725                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2726                                        data[k-l-1]);                \
2727                 }                                  \
2728                 data[0] = div*(1-dr)*data[0];      \
2729             }                                      \
2730             /* differentiate */                    \
2731             ddata[0] = -data[0];                   \
2732             for (k = 1; (k < order); k++)               \
2733             {                                      \
2734                 ddata[k] = data[k-1] - data[k];    \
2735             }                                      \
2736                                                \
2737             div           = 1.0/(order - 1);                 \
2738             data[order-1] = div*dr*data[order-2];  \
2739             for (l = 1; (l < (order-1)); l++)           \
2740             {                                      \
2741                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2742                                        (order-l-dr)*data[order-l-1]); \
2743             }                                      \
2744             data[0] = div*(1 - dr)*data[0];        \
2745                                                \
2746             for (k = 0; k < order; k++)                 \
2747             {                                      \
2748                 theta[j][i*order+k]  = data[k];    \
2749                 dtheta[j][i*order+k] = ddata[k];   \
2750             }                                      \
2751         }                                          \
2752     }
2753
2754 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2755                    rvec fractx[], int nr, int ind[], real coefficient[],
2756                    gmx_bool bDoSplines)
2757 {
2758     /* construct splines for local atoms */
2759     int  i, ii;
2760     real *xptr;
2761
2762     for (i = 0; i < nr; i++)
2763     {
2764         /* With free energy we do not use the coefficient check.
2765          * In most cases this will be more efficient than calling make_bsplines
2766          * twice, since usually more than half the particles have non-zero coefficients.
2767          */
2768         ii = ind[i];
2769         if (bDoSplines || coefficient[ii] != 0.0)
2770         {
2771             xptr = fractx[ii];
2772             switch (order)
2773             {
2774                 case 4:  CALC_SPLINE(4);     break;
2775                 case 5:  CALC_SPLINE(5);     break;
2776                 default: CALC_SPLINE(order); break;
2777             }
2778         }
2779     }
2780 }
2781
2782
2783 void make_dft_mod(real *mod, real *data, int ndata)
2784 {
2785     int i, j;
2786     real sc, ss, arg;
2787
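    /* mod[i] = |sum_j data[j]*exp(2*pi*I*i*j/ndata)|^2, the squared modulus of
     * the discrete Fourier transform of the B-spline values, computed below as
     * the sum of the squared cosine and sine sums.
     */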
2788     for (i = 0; i < ndata; i++)
2789     {
2790         sc = ss = 0;
2791         for (j = 0; j < ndata; j++)
2792         {
2793             arg = (2.0*M_PI*i*j)/ndata;
2794             sc += data[j]*cos(arg);
2795             ss += data[j]*sin(arg);
2796         }
2797         mod[i] = sc*sc+ss*ss;
2798     }
2799     for (i = 0; i < ndata; i++)
2800     {
2801         if (mod[i] < 1e-7)
2802         {
2803             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2804         }
2805     }
2806 }
2807
2808
2809 static void make_bspline_moduli(splinevec bsp_mod,
2810                                 int nx, int ny, int nz, int order)
2811 {
2812     int nmax = max(nx, max(ny, nz));
2813     real *data, *ddata, *bsp_data;
2814     int i, k, l;
2815     real div;
2816
2817     snew(data, order);
2818     snew(ddata, order);
2819     snew(bsp_data, nmax);
2820
2821     data[order-1] = 0;
2822     data[1]       = 0;
2823     data[0]       = 1;
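    /* The recursion below is CALC_SPLINE evaluated at dr = 0, i.e. the B-spline
     * values at the integer grid points, which are then Fourier transformed by
     * make_dft_mod further down.
     */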
2824
2825     for (k = 3; k < order; k++)
2826     {
2827         div       = 1.0/(k-1.0);
2828         data[k-1] = 0;
2829         for (l = 1; l < (k-1); l++)
2830         {
2831             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2832         }
2833         data[0] = div*data[0];
2834     }
2835     /* differentiate */
2836     ddata[0] = -data[0];
2837     for (k = 1; k < order; k++)
2838     {
2839         ddata[k] = data[k-1]-data[k];
2840     }
2841     div           = 1.0/(order-1);
2842     data[order-1] = 0;
2843     for (l = 1; l < (order-1); l++)
2844     {
2845         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2846     }
2847     data[0] = div*data[0];
2848
2849     for (i = 0; i < nmax; i++)
2850     {
2851         bsp_data[i] = 0;
2852     }
2853     for (i = 1; i <= order; i++)
2854     {
2855         bsp_data[i] = data[i-1];
2856     }
2857
2858     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2859     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2860     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2861
2862     sfree(data);
2863     sfree(ddata);
2864     sfree(bsp_data);
2865 }
2866
2867
2868 /* Return the P3M optimal influence function */
2869 static double do_p3m_influence(double z, int order)
2870 {
2871     double z2, z4;
2872
2873     z2 = z*z;
2874     z4 = z2*z2;
2875
2876     /* The formula and most constants can be found in:
2877      * Ballenegger et al., JCTC 8, 936 (2012)
2878      */
2879     switch (order)
2880     {
2881         case 2:
2882             return 1.0 - 2.0*z2/3.0;
2883             break;
2884         case 3:
2885             return 1.0 - z2 + 2.0*z4/15.0;
2886             break;
2887         case 4:
2888             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2889             break;
2890         case 5:
2891             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2892             break;
2893         case 6:
2894             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2895             break;
2896         case 7:
2897             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2898         case 8:
2899             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2900             break;
2901     }
2902
2903     return 0.0;
2904 }
2905
2906 /* Calculate the P3M B-spline moduli for one dimension */
2907 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2908 {
2909     double zarg, zai, sinzai, infl;
2910     int    maxk, i;
2911
2912     if (order > 8)
2913     {
2914         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2915     }
2916
2917     zarg = M_PI/n;
2918
2919     maxk = (n + 1)/2;
2920
2921     for (i = -maxk; i < 0; i++)
2922     {
2923         zai          = zarg*i;
2924         sinzai       = sin(zai);
2925         infl         = do_p3m_influence(sinzai, order);
2926         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2927     }
2928     bsp_mod[0] = 1.0;
2929     for (i = 1; i < maxk; i++)
2930     {
2931         zai        = zarg*i;
2932         sinzai     = sin(zai);
2933         infl       = do_p3m_influence(sinzai, order);
2934         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2935     }
2936 }
2937
2938 /* Calculate the P3M B-spline moduli */
2939 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2940                                     int nx, int ny, int nz, int order)
2941 {
2942     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2943     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2944     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2945 }
2946
2947
2948 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2949 {
2950     int nslab, n, i;
2951     int fw, bw;
2952
2953     nslab = atc->nslab;
2954
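    /* Order the communication partners as fw(1), bw(1), fw(2), bw(2), ...,
     * i.e. alternating forwards and backwards around the slab ring with
     * increasing distance, filling at most nslab-1 destination/source pairs.
     */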
2955     n = 0;
2956     for (i = 1; i <= nslab/2; i++)
2957     {
2958         fw = (atc->nodeid + i) % nslab;
2959         bw = (atc->nodeid - i + nslab) % nslab;
2960         if (n < nslab - 1)
2961         {
2962             atc->node_dest[n] = fw;
2963             atc->node_src[n]  = bw;
2964             n++;
2965         }
2966         if (n < nslab - 1)
2967         {
2968             atc->node_dest[n] = bw;
2969             atc->node_src[n]  = fw;
2970             n++;
2971         }
2972     }
2973 }
2974
2975 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2976 {
2977     int thread, i;
2978
2979     if (NULL != log)
2980     {
2981         fprintf(log, "Destroying PME data structures.\n");
2982     }
2983
2984     sfree((*pmedata)->nnx);
2985     sfree((*pmedata)->nny);
2986     sfree((*pmedata)->nnz);
2987
2988     for (i = 0; i < (*pmedata)->ngrids; ++i)
2989     {
2990         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
2991         sfree((*pmedata)->fftgrid[i]);
2992         sfree((*pmedata)->cfftgrid[i]);
2993         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
2994     }
2995
2996     sfree((*pmedata)->lb_buf1);
2997     sfree((*pmedata)->lb_buf2);
2998
2999     for (thread = 0; thread < (*pmedata)->nthread; thread++)
3000     {
3001         free_work(&(*pmedata)->work[thread]);
3002     }
3003     sfree((*pmedata)->work);
3004
3005     sfree(*pmedata);
3006     *pmedata = NULL;
3007
3008     return 0;
3009 }
3010
3011 static int mult_up(int n, int f)
3012 {
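    /* Round n up to a multiple of f, e.g. mult_up(10, 4) = 12 (illustrative values). */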
3013     return ((n + f - 1)/f)*f;
3014 }
3015
3016
3017 static double pme_load_imbalance(gmx_pme_t pme)
3018 {
3019     int    nma, nmi;
3020     double n1, n2, n3;
3021
3022     nma = pme->nnodes_major;
3023     nmi = pme->nnodes_minor;
3024
3025     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3026     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3027     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3028
3029     /* pme_solve is roughly double the cost of an fft */
3030
3031     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3032 }
3033
3034 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3035                           int dimind, gmx_bool bSpread)
3036 {
3037     int nk, k, s, thread;
3038
3039     atc->dimind    = dimind;
3040     atc->nslab     = 1;
3041     atc->nodeid    = 0;
3042     atc->pd_nalloc = 0;
3043 #ifdef GMX_MPI
3044     if (pme->nnodes > 1)
3045     {
3046         atc->mpi_comm = pme->mpi_comm_d[dimind];
3047         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3048         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3049     }
3050     if (debug)
3051     {
3052         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3053     }
3054 #endif
3055
3056     atc->bSpread   = bSpread;
3057     atc->pme_order = pme->pme_order;
3058
3059     if (atc->nslab > 1)
3060     {
3061         snew(atc->node_dest, atc->nslab);
3062         snew(atc->node_src, atc->nslab);
3063         setup_coordinate_communication(atc);
3064
3065         snew(atc->count_thread, pme->nthread);
3066         for (thread = 0; thread < pme->nthread; thread++)
3067         {
3068             snew(atc->count_thread[thread], atc->nslab);
3069         }
3070         atc->count = atc->count_thread[0];
3071         snew(atc->rcount, atc->nslab);
3072         snew(atc->buf_index, atc->nslab);
3073     }
3074
3075     atc->nthread = pme->nthread;
3076     if (atc->nthread > 1)
3077     {
3078         snew(atc->thread_plist, atc->nthread);
3079     }
3080     snew(atc->spline, atc->nthread);
3081     for (thread = 0; thread < atc->nthread; thread++)
3082     {
3083         if (atc->nthread > 1)
3084         {
3085             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3086             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3087         }
3088         snew(atc->spline[thread].thread_one, pme->nthread);
3089         atc->spline[thread].thread_one[thread] = 1;
3090     }
3091 }
3092
3093 static void
3094 init_overlap_comm(pme_overlap_t *  ol,
3095                   int              norder,
3096 #ifdef GMX_MPI
3097                   MPI_Comm         comm,
3098 #endif
3099                   int              nnodes,
3100                   int              nodeid,
3101                   int              ndata,
3102                   int              commplainsize)
3103 {
3104     int lbnd, rbnd, maxlr, b, i;
3105     int exten;
3106     int nn, nk;
3107     pme_grid_comm_t *pgc;
3108     gmx_bool bCont;
3109     int fft_start, fft_end, send_index1, recv_index1;
3110 #ifdef GMX_MPI
3111     MPI_Status stat;
3112
3113     ol->mpi_comm = comm;
3114 #endif
3115
3116     ol->nnodes = nnodes;
3117     ol->nodeid = nodeid;
3118
3119     /* Linear translation of the PME grid won't affect reciprocal space
3120      * calculations, so to optimize we only interpolate "upwards",
3121      * which also means we only have to consider overlap in one direction.
3122      * I.e., particles on this node might also be spread to grid indices
3123      * that belong to higher nodes (modulo nnodes).
3124      */
3125
3126     snew(ol->s2g0, ol->nnodes+1);
3127     snew(ol->s2g1, ol->nnodes);
3128     if (debug)
3129     {
3130         fprintf(debug, "PME slab boundaries:");
3131     }
3132     for (i = 0; i < nnodes; i++)
3133     {
3134         /* s2g0 is the local interpolation grid start.
3135          * s2g1 is the local interpolation grid end.
3136          * Since in calc_pidx we divide the particles, and not the grid lines,
3137          * spatially uniformly along dimension x or y, we need to round
3138          * s2g0 down and s2g1 up.
3139          */
3140         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3141         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3142
3143         if (debug)
3144         {
3145             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3146         }
3147     }
3148     ol->s2g0[nnodes] = ndata;
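    /* Worked example (illustrative numbers only): ndata = 15, nnodes = 4 and
     * norder = 4 give s2g0 = {0, 3, 7, 11, 15} and s2g1 = {7, 11, 15, 18}, so
     * each slab's interpolation range overlaps one or more following slabs
     * (modulo nnodes).
     */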
3149     if (debug)
3150     {
3151         fprintf(debug, "\n");
3152     }
3153
3154     /* Determine with how many nodes we need to communicate the grid overlap */
3155     b = 0;
3156     do
3157     {
3158         b++;
3159         bCont = FALSE;
3160         for (i = 0; i < nnodes; i++)
3161         {
3162             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3163                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3164             {
3165                 bCont = TRUE;
3166             }
3167         }
3168     }
3169     while (bCont && b < nnodes);
3170     ol->noverlap_nodes = b - 1;
3171
3172     snew(ol->send_id, ol->noverlap_nodes);
3173     snew(ol->recv_id, ol->noverlap_nodes);
3174     for (b = 0; b < ol->noverlap_nodes; b++)
3175     {
3176         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3177         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3178     }
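    /* Example with illustrative values nnodes=4, nodeid=1, noverlap_nodes=2:
     * the ring above gives send_id = { 2, 3 } and recv_id = { 0, 3 },
     * i.e. we send our overlap "upwards" to the following ranks and
     * receive the overlap of the preceding ranks (modulo nnodes).
     */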
3179     snew(ol->comm_data, ol->noverlap_nodes);
3180
3181     ol->send_size = 0;
3182     for (b = 0; b < ol->noverlap_nodes; b++)
3183     {
3184         pgc = &ol->comm_data[b];
3185         /* Send */
3186         fft_start        = ol->s2g0[ol->send_id[b]];
3187         fft_end          = ol->s2g0[ol->send_id[b]+1];
3188         if (ol->send_id[b] < nodeid)
3189         {
3190             fft_start += ndata;
3191             fft_end   += ndata;
3192         }
3193         send_index1       = ol->s2g1[nodeid];
3194         send_index1       = min(send_index1, fft_end);
3195         pgc->send_index0  = fft_start;
3196         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3197         ol->send_size    += pgc->send_nindex;
3198
3199         /* We always start receiving at the first index of our slab */
3200         fft_start        = ol->s2g0[ol->nodeid];
3201         fft_end          = ol->s2g0[ol->nodeid+1];
3202         recv_index1      = ol->s2g1[ol->recv_id[b]];
3203         if (ol->recv_id[b] > nodeid)
3204         {
3205             recv_index1 -= ndata;
3206         }
3207         recv_index1      = min(recv_index1, fft_end);
3208         pgc->recv_index0 = fft_start;
3209         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3210     }
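    /* Continuing the illustrative ndata=30, nnodes=4, norder=4 example for
     * nodeid=1, pulse b=0 (send_id=2, recv_id=0):
     *   send: fft_start=s2g0[2]=15, fft_end=s2g0[3]=22,
     *         send_index1=min(s2g1[1]=18, 22)=18
     *         -> send_index0=15, send_nindex=3  (our spread into slab 2)
     *   recv: fft_start=s2g0[1]=7, fft_end=s2g0[2]=15,
     *         recv_index1=min(s2g1[0]=11, 15)=11
     *         -> recv_index0=7, recv_nindex=4   (slab 0's spread into our slab)
     */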
3211
3212 #ifdef GMX_MPI
3213     /* Communicate the buffer sizes to receive */
3214     for (b = 0; b < ol->noverlap_nodes; b++)
3215     {
3216         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3217                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3218                      ol->mpi_comm, &stat);
3219     }
3220 #endif
3221
3222     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3223     snew(ol->sendbuf, norder*commplainsize);
3224     snew(ol->recvbuf, norder*commplainsize);
3225 }
3226
3227 static void
3228 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3229                               int **global_to_local,
3230                               real **fraction_shift)
3231 {
3232     int i;
3233     int * gtl;
3234     real * fsh;
3235
3236     snew(gtl, 5*n);
3237     snew(fsh, 5*n);
3238     for (i = 0; (i < 5*n); i++)
3239     {
3240         /* Determine the global to local grid index */
3241         gtl[i] = (i - local_start + n) % n;
3242         /* For coordinates that fall within the local grid the fraction
3243          * is correct, we don't need to shift it.
3244          */
3245         fsh[i] = 0;
3246         if (local_range < n)
3247         {
3248             /* Due to rounding issues i could be 1 beyond the lower or
3249              * upper boundary of the local grid. Correct the index for this.
3250              * If we shift the index, we need to shift the fraction by
3251              * the same amount in the other direction to not affect
3252              * the weights.
3253              * Note that due to this shifting the weights at the end of
3254              * the spline might change, but that will only involve values
3255              * between zero and values close to the precision of a real,
3256              * which is anyhow the accuracy of the whole mesh calculation.
3257              */
3258             /* With local_range=0 we should not change i=local_start */
3259             if (i % n != local_start)
3260             {
3261                 if (gtl[i] == n-1)
3262                 {
3263                     gtl[i] = 0;
3264                     fsh[i] = -1;
3265                 }
3266                 else if (gtl[i] == local_range)
3267                 {
3268                     gtl[i] = local_range - 1;
3269                     fsh[i] = 1;
3270                 }
3271             }
3272         }
3273     }
3274
3275     *global_to_local = gtl;
3276     *fraction_shift  = fsh;
3277 }
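
/* Example of the mapping above with illustrative values n=10, local_start=7,
 * local_range=4 (the local grid covers global lines 7..10 with wrap-around):
 *   i=7  -> gtl=0, fsh=0          i=10 -> gtl=3, fsh=0
 *   i=6  -> gtl=9=n-1,         corrected to gtl=0, fsh=-1
 *   i=11 -> gtl=4=local_range, corrected to gtl=3, fsh=+1
 * Indices that fall one line outside the local grid due to rounding are
 * pulled back in and the spline fraction is shifted to compensate.
 */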
3278
3279 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
3280 {
3281     pme_spline_work_t *work;
3282
3283 #ifdef PME_SIMD4_SPREAD_GATHER
3284     real             tmp[GMX_SIMD4_WIDTH*3], *tmp_aligned;
3285     gmx_simd4_real_t zero_S;
3286     gmx_simd4_real_t real_mask_S0, real_mask_S1;
3287     int              of, i;
3288
3289     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3290
3291     tmp_aligned = gmx_simd4_align_r(tmp);
3292
3293     zero_S = gmx_simd4_setzero_r();
3294
3295     /* Generate bit masks to mask out the unused grid entries,
3296      * as we only operate on 'order' of the 2*GMX_SIMD4_WIDTH grid entries
3297      * that are loaded into 2 SIMD registers.
3298      */
3299     for (of = 0; of < 2*GMX_SIMD4_WIDTH-(order-1); of++)
3300     {
3301         for (i = 0; i < 2*GMX_SIMD4_WIDTH; i++)
3302         {
3303             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3304         }
3305         real_mask_S0      = gmx_simd4_load_r(tmp_aligned);
3306         real_mask_S1      = gmx_simd4_load_r(tmp_aligned+GMX_SIMD4_WIDTH);
3307         work->mask_S0[of] = gmx_simd4_cmplt_r(real_mask_S0, zero_S);
3308         work->mask_S1[of] = gmx_simd4_cmplt_r(real_mask_S1, zero_S);
3309     }
3310 #else
3311     work = NULL;
3312 #endif
3313
3314     return work;
3315 }
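
/* For example, with GMX_SIMD4_WIDTH=4 and order=4 the loop above produces
 * 5 mask pairs (of=0..4); for of=3 the active grid entries are 3..6, so
 *   mask_S0 selects { F, F, F, T } of entries 0..3
 *   mask_S1 selects { T, T, T, F } of entries 4..7
 * (illustrative values; the masks depend on the SIMD width and pme_order).
 */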
3316
3317 void gmx_pme_check_restrictions(int pme_order,
3318                                 int nkx, int nky, int nkz,
3319                                 int nnodes_major,
3320                                 int nnodes_minor,
3321                                 gmx_bool bUseThreads,
3322                                 gmx_bool bFatal,
3323                                 gmx_bool *bValidSettings)
3324 {
3325     if (pme_order > PME_ORDER_MAX)
3326     {
3327         if (!bFatal)
3328         {
3329             *bValidSettings = FALSE;
3330             return;
3331         }
3332         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3333                   pme_order, PME_ORDER_MAX);
3334     }
3335
3336     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3337         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3338         nkz <= pme_order)
3339     {
3340         if (!bFatal)
3341         {
3342             *bValidSettings = FALSE;
3343             return;
3344         }
3345         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3346                   pme_order);
3347     }
3348
3349     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3350      * We only allow multiple communication pulses in dim 1, not in dim 0.
3351      */
3352     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3353                         nkx != nnodes_major*(pme_order - 1)))
3354     {
3355         if (!bFatal)
3356         {
3357             *bValidSettings = FALSE;
3358             return;
3359         }
3360         gmx_fatal(FARGS, "The number of PME grid lines per rank along x is %g, but when using OpenMP threads, the number of grid lines per rank along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer ranks along x (and possibly more along y and/or z) by specifying -dd manually.",
3361                   nkx/(double)nnodes_major, pme_order);
3362     }
3363
3364     if (bValidSettings != NULL)
3365     {
3366         *bValidSettings = TRUE;
3367     }
3368
3369     return;
3370 }
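
/* A minimal (hypothetical) non-fatal usage sketch of the check above:
 *
 *     gmx_bool bValid;
 *     gmx_pme_check_restrictions(4, 48, 48, 48, 4, 1, TRUE, FALSE, &bValid);
 *     if (!bValid)
 *     {
 *         fall back to a different grid or rank setup
 *     }
 *
 * With bFatal==FALSE the routine only reports through bValidSettings and
 * never calls gmx_fatal.
 */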
3371
3372 int gmx_pme_init(gmx_pme_t *         pmedata,
3373                  t_commrec *         cr,
3374                  int                 nnodes_major,
3375                  int                 nnodes_minor,
3376                  t_inputrec *        ir,
3377                  int                 homenr,
3378                  gmx_bool            bFreeEnergy_q,
3379                  gmx_bool            bFreeEnergy_lj,
3380                  gmx_bool            bReproducible,
3381                  int                 nthread)
3382 {
3383     gmx_pme_t pme = NULL;
3384
3385     int  use_threads, sum_use_threads, i;
3386     ivec ndata;
3387
3388     if (debug)
3389     {
3390         fprintf(debug, "Creating PME data structures.\n");
3391     }
3392     snew(pme, 1);
3393
3394     pme->sum_qgrid_tmp       = NULL;
3395     pme->sum_qgrid_dd_tmp    = NULL;
3396     pme->buf_nalloc          = 0;
3397
3398     pme->nnodes              = 1;
3399     pme->bPPnode             = TRUE;
3400
3401     pme->nnodes_major        = nnodes_major;
3402     pme->nnodes_minor        = nnodes_minor;
3403
3404 #ifdef GMX_MPI
3405     if (nnodes_major*nnodes_minor > 1)
3406     {
3407         pme->mpi_comm = cr->mpi_comm_mygroup;
3408
3409         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3410         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3411         if (pme->nnodes != nnodes_major*nnodes_minor)
3412         {
3413             gmx_incons("PME rank count mismatch");
3414         }
3415     }
3416     else
3417     {
3418         pme->mpi_comm = MPI_COMM_NULL;
3419     }
3420 #endif
3421
3422     if (pme->nnodes == 1)
3423     {
3424 #ifdef GMX_MPI
3425         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3426         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3427 #endif
3428         pme->ndecompdim   = 0;
3429         pme->nodeid_major = 0;
3430         pme->nodeid_minor = 0;
3434     }
3435     else
3436     {
3437         if (nnodes_minor == 1)
3438         {
3439 #ifdef GMX_MPI
3440             pme->mpi_comm_d[0] = pme->mpi_comm;
3441             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3442 #endif
3443             pme->ndecompdim   = 1;
3444             pme->nodeid_major = pme->nodeid;
3445             pme->nodeid_minor = 0;
3446
3447         }
3448         else if (nnodes_major == 1)
3449         {
3450 #ifdef GMX_MPI
3451             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3452             pme->mpi_comm_d[1] = pme->mpi_comm;
3453 #endif
3454             pme->ndecompdim   = 1;
3455             pme->nodeid_major = 0;
3456             pme->nodeid_minor = pme->nodeid;
3457         }
3458         else
3459         {
3460             if (pme->nnodes % nnodes_major != 0)
3461             {
3462                 gmx_incons("For 2D PME decomposition, #PME ranks must be divisible by the number of ranks in the major dimension");
3463             }
3464             pme->ndecompdim = 2;
3465
3466 #ifdef GMX_MPI
3467             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3468                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3469             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3470                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3471
3472             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3473             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3474             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3475             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3476 #endif
3477         }
3478         pme->bPPnode = (cr->duty & DUTY_PP);
3479     }
3480
3481     pme->nthread = nthread;
3482
3483     /* Check if any of the PME MPI ranks uses threads */
3484     use_threads = (pme->nthread > 1 ? 1 : 0);
3485 #ifdef GMX_MPI
3486     if (pme->nnodes > 1)
3487     {
3488         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3489                       MPI_SUM, pme->mpi_comm);
3490     }
3491     else
3492 #endif
3493     {
3494         sum_use_threads = use_threads;
3495     }
3496     pme->bUseThreads = (sum_use_threads > 0);
3497
3498     if (ir->ePBC == epbcSCREW)
3499     {
3500         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3501     }
3502
3503     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3504     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3505     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3506     pme->nkx         = ir->nkx;
3507     pme->nky         = ir->nky;
3508     pme->nkz         = ir->nkz;
3509     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3510     pme->pme_order   = ir->pme_order;
3511
3512     /* Always constant electrostatics coefficients */
3513     pme->epsilon_r   = ir->epsilon_r;
3514
3515     /* Always constant LJ coefficients */
3516     pme->ljpme_combination_rule = ir->ljpme_combination_rule;
3517
3518     /* If we violate restrictions, generate a fatal error here */
3519     gmx_pme_check_restrictions(pme->pme_order,
3520                                pme->nkx, pme->nky, pme->nkz,
3521                                pme->nnodes_major,
3522                                pme->nnodes_minor,
3523                                pme->bUseThreads,
3524                                TRUE,
3525                                NULL);
3526
3527     if (pme->nnodes > 1)
3528     {
3529         double imbal;
3530
3531 #ifdef GMX_MPI
3532         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3533         MPI_Type_commit(&(pme->rvec_mpi));
3534 #endif
3535
3536         /* Note that the coefficient spreading and force gathering, which usually
3537          * take about the same amount of time as FFT+solve_pme,
3538          * are always fully load balanced
3539          * (unless the coefficient distribution is inhomogeneous).
3540          */
3541
3542         imbal = pme_load_imbalance(pme);
3543         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3544         {
3545             fprintf(stderr,
3546                     "\n"
3547                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3548                     "      For optimal PME load balancing\n"
3549                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_ranks_x (%d)\n"
3550                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_ranks_y (%d)\n"
3551                     "\n",
3552                     (int)((imbal-1)*100 + 0.5),
3553                     pme->nkx, pme->nky, pme->nnodes_major,
3554                     pme->nky, pme->nkz, pme->nnodes_minor);
3555         }
3556     }
3557
3558     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3559     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3560      * y is always copied through a buffer: we don't need padding in z,
3561      * but we do need the overlap in x because of the communication order.
3562      */
3563     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3564 #ifdef GMX_MPI
3565                       pme->mpi_comm_d[0],
3566 #endif
3567                       pme->nnodes_major, pme->nodeid_major,
3568                       pme->nkx,
3569                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3570
3571     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3572      * We do this with an offset buffer of equal size, so we need to allocate
3573      * extra for the offset. That's what the (+1)*pme->nkz is for.
3574      */
3575     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3576 #ifdef GMX_MPI
3577                       pme->mpi_comm_d[1],
3578 #endif
3579                       pme->nnodes_minor, pme->nodeid_minor,
3580                       pme->nky,
3581                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3582
3583     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3584      * Note that gmx_pme_check_restrictions checked for this already.
3585      */
3586     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3587     {
3588         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3589     }
3590
3591     snew(pme->bsp_mod[XX], pme->nkx);
3592     snew(pme->bsp_mod[YY], pme->nky);
3593     snew(pme->bsp_mod[ZZ], pme->nkz);
3594
3595     /* The required size of the interpolation grid, including overlap.
3596      * The allocated size (pmegrid_n?) might be slightly larger.
3597      */
3598     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3599         pme->overlap[0].s2g0[pme->nodeid_major];
3600     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3601         pme->overlap[1].s2g0[pme->nodeid_minor];
3602     pme->pmegrid_nz_base = pme->nkz;
3603     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3604     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3605
3606     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3607     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3608     pme->pmegrid_start_iz = 0;
3609
3610     make_gridindex5_to_localindex(pme->nkx,
3611                                   pme->pmegrid_start_ix,
3612                                   pme->pmegrid_nx - (pme->pme_order-1),
3613                                   &pme->nnx, &pme->fshx);
3614     make_gridindex5_to_localindex(pme->nky,
3615                                   pme->pmegrid_start_iy,
3616                                   pme->pmegrid_ny - (pme->pme_order-1),
3617                                   &pme->nny, &pme->fshy);
3618     make_gridindex5_to_localindex(pme->nkz,
3619                                   pme->pmegrid_start_iz,
3620                                   pme->pmegrid_nz_base,
3621                                   &pme->nnz, &pme->fshz);
3622
3623     pme->spline_work = make_pme_spline_work(pme->pme_order);
3624
3625     ndata[0]    = pme->nkx;
3626     ndata[1]    = pme->nky;
3627     ndata[2]    = pme->nkz;
3628     /* It doesn't matter if we allocate a few too many grid pointers here:
3629      * only the grids we actually need are initialized and used.
3630      */
3631     if (EVDW_PME(ir->vdwtype))
3632     {
3633         pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3634     }
3635     else
3636     {
3637         pme->ngrids = DO_Q;
3638     }
3639     snew(pme->fftgrid, pme->ngrids);
3640     snew(pme->cfftgrid, pme->ngrids);
3641     snew(pme->pfft_setup, pme->ngrids);
3642
3643     for (i = 0; i < pme->ngrids; ++i)
3644     {
3645         if ((i <  DO_Q && EEL_PME(ir->coulombtype) && (i == 0 ||
3646                                                        bFreeEnergy_q)) ||
3647             (i >= DO_Q && EVDW_PME(ir->vdwtype) && (i == 2 ||
3648                                                     bFreeEnergy_lj ||
3649                                                     ir->ljpme_combination_rule == eljpmeLB)))
3650         {
3651             pmegrids_init(&pme->pmegrid[i],
3652                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3653                           pme->pmegrid_nz_base,
3654                           pme->pme_order,
3655                           pme->bUseThreads,
3656                           pme->nthread,
3657                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3658                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3659             /* This routine will allocate the grid data to fit the FFTs */
3660             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3661                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3662                                     pme->mpi_comm_d,
3663                                     bReproducible, pme->nthread);
3664
3665         }
3666     }
3667
3668     if (!pme->bP3M)
3669     {
3670         /* Use plain SPME B-spline interpolation */
3671         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3672     }
3673     else
3674     {
3675         /* Use the P3M grid-optimized influence function */
3676         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3677     }
3678
3679     /* Use atc[0] for spreading */
3680     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3681     if (pme->ndecompdim >= 2)
3682     {
3683         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3684     }
3685
3686     if (pme->nnodes == 1)
3687     {
3688         pme->atc[0].n = homenr;
3689         pme_realloc_atomcomm_things(&pme->atc[0]);
3690     }
3691
3692     pme->lb_buf1       = NULL;
3693     pme->lb_buf2       = NULL;
3694     pme->lb_buf_nalloc = 0;
3695
3696     {
3697         int thread;
3698
3699         /* Use fft5d, order after FFT is y major, z, x minor */
3700
3701         snew(pme->work, pme->nthread);
3702         for (thread = 0; thread < pme->nthread; thread++)
3703         {
3704             realloc_work(&pme->work[thread], pme->nkx);
3705         }
3706     }
3707
3708     *pmedata = pme;
3709
3710     return 0;
3711 }
3712
3713 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3714 {
3715     int d, t;
3716
3717     for (d = 0; d < DIM; d++)
3718     {
3719         if (new->grid.n[d] > old->grid.n[d])
3720         {
3721             return;
3722         }
3723     }
3724
3725     sfree_aligned(new->grid.grid);
3726     new->grid.grid = old->grid.grid;
3727
3728     if (new->grid_th != NULL && new->nthread == old->nthread)
3729     {
3730         sfree_aligned(new->grid_all);
3731         for (t = 0; t < new->nthread; t++)
3732         {
3733             new->grid_th[t].grid = old->grid_th[t].grid;
3734         }
3735     }
3736 }
3737
3738 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3739                    t_commrec *         cr,
3740                    gmx_pme_t           pme_src,
3741                    const t_inputrec *  ir,
3742                    ivec                grid_size)
3743 {
3744     t_inputrec irc;
3745     int homenr;
3746     int ret;
3747
3748     irc     = *ir;
3749     irc.nkx = grid_size[XX];
3750     irc.nky = grid_size[YY];
3751     irc.nkz = grid_size[ZZ];
3752
3753     if (pme_src->nnodes == 1)
3754     {
3755         homenr = pme_src->atc[0].n;
3756     }
3757     else
3758     {
3759         homenr = -1;
3760     }
3761
3762     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3763                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3764
3765     if (ret == 0)
3766     {
3767         /* We can easily reuse the allocated pme grids in pme_src */
3768         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3769         /* We would like to reuse the fft grids, but that's harder */
3770     }
3771
3772     return ret;
3773 }
3774
3775
3776 static void copy_local_grid(gmx_pme_t pme, pmegrids_t *pmegrids,
3777                             int grid_index, int thread, real *fftgrid)
3778 {
3779     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3780     int  fft_my, fft_mz;
3781     int  nsx, nsy, nsz;
3782     ivec nf;
3783     int  offx, offy, offz, x, y, z, i0, i0t;
3784     int  d;
3785     pmegrid_t *pmegrid;
3786     real *grid_th;
3787
3788     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3789                                    local_fft_ndata,
3790                                    local_fft_offset,
3791                                    local_fft_size);
3792     fft_my = local_fft_size[YY];
3793     fft_mz = local_fft_size[ZZ];
3794
3795     pmegrid = &pmegrids->grid_th[thread];
3796
3797     nsx = pmegrid->s[XX];
3798     nsy = pmegrid->s[YY];
3799     nsz = pmegrid->s[ZZ];
3800
3801     for (d = 0; d < DIM; d++)
3802     {
3803         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3804                     local_fft_ndata[d] - pmegrid->offset[d]);
3805     }
3806
3807     offx = pmegrid->offset[XX];
3808     offy = pmegrid->offset[YY];
3809     offz = pmegrid->offset[ZZ];
3810
3811     /* Directly copy the non-overlapping parts of the local grids.
3812      * This also initializes the full grid.
3813      */
3814     grid_th = pmegrid->grid;
3815     for (x = 0; x < nf[XX]; x++)
3816     {
3817         for (y = 0; y < nf[YY]; y++)
3818         {
3819             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3820             i0t = (x*nsy + y)*nsz;
3821             for (z = 0; z < nf[ZZ]; z++)
3822             {
3823                 fftgrid[i0+z] = grid_th[i0t+z];
3824             }
3825         }
3826     }
3827 }
3828
3829 static void
3830 reduce_threadgrid_overlap(gmx_pme_t pme,
3831                           const pmegrids_t *pmegrids, int thread,
3832                           real *fftgrid, real *commbuf_x, real *commbuf_y,
3833                           int grid_index)
3834 {
3835     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3836     int  fft_nx, fft_ny, fft_nz;
3837     int  fft_my, fft_mz;
3838     int  buf_my = -1;
3839     int  nsx, nsy, nsz;
3840     ivec localcopy_end;
3841     int  offx, offy, offz, x, y, z, i0, i0t;
3842     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3843     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3844     gmx_bool bCommX, bCommY;
3845     int  d;
3846     int  thread_f;
3847     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3848     const real *grid_th;
3849     real *commbuf = NULL;
3850
3851     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3852                                    local_fft_ndata,
3853                                    local_fft_offset,
3854                                    local_fft_size);
3855     fft_nx = local_fft_ndata[XX];
3856     fft_ny = local_fft_ndata[YY];
3857     fft_nz = local_fft_ndata[ZZ];
3858
3859     fft_my = local_fft_size[YY];
3860     fft_mz = local_fft_size[ZZ];
3861
3862     /* This routine is called when all threads have finished spreading.
3863      * Here each thread sums grid contributions calculated by other threads
3864      * into the thread-local grid volume.
3865      * To minimize the number of grid copying operations,
3866      * this routine sums directly from the pmegrid into the fftgrid.
3867      */
3868
3869     /* Determine which part of the full node grid we should operate on,
3870      * this is our thread local part of the full grid.
3871      */
3872     pmegrid = &pmegrids->grid_th[thread];
3873
3874     for (d = 0; d < DIM; d++)
3875     {
3876         /* Determine up to where our thread needs to copy from the
3877          * thread-local charge spreading grid to the rank-local FFT grid.
3878          * This is up to our spreading grid end minus order-1 and
3879          * not beyond the local FFT grid.
3880          */
3881         localcopy_end[d] =
3882             min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3883                 local_fft_ndata[d]);
3884     }
3885
3886     offx = pmegrid->offset[XX];
3887     offy = pmegrid->offset[YY];
3888     offz = pmegrid->offset[ZZ];
3889
3890
3891     bClearBufX  = TRUE;
3892     bClearBufY  = TRUE;
3893     bClearBufXY = TRUE;
3894
3895     /* Now loop over all the thread data blocks that contribute
3896      * to the grid region we (our thread) are operating on.
3897      */
3898     /* Note that fft_nx/y is equal to the number of grid points
3899      * between the first point of our node grid and that of the next node's grid.
3900      */
3901     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3902     {
3903         fx     = pmegrid->ci[XX] + sx;
3904         ox     = 0;
3905         bCommX = FALSE;
3906         if (fx < 0)
3907         {
3908             fx    += pmegrids->nc[XX];
3909             ox    -= fft_nx;
3910             bCommX = (pme->nnodes_major > 1);
3911         }
3912         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3913         ox       += pmegrid_g->offset[XX];
3914         /* Determine the end of our part of the source grid */
3915         if (!bCommX)
3916         {
3917             /* Use our thread local source grid and target grid part */
3918             tx1 = min(ox + pmegrid_g->n[XX], localcopy_end[XX]);
3919         }
3920         else
3921         {
3922             /* Use our thread local source grid and the spreading range */
3923             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3924         }
3925
3926         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3927         {
3928             fy     = pmegrid->ci[YY] + sy;
3929             oy     = 0;
3930             bCommY = FALSE;
3931             if (fy < 0)
3932             {
3933                 fy    += pmegrids->nc[YY];
3934                 oy    -= fft_ny;
3935                 bCommY = (pme->nnodes_minor > 1);
3936             }
3937             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3938             oy       += pmegrid_g->offset[YY];
3939             /* Determine the end of our part of the source grid */
3940             if (!bCommY)
3941             {
3942                 /* Use our thread local source grid and target grid part */
3943                 ty1 = min(oy + pmegrid_g->n[YY], localcopy_end[YY]);
3944             }
3945             else
3946             {
3947                 /* Use our thread local source grid and the spreading range */
3948                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3949             }
3950
3951             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3952             {
3953                 fz = pmegrid->ci[ZZ] + sz;
3954                 oz = 0;
3955                 if (fz < 0)
3956                 {
3957                     fz += pmegrids->nc[ZZ];
3958                     oz -= fft_nz;
3959                 }
3960                 pmegrid_g = &pmegrids->grid_th[fz];
3961                 oz       += pmegrid_g->offset[ZZ];
3962                 tz1       = min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);
3963
3964                 if (sx == 0 && sy == 0 && sz == 0)
3965                 {
3966                     /* We have already added our local contribution
3967                      * before calling this routine, so skip it here.
3968                      */
3969                     continue;
3970                 }
3971
3972                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3973
3974                 pmegrid_f = &pmegrids->grid_th[thread_f];
3975
3976                 grid_th = pmegrid_f->grid;
3977
3978                 nsx = pmegrid_f->s[XX];
3979                 nsy = pmegrid_f->s[YY];
3980                 nsz = pmegrid_f->s[ZZ];
3981
3982 #ifdef DEBUG_PME_REDUCE
3983                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3984                        pme->nodeid, thread, thread_f,
3985                        pme->pmegrid_start_ix,
3986                        pme->pmegrid_start_iy,
3987                        pme->pmegrid_start_iz,
3988                        sx, sy, sz,
3989                        offx-ox, tx1-ox, offx, tx1,
3990                        offy-oy, ty1-oy, offy, ty1,
3991                        offz-oz, tz1-oz, offz, tz1);
3992 #endif
3993
3994                 if (!(bCommX || bCommY))
3995                 {
3996                     /* Copy from the thread local grid to the node grid */
3997                     for (x = offx; x < tx1; x++)
3998                     {
3999                         for (y = offy; y < ty1; y++)
4000                         {
4001                             i0  = (x*fft_my + y)*fft_mz;
4002                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4003                             for (z = offz; z < tz1; z++)
4004                             {
4005                                 fftgrid[i0+z] += grid_th[i0t+z];
4006                             }
4007                         }
4008                     }
4009                 }
4010                 else
4011                 {
4012                     /* The order of this conditional decides
4013                      * where the corner volume gets stored with x+y decomp.
4014                      */
4015                     if (bCommY)
4016                     {
4017                         commbuf = commbuf_y;
4018                         /* The y-size of the communication buffer is set by
4019                          * the overlap of the grid part of our local slab
4020                          * with the part starting at the next slab.
4021                          */
4022                         buf_my  =
4023                             pme->overlap[1].s2g1[pme->nodeid_minor] -
4024                             pme->overlap[1].s2g0[pme->nodeid_minor+1];
4025                         if (bCommX)
4026                         {
4027                             /* We index commbuf modulo the local grid size */
4028                             commbuf += buf_my*fft_nx*fft_nz;
4029
4030                             bClearBuf   = bClearBufXY;
4031                             bClearBufXY = FALSE;
4032                         }
4033                         else
4034                         {
4035                             bClearBuf  = bClearBufY;
4036                             bClearBufY = FALSE;
4037                         }
4038                     }
4039                     else
4040                     {
4041                         commbuf    = commbuf_x;
4042                         buf_my     = fft_ny;
4043                         bClearBuf  = bClearBufX;
4044                         bClearBufX = FALSE;
4045                     }
4046
4047                     /* Copy to the communication buffer */
4048                     for (x = offx; x < tx1; x++)
4049                     {
4050                         for (y = offy; y < ty1; y++)
4051                         {
4052                             i0  = (x*buf_my + y)*fft_nz;
4053                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4054
4055                             if (bClearBuf)
4056                             {
4057                                 /* First access of commbuf, initialize it */
4058                                 for (z = offz; z < tz1; z++)
4059                                 {
4060                                     commbuf[i0+z]  = grid_th[i0t+z];
4061                                 }
4062                             }
4063                             else
4064                             {
4065                                 for (z = offz; z < tz1; z++)
4066                                 {
4067                                     commbuf[i0+z] += grid_th[i0t+z];
4068                                 }
4069                             }
4070                         }
4071                     }
4072                 }
4073             }
4074         }
4075     }
4076 }
4077
4078
4079 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid, int grid_index)
4080 {
4081     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4082     pme_overlap_t *overlap;
4083     int  send_index0, send_nindex;
4084     int  recv_nindex;
4085 #ifdef GMX_MPI
4086     MPI_Status stat;
4087 #endif
4088     int  send_size_y, recv_size_y;
4089     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
4090     real *sendptr, *recvptr;
4091     int  x, y, z, indg, indb;
4092
4093     /* Note that this routine is only used for forward communication.
4094      * Since the force gathering, unlike the coefficient spreading,
4095      * can be trivially parallelized over the particles,
4096      * the backwards process is much simpler and can use the "old"
4097      * communication setup.
4098      */
4099
4100     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
4101                                    local_fft_ndata,
4102                                    local_fft_offset,
4103                                    local_fft_size);
4104
4105     if (pme->nnodes_minor > 1)
4106     {
4107         /* Minor dimension */
4108         overlap = &pme->overlap[1];
4109
4110         if (pme->nnodes_major > 1)
4111         {
4112             size_yx = pme->overlap[0].comm_data[0].send_nindex;
4113         }
4114         else
4115         {
4116             size_yx = 0;
4117         }
4118         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
4119
4120         send_size_y = overlap->send_size;
4121
4122         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
4123         {
4124             send_id       = overlap->send_id[ipulse];
4125             recv_id       = overlap->recv_id[ipulse];
4126             send_index0   =
4127                 overlap->comm_data[ipulse].send_index0 -
4128                 overlap->comm_data[0].send_index0;
4129             send_nindex   = overlap->comm_data[ipulse].send_nindex;
4130             /* We don't use recv_index0, as we always receive starting at 0 */
4131             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4132             recv_size_y   = overlap->comm_data[ipulse].recv_size;
4133
4134             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
4135             recvptr = overlap->recvbuf;
4136
4137             if (debug != NULL)
4138             {
4139                 fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n",
4140                         local_fft_ndata[XX], send_nindex, local_fft_ndata[ZZ]);
4141             }
4142
4143 #ifdef GMX_MPI
4144             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
4145                          send_id, ipulse,
4146                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
4147                          recv_id, ipulse,
4148                          overlap->mpi_comm, &stat);
4149 #endif
4150
4151             for (x = 0; x < local_fft_ndata[XX]; x++)
4152             {
4153                 for (y = 0; y < recv_nindex; y++)
4154                 {
4155                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
4156                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
4157                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
4158                     {
4159                         fftgrid[indg+z] += recvptr[indb+z];
4160                     }
4161                 }
4162             }
4163
4164             if (pme->nnodes_major > 1)
4165             {
4166                 /* Copy from the received buffer to the send buffer for dim 0 */
4167                 sendptr = pme->overlap[0].sendbuf;
4168                 for (x = 0; x < size_yx; x++)
4169                 {
4170                     for (y = 0; y < recv_nindex; y++)
4171                     {
4172                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4173                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
4174                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
4175                         {
4176                             sendptr[indg+z] += recvptr[indb+z];
4177                         }
4178                     }
4179                 }
4180             }
4181         }
4182     }
4183
4184     /* We only support a single pulse here.
4185      * This is not a severe limitation, as this code is only used
4186      * with OpenMP and with OpenMP the (PME) domains can be larger.
4187      */
4188     if (pme->nnodes_major > 1)
4189     {
4190         /* Major dimension */
4191         overlap = &pme->overlap[0];
4192
4193         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
4194         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
4195
4196         ipulse = 0;
4197
4198         send_id       = overlap->send_id[ipulse];
4199         recv_id       = overlap->recv_id[ipulse];
4200         send_nindex   = overlap->comm_data[ipulse].send_nindex;
4201         /* We don't use recv_index0, as we always receive starting at 0 */
4202         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4203
4204         sendptr = overlap->sendbuf;
4205         recvptr = overlap->recvbuf;
4206
4207         if (debug != NULL)
4208         {
4209             fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n",
4210                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
4211         }
4212
4213 #ifdef GMX_MPI
4214         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
4215                      send_id, ipulse,
4216                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
4217                      recv_id, ipulse,
4218                      overlap->mpi_comm, &stat);
4219 #endif
4220
4221         for (x = 0; x < recv_nindex; x++)
4222         {
4223             for (y = 0; y < local_fft_ndata[YY]; y++)
4224             {
4225                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
4226                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4227                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
4228                 {
4229                     fftgrid[indg+z] += recvptr[indb+z];
4230                 }
4231             }
4232         }
4233     }
4234 }
4235
4236
4237 static void spread_on_grid(gmx_pme_t pme,
4238                            pme_atomcomm_t *atc, pmegrids_t *grids,
4239                            gmx_bool bCalcSplines, gmx_bool bSpread,
4240                            real *fftgrid, gmx_bool bDoSplines, int grid_index)
4241 {
4242     int nthread, thread;
4243 #ifdef PME_TIME_THREADS
4244     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
4245     static double cs1     = 0, cs2 = 0, cs3 = 0;
4246     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
4247     static int cnt        = 0;
4248 #endif
4249
4250     nthread = pme->nthread;
4251     assert(nthread > 0);
4252
4253 #ifdef PME_TIME_THREADS
4254     c1 = omp_cyc_start();
4255 #endif
4256     if (bCalcSplines)
4257     {
4258 #pragma omp parallel for num_threads(nthread) schedule(static)
4259         for (thread = 0; thread < nthread; thread++)
4260         {
4261             int start, end;
4262
4263             start = atc->n* thread   /nthread;
4264             end   = atc->n*(thread+1)/nthread;
4265
4266             /* Compute fftgrid index for all atoms,
4267              * with help of some extra variables.
4268              */
4269             calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
4270         }
4271     }
4272 #ifdef PME_TIME_THREADS
4273     c1   = omp_cyc_end(c1);
4274     cs1 += (double)c1;
4275 #endif
4276
4277 #ifdef PME_TIME_THREADS
4278     c2 = omp_cyc_start();
4279 #endif
4280 #pragma omp parallel for num_threads(nthread) schedule(static)
4281     for (thread = 0; thread < nthread; thread++)
4282     {
4283         splinedata_t *spline;
4284         pmegrid_t *grid = NULL;
4285
4286         /* make local bsplines  */
4287         if (grids == NULL || !pme->bUseThreads)
4288         {
4289             spline = &atc->spline[0];
4290
4291             spline->n = atc->n;
4292
4293             if (bSpread)
4294             {
4295                 grid = &grids->grid;
4296             }
4297         }
4298         else
4299         {
4300             spline = &atc->spline[thread];
4301
4302             if (grids->nthread == 1)
4303             {
4304                 /* One thread, we operate on all coefficients */
4305                 spline->n = atc->n;
4306             }
4307             else
4308             {
4309                 /* Get the indices our thread should operate on */
4310                 make_thread_local_ind(atc, thread, spline);
4311             }
4312
4313             grid = &grids->grid_th[thread];
4314         }
4315
4316         if (bCalcSplines)
4317         {
4318             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4319                           atc->fractx, spline->n, spline->ind, atc->coefficient, bDoSplines);
4320         }
4321
4322         if (bSpread)
4323         {
4324             /* put local atoms on grid. */
4325 #ifdef PME_TIME_SPREAD
4326             ct1a = omp_cyc_start();
4327 #endif
4328             spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
4329
4330             if (pme->bUseThreads)
4331             {
4332                 copy_local_grid(pme, grids, grid_index, thread, fftgrid);
4333             }
4334 #ifdef PME_TIME_SPREAD
4335             ct1a          = omp_cyc_end(ct1a);
4336             cs1a[thread] += (double)ct1a;
4337 #endif
4338         }
4339     }
4340 #ifdef PME_TIME_THREADS
4341     c2   = omp_cyc_end(c2);
4342     cs2 += (double)c2;
4343 #endif
4344
4345     if (bSpread && pme->bUseThreads)
4346     {
4347 #ifdef PME_TIME_THREADS
4348         c3 = omp_cyc_start();
4349 #endif
4350 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4351         for (thread = 0; thread < grids->nthread; thread++)
4352         {
4353             reduce_threadgrid_overlap(pme, grids, thread,
4354                                       fftgrid,
4355                                       pme->overlap[0].sendbuf,
4356                                       pme->overlap[1].sendbuf,
4357                                       grid_index);
4358         }
4359 #ifdef PME_TIME_THREADS
4360         c3   = omp_cyc_end(c3);
4361         cs3 += (double)c3;
4362 #endif
4363
4364         if (pme->nnodes > 1)
4365         {
4366             /* Communicate the overlapping part of the fftgrid.
4367              * For this communication call we need to check pme->bUseThreads
4368              * to have all ranks communicate here, regardless of pme->nthread.
4369              */
4370             sum_fftgrid_dd(pme, fftgrid, grid_index);
4371         }
4372     }
4373
4374 #ifdef PME_TIME_THREADS
4375     cnt++;
4376     if (cnt % 20 == 0)
4377     {
4378         printf("idx %.2f spread %.2f red %.2f",
4379                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4380 #ifdef PME_TIME_SPREAD
4381         for (thread = 0; thread < nthread; thread++)
4382         {
4383             printf(" %.2f", cs1a[thread]*1e-9);
4384         }
4385 #endif
4386         printf("\n");
4387     }
4388 #endif
4389 }
4390
4391
4392 static void dump_grid(FILE *fp,
4393                       int sx, int sy, int sz, int nx, int ny, int nz,
4394                       int my, int mz, const real *g)
4395 {
4396     int x, y, z;
4397
4398     for (x = 0; x < nx; x++)
4399     {
4400         for (y = 0; y < ny; y++)
4401         {
4402             for (z = 0; z < nz; z++)
4403             {
4404                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4405                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4406             }
4407         }
4408     }
4409 }
4410
4411 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4412 {
4413     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4414
4415     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4416                                    local_fft_ndata,
4417                                    local_fft_offset,
4418                                    local_fft_size);
4419
4420     dump_grid(stderr,
4421               pme->pmegrid_start_ix,
4422               pme->pmegrid_start_iy,
4423               pme->pmegrid_start_iz,
4424               pme->pmegrid_nx-pme->pme_order+1,
4425               pme->pmegrid_ny-pme->pme_order+1,
4426               pme->pmegrid_nz-pme->pme_order+1,
4427               local_fft_size[YY],
4428               local_fft_size[ZZ],
4429               fftgrid);
4430 }
4431
4432
4433 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4434 {
4435     pme_atomcomm_t *atc;
4436     pmegrids_t *grid;
4437
4438     if (pme->nnodes > 1)
4439     {
4440         gmx_incons("gmx_pme_calc_energy called in parallel");
4441     }
4442     if (pme->bFEP_q)
4443     {
4444         gmx_incons("gmx_pme_calc_energy with free energy");
4445     }
4446
4447     atc            = &pme->atc_energy;
4448     atc->nthread   = 1;
4449     if (atc->spline == NULL)
4450     {
4451         snew(atc->spline, atc->nthread);
4452     }
4453     atc->nslab     = 1;
4454     atc->bSpread   = TRUE;
4455     atc->pme_order = pme->pme_order;
4456     atc->n         = n;
4457     pme_realloc_atomcomm_things(atc);
4458     atc->x           = x;
4459     atc->coefficient = q;
4460
4461     /* We only use the A-charges grid */
4462     grid = &pme->pmegrid[PME_GRID_QA];
4463
4464     /* Only calculate the spline coefficients, don't actually spread */
4465     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE, PME_GRID_QA);
4466
4467     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4468 }
4469
4470
4471 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4472                                    gmx_walltime_accounting_t walltime_accounting,
4473                                    t_nrnb *nrnb, t_inputrec *ir,
4474                                    gmx_int64_t step)
4475 {
4476     /* Reset all the counters related to performance over the run */
4477     wallcycle_stop(wcycle, ewcRUN);
4478     wallcycle_reset_all(wcycle);
4479     init_nrnb(nrnb);
4480     if (ir->nsteps >= 0)
4481     {
4482         /* ir->nsteps is not used here, but we update it for consistency */
4483         ir->nsteps -= step - ir->init_step;
4484     }
4485     ir->init_step = step;
4486     wallcycle_start(wcycle, ewcRUN);
4487     walltime_accounting_start(walltime_accounting);
4488 }
4489
4490
4491 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4492                                ivec grid_size,
4493                                t_commrec *cr, t_inputrec *ir,
4494                                gmx_pme_t *pme_ret)
4495 {
4496     int ind;
4497     gmx_pme_t pme = NULL;
4498
4499     ind = 0;
4500     while (ind < *npmedata)
4501     {
4502         pme = (*pmedata)[ind];
4503         if (pme->nkx == grid_size[XX] &&
4504             pme->nky == grid_size[YY] &&
4505             pme->nkz == grid_size[ZZ])
4506         {
4507             *pme_ret = pme;
4508
4509             return;
4510         }
4511
4512         ind++;
4513     }
4514
4515     (*npmedata)++;
4516     srenew(*pmedata, *npmedata);
4517
4518     /* Generate a new PME data structure, copying part of the old pointers */
4519     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4520
4521     *pme_ret = (*pmedata)[ind];
4522 }
4523
4524 int gmx_pmeonly(gmx_pme_t pme,
4525                 t_commrec *cr,    t_nrnb *mynrnb,
4526                 gmx_wallcycle_t wcycle,
4527                 gmx_walltime_accounting_t walltime_accounting,
4528                 real ewaldcoeff_q, real ewaldcoeff_lj,
4529                 t_inputrec *ir)
4530 {
4531     int npmedata;
4532     gmx_pme_t *pmedata;
4533     gmx_pme_pp_t pme_pp;
4534     int  ret;
4535     int  natoms;
4536     matrix box;
4537     rvec *x_pp      = NULL, *f_pp = NULL;
4538     real *chargeA   = NULL, *chargeB = NULL;
4539     real *c6A       = NULL, *c6B = NULL;
4540     real *sigmaA    = NULL, *sigmaB = NULL;
4541     real lambda_q   = 0;
4542     real lambda_lj  = 0;
4543     int  maxshift_x = 0, maxshift_y = 0;
4544     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4545     matrix vir_q, vir_lj;
4546     float cycles;
4547     int  count;
4548     gmx_bool bEnerVir;
4549     int pme_flags;
4550     gmx_int64_t step, step_rel;
4551     ivec grid_switch;
4552
4553     /* This data is only used with PME tuning, i.e. switching PME grids */
4554     npmedata = 1;
4555     snew(pmedata, npmedata);
4556     pmedata[0] = pme;
4557
4558     pme_pp = gmx_pme_pp_init(cr);
4559
4560     init_nrnb(mynrnb);
4561
4562     count = 0;
4563     do /****** this is a quasi-loop over time steps! */
4564     {
4565         /* The reason for having a loop here is PME grid tuning/switching */
4566         do
4567         {
4568             /* Domain decomposition */
4569             ret = gmx_pme_recv_coeffs_coords(pme_pp,
4570                                              &natoms,
4571                                              &chargeA, &chargeB,
4572                                              &c6A, &c6B,
4573                                              &sigmaA, &sigmaB,
4574                                              box, &x_pp, &f_pp,
4575                                              &maxshift_x, &maxshift_y,
4576                                              &pme->bFEP_q, &pme->bFEP_lj,
4577                                              &lambda_q, &lambda_lj,
4578                                              &bEnerVir,
4579                                              &pme_flags,
4580                                              &step,
4581                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4582
4583             if (ret == pmerecvqxSWITCHGRID)
4584             {
4585                 /* Switch the PME grid to grid_switch */
4586                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4587             }
4588
4589             if (ret == pmerecvqxRESETCOUNTERS)
4590             {
4591                 /* Reset the cycle and flop counters */
4592                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4593             }
4594         }
4595         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4596
4597         if (ret == pmerecvqxFINISH)
4598         {
4599             /* We should stop: break out of the loop */
4600             break;
4601         }
4602
4603         step_rel = step - ir->init_step;
4604
4605         if (count == 0)
4606         {
4607             wallcycle_start(wcycle, ewcRUN);
4608             walltime_accounting_start(walltime_accounting);
4609         }
4610
4611         wallcycle_start(wcycle, ewcPMEMESH);
4612
4613         dvdlambda_q  = 0;
4614         dvdlambda_lj = 0;
4615         clear_mat(vir_q);
4616         clear_mat(vir_lj);
4617
4618         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4619                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4620                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4621                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4622                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4623                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4624
4625         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4626
4627         gmx_pme_send_force_vir_ener(pme_pp,
4628                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4629                                     dvdlambda_q, dvdlambda_lj, cycles);
4630
4631         count++;
4632     } /***** end of quasi-loop, we stop with the break above */
4633     while (TRUE);
4634
4635     walltime_accounting_end(walltime_accounting);
4636
4637     return 0;
4638 }
4639
4640 static void
4641 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4642 {
4643     int  i;
4644
4645     for (i = 0; i < pme->atc[0].n; ++i)
4646     {
4647         real sigma4;
4648
4649         sigma4                     = local_sigma[i];
4650         sigma4                     = sigma4*sigma4;
4651         sigma4                     = sigma4*sigma4;
4652         pme->atc[0].coefficient[i] = local_c6[i] / sigma4;
4653     }
4654 }
4655
4656 static void
4657 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4658 {
4659     int  i;
4660
4661     for (i = 0; i < pme->atc[0].n; ++i)
4662     {
4663         pme->atc[0].coefficient[i] *= local_sigma[i];
4664     }
4665 }
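
/* Taken together: calc_initial_lb_coeffs() sets coefficient[i] = C6_i/sigma_i^4
 * and every subsequent calc_next_lb_coeffs() call multiplies by sigma_i, so
 * after k calls coefficient[i] = C6_i*sigma_i^(k-4). These per-grid
 * coefficients are used to build the successive LJ-PME grids for the
 * Lorentz-Berthelot combination rule (eljpmeLB).
 */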
4666
4667 static void
4668 do_redist_pos_coeffs(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4669                      gmx_bool bFirst, rvec x[], real *data)
4670 {
4671     int      d;
4672     pme_atomcomm_t *atc;
4673     atc = &pme->atc[0];
4674
4675     for (d = pme->ndecompdim - 1; d >= 0; d--)
4676     {
4677         int             n_d;
4678         rvec           *x_d;
4679         real           *param_d;
4680
4681         if (d == pme->ndecompdim - 1)
4682         {
4683             n_d     = homenr;
4684             x_d     = x + start;
4685             param_d = data;
4686         }
4687         else
4688         {
4689             n_d     = pme->atc[d + 1].n;
4690             x_d     = atc->x;
4691             param_d = atc->coefficient;
4692         }
4693         atc      = &pme->atc[d];
4694         atc->npd = n_d;
4695         if (atc->npd > atc->pd_nalloc)
4696         {
4697             atc->pd_nalloc = over_alloc_dd(atc->npd);
4698             srenew(atc->pd, atc->pd_nalloc);
4699         }
4700         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4701         where();
4702         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4703         if (DOMAINDECOMP(cr))
4704         {
4705             dd_pmeredist_pos_coeffs(pme, n_d, bFirst, x_d, param_d, atc);
4706         }
4707     }
4708 }
4709
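/* Main PME driver. For each grid that is needed (Coulomb A/B, LJ A/B and,
 * with Lorentz-Berthelot rules, the seven LB grids) this routine
 * redistributes coordinates and coefficients over the PME ranks when
 * running in parallel, spreads them on the grid, performs the forward 3D
 * FFT, solves in reciprocal space, performs the backward FFT, gathers the
 * forces, and finally assembles the energies, virials and dV/dlambda terms.
 */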
4710 int gmx_pme_do(gmx_pme_t pme,
4711                int start,       int homenr,
4712                rvec x[],        rvec f[],
4713                real *chargeA,   real *chargeB,
4714                real *c6A,       real *c6B,
4715                real *sigmaA,    real *sigmaB,
4716                matrix box, t_commrec *cr,
4717                int  maxshift_x, int maxshift_y,
4718                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4719                matrix vir_q,    real ewaldcoeff_q,
4720                matrix vir_lj,   real ewaldcoeff_lj,
4721                real *energy_q,  real *energy_lj,
4722                real lambda_q, real lambda_lj,
4723                real *dvdlambda_q, real *dvdlambda_lj,
4724                int flags)
4725 {
4726     int     d, i, j, k, ntot, npme, grid_index, max_grid_index;
4727     int     nx, ny, nz;
4728     int     n_d, local_ny;
4729     pme_atomcomm_t *atc = NULL;
4730     pmegrids_t *pmegrid = NULL;
4731     real    *grid       = NULL;
4732     real    *ptr;
4733     rvec    *x_d, *f_d;
4734     real    *coefficient = NULL;
4735     real    energy_AB[4];
4736     matrix  vir_AB[4];
4737     real    scale, lambda;
4738     gmx_bool bClearF;
4739     gmx_parallel_3dfft_t pfft_setup;
4740     real *  fftgrid;
4741     t_complex * cfftgrid;
4742     int     thread;
4743     gmx_bool bFirst, bDoSplines;
4744     int fep_state;
4745     int fep_states_lj           = pme->bFEP_lj ? 2 : 1;
4746     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4747     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4748
4749     assert(pme->nnodes > 0);
4750     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4751
4752     if (pme->nnodes > 1)
4753     {
4754         atc      = &pme->atc[0];
4755         atc->npd = homenr;
4756         if (atc->npd > atc->pd_nalloc)
4757         {
4758             atc->pd_nalloc = over_alloc_dd(atc->npd);
4759             srenew(atc->pd, atc->pd_nalloc);
4760         }
4761         for (d = pme->ndecompdim-1; d >= 0; d--)
4762         {
4763             atc           = &pme->atc[d];
4764             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4765         }
4766     }
4767     else
4768     {
4769         atc = &pme->atc[0];
4770         /* This could be necessary for TPI */
4771         pme->atc[0].n = homenr;
4772         if (DOMAINDECOMP(cr))
4773         {
4774             pme_realloc_atomcomm_things(atc);
4775         }
4776         atc->x = x;
4777         atc->f = f;
4778     }
4779
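    /* Recompute the reciprocal box vectors for the current, possibly
     * triclinic, box. */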
4780     m_inv_ur0(box, pme->recipbox);
4781     bFirst = TRUE;
4782
4783     /* For simplicity, we construct the splines for all particles if
4784      * more than one PME calculation is needed. Some optimization
4785      * could be done by keeping track of which atoms already have
4786      * splines constructed, and constructing new splines on each pass
4787      * only for atoms that do not yet have them.
4788      */
4789
4790     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4791
4792     /* We need a maximum of four separate PME calculations:
4793      * grid_index=0: Coulomb PME with charges from state A
4794      * grid_index=1: Coulomb PME with charges from state B
4795      * grid_index=2: LJ PME with C6 from state A
4796      * grid_index=3: LJ PME with C6 from state B
4797      * For Lorentz-Berthelot combination rules, a separate loop below is
4798      * used to calculate all seven terms.
4799      */
4800
4801     /* If we are doing LJ-PME with LB, we only do Q here */
4802     max_grid_index = (pme->ljpme_combination_rule == eljpmeLB) ? DO_Q : DO_Q_AND_LJ;
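    /* DO_Q covers the two Coulomb grids (states A and B); DO_Q_AND_LJ adds
     * the two LJ grids. With LB combination rules the LJ part is instead
     * handled by the seven-grid loop further down.
     */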
4803
4804     for (grid_index = 0; grid_index < max_grid_index; ++grid_index)
4805     {
4806         /* Check if we should do calculations at this grid_index:
4807          * if grid_index is odd we should be doing FEP (the B state),
4808          * if grid_index < 2 we should be doing electrostatic PME,
4809          * if grid_index >= 2 we should be doing LJ-PME.
4810          */
4811         if ((grid_index <  DO_Q && (!(flags & GMX_PME_DO_COULOMB) ||
4812                                     (grid_index == 1 && !pme->bFEP_q))) ||
4813             (grid_index >= DO_Q && (!(flags & GMX_PME_DO_LJ) ||
4814                                     (grid_index == 3 && !pme->bFEP_lj))))
4815         {
4816             continue;
4817         }
4818         /* Unpack structure */
4819         pmegrid    = &pme->pmegrid[grid_index];
4820         fftgrid    = pme->fftgrid[grid_index];
4821         cfftgrid   = pme->cfftgrid[grid_index];
4822         pfft_setup = pme->pfft_setup[grid_index];
4823         switch (grid_index)
4824         {
4825             case 0: coefficient = chargeA + start; break;
4826             case 1: coefficient = chargeB + start; break;
4827             case 2: coefficient = c6A + start; break;
4828             case 3: coefficient = c6B + start; break;
4829         }
4830
4831         grid = pmegrid->grid.grid;
4832
4833         if (debug)
4834         {
4835             fprintf(debug, "PME: number of ranks = %d, rank = %d\n",
4836                     cr->nnodes, cr->nodeid);
4837             fprintf(debug, "Grid = %p\n", (void*)grid);
4838             if (grid == NULL)
4839             {
4840                 gmx_fatal(FARGS, "No grid!");
4841             }
4842         }
4843         where();
4844
4845         if (pme->nnodes == 1)
4846         {
4847             atc->coefficient = coefficient;
4848         }
4849         else
4850         {
4851             wallcycle_start(wcycle, ewcPME_REDISTXF);
4852             do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, coefficient);
4853             where();
4854
4855             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4856         }
4857
4858         if (debug)
4859         {
4860             fprintf(debug, "Rank= %6d, pme local particles=%6d\n",
4861                     cr->nodeid, atc->n);
4862         }
4863
4864         if (flags & GMX_PME_SPREAD)
4865         {
4866             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4867
4868             /* Spread the coefficients on a grid */
4869             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
4870
4871             if (bFirst)
4872             {
4873                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4874             }
4875             inc_nrnb(nrnb, eNR_SPREADBSP,
4876                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4877
4878             if (!pme->bUseThreads)
4879             {
4880                 wrap_periodic_pmegrid(pme, grid);
4881
4882                 /* sum contributions to local grid from other nodes */
4883 #ifdef GMX_MPI
4884                 if (pme->nnodes > 1)
4885                 {
4886                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
4887                     where();
4888                 }
4889 #endif
4890
4891                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
4892             }
4893
4894             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4895
4896             /*
4897                dump_local_fftgrid(pme,fftgrid);
4898                exit(0);
4899              */
4900         }
4901
4902         /* Here we start a large thread parallel region */
4903 #pragma omp parallel num_threads(pme->nthread) private(thread)
4904         {
4905             thread = gmx_omp_get_thread_num();
4906             if (flags & GMX_PME_SOLVE)
4907             {
4908                 int loop_count;
4909
4910                 /* do 3d-fft */
4911                 if (thread == 0)
4912                 {
4913                     wallcycle_start(wcycle, ewcPME_FFT);
4914                 }
4915                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4916                                            thread, wcycle);
4917                 if (thread == 0)
4918                 {
4919                     wallcycle_stop(wcycle, ewcPME_FFT);
4920                 }
4921                 where();
4922
4923                 /* solve in k-space for our local cells */
4924                 if (thread == 0)
4925                 {
4926                     wallcycle_start(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4927                 }
4928                 if (grid_index < DO_Q)
4929                 {
4930                     loop_count =
4931                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
4932                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4933                                       bCalcEnerVir,
4934                                       pme->nthread, thread);
4935                 }
4936                 else
4937                 {
4938                     loop_count =
4939                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
4940                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4941                                          bCalcEnerVir,
4942                                          pme->nthread, thread);
4943                 }
4944
4945                 if (thread == 0)
4946                 {
4947                     wallcycle_stop(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4948                     where();
4949                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4950                 }
4951             }
4952
4953             if (bCalcF)
4954             {
4955                 /* do 3d-invfft */
4956                 if (thread == 0)
4957                 {
4958                     where();
4959                     wallcycle_start(wcycle, ewcPME_FFT);
4960                 }
4961                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4962                                            thread, wcycle);
4963                 if (thread == 0)
4964                 {
4965                     wallcycle_stop(wcycle, ewcPME_FFT);
4966
4967                     where();
4968
4969                     if (pme->nodeid == 0)
4970                     {
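                        /* Flop-count bookkeeping: estimate ~N*log2(N)
                         * operations per 3D FFT; the factor 2 presumably
                         * covers the forward and backward transforms. */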
4971                         ntot  = pme->nkx*pme->nky*pme->nkz;
4972                         npme  = ntot*log((real)ntot)/log(2.0);
4973                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4974                     }
4975
4976                     /* Note: this wallcycle region is closed below
4977                        outside an OpenMP region, so take care if
4978                        refactoring code here. */
4979                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4980                 }
4981
4982                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
4983             }
4984         }
4985         /* End of thread parallel section.
4986          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4987          */
4988
4989         if (bCalcF)
4990         {
4991             /* distribute local grid to all nodes */
4992 #ifdef GMX_MPI
4993             if (pme->nnodes > 1)
4994             {
4995                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
4996             }
4997 #endif
4998             where();
4999
5000             unwrap_periodic_pmegrid(pme, grid);
5001
5002             /* interpolate forces for our local atoms */
5003
5004             where();
5005
5006             /* If we are running without parallelization,
5007              * atc->f is the actual force array, not a buffer,
5008              * therefore we should not clear it.
5009              */
5010             lambda  = grid_index < DO_Q ? lambda_q : lambda_lj;
5011             bClearF = (bFirst && PAR(cr));
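            /* With FEP, even grid indices hold A-state coefficients and are
             * weighted by 1-lambda; odd indices hold B-state coefficients
             * weighted by lambda. */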
5012 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5013             for (thread = 0; thread < pme->nthread; thread++)
5014             {
5015                 gather_f_bsplines(pme, grid, bClearF, atc,
5016                                   &atc->spline[thread],
5017                                   pme->bFEP ? (grid_index % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
5018             }
5019
5020             where();
5021
5022             inc_nrnb(nrnb, eNR_GATHERFBSP,
5023                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5024             /* Note: this wallcycle region is opened above inside an OpenMP
5025                region, so take care if refactoring code here. */
5026             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5027         }
5028
5029         if (bCalcEnerVir)
5030         {
5031             /* This should only be called on the master thread
5032              * and after the threads have synchronized.
5033              */
5034             if (grid_index < 2)
5035             {
5036                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5037             }
5038             else
5039             {
5040                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5041             }
5042         }
5043         bFirst = FALSE;
5044     } /* of grid_index-loop */
5045
5046     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
5047      * seven terms. */
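    /* The seven terms arise because the LB cross coefficient involves
     * (sigma_i + sigma_j)^6, which does not factorize into a single
     * per-atom product; each power of sigma gets its own grid. */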
5048
5049     if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB)
5050     {
5051         /* Loop over A- and B-state if we are doing FEP */
5052         for (fep_state = 0; fep_state < fep_states_lj; ++fep_state)
5053         {
5054             real *local_c6 = NULL, *local_sigma = NULL, *RedistC6 = NULL, *RedistSigma = NULL;
5055             if (pme->nnodes == 1)
5056             {
5057                 if (pme->lb_buf1 == NULL)
5058                 {
5059                     pme->lb_buf_nalloc = pme->atc[0].n;
5060                     snew(pme->lb_buf1, pme->lb_buf_nalloc);
5061                 }
5062                 pme->atc[0].coefficient = pme->lb_buf1;
5063                 switch (fep_state)
5064                 {
5065                     case 0:
5066                         local_c6      = c6A;
5067                         local_sigma   = sigmaA;
5068                         break;
5069                     case 1:
5070                         local_c6      = c6B;
5071                         local_sigma   = sigmaB;
5072                         break;
5073                     default:
5074                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5075                 }
5076             }
5077             else
5078             {
5079                 atc = &pme->atc[0];
5080                 switch (fep_state)
5081                 {
5082                     case 0:
5083                         RedistC6      = c6A;
5084                         RedistSigma   = sigmaA;
5085                         break;
5086                     case 1:
5087                         RedistC6      = c6B;
5088                         RedistSigma   = sigmaB;
5089                         break;
5090                     default:
5091                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5092                 }
5093                 wallcycle_start(wcycle, ewcPME_REDISTXF);
5094
5095                 do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, RedistC6);
5096                 if (pme->lb_buf_nalloc < atc->n)
5097                 {
5098                     pme->lb_buf_nalloc = atc->nalloc;
5099                     srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5100                     srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5101                 }
5102                 local_c6 = pme->lb_buf1;
5103                 for (i = 0; i < atc->n; ++i)
5104                 {
5105                     local_c6[i] = atc->coefficient[i];
5106                 }
5107                 where();
5108
5109                 do_redist_pos_coeffs(pme, cr, start, homenr, FALSE, x, RedistSigma);
5110                 local_sigma = pme->lb_buf2;
5111                 for (i = 0; i < atc->n; ++i)
5112                 {
5113                     local_sigma[i] = atc->coefficient[i];
5114                 }
5115                 where();
5116
5117                 wallcycle_stop(wcycle, ewcPME_REDISTXF);
5118             }
5119             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5120
5121             /* Seven terms in LJ-PME with LB; grid_index < 2 is reserved for electrostatics */
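            /* The coefficient is advanced by one power of sigma before each
             * spread, so grids 2..8 carry C6/sigma^3 up through C6*sigma^3
             * relative to the C6/sigma^4 set by calc_initial_lb_coeffs();
             * in the force gather below each grid's contribution is further
             * weighted by lb_scale_factor[grid_index-2]. */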
5122             for (grid_index = 2; grid_index < 9; ++grid_index)
5123             {
5124                 /* Unpack structure */
5125                 pmegrid    = &pme->pmegrid[grid_index];
5126                 fftgrid    = pme->fftgrid[grid_index];
5127                 cfftgrid   = pme->cfftgrid[grid_index];
5128                 pfft_setup = pme->pfft_setup[grid_index];
5129                 calc_next_lb_coeffs(pme, local_sigma);
5130                 grid = pmegrid->grid.grid;
5131                 where();
5132
5133                 if (flags & GMX_PME_SPREAD)
5134                 {
5135                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5136                     /* Spread the c6 on a grid */
5137                     spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
5138
5139                     if (bFirst)
5140                     {
5141                         inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5142                     }
5143
5144                     inc_nrnb(nrnb, eNR_SPREADBSP,
5145                              pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5146                     if (pme->nthread == 1)
5147                     {
5148                         wrap_periodic_pmegrid(pme, grid);
5149                         /* sum contributions to local grid from other nodes */
5150 #ifdef GMX_MPI
5151                         if (pme->nnodes > 1)
5152                         {
5153                             gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
5154                             where();
5155                         }
5156 #endif
5157                         copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
5158                     }
5159                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5160                 }
5161                 /* Here we start a large thread parallel region */
5162 #pragma omp parallel num_threads(pme->nthread) private(thread)
5163                 {
5164                     thread = gmx_omp_get_thread_num();
5165                     if (flags & GMX_PME_SOLVE)
5166                     {
5167                         /* do 3d-fft */
5168                         if (thread == 0)
5169                         {
5170                             wallcycle_start(wcycle, ewcPME_FFT);
5171                         }
5172
5173                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5174                                                    thread, wcycle);
5175                         if (thread == 0)
5176                         {
5177                             wallcycle_stop(wcycle, ewcPME_FFT);
5178                         }
5179                         where();
5180                     }
5181                 }
5182                 bFirst = FALSE;
5183             }
5184             if (flags & GMX_PME_SOLVE)
5185             {
5186                 /* solve in k-space for our local cells */
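                /* In contrast to the single-grid call earlier in this
                 * function, the solver is handed the whole set of LB grids
                 * (&pme->cfftgrid[2]) with the LB flag set to TRUE. */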
5187 #pragma omp parallel num_threads(pme->nthread) private(thread)
5188                 {
5189                     int loop_count;
5190                     thread = gmx_omp_get_thread_num();
5191                     if (thread == 0)
5192                     {
5193                         wallcycle_start(wcycle, ewcLJPME);
5194                     }
5195
5196                     loop_count =
5197                         solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5198                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5199                                          bCalcEnerVir,
5200                                          pme->nthread, thread);
5201                     if (thread == 0)
5202                     {
5203                         wallcycle_stop(wcycle, ewcLJPME);
5204                         where();
5205                         inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5206                     }
5207                 }
5208             }
5209
5210             if (bCalcEnerVir)
5211             {
5212                 /* This should only be called on the master thread and
5213                  * after the threads have synchronized.
5214                  */
5215                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
5216             }
5217
5218             if (bCalcF)
5219             {
5220                 bFirst = !(flags & GMX_PME_DO_COULOMB);
5221                 calc_initial_lb_coeffs(pme, local_c6, local_sigma);
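                /* The gather loop runs over the grids in reverse order while
                 * the coefficients keep advancing by one power of sigma per
                 * pass, so each grid is gathered with the power of sigma
                 * complementary to the one it was spread with. */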
5222                 for (grid_index = 8; grid_index >= 2; --grid_index)
5223                 {
5224                     /* Unpack structure */
5225                     pmegrid    = &pme->pmegrid[grid_index];
5226                     fftgrid    = pme->fftgrid[grid_index];
5227                     cfftgrid   = pme->cfftgrid[grid_index];
5228                     pfft_setup = pme->pfft_setup[grid_index];
5229                     grid       = pmegrid->grid.grid;
5230                     calc_next_lb_coeffs(pme, local_sigma);
5231                     where();
5232 #pragma omp parallel num_threads(pme->nthread) private(thread)
5233                     {
5234                         thread = gmx_omp_get_thread_num();
5235                         /* do 3d-invfft */
5236                         if (thread == 0)
5237                         {
5238                             where();
5239                             wallcycle_start(wcycle, ewcPME_FFT);
5240                         }
5241
5242                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5243                                                    thread, wcycle);
5244                         if (thread == 0)
5245                         {
5246                             wallcycle_stop(wcycle, ewcPME_FFT);
5247
5248                             where();
5249
5250                             if (pme->nodeid == 0)
5251                             {
5252                                 ntot  = pme->nkx*pme->nky*pme->nkz;
5253                                 npme  = ntot*log((real)ntot)/log(2.0);
5254                                 inc_nrnb(nrnb, eNR_FFT, 2*npme);
5255                             }
5256                             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5257                         }
5258
5259                         copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
5260
5261                     } /*#pragma omp parallel*/
5262
5263                     /* distribute local grid to all nodes */
5264 #ifdef GMX_MPI
5265                     if (pme->nnodes > 1)
5266                     {
5267                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
5268                     }
5269 #endif
5270                     where();
5271
5272                     unwrap_periodic_pmegrid(pme, grid);
5273
5274                     /* interpolate forces for our local atoms */
5275                     where();
5276                     bClearF = (bFirst && PAR(cr));
5277                     scale   = pme->bFEP ? (fep_state < 1 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5278                     scale  *= lb_scale_factor[grid_index-2];
5279 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5280                     for (thread = 0; thread < pme->nthread; thread++)
5281                     {
5282                         gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5283                                           &pme->atc[0].spline[thread],
5284                                           scale);
5285                     }
5286                     where();
5287
5288                     inc_nrnb(nrnb, eNR_GATHERFBSP,
5289                              pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5290                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5291
5292                     bFirst = FALSE;
5293                 } /* for (grid_index = 8; grid_index >= 2; --grid_index) */
5294             }     /* if (bCalcF) */
5295         }         /* for (fep_state = 0; fep_state < fep_states_lj; ++fep_state) */
5296     }             /* if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB) */
5297
5298     if (bCalcF && pme->nnodes > 1)
5299     {
5300         wallcycle_start(wcycle, ewcPME_REDISTXF);
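        /* Communicate the forces back along the decomposition dimensions to
         * the ranks that own the atoms, reversing the redistribution done in
         * do_redist_pos_coeffs(). */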
5301         for (d = 0; d < pme->ndecompdim; d++)
5302         {
5303             atc = &pme->atc[d];
5304             if (d == pme->ndecompdim - 1)
5305             {
5306                 n_d = homenr;
5307                 f_d = f + start;
5308             }
5309             else
5310             {
5311                 n_d = pme->atc[d+1].n;
5312                 f_d = pme->atc[d+1].f;
5313             }
5314             if (DOMAINDECOMP(cr))
5315             {
5316                 dd_pmeredist_f(pme, atc, n_d, f_d,
5317                                d == pme->ndecompdim-1 && pme->bPPnode);
5318             }
5319         }
5320
5321         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5322     }
5323     where();
5324
5325     if (bCalcEnerVir)
5326     {
5327         if (flags & GMX_PME_DO_COULOMB)
5328         {
5329             if (!pme->bFEP_q)
5330             {
5331                 *energy_q = energy_AB[0];
5332                 m_add(vir_q, vir_AB[0], vir_q);
5333             }
5334             else
5335             {
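                /* FEP: mix the A- and B-state energies and virials linearly
                 * in lambda_q; dV/dlambda is the B-A energy difference. */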
5336                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5337                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5338                 for (i = 0; i < DIM; i++)
5339                 {
5340                     for (j = 0; j < DIM; j++)
5341                     {
5342                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5343                             lambda_q*vir_AB[1][i][j];
5344                     }
5345                 }
5346             }
5347             if (debug)
5348             {
5349                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5350             }
5351         }
5352         else
5353         {
5354             *energy_q = 0;
5355         }
5356
5357         if (flags & GMX_PME_DO_LJ)
5358         {
5359             if (!pme->bFEP_lj)
5360             {
5361                 *energy_lj = energy_AB[2];
5362                 m_add(vir_lj, vir_AB[2], vir_lj);
5363             }
5364             else
5365             {
5366                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5367                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5368                 for (i = 0; i < DIM; i++)
5369                 {
5370                     for (j = 0; j < DIM; j++)
5371                     {
5372                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5373                     }
5374                 }
5375             }
5376             if (debug)
5377             {
5378                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5379             }
5380         }
5381         else
5382         {
5383             *energy_lj = 0;
5384         }
5385     }
5386     return 0;
5387 }