Fix OpenMP scope error
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
56  * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #ifdef HAVE_CONFIG_H
61 #include <config.h>
62 #endif
63
64 #include <stdio.h>
65 #include <string.h>
66 #include <math.h>
67 #include <assert.h>
68 #include "typedefs.h"
69 #include "txtdump.h"
70 #include "vec.h"
71 #include "gromacs/utility/smalloc.h"
72 #include "coulomb.h"
73 #include "gmx_fatal.h"
74 #include "pme.h"
75 #include "network.h"
76 #include "physics.h"
77 #include "nrnb.h"
78 #include "macros.h"
79
80 #include "gromacs/fft/parallel_3dfft.h"
81 #include "gromacs/fileio/futil.h"
82 #include "gromacs/fileio/pdbio.h"
83 #include "gromacs/math/gmxcomplex.h"
84 #include "gromacs/timing/cyclecounter.h"
85 #include "gromacs/timing/wallcycle.h"
86 #include "gromacs/utility/gmxmpi.h"
87 #include "gromacs/utility/gmxomp.h"
88
89 /* Include the SIMD macro file and then check for support */
90 #include "gromacs/simd/simd.h"
91 #include "gromacs/simd/simd_math.h"
92 #ifdef GMX_SIMD_HAVE_REAL
93 /* Turn on arbitrary width SIMD intrinsics for PME solve */
94 #    define PME_SIMD_SOLVE
95 #endif
96
97 #define PME_GRID_QA    0 /* Gridindex for A-state for Q */
98 #define PME_GRID_C6A   2 /* Gridindex for A-state for LJ */
99 #define DO_Q           2 /* Electrostatic grids have index q<2 */
100 #define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
101 #define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */
102
103 /* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules */
104 const real lb_scale_factor[] = {
105     1.0/64, 6.0/64, 15.0/64, 20.0/64,
106     15.0/64, 6.0/64, 1.0/64
107 };
108
109 /* Pascal triangle coefficients used in solve_pme_lj_yzx, only need to do 4 calculations due to symmetry */
110 const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
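/* Note: the symmetric factors above pair equal entries of lb_scale_factor,
 * i.e. 2/64 = 1/64 + 1/64, 12/64 = 6/64 + 6/64, 30/64 = 15/64 + 15/64, with
 * the central 20/64 term standing alone, which is why only 4 values are
 * needed when the solver exploits this symmetry.
 */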
111
112 /* Check if we have 4-wide SIMD macro support */
113 #if (defined GMX_SIMD4_HAVE_REAL)
114 /* Do PME spread and gather with 4-wide SIMD.
115  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
116  */
117 #    define PME_SIMD4_SPREAD_GATHER
118
119 #    if (defined GMX_SIMD_HAVE_LOADU) && (defined GMX_SIMD_HAVE_STOREU)
120 /* With PME-order=4 on x86, unaligned load+store is slightly faster
121  * than doubling all SIMD operations when using aligned load+store.
122  */
123 #        define PME_SIMD4_UNALIGNED
124 #    endif
125 #endif
126
127 #define DFT_TOL 1e-7
128 /* #define PRT_FORCE */
129 /* conditions for on-the-fly time measurement */
130 /* #define TAKETIME (step > 1 && timesteps < 10) */
131 #define TAKETIME FALSE
132
133 /* #define PME_TIME_THREADS */
134
135 #ifdef GMX_DOUBLE
136 #define mpi_type MPI_DOUBLE
137 #else
138 #define mpi_type MPI_FLOAT
139 #endif
140
141 #ifdef PME_SIMD4_SPREAD_GATHER
142 #    define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
143 #else
144 /* We can use any alignment, apart from 0, so we use 4 reals */
145 #    define SIMD4_ALIGNMENT  (4*sizeof(real))
146 #endif
147
148 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
149  * to preserve alignment.
150  */
151 #define GMX_CACHE_SEP 64
152
153 /* We only define a maximum to be able to use local arrays without allocation.
154  * An order larger than 12 should never be needed, even for test cases.
155  * If needed it can be changed here.
156  */
157 #define PME_ORDER_MAX 12
158
159 /* Internal datastructures */
160 typedef struct {
161     int send_index0;
162     int send_nindex;
163     int recv_index0;
164     int recv_nindex;
165     int recv_size;   /* Receive buffer width, used with OpenMP */
166 } pme_grid_comm_t;
167
168 typedef struct {
169 #ifdef GMX_MPI
170     MPI_Comm         mpi_comm;
171 #endif
172     int              nnodes, nodeid;
173     int             *s2g0;
174     int             *s2g1;
175     int              noverlap_nodes;
176     int             *send_id, *recv_id;
177     int              send_size; /* Send buffer width, used with OpenMP */
178     pme_grid_comm_t *comm_data;
179     real            *sendbuf;
180     real            *recvbuf;
181 } pme_overlap_t;
182
183 typedef struct {
184     int *n;      /* Cumulative counts of the number of particles per thread */
185     int  nalloc; /* Allocation size of i */
186     int *i;      /* Particle indices ordered on thread index (n) */
187 } thread_plist_t;
188
189 typedef struct {
190     int      *thread_one;
191     int       n;
192     int      *ind;
193     splinevec theta;
194     real     *ptr_theta_z;
195     splinevec dtheta;
196     real     *ptr_dtheta_z;
197 } splinedata_t;
198
199 typedef struct {
200     int      dimind;        /* The index of the dimension, 0=x, 1=y */
201     int      nslab;
202     int      nodeid;
203 #ifdef GMX_MPI
204     MPI_Comm mpi_comm;
205 #endif
206
207     int     *node_dest;     /* The nodes to send x and q to with DD */
208     int     *node_src;      /* The nodes to receive x and q from with DD */
209     int     *buf_index;     /* Index for commnode into the buffers */
210
211     int      maxshift;
212
213     int      npd;
214     int      pd_nalloc;
215     int     *pd;
216     int     *count;         /* The number of atoms to send to each node */
217     int    **count_thread;
218     int     *rcount;        /* The number of atoms to receive */
219
220     int      n;
221     int      nalloc;
222     rvec    *x;
223     real    *coefficient;
224     rvec    *f;
225     gmx_bool bSpread;       /* These coordinates are used for spreading */
226     int      pme_order;
227     ivec    *idx;
228     rvec    *fractx;            /* Fractional coordinate relative to
229                                  * the lower cell boundary
230                                  */
231     int             nthread;
232     int            *thread_idx; /* Which thread should spread which coefficient */
233     thread_plist_t *thread_plist;
234     splinedata_t   *spline;
235 } pme_atomcomm_t;
236
237 #define FLBS  3
238 #define FLBSZ 4
239
240 typedef struct {
241     ivec  ci;     /* The spatial location of this grid         */
242     ivec  n;      /* The used size of *grid, including order-1 */
243     ivec  offset; /* The grid offset from the full node grid   */
244     int   order;  /* PME spreading order                       */
245     ivec  s;      /* The allocated size of *grid, s >= n       */
246     real *grid;   /* The grid of the local thread, size n      */
247 } pmegrid_t;
248
249 typedef struct {
250     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
251     int        nthread;      /* The number of threads operating on this grid     */
252     ivec       nc;           /* The local spatial decomposition over the threads */
253     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
254     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
255     int      **g2t;          /* The grid to thread index                         */
256     ivec       nthread_comm; /* The number of threads to communicate with        */
257 } pmegrids_t;
258
259 typedef struct {
260 #ifdef PME_SIMD4_SPREAD_GATHER
261     /* Masks for 4-wide SIMD aligned spreading and gathering */
262     gmx_simd4_bool_t mask_S0[6], mask_S1[6];
263 #else
264     int              dummy; /* C89 requires that struct has at least one member */
265 #endif
266 } pme_spline_work_t;
267
268 typedef struct {
269     /* work data for solve_pme */
270     int      nalloc;
271     real *   mhx;
272     real *   mhy;
273     real *   mhz;
274     real *   m2;
275     real *   denom;
276     real *   tmp1_alloc;
277     real *   tmp1;
278     real *   tmp2;
279     real *   eterm;
280     real *   m2inv;
281
282     real     energy_q;
283     matrix   vir_q;
284     real     energy_lj;
285     matrix   vir_lj;
286 } pme_work_t;
287
288 typedef struct gmx_pme {
289     int           ndecompdim; /* The number of decomposition dimensions */
290     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
291     int           nodeid_major;
292     int           nodeid_minor;
293     int           nnodes;    /* The number of nodes doing PME */
294     int           nnodes_major;
295     int           nnodes_minor;
296
297     MPI_Comm      mpi_comm;
298     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
299 #ifdef GMX_MPI
300     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
301 #endif
302
303     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
304     int        nthread;       /* The number of threads doing PME on our rank */
305
306     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
307     gmx_bool   bFEP;          /* Compute Free energy contribution */
308     gmx_bool   bFEP_q;
309     gmx_bool   bFEP_lj;
310     int        nkx, nky, nkz; /* Grid dimensions */
311     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
312     int        pme_order;
313     real       epsilon_r;
314
315     int        ljpme_combination_rule;  /* Type of combination rule in LJ-PME */
316
317     int        ngrids;                  /* number of grids we maintain for pmegrid, (c)fftgrid and pfft_setups */
318
319     pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
320                                          * includes overlap. Grid indices are ordered as
321                                          * follows:
322                                          * 0: Coulomb PME, state A
323                                          * 1: Coulomb PME, state B
324                                          * 2-8: LJ-PME
325                                          * This can probably be done in a better way
326                                          * but this simple hack works for now
327                                          */
328     /* The PME coefficient spreading grid sizes/strides, includes pme_order-1 */
329     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
330     /* pmegrid_nz might be larger than strictly necessary to ensure
331      * memory alignment, pmegrid_nz_base gives the real base size.
332      */
333     int     pmegrid_nz_base;
334     /* The local PME grid starting indices */
335     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
336
337     /* Work data for spreading and gathering */
338     pme_spline_work_t     *spline_work;
339
340     real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
341     /* inside the interpolation grid, but separate for 2D PME decomp. */
342     int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;
343
344     t_complex            **cfftgrid;  /* Grids for complex FFT data */
345
346     int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
347
348     gmx_parallel_3dfft_t  *pfft_setup;
349
350     int                   *nnx, *nny, *nnz;
351     real                  *fshx, *fshy, *fshz;
352
353     pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
354     matrix                 recipbox;
355     splinevec              bsp_mod;
356     /* Buffers to store data for local atoms for L-B combination rule
357      * calculations in LJ-PME. lb_buf1 stores either the coefficients
358      * for spreading/gathering (in serial), or the C6 coefficient for
359      * local atoms (in parallel).  lb_buf2 is only used in parallel,
360      * and stores the sigma values for local atoms. */
361     real                 *lb_buf1, *lb_buf2;
362     int                   lb_buf_nalloc; /* Allocation size for the above buffers. */
363
364     pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */
365
366     pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */
367
368     rvec                 *bufv;          /* Communication buffer */
369     real                 *bufr;          /* Communication buffer */
370     int                   buf_nalloc;    /* The communication buffer size */
371
372     /* thread local work data for solve_pme */
373     pme_work_t *work;
374
375     /* Work data for sum_qgrid */
376     real *   sum_qgrid_tmp;
377     real *   sum_qgrid_dd_tmp;
378 } t_gmx_pme;
379
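/* Sketch of the index calculation below: each atom is mapped to the grid
 * point at or below it plus a fractional offset for the B-splines.  For
 * example, with nx = 32 and a fractional x-coordinate of 0.26,
 * tx = 32*(0.26 + 2.0) = 72.32, so tix = 72 and the fraction is 0.32; the
 * nnx table then folds the +2-box shift back onto the local grid, while
 * fshx supplies the shift correction needed only in the decomposed x and y
 * dimensions.
 */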
380 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
381                                    int start, int grid_index, int end, int thread)
382 {
383     int             i;
384     int            *idxptr, tix, tiy, tiz;
385     real           *xptr, *fptr, tx, ty, tz;
386     real            rxx, ryx, ryy, rzx, rzy, rzz;
387     int             nx, ny, nz;
388     int             start_ix, start_iy, start_iz;
389     int            *g2tx, *g2ty, *g2tz;
390     gmx_bool        bThreads;
391     int            *thread_idx = NULL;
392     thread_plist_t *tpl        = NULL;
393     int            *tpl_n      = NULL;
394     int             thread_i;
395
396     nx  = pme->nkx;
397     ny  = pme->nky;
398     nz  = pme->nkz;
399
400     start_ix = pme->pmegrid_start_ix;
401     start_iy = pme->pmegrid_start_iy;
402     start_iz = pme->pmegrid_start_iz;
403
404     rxx = pme->recipbox[XX][XX];
405     ryx = pme->recipbox[YY][XX];
406     ryy = pme->recipbox[YY][YY];
407     rzx = pme->recipbox[ZZ][XX];
408     rzy = pme->recipbox[ZZ][YY];
409     rzz = pme->recipbox[ZZ][ZZ];
410
411     g2tx = pme->pmegrid[grid_index].g2t[XX];
412     g2ty = pme->pmegrid[grid_index].g2t[YY];
413     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
414
415     bThreads = (atc->nthread > 1);
416     if (bThreads)
417     {
418         thread_idx = atc->thread_idx;
419
420         tpl   = &atc->thread_plist[thread];
421         tpl_n = tpl->n;
422         for (i = 0; i < atc->nthread; i++)
423         {
424             tpl_n[i] = 0;
425         }
426     }
427
428     for (i = start; i < end; i++)
429     {
430         xptr   = atc->x[i];
431         idxptr = atc->idx[i];
432         fptr   = atc->fractx[i];
433
434         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
435         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
436         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
437         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
438
439         tix = (int)(tx);
440         tiy = (int)(ty);
441         tiz = (int)(tz);
442
443         /* Because decomposition only occurs in x and y,
444          * we never have a fraction correction in z.
445          */
446         fptr[XX] = tx - tix + pme->fshx[tix];
447         fptr[YY] = ty - tiy + pme->fshy[tiy];
448         fptr[ZZ] = tz - tiz;
449
450         idxptr[XX] = pme->nnx[tix];
451         idxptr[YY] = pme->nny[tiy];
452         idxptr[ZZ] = pme->nnz[tiz];
453
454 #ifdef DEBUG
455         range_check(idxptr[XX], 0, pme->pmegrid_nx);
456         range_check(idxptr[YY], 0, pme->pmegrid_ny);
457         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
458 #endif
459
460         if (bThreads)
461         {
462             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
463             thread_idx[i] = thread_i;
464             tpl_n[thread_i]++;
465         }
466     }
467
468     if (bThreads)
469     {
470         /* Make a list of particle indices sorted on thread */
471
472         /* Get the cumulative count */
473         for (i = 1; i < atc->nthread; i++)
474         {
475             tpl_n[i] += tpl_n[i-1];
476         }
477         /* The current implementation distributes particles equally
478          * over the threads, so we could actually allocate for that
479          * in pme_realloc_atomcomm_things.
480          */
481         if (tpl_n[atc->nthread-1] > tpl->nalloc)
482         {
483             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
484             srenew(tpl->i, tpl->nalloc);
485         }
486         /* Set tpl_n to the cumulative start */
487         for (i = atc->nthread-1; i >= 1; i--)
488         {
489             tpl_n[i] = tpl_n[i-1];
490         }
491         tpl_n[0] = 0;
492
493         /* Fill our thread local array with indices sorted on thread */
494         for (i = start; i < end; i++)
495         {
496             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
497         }
498         /* Now tpl_n contains the cumulative count again */
499     }
500 }
501
502 static void make_thread_local_ind(pme_atomcomm_t *atc,
503                                   int thread, splinedata_t *spline)
504 {
505     int             n, t, i, start, end;
506     thread_plist_t *tpl;
507
508     /* Combine the indices made by each thread into one index */
509
510     n     = 0;
511     start = 0;
512     for (t = 0; t < atc->nthread; t++)
513     {
514         tpl = &atc->thread_plist[t];
515         /* Copy our part (start - end) from the list of thread t */
516         if (thread > 0)
517         {
518             start = tpl->n[thread-1];
519         }
520         end = tpl->n[thread];
521         for (i = start; i < end; i++)
522         {
523             spline->ind[n++] = tpl->i[i];
524         }
525     }
526
527     spline->n = n;
528 }
529
530
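/* Worked example for the slab assignment below: with nslab = 4 and a
 * fractional coordinate of -0.1 (i.e. 0.9 after periodic wrapping),
 * s = 4*(-0.1) = -0.4 and si = (int)(-0.4 + 8) % 4 = 7 % 4 = 3, so the atom
 * is assigned to the last slab, consistent with floor(0.9*4) = 3.
 */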
531 static void pme_calc_pidx(int start, int end,
532                           matrix recipbox, rvec x[],
533                           pme_atomcomm_t *atc, int *count)
534 {
535     int   nslab, i;
536     int   si;
537     real *xptr, s;
538     real  rxx, ryx, rzx, ryy, rzy;
539     int  *pd;
540
541     /* Calculate the PME task index (pidx) for each particle.
542      * Here we always assign equally sized slabs to each node
543      * for load balancing reasons (the PME grid spacing is not used).
544      */
545
546     nslab = atc->nslab;
547     pd    = atc->pd;
548
549     /* Reset the count */
550     for (i = 0; i < nslab; i++)
551     {
552         count[i] = 0;
553     }
554
555     if (atc->dimind == 0)
556     {
557         rxx = recipbox[XX][XX];
558         ryx = recipbox[YY][XX];
559         rzx = recipbox[ZZ][XX];
560         /* Calculate the node index in x-dimension */
561         for (i = start; i < end; i++)
562         {
563             xptr   = x[i];
564             /* Fractional coordinates along box vectors */
565             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
566             si    = (int)(s + 2*nslab) % nslab;
567             pd[i] = si;
568             count[si]++;
569         }
570     }
571     else
572     {
573         ryy = recipbox[YY][YY];
574         rzy = recipbox[ZZ][YY];
575         /* Calculate the node index in y-dimension */
576         for (i = start; i < end; i++)
577         {
578             xptr   = x[i];
579             /* Fractional coordinates along box vectors */
580             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
581             si    = (int)(s + 2*nslab) % nslab;
582             pd[i] = si;
583             count[si]++;
584         }
585     }
586 }
587
588 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
589                                   pme_atomcomm_t *atc)
590 {
591     int nthread, thread, slab;
592
593     nthread = atc->nthread;
594
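    /* The loop variable of an OpenMP parallel for is private to each thread
     * by default, and each thread writes only its own count_thread[thread]
     * array, so no explicit private() clause is required here.
     */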
595 #pragma omp parallel for num_threads(nthread) schedule(static)
596     for (thread = 0; thread < nthread; thread++)
597     {
598         pme_calc_pidx(natoms* thread   /nthread,
599                       natoms*(thread+1)/nthread,
600                       recipbox, x, atc, atc->count_thread[thread]);
601     }
602     /* Non-parallel reduction, since nslab is small */
603
604     for (thread = 1; thread < nthread; thread++)
605     {
606         for (slab = 0; slab < atc->nslab; slab++)
607         {
608             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
609         }
610     }
611 }
612
613 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
614 {
615     const int padding = 4;
616     int       i;
617
618     srenew(th[XX], nalloc);
619     srenew(th[YY], nalloc);
620     /* In z we add padding, this is only required for the aligned SIMD code */
621     sfree_aligned(*ptr_z);
622     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
623     th[ZZ] = *ptr_z + padding;
624
625     for (i = 0; i < padding; i++)
626     {
627         (*ptr_z)[               i] = 0;
628         (*ptr_z)[padding+nalloc+i] = 0;
629     }
630 }
631
632 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
633 {
634     int i, d;
635
636     srenew(spline->ind, atc->nalloc);
637     /* Initialize the index to identity so it works without threads */
638     for (i = 0; i < atc->nalloc; i++)
639     {
640         spline->ind[i] = i;
641     }
642
643     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
644                       atc->pme_order*atc->nalloc);
645     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
646                       atc->pme_order*atc->nalloc);
647 }
648
649 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
650 {
651     int nalloc_old, i, j, nalloc_tpl;
652
653     /* We have to avoid a NULL pointer for atc->x to prevent
654      * possible fatal errors in MPI routines.
655      */
656     if (atc->n > atc->nalloc || atc->nalloc == 0)
657     {
658         nalloc_old  = atc->nalloc;
659         atc->nalloc = over_alloc_dd(max(atc->n, 1));
660
661         if (atc->nslab > 1)
662         {
663             srenew(atc->x, atc->nalloc);
664             srenew(atc->coefficient, atc->nalloc);
665             srenew(atc->f, atc->nalloc);
666             for (i = nalloc_old; i < atc->nalloc; i++)
667             {
668                 clear_rvec(atc->f[i]);
669             }
670         }
671         if (atc->bSpread)
672         {
673             srenew(atc->fractx, atc->nalloc);
674             srenew(atc->idx, atc->nalloc);
675
676             if (atc->nthread > 1)
677             {
678                 srenew(atc->thread_idx, atc->nalloc);
679             }
680
681             for (i = 0; i < atc->nthread; i++)
682             {
683                 pme_realloc_splinedata(&atc->spline[i], atc);
684             }
685         }
686     }
687 }
688
689 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
690                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
691                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
692                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
693 {
694 #ifdef GMX_MPI
695     int        dest, src;
696     MPI_Status stat;
697
698     if (bBackward == FALSE)
699     {
700         dest = atc->node_dest[shift];
701         src  = atc->node_src[shift];
702     }
703     else
704     {
705         dest = atc->node_src[shift];
706         src  = atc->node_dest[shift];
707     }
708
709     if (nbyte_s > 0 && nbyte_r > 0)
710     {
711         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
712                      dest, shift,
713                      buf_r, nbyte_r, MPI_BYTE,
714                      src, shift,
715                      atc->mpi_comm, &stat);
716     }
717     else if (nbyte_s > 0)
718     {
719         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
720                  dest, shift,
721                  atc->mpi_comm);
722     }
723     else if (nbyte_r > 0)
724     {
725         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
726                  src, shift,
727                  atc->mpi_comm, &stat);
728     }
729 #endif
730 }
731
732 static void dd_pmeredist_pos_coeffs(gmx_pme_t pme,
733                                     int n, gmx_bool bX, rvec *x, real *data,
734                                     pme_atomcomm_t *atc)
735 {
736     int *commnode, *buf_index;
737     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
738
739     commnode  = atc->node_dest;
740     buf_index = atc->buf_index;
741
742     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
743
744     nsend = 0;
745     for (i = 0; i < nnodes_comm; i++)
746     {
747         buf_index[commnode[i]] = nsend;
748         nsend                 += atc->count[commnode[i]];
749     }
750     if (bX)
751     {
752         if (atc->count[atc->nodeid] + nsend != n)
753         {
754             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
755                       "This usually means that your system is not well equilibrated.",
756                       n - (atc->count[atc->nodeid] + nsend),
757                       pme->nodeid, 'x'+atc->dimind);
758         }
759
760         if (nsend > pme->buf_nalloc)
761         {
762             pme->buf_nalloc = over_alloc_dd(nsend);
763             srenew(pme->bufv, pme->buf_nalloc);
764             srenew(pme->bufr, pme->buf_nalloc);
765         }
766
767         atc->n = atc->count[atc->nodeid];
768         for (i = 0; i < nnodes_comm; i++)
769         {
770             scount = atc->count[commnode[i]];
771             /* Communicate the count */
772             if (debug)
773             {
774                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
775                         atc->dimind, atc->nodeid, commnode[i], scount);
776             }
777             pme_dd_sendrecv(atc, FALSE, i,
778                             &scount, sizeof(int),
779                             &atc->rcount[i], sizeof(int));
780             atc->n += atc->rcount[i];
781         }
782
783         pme_realloc_atomcomm_things(atc);
784     }
785
786     local_pos = 0;
787     for (i = 0; i < n; i++)
788     {
789         node = atc->pd[i];
790         if (node == atc->nodeid)
791         {
792             /* Copy direct to the receive buffer */
793             if (bX)
794             {
795                 copy_rvec(x[i], atc->x[local_pos]);
796             }
797             atc->coefficient[local_pos] = data[i];
798             local_pos++;
799         }
800         else
801         {
802             /* Copy to the send buffer */
803             if (bX)
804             {
805                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
806             }
807             pme->bufr[buf_index[node]] = data[i];
808             buf_index[node]++;
809         }
810     }
811
812     buf_pos = 0;
813     for (i = 0; i < nnodes_comm; i++)
814     {
815         scount = atc->count[commnode[i]];
816         rcount = atc->rcount[i];
817         if (scount > 0 || rcount > 0)
818         {
819             if (bX)
820             {
821                 /* Communicate the coordinates */
822                 pme_dd_sendrecv(atc, FALSE, i,
823                                 pme->bufv[buf_pos], scount*sizeof(rvec),
824                                 atc->x[local_pos], rcount*sizeof(rvec));
825             }
826             /* Communicate the coefficients */
827             pme_dd_sendrecv(atc, FALSE, i,
828                             pme->bufr+buf_pos, scount*sizeof(real),
829                             atc->coefficient+local_pos, rcount*sizeof(real));
830             buf_pos   += scount;
831             local_pos += atc->rcount[i];
832         }
833     }
834 }
835
836 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
837                            int n, rvec *f,
838                            gmx_bool bAddF)
839 {
840     int *commnode, *buf_index;
841     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
842
843     commnode  = atc->node_dest;
844     buf_index = atc->buf_index;
845
846     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
847
848     local_pos = atc->count[atc->nodeid];
849     buf_pos   = 0;
850     for (i = 0; i < nnodes_comm; i++)
851     {
852         scount = atc->rcount[i];
853         rcount = atc->count[commnode[i]];
854         if (scount > 0 || rcount > 0)
855         {
856             /* Communicate the forces */
857             pme_dd_sendrecv(atc, TRUE, i,
858                             atc->f[local_pos], scount*sizeof(rvec),
859                             pme->bufv[buf_pos], rcount*sizeof(rvec));
860             local_pos += scount;
861         }
862         buf_index[commnode[i]] = buf_pos;
863         buf_pos               += rcount;
864     }
865
866     local_pos = 0;
867     if (bAddF)
868     {
869         for (i = 0; i < n; i++)
870         {
871             node = atc->pd[i];
872             if (node == atc->nodeid)
873             {
874                 /* Add from the local force array */
875                 rvec_inc(f[i], atc->f[local_pos]);
876                 local_pos++;
877             }
878             else
879             {
880                 /* Add from the receive buffer */
881                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
882                 buf_index[node]++;
883             }
884         }
885     }
886     else
887     {
888         for (i = 0; i < n; i++)
889         {
890             node = atc->pd[i];
891             if (node == atc->nodeid)
892             {
893                 /* Copy from the local force array */
894                 copy_rvec(atc->f[local_pos], f[i]);
895                 local_pos++;
896             }
897             else
898             {
899                 /* Copy from the receive buffer */
900                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
901                 buf_index[node]++;
902             }
903         }
904     }
905 }
906
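/* gmx_sum_qgrid_dd below exchanges the order-1 overlap region of the real-space
 * grid with neighbouring PME ranks.  In the forward direction the received
 * planes are added into the local grid (summing contributions spread by
 * neighbours); in the backward direction they overwrite it, so each rank gets
 * back the grid data it needs.  The minor (y) dimension is packed into
 * sendbuf/recvbuf because its planes are not contiguous in memory, whereas the
 * major (x) dimension can send whole grid slices in place.
 */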
907 #ifdef GMX_MPI
908 static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
909 {
910     pme_overlap_t *overlap;
911     int            send_index0, send_nindex;
912     int            recv_index0, recv_nindex;
913     MPI_Status     stat;
914     int            i, j, k, ix, iy, iz, icnt;
915     int            ipulse, send_id, recv_id, datasize;
916     real          *p;
917     real          *sendptr, *recvptr;
918
919     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
920     overlap = &pme->overlap[1];
921
922     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
923     {
924         /* Since we have already (un)wrapped the overlap in the z-dimension,
925          * we only have to communicate 0 to nkz (not pmegrid_nz).
926          */
927         if (direction == GMX_SUM_GRID_FORWARD)
928         {
929             send_id       = overlap->send_id[ipulse];
930             recv_id       = overlap->recv_id[ipulse];
931             send_index0   = overlap->comm_data[ipulse].send_index0;
932             send_nindex   = overlap->comm_data[ipulse].send_nindex;
933             recv_index0   = overlap->comm_data[ipulse].recv_index0;
934             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
935         }
936         else
937         {
938             send_id       = overlap->recv_id[ipulse];
939             recv_id       = overlap->send_id[ipulse];
940             send_index0   = overlap->comm_data[ipulse].recv_index0;
941             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
942             recv_index0   = overlap->comm_data[ipulse].send_index0;
943             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
944         }
945
946         /* Copy data to contiguous send buffer */
947         if (debug)
948         {
949             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
950                     pme->nodeid, overlap->nodeid, send_id,
951                     pme->pmegrid_start_iy,
952                     send_index0-pme->pmegrid_start_iy,
953                     send_index0-pme->pmegrid_start_iy+send_nindex);
954         }
955         icnt = 0;
956         for (i = 0; i < pme->pmegrid_nx; i++)
957         {
958             ix = i;
959             for (j = 0; j < send_nindex; j++)
960             {
961                 iy = j + send_index0 - pme->pmegrid_start_iy;
962                 for (k = 0; k < pme->nkz; k++)
963                 {
964                     iz = k;
965                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
966                 }
967             }
968         }
969
970         datasize      = pme->pmegrid_nx * pme->nkz;
971
972         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
973                      send_id, ipulse,
974                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
975                      recv_id, ipulse,
976                      overlap->mpi_comm, &stat);
977
978         /* Get data from contiguous recv buffer */
979         if (debug)
980         {
981             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
982                     pme->nodeid, overlap->nodeid, recv_id,
983                     pme->pmegrid_start_iy,
984                     recv_index0-pme->pmegrid_start_iy,
985                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
986         }
987         icnt = 0;
988         for (i = 0; i < pme->pmegrid_nx; i++)
989         {
990             ix = i;
991             for (j = 0; j < recv_nindex; j++)
992             {
993                 iy = j + recv_index0 - pme->pmegrid_start_iy;
994                 for (k = 0; k < pme->nkz; k++)
995                 {
996                     iz = k;
997                     if (direction == GMX_SUM_GRID_FORWARD)
998                     {
999                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1000                     }
1001                     else
1002                     {
1003                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1004                     }
1005                 }
1006             }
1007         }
1008     }
1009
1010     /* Major dimension is easier, no copying required,
1011      * but we might have to sum to a separate array.
1012      * Since we don't copy, we have to communicate up to pmegrid_nz,
1013      * not nkz as for the minor direction.
1014      */
1015     overlap = &pme->overlap[0];
1016
1017     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1018     {
1019         if (direction == GMX_SUM_GRID_FORWARD)
1020         {
1021             send_id       = overlap->send_id[ipulse];
1022             recv_id       = overlap->recv_id[ipulse];
1023             send_index0   = overlap->comm_data[ipulse].send_index0;
1024             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1025             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1026             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1027             recvptr       = overlap->recvbuf;
1028         }
1029         else
1030         {
1031             send_id       = overlap->recv_id[ipulse];
1032             recv_id       = overlap->send_id[ipulse];
1033             send_index0   = overlap->comm_data[ipulse].recv_index0;
1034             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1035             recv_index0   = overlap->comm_data[ipulse].send_index0;
1036             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1037             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1038         }
1039
1040         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1041         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1042
1043         if (debug)
1044         {
1045             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1046                     pme->nodeid, overlap->nodeid, send_id,
1047                     pme->pmegrid_start_ix,
1048                     send_index0-pme->pmegrid_start_ix,
1049                     send_index0-pme->pmegrid_start_ix+send_nindex);
1050             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1051                     pme->nodeid, overlap->nodeid, recv_id,
1052                     pme->pmegrid_start_ix,
1053                     recv_index0-pme->pmegrid_start_ix,
1054                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1055         }
1056
1057         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1058                      send_id, ipulse,
1059                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1060                      recv_id, ipulse,
1061                      overlap->mpi_comm, &stat);
1062
1063         /* ADD data from contiguous recv buffer */
1064         if (direction == GMX_SUM_GRID_FORWARD)
1065         {
1066             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1067             for (i = 0; i < recv_nindex*datasize; i++)
1068             {
1069                 p[i] += overlap->recvbuf[i];
1070             }
1071         }
1072     }
1073 }
1074 #endif
1075
1076
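/* Indexing note for the copy below: both grids use (ix*size_y + iy)*size_z + iz
 * layout but with different strides; the PME grid strides (pmegrid_ny/_nz)
 * include the pme_order-1 overlap, while the FFT grid uses the local_fft_size
 * reported by gmx_parallel_3dfft_real_limits, so the copy has to track the two
 * indices pmeidx and fftidx separately.
 */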
1077 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1078 {
1079     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1080     ivec    local_pme_size;
1081     int     i, ix, iy, iz;
1082     int     pmeidx, fftidx;
1083
1084     /* Dimensions should be identical for A/B grid, so we just use A here */
1085     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1086                                    local_fft_ndata,
1087                                    local_fft_offset,
1088                                    local_fft_size);
1089
1090     local_pme_size[0] = pme->pmegrid_nx;
1091     local_pme_size[1] = pme->pmegrid_ny;
1092     local_pme_size[2] = pme->pmegrid_nz;
1093
1094     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1095        the offset is identical, and the PME grid always has more data (due to overlap)
1096      */
1097     {
1098 #ifdef DEBUG_PME
1099         FILE *fp, *fp2;
1100         char  fn[STRLEN], format[STRLEN];
1101         real  val;
1102         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1103         fp = gmx_ffopen(fn, "w");
1104         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1105         fp2 = gmx_ffopen(fn, "w");
1106         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1107 #endif
1108
1109         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1110         {
1111             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1112             {
1113                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1114                 {
1115                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1116                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1117                     fftgrid[fftidx] = pmegrid[pmeidx];
1118 #ifdef DEBUG_PME
1119                     val = 100*pmegrid[pmeidx];
1120                     if (pmegrid[pmeidx] != 0)
1121                     {
1122                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1123                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1124                     }
1125                     if (pmegrid[pmeidx] != 0)
1126                     {
1127                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1128                                 "qgrid",
1129                                 pme->pmegrid_start_ix + ix,
1130                                 pme->pmegrid_start_iy + iy,
1131                                 pme->pmegrid_start_iz + iz,
1132                                 pmegrid[pmeidx]);
1133                     }
1134 #endif
1135                 }
1136             }
1137         }
1138 #ifdef DEBUG_PME
1139         gmx_ffclose(fp);
1140         gmx_ffclose(fp2);
1141 #endif
1142     }
1143     return 0;
1144 }
1145
1146
1147 static gmx_cycles_t omp_cyc_start()
1148 {
1149     return gmx_cycles_read();
1150 }
1151
1152 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1153 {
1154     return gmx_cycles_read() - c;
1155 }
1156
1157
1158 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1159                                    int nthread, int thread)
1160 {
1161     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1162     ivec          local_pme_size;
1163     int           ixy0, ixy1, ixy, ix, iy, iz;
1164     int           pmeidx, fftidx;
1165 #ifdef PME_TIME_THREADS
1166     gmx_cycles_t  c1;
1167     static double cs1 = 0;
1168     static int    cnt = 0;
1169 #endif
1170
1171 #ifdef PME_TIME_THREADS
1172     c1 = omp_cyc_start();
1173 #endif
1174     /* Dimensions should be identical for A/B grid, so we just use A here */
1175     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1176                                    local_fft_ndata,
1177                                    local_fft_offset,
1178                                    local_fft_size);
1179
1180     local_pme_size[0] = pme->pmegrid_nx;
1181     local_pme_size[1] = pme->pmegrid_ny;
1182     local_pme_size[2] = pme->pmegrid_nz;
1183
1184     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1185        the offset is identical, and the PME grid always has more data (due to overlap)
1186      */
1187     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1188     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1189
1190     for (ixy = ixy0; ixy < ixy1; ixy++)
1191     {
1192         ix = ixy/local_fft_ndata[YY];
1193         iy = ixy - ix*local_fft_ndata[YY];
1194
1195         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1196         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1197         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1198         {
1199             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1200         }
1201     }
1202
1203 #ifdef PME_TIME_THREADS
1204     c1   = omp_cyc_end(c1);
1205     cs1 += (double)c1;
1206     cnt++;
1207     if (cnt % 20 == 0)
1208     {
1209         printf("copy %.2f\n", cs1*1e-9);
1210     }
1211 #endif
1212
1213     return 0;
1214 }
1215
1216
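/* wrap_periodic_pmegrid folds the pme_order-1 overlap region of the spreading
 * grid back into the primary cell: values spread beyond nz (and, when there is
 * no node decomposition in that dimension, beyond ny and nx) are added onto
 * the wrapped-around grid lines.  unwrap_periodic_pmegrid further down does
 * the reverse copy, filling the overlap region from the interior.
 */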
1217 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1218 {
1219     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1220
1221     nx = pme->nkx;
1222     ny = pme->nky;
1223     nz = pme->nkz;
1224
1225     pnx = pme->pmegrid_nx;
1226     pny = pme->pmegrid_ny;
1227     pnz = pme->pmegrid_nz;
1228
1229     overlap = pme->pme_order - 1;
1230
1231     /* Add periodic overlap in z */
1232     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1233     {
1234         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1235         {
1236             for (iz = 0; iz < overlap; iz++)
1237             {
1238                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1239                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1240             }
1241         }
1242     }
1243
1244     if (pme->nnodes_minor == 1)
1245     {
1246         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1247         {
1248             for (iy = 0; iy < overlap; iy++)
1249             {
1250                 for (iz = 0; iz < nz; iz++)
1251                 {
1252                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1253                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1254                 }
1255             }
1256         }
1257     }
1258
1259     if (pme->nnodes_major == 1)
1260     {
1261         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1262
1263         for (ix = 0; ix < overlap; ix++)
1264         {
1265             for (iy = 0; iy < ny_x; iy++)
1266             {
1267                 for (iz = 0; iz < nz; iz++)
1268                 {
1269                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1270                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1271                 }
1272             }
1273         }
1274     }
1275 }
1276
1277
1278 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1279 {
1280     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1281
1282     nx = pme->nkx;
1283     ny = pme->nky;
1284     nz = pme->nkz;
1285
1286     pnx = pme->pmegrid_nx;
1287     pny = pme->pmegrid_ny;
1288     pnz = pme->pmegrid_nz;
1289
1290     overlap = pme->pme_order - 1;
1291
1292     if (pme->nnodes_major == 1)
1293     {
1294         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1295
1296         for (ix = 0; ix < overlap; ix++)
1297         {
1298             int iy, iz;
1299
1300             for (iy = 0; iy < ny_x; iy++)
1301             {
1302                 for (iz = 0; iz < nz; iz++)
1303                 {
1304                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1305                         pmegrid[(ix*pny+iy)*pnz+iz];
1306                 }
1307             }
1308         }
1309     }
1310
1311     if (pme->nnodes_minor == 1)
1312     {
1313 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1314         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1315         {
1316             int iy, iz;
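            /* Declaring iy and iz inside the loop body makes them private to
             * each OpenMP thread; declared at function scope without a
             * private() clause they would be shared and the threads would
             * race on them.
             */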
1317
1318             for (iy = 0; iy < overlap; iy++)
1319             {
1320                 for (iz = 0; iz < nz; iz++)
1321                 {
1322                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1323                         pmegrid[(ix*pny+iy)*pnz+iz];
1324                 }
1325             }
1326         }
1327     }
1328
1329     /* Copy periodic overlap in z */
1330 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1331     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1332     {
1333         int iy, iz;
1334
1335         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1336         {
1337             for (iz = 0; iz < overlap; iz++)
1338             {
1339                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1340                     pmegrid[(ix*pny+iy)*pnz+iz];
1341             }
1342         }
1343     }
1344 }
1345
1346
1347 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1348 #define DO_BSPLINE(order)                            \
1349     for (ithx = 0; (ithx < order); ithx++)                    \
1350     {                                                    \
1351         index_x = (i0+ithx)*pny*pnz;                     \
1352         valx    = coefficient*thx[ithx];                          \
1353                                                      \
1354         for (ithy = 0; (ithy < order); ithy++)                \
1355         {                                                \
1356             valxy    = valx*thy[ithy];                   \
1357             index_xy = index_x+(j0+ithy)*pnz;            \
1358                                                      \
1359             for (ithz = 0; (ithz < order); ithz++)            \
1360             {                                            \
1361                 index_xyz        = index_xy+(k0+ithz);   \
1362                 grid[index_xyz] += valxy*thz[ithz];      \
1363             }                                            \
1364         }                                                \
1365     }
1366
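/* In DO_BSPLINE above, each atom's coefficient is spread over order^3
 * neighbouring grid points (64 points for pme_order = 4), each receiving
 * coefficient*thx[ithx]*thy[ithy]*thz[ithz], where thx/thy/thz are the
 * per-dimension B-spline weights stored in spline->theta.
 */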
1367
1368 static void spread_coefficients_bsplines_thread(pmegrid_t                    *pmegrid,
1369                                                 pme_atomcomm_t               *atc,
1370                                                 splinedata_t                 *spline,
1371                                                 pme_spline_work_t gmx_unused *work)
1372 {
1373
1374     /* spread coefficients from home atoms to local grid */
1375     real          *grid;
1376     pme_overlap_t *ol;
1377     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1378     int       *    idxptr;
1379     int            order, norder, index_x, index_xy, index_xyz;
1380     real           valx, valxy, coefficient;
1381     real          *thx, *thy, *thz;
1382     int            localsize, bndsize;
1383     int            pnx, pny, pnz, ndatatot;
1384     int            offx, offy, offz;
1385
1386 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1387     real           thz_buffer[GMX_SIMD4_WIDTH*3], *thz_aligned;
1388
1389     thz_aligned = gmx_simd4_align_r(thz_buffer);
1390 #endif
1391
1392     pnx = pmegrid->s[XX];
1393     pny = pmegrid->s[YY];
1394     pnz = pmegrid->s[ZZ];
1395
1396     offx = pmegrid->offset[XX];
1397     offy = pmegrid->offset[YY];
1398     offz = pmegrid->offset[ZZ];
1399
1400     ndatatot = pnx*pny*pnz;
1401     grid     = pmegrid->grid;
1402     for (i = 0; i < ndatatot; i++)
1403     {
1404         grid[i] = 0;
1405     }
1406
1407     order = pmegrid->order;
1408
1409     for (nn = 0; nn < spline->n; nn++)
1410     {
1411         n           = spline->ind[nn];
1412         coefficient = atc->coefficient[n];
1413
1414         if (coefficient != 0)
1415         {
1416             idxptr = atc->idx[n];
1417             norder = nn*order;
1418
1419             i0   = idxptr[XX] - offx;
1420             j0   = idxptr[YY] - offy;
1421             k0   = idxptr[ZZ] - offz;
1422
1423             thx = spline->theta[XX] + norder;
1424             thy = spline->theta[YY] + norder;
1425             thz = spline->theta[ZZ] + norder;
1426
1427             switch (order)
1428             {
1429                 case 4:
1430 #ifdef PME_SIMD4_SPREAD_GATHER
1431 #ifdef PME_SIMD4_UNALIGNED
1432 #define PME_SPREAD_SIMD4_ORDER4
1433 #else
1434 #define PME_SPREAD_SIMD4_ALIGNED
1435 #define PME_ORDER 4
1436 #endif
1437 #include "pme_simd4.h"
1438 #else
1439                     DO_BSPLINE(4);
1440 #endif
1441                     break;
1442                 case 5:
1443 #ifdef PME_SIMD4_SPREAD_GATHER
1444 #define PME_SPREAD_SIMD4_ALIGNED
1445 #define PME_ORDER 5
1446 #include "pme_simd4.h"
1447 #else
1448                     DO_BSPLINE(5);
1449 #endif
1450                     break;
1451                 default:
1452                     DO_BSPLINE(order);
1453                     break;
1454             }
1455         }
1456     }
1457 }
1458
1459 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1460 {
1461 #ifdef PME_SIMD4_SPREAD_GATHER
1462     if (pme_order == 5
1463 #ifndef PME_SIMD4_UNALIGNED
1464         || pme_order == 4
1465 #endif
1466         )
1467     {
1468         /* Round nz up to a multiple of 4 to ensure alignment */
1469         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1470     }
1471 #endif
1472 }
1473
1474 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1475 {
1476 #ifdef PME_SIMD4_SPREAD_GATHER
1477 #ifndef PME_SIMD4_UNALIGNED
1478     if (pme_order == 4)
1479     {
1480         /* Add extra elements to ensure that aligned operations do not go
1481          * beyond the allocated grid size.
1482          * Note that for pme_order=5, the pme grid z-size alignment
1483          * ensures that we will not go beyond the grid size.
1484          */
1485         *gridsize += 4;
1486     }
1487 #endif
1488 #endif
1489 }
1490
1491 static void pmegrid_init(pmegrid_t *grid,
1492                          int cx, int cy, int cz,
1493                          int x0, int y0, int z0,
1494                          int x1, int y1, int z1,
1495                          gmx_bool set_alignment,
1496                          int pme_order,
1497                          real *ptr)
1498 {
1499     int nz, gridsize;
1500
1501     grid->ci[XX]     = cx;
1502     grid->ci[YY]     = cy;
1503     grid->ci[ZZ]     = cz;
1504     grid->offset[XX] = x0;
1505     grid->offset[YY] = y0;
1506     grid->offset[ZZ] = z0;
1507     grid->n[XX]      = x1 - x0 + pme_order - 1;
1508     grid->n[YY]      = y1 - y0 + pme_order - 1;
1509     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1510     copy_ivec(grid->n, grid->s);
1511
1512     nz = grid->s[ZZ];
1513     set_grid_alignment(&nz, pme_order);
1514     if (set_alignment)
1515     {
1516         grid->s[ZZ] = nz;
1517     }
1518     else if (nz != grid->s[ZZ])
1519     {
1520         gmx_incons("pmegrid_init call with an unaligned z size");
1521     }
1522
1523     grid->order = pme_order;
1524     if (ptr == NULL)
1525     {
1526         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1527         set_gridsize_alignment(&gridsize, pme_order);
1528         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1529     }
1530     else
1531     {
1532         grid->grid = ptr;
1533     }
1534 }
1535
1536 static int div_round_up(int enumerator, int denominator)
1537 {
1538     return (enumerator + denominator - 1)/denominator;
1539 }
1540
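/* Sketch of the heuristic below: every factorization nsx*nsy*nsz = nthread is
 * scanned and the one minimizing the per-thread grid volume
 * (n/ns + overlap in each dimension) wins; ties are broken by preferring
 * fewer cuts along z, then along y.  For example (my arithmetic), with
 * n = (60,60,60), overlap 3 and 6 threads, every permutation of (1,2,3) gives
 * the same volume and the tie-break selects a 3x2x1 division: three blocks
 * along x, two along y, one along z.
 */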
1541 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1542                                   ivec nsub)
1543 {
1544     int gsize_opt, gsize;
1545     int nsx, nsy, nsz;
1546     char *env;
1547
1548     gsize_opt = -1;
1549     for (nsx = 1; nsx <= nthread; nsx++)
1550     {
1551         if (nthread % nsx == 0)
1552         {
1553             for (nsy = 1; nsy <= nthread; nsy++)
1554             {
1555                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1556                 {
1557                     nsz = nthread/(nsx*nsy);
1558
1559                     /* Determine the number of grid points per thread */
1560                     gsize =
1561                         (div_round_up(n[XX], nsx) + ovl)*
1562                         (div_round_up(n[YY], nsy) + ovl)*
1563                         (div_round_up(n[ZZ], nsz) + ovl);
1564
1565                     /* Minimize the number of grid points per thread
1566                      * and, secondarily, the number of cuts in minor dimensions.
1567                      */
1568                     if (gsize_opt == -1 ||
1569                         gsize < gsize_opt ||
1570                         (gsize == gsize_opt &&
1571                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1572                     {
1573                         nsub[XX]  = nsx;
1574                         nsub[YY]  = nsy;
1575                         nsub[ZZ]  = nsz;
1576                         gsize_opt = gsize;
1577                     }
1578                 }
1579             }
1580         }
1581     }
1582
1583     env = getenv("GMX_PME_THREAD_DIVISION");
1584     if (env != NULL)
1585     {
1586         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1587     }
1588
1589     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1590     {
1591         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1592     }
1593 }
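/* make_subgrid_division picks the factorization nsx*nsy*nsz == nthread that
 * minimizes the per-thread grid volume,
 * (div_round_up(n[XX],nsx) + ovl)*(div_round_up(n[YY],nsy) + ovl)*(div_round_up(n[ZZ],nsz) + ovl);
 * ties are broken in favor of fewer cuts in z, then y. The choice can be
 * overridden by setting the environment variable GMX_PME_THREAD_DIVISION to
 * three integers, e.g. "2 2 1" for 4 PME threads; a division that does not
 * multiply up to the thread count is a fatal error.
 */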
1594
1595 static void pmegrids_init(pmegrids_t *grids,
1596                           int nx, int ny, int nz, int nz_base,
1597                           int pme_order,
1598                           gmx_bool bUseThreads,
1599                           int nthread,
1600                           int overlap_x,
1601                           int overlap_y)
1602 {
1603     ivec n, n_base, g0, g1;
1604     int t, x, y, z, d, i, tfac;
1605     int max_comm_lines = -1;
1606
1607     n[XX] = nx - (pme_order - 1);
1608     n[YY] = ny - (pme_order - 1);
1609     n[ZZ] = nz - (pme_order - 1);
1610
1611     copy_ivec(n, n_base);
1612     n_base[ZZ] = nz_base;
1613
1614     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1615                  NULL);
1616
1617     grids->nthread = nthread;
1618
1619     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1620
1621     if (bUseThreads)
1622     {
1623         ivec nst;
1624         int gridsize;
1625
1626         for (d = 0; d < DIM; d++)
1627         {
1628             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1629         }
1630         set_grid_alignment(&nst[ZZ], pme_order);
1631
1632         if (debug)
1633         {
1634             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1635                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1636             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1637                     nx, ny, nz,
1638                     nst[XX], nst[YY], nst[ZZ]);
1639         }
1640
1641         snew(grids->grid_th, grids->nthread);
1642         t        = 0;
1643         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1644         set_gridsize_alignment(&gridsize, pme_order);
1645         snew_aligned(grids->grid_all,
1646                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1647                      SIMD4_ALIGNMENT);
1648
1649         for (x = 0; x < grids->nc[XX]; x++)
1650         {
1651             for (y = 0; y < grids->nc[YY]; y++)
1652             {
1653                 for (z = 0; z < grids->nc[ZZ]; z++)
1654                 {
1655                     pmegrid_init(&grids->grid_th[t],
1656                                  x, y, z,
1657                                  (n[XX]*(x  ))/grids->nc[XX],
1658                                  (n[YY]*(y  ))/grids->nc[YY],
1659                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1660                                  (n[XX]*(x+1))/grids->nc[XX],
1661                                  (n[YY]*(y+1))/grids->nc[YY],
1662                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1663                                  TRUE,
1664                                  pme_order,
1665                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1666                     t++;
1667                 }
1668             }
1669         }
1670     }
1671     else
1672     {
1673         grids->grid_th = NULL;
1674     }
1675
1676     snew(grids->g2t, DIM);
1677     tfac = 1;
1678     for (d = DIM-1; d >= 0; d--)
1679     {
1680         snew(grids->g2t[d], n[d]);
1681         t = 0;
1682         for (i = 0; i < n[d]; i++)
1683         {
1684             /* The second check should match the parameters
1685              * of the pmegrid_init call above.
1686              */
1687             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1688             {
1689                 t++;
1690             }
1691             grids->g2t[d][i] = t*tfac;
1692         }
1693
1694         tfac *= grids->nc[d];
1695
1696         switch (d)
1697         {
1698             case XX: max_comm_lines = overlap_x;     break;
1699             case YY: max_comm_lines = overlap_y;     break;
1700             case ZZ: max_comm_lines = pme_order - 1; break;
1701         }
1702         grids->nthread_comm[d] = 0;
1703         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1704                grids->nthread_comm[d] < grids->nc[d])
1705         {
1706             grids->nthread_comm[d]++;
1707         }
1708         if (debug != NULL)
1709         {
1710             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1711                     'x'+d, grids->nthread_comm[d]);
1712         }
1713         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1714          * work, but this is not a problematic restriction.
1715          */
1716         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1717         {
1718             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1719         }
1720     }
1721 }
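/* After pmegrids_init, grids->g2t[d][i] holds the thread-index contribution
 * of grid line i along dimension d, pre-multiplied by the stride tfac, so
 * the thread owning cell (ix, iy, iz) is simply
 * g2t[XX][ix] + g2t[YY][iy] + g2t[ZZ][iz]. grids->nthread_comm[d] counts how
 * many forward-neighbor thread sub-grids can overlap a thread's own region
 * along d, bounded by max_comm_lines as set in the loop above.
 */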
1722
1723
1724 static void pmegrids_destroy(pmegrids_t *grids)
1725 {
1726     int t;
1727
1728     if (grids->grid.grid != NULL)
1729     {
1730         sfree(grids->grid.grid);
1731
1732         if (grids->nthread > 0)
1733         {
1734             for (t = 0; t < grids->nthread; t++)
1735             {
1736                 sfree(grids->grid_th[t].grid);
1737             }
1738             sfree(grids->grid_th);
1739         }
1740     }
1741 }
1742
1743
1744 static void realloc_work(pme_work_t *work, int nkx)
1745 {
1746     int simd_width;
1747
1748     if (nkx > work->nalloc)
1749     {
1750         work->nalloc = nkx;
1751         srenew(work->mhx, work->nalloc);
1752         srenew(work->mhy, work->nalloc);
1753         srenew(work->mhz, work->nalloc);
1754         srenew(work->m2, work->nalloc);
1755         /* Allocate an aligned pointer for SIMD operations, including extra
1756          * elements at the end for padding.
1757          */
1758 #ifdef PME_SIMD_SOLVE
1759         simd_width = GMX_SIMD_REAL_WIDTH;
1760 #else
1761         /* We can use any alignment, apart from 0, so we use 4 */
1762         simd_width = 4;
1763 #endif
1764         sfree_aligned(work->denom);
1765         sfree_aligned(work->tmp1);
1766         sfree_aligned(work->tmp2);
1767         sfree_aligned(work->eterm);
1768         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1769         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1770         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1771         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1772         srenew(work->m2inv, work->nalloc);
1773     }
1774 }
1775
1776
1777 static void free_work(pme_work_t *work)
1778 {
1779     sfree(work->mhx);
1780     sfree(work->mhy);
1781     sfree(work->mhz);
1782     sfree(work->m2);
1783     sfree_aligned(work->denom);
1784     sfree_aligned(work->tmp1);
1785     sfree_aligned(work->tmp2);
1786     sfree_aligned(work->eterm);
1787     sfree(work->m2inv);
1788 }
1789
1790
1791 #if defined PME_SIMD_SOLVE
1792 /* Calculate exponentials through SIMD */
1793 gmx_inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1794 {
1795     {
1796         const gmx_simd_real_t two = gmx_simd_set1_r(2.0);
1797         gmx_simd_real_t f_simd;
1798         gmx_simd_real_t lu;
1799         gmx_simd_real_t tmp_d1, d_inv, tmp_r, tmp_e;
1800         int kx;
1801         f_simd = gmx_simd_set1_r(f);
1802         /* We only need to calculate from start. But since start is 0 or 1
1803          * and we want to use aligned loads/stores, we always start from 0.
1804          */
1805         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1806         {
1807             tmp_d1   = gmx_simd_load_r(d_aligned+kx);
1808             d_inv    = gmx_simd_inv_r(tmp_d1);
1809             tmp_r    = gmx_simd_load_r(r_aligned+kx);
1810             tmp_r    = gmx_simd_exp_r(tmp_r);
1811             tmp_e    = gmx_simd_mul_r(f_simd, d_inv);
1812             tmp_e    = gmx_simd_mul_r(tmp_e, tmp_r);
1813             gmx_simd_store_r(e_aligned+kx, tmp_e);
1814         }
1815     }
1816 }
1817 #else
1818 gmx_inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1819 {
1820     int kx;
1821     for (kx = start; kx < end; kx++)
1822     {
1823         d[kx] = 1.0/d[kx];
1824     }
1825     for (kx = start; kx < end; kx++)
1826     {
1827         r[kx] = exp(r[kx]);
1828     }
1829     for (kx = start; kx < end; kx++)
1830     {
1831         e[kx] = f*r[kx]*d[kx];
1832     }
1833 }
1834 #endif
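/* Both versions of calc_exponentials_q compute, for each kx,
 *   e[kx] = f * exp(r[kx]) / d[kx].
 * In solve_pme_yzx below, r holds -pi^2 m^2 / ewaldcoeff^2 and d holds the
 * denominator assembled from m^2, the B-spline moduli and pi*V, so e becomes
 * the per-k-vector energy prefactor. The SIMD version may also process a few
 * padding elements beyond 'end', which is why realloc_work over-allocates
 * the aligned work arrays.
 */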
1835
1836 #if defined PME_SIMD_SOLVE
1837 /* Calculate exponentials through SIMD */
1838 gmx_inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1839 {
1840     gmx_simd_real_t tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1841     const gmx_simd_real_t sqr_PI = gmx_simd_sqrt_r(gmx_simd_set1_r(M_PI));
1842     int kx;
1843     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1844     {
1845         /* We only need to calculate from start. But since start is 0 or 1
1846          * and we want to use aligned loads/stores, we always start from 0.
1847          */
1848         tmp_d = gmx_simd_load_r(d_aligned+kx);
1849         d_inv = gmx_simd_inv_r(tmp_d);
1850         gmx_simd_store_r(d_aligned+kx, d_inv);
1851         tmp_r = gmx_simd_load_r(r_aligned+kx);
1852         tmp_r = gmx_simd_exp_r(tmp_r);
1853         gmx_simd_store_r(r_aligned+kx, tmp_r);
1854         tmp_mk  = gmx_simd_load_r(factor_aligned+kx);
1855         tmp_fac = gmx_simd_mul_r(sqr_PI, gmx_simd_mul_r(tmp_mk, gmx_simd_erfc_r(tmp_mk)));
1856         gmx_simd_store_r(factor_aligned+kx, tmp_fac);
1857     }
1858 }
1859 #else
1860 gmx_inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1861 {
1862     int kx;
1863     real mk;
1864     for (kx = start; kx < end; kx++)
1865     {
1866         d[kx] = 1.0/d[kx];
1867     }
1868
1869     for (kx = start; kx < end; kx++)
1870     {
1871         r[kx] = exp(r[kx]);
1872     }
1873
1874     for (kx = start; kx < end; kx++)
1875     {
1876         mk       = tmp2[kx];
1877         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1878     }
1879 }
1880 #endif
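/* Both versions of calc_exponentials_lj transform the arrays in place:
 *   d[kx]    -> 1/d[kx]
 *   r[kx]    -> exp(r[kx])
 *   tmp2[kx] -> sqrt(pi) * m * erfc(m), with m = tmp2[kx] on input.
 * These are the ingredients of the LJ-PME (dispersion) kernel assembled
 * in solve_pme_lj_yzx below.
 */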
1881
1882 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1883                          real ewaldcoeff, real vol,
1884                          gmx_bool bEnerVir,
1885                          int nthread, int thread)
1886 {
1887     /* do recip sum over local cells in grid */
1888     /* y major, z middle, x minor or continuous */
1889     t_complex *p0;
1890     int     kx, ky, kz, maxkx, maxky, maxkz;
1891     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1892     real    mx, my, mz;
1893     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1894     real    ets2, struct2, vfactor, ets2vf;
1895     real    d1, d2, energy = 0;
1896     real    by, bz;
1897     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1898     real    rxx, ryx, ryy, rzx, rzy, rzz;
1899     pme_work_t *work;
1900     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1901     real    mhxk, mhyk, mhzk, m2k;
1902     real    corner_fac;
1903     ivec    complex_order;
1904     ivec    local_ndata, local_offset, local_size;
1905     real    elfac;
1906
1907     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1908
1909     nx = pme->nkx;
1910     ny = pme->nky;
1911     nz = pme->nkz;
1912
1913     /* Dimensions should be identical for A/B grid, so we just use A here */
1914     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
1915                                       complex_order,
1916                                       local_ndata,
1917                                       local_offset,
1918                                       local_size);
1919
1920     rxx = pme->recipbox[XX][XX];
1921     ryx = pme->recipbox[YY][XX];
1922     ryy = pme->recipbox[YY][YY];
1923     rzx = pme->recipbox[ZZ][XX];
1924     rzy = pme->recipbox[ZZ][YY];
1925     rzz = pme->recipbox[ZZ][ZZ];
1926
1927     maxkx = (nx+1)/2;
1928     maxky = (ny+1)/2;
1929     maxkz = nz/2+1;
1930
1931     work  = &pme->work[thread];
1932     mhx   = work->mhx;
1933     mhy   = work->mhy;
1934     mhz   = work->mhz;
1935     m2    = work->m2;
1936     denom = work->denom;
1937     tmp1  = work->tmp1;
1938     eterm = work->eterm;
1939     m2inv = work->m2inv;
1940
1941     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1942     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1943
1944     for (iyz = iyz0; iyz < iyz1; iyz++)
1945     {
1946         iy = iyz/local_ndata[ZZ];
1947         iz = iyz - iy*local_ndata[ZZ];
1948
1949         ky = iy + local_offset[YY];
1950
1951         if (ky < maxky)
1952         {
1953             my = ky;
1954         }
1955         else
1956         {
1957             my = (ky - ny);
1958         }
1959
1960         by = M_PI*vol*pme->bsp_mod[YY][ky];
1961
1962         kz = iz + local_offset[ZZ];
1963
1964         mz = kz;
1965
1966         bz = pme->bsp_mod[ZZ][kz];
1967
1968         /* 0.5 correction for corner points */
1969         corner_fac = 1;
1970         if (kz == 0 || kz == (nz+1)/2)
1971         {
1972             corner_fac = 0.5;
1973         }
1974
1975         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1976
1977         /* We should skip the k-space point (0,0,0) */
1978         /* Note that since here x is the minor index, local_offset[XX]=0 */
1979         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1980         {
1981             kxstart = local_offset[XX];
1982         }
1983         else
1984         {
1985             kxstart = local_offset[XX] + 1;
1986             p0++;
1987         }
1988         kxend = local_offset[XX] + local_ndata[XX];
1989
1990         if (bEnerVir)
1991         {
1992             /* More expensive inner loop, especially because of the storage
1993              * of the mh elements in arrays.
1994              * Because x is the minor grid index, all mh elements
1995              * depend on kx for triclinic unit cells.
1996              */
1997
1998             /* Two explicit loops to avoid a conditional inside the loop */
1999             for (kx = kxstart; kx < maxkx; kx++)
2000             {
2001                 mx = kx;
2002
2003                 mhxk      = mx * rxx;
2004                 mhyk      = mx * ryx + my * ryy;
2005                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2006                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2007                 mhx[kx]   = mhxk;
2008                 mhy[kx]   = mhyk;
2009                 mhz[kx]   = mhzk;
2010                 m2[kx]    = m2k;
2011                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2012                 tmp1[kx]  = -factor*m2k;
2013             }
2014
2015             for (kx = maxkx; kx < kxend; kx++)
2016             {
2017                 mx = (kx - nx);
2018
2019                 mhxk      = mx * rxx;
2020                 mhyk      = mx * ryx + my * ryy;
2021                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2022                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2023                 mhx[kx]   = mhxk;
2024                 mhy[kx]   = mhyk;
2025                 mhz[kx]   = mhzk;
2026                 m2[kx]    = m2k;
2027                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2028                 tmp1[kx]  = -factor*m2k;
2029             }
2030
2031             for (kx = kxstart; kx < kxend; kx++)
2032             {
2033                 m2inv[kx] = 1.0/m2[kx];
2034             }
2035
2036             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2037
2038             for (kx = kxstart; kx < kxend; kx++, p0++)
2039             {
2040                 d1      = p0->re;
2041                 d2      = p0->im;
2042
2043                 p0->re  = d1*eterm[kx];
2044                 p0->im  = d2*eterm[kx];
2045
2046                 struct2 = 2.0*(d1*d1+d2*d2);
2047
2048                 tmp1[kx] = eterm[kx]*struct2;
2049             }
2050
2051             for (kx = kxstart; kx < kxend; kx++)
2052             {
2053                 ets2     = corner_fac*tmp1[kx];
2054                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2055                 energy  += ets2;
2056
2057                 ets2vf   = ets2*vfactor;
2058                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2059                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2060                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2061                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2062                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2063                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2064             }
2065         }
2066         else
2067         {
2068             /* We don't need to calculate the energy and the virial.
2069              * In this case the triclinic overhead is small.
2070              */
2071
2072             /* Two explicit loops to avoid a conditional inside the loop */
2073
2074             for (kx = kxstart; kx < maxkx; kx++)
2075             {
2076                 mx = kx;
2077
2078                 mhxk      = mx * rxx;
2079                 mhyk      = mx * ryx + my * ryy;
2080                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2081                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2082                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2083                 tmp1[kx]  = -factor*m2k;
2084             }
2085
2086             for (kx = maxkx; kx < kxend; kx++)
2087             {
2088                 mx = (kx - nx);
2089
2090                 mhxk      = mx * rxx;
2091                 mhyk      = mx * ryx + my * ryy;
2092                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2093                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2094                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2095                 tmp1[kx]  = -factor*m2k;
2096             }
2097
2098             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2099
2100             for (kx = kxstart; kx < kxend; kx++, p0++)
2101             {
2102                 d1      = p0->re;
2103                 d2      = p0->im;
2104
2105                 p0->re  = d1*eterm[kx];
2106                 p0->im  = d2*eterm[kx];
2107             }
2108         }
2109     }
2110
2111     if (bEnerVir)
2112     {
2113         /* Update virial with local values.
2114          * The virial is symmetric by definition.
2115          * This virial seems OK for isotropic scaling, but I'm
2116          * experiencing problems on semi-isotropic membranes.
2117          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2118          */
2119         work->vir_q[XX][XX] = 0.25*virxx;
2120         work->vir_q[YY][YY] = 0.25*viryy;
2121         work->vir_q[ZZ][ZZ] = 0.25*virzz;
2122         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
2123         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
2124         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
2125
2126         /* This energy should be corrected for a charged system */
2127         work->energy_q = 0.5*energy;
2128     }
2129
2130     /* Return the loop count */
2131     return local_ndata[YY]*local_ndata[XX];
2132 }
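/* solve_pme_yzx is written to be called concurrently by nthread threads:
 * thread 'thread' handles the contiguous range [iyz0, iyz1) of local (y,z)
 * columns, and its energy/virial contributions go into pme->work[thread],
 * to be summed after synchronization by get_pme_ener_vir_q() below.
 * The loop count is returned for the caller's bookkeeping.
 */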
2133
2134 static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
2135                             real ewaldcoeff, real vol,
2136                             gmx_bool bEnerVir, int nthread, int thread)
2137 {
2138     /* do recip sum over local cells in grid */
2139     /* y major, z middle, x minor or continuous */
2140     int     ig, gcount;
2141     int     kx, ky, kz, maxkx, maxky, maxkz;
2142     int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
2143     real    mx, my, mz;
2144     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2145     real    ets2, ets2vf;
2146     real    eterm, vterm, d1, d2, energy = 0;
2147     real    by, bz;
2148     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2149     real    rxx, ryx, ryy, rzx, rzy, rzz;
2150     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
2151     real    mhxk, mhyk, mhzk, m2k;
2152     real    mk;
2153     pme_work_t *work;
2154     real    corner_fac;
2155     ivec    complex_order;
2156     ivec    local_ndata, local_offset, local_size;
2157     nx = pme->nkx;
2158     ny = pme->nky;
2159     nz = pme->nkz;
2160
2161     /* Dimensions should be identical for A/B grid, so we just use A here */
2162     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
2163                                       complex_order,
2164                                       local_ndata,
2165                                       local_offset,
2166                                       local_size);
2167     rxx = pme->recipbox[XX][XX];
2168     ryx = pme->recipbox[YY][XX];
2169     ryy = pme->recipbox[YY][YY];
2170     rzx = pme->recipbox[ZZ][XX];
2171     rzy = pme->recipbox[ZZ][YY];
2172     rzz = pme->recipbox[ZZ][ZZ];
2173
2174     maxkx = (nx+1)/2;
2175     maxky = (ny+1)/2;
2176     maxkz = nz/2+1;
2177
2178     work  = &pme->work[thread];
2179     mhx   = work->mhx;
2180     mhy   = work->mhy;
2181     mhz   = work->mhz;
2182     m2    = work->m2;
2183     denom = work->denom;
2184     tmp1  = work->tmp1;
2185     tmp2  = work->tmp2;
2186
2187     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2188     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2189
2190     for (iyz = iyz0; iyz < iyz1; iyz++)
2191     {
2192         iy = iyz/local_ndata[ZZ];
2193         iz = iyz - iy*local_ndata[ZZ];
2194
2195         ky = iy + local_offset[YY];
2196
2197         if (ky < maxky)
2198         {
2199             my = ky;
2200         }
2201         else
2202         {
2203             my = (ky - ny);
2204         }
2205
2206         by = 3.0*vol*pme->bsp_mod[YY][ky]
2207             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
2208
2209         kz = iz + local_offset[ZZ];
2210
2211         mz = kz;
2212
2213         bz = pme->bsp_mod[ZZ][kz];
2214
2215         /* 0.5 correction for corner points */
2216         corner_fac = 1;
2217         if (kz == 0 || kz == (nz+1)/2)
2218         {
2219             corner_fac = 0.5;
2220         }
2221
2222         kxstart = local_offset[XX];
2223         kxend   = local_offset[XX] + local_ndata[XX];
2224         if (bEnerVir)
2225         {
2226             /* More expensive inner loop, especially because of the
2227              * storage of the mh elements in arrays.  Because x is the
2228              * minor grid index, all mh elements depend on kx for
2229              * triclinic unit cells.
2230              */
2231
2232             /* Two explicit loops to avoid a conditional inside the loop */
2233             for (kx = kxstart; kx < maxkx; kx++)
2234             {
2235                 mx = kx;
2236
2237                 mhxk      = mx * rxx;
2238                 mhyk      = mx * ryx + my * ryy;
2239                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2240                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2241                 mhx[kx]   = mhxk;
2242                 mhy[kx]   = mhyk;
2243                 mhz[kx]   = mhzk;
2244                 m2[kx]    = m2k;
2245                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2246                 tmp1[kx]  = -factor*m2k;
2247                 tmp2[kx]  = sqrt(factor*m2k);
2248             }
2249
2250             for (kx = maxkx; kx < kxend; kx++)
2251             {
2252                 mx = (kx - nx);
2253
2254                 mhxk      = mx * rxx;
2255                 mhyk      = mx * ryx + my * ryy;
2256                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2257                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2258                 mhx[kx]   = mhxk;
2259                 mhy[kx]   = mhyk;
2260                 mhz[kx]   = mhzk;
2261                 m2[kx]    = m2k;
2262                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2263                 tmp1[kx]  = -factor*m2k;
2264                 tmp2[kx]  = sqrt(factor*m2k);
2265             }
2266
2267             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2268
2269             for (kx = kxstart; kx < kxend; kx++)
2270             {
2271                 m2k   = factor*m2[kx];
2272                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
2273                           + 2.0*m2k*tmp2[kx]);
2274                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
2275                 tmp1[kx] = eterm*denom[kx];
2276                 tmp2[kx] = vterm*denom[kx];
2277             }
2278
2279             if (!bLB)
2280             {
2281                 t_complex *p0;
2282                 real       struct2;
2283
2284                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2285                 for (kx = kxstart; kx < kxend; kx++, p0++)
2286                 {
2287                     d1      = p0->re;
2288                     d2      = p0->im;
2289
2290                     eterm   = tmp1[kx];
2291                     vterm   = tmp2[kx];
2292                     p0->re  = d1*eterm;
2293                     p0->im  = d2*eterm;
2294
2295                     struct2 = 2.0*(d1*d1+d2*d2);
2296
2297                     tmp1[kx] = eterm*struct2;
2298                     tmp2[kx] = vterm*struct2;
2299                 }
2300             }
2301             else
2302             {
2303                 real *struct2 = denom;
2304                 real  str2;
2305
2306                 for (kx = kxstart; kx < kxend; kx++)
2307                 {
2308                     struct2[kx] = 0.0;
2309                 }
2310                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
2311                 for (ig = 0; ig <= 3; ++ig)
2312                 {
2313                     t_complex *p0, *p1;
2314                     real       scale;
2315
2316                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2317                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2318                     scale = 2.0*lb_scale_factor_symm[ig];
2319                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
2320                     {
2321                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
2322                     }
2323
2324                 }
2325                 for (ig = 0; ig <= 6; ++ig)
2326                 {
2327                     t_complex *p0;
2328
2329                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2330                     for (kx = kxstart; kx < kxend; kx++, p0++)
2331                     {
2332                         d1     = p0->re;
2333                         d2     = p0->im;
2334
2335                         eterm  = tmp1[kx];
2336                         p0->re = d1*eterm;
2337                         p0->im = d2*eterm;
2338                     }
2339                 }
2340                 for (kx = kxstart; kx < kxend; kx++)
2341                 {
2342                     eterm    = tmp1[kx];
2343                     vterm    = tmp2[kx];
2344                     str2     = struct2[kx];
2345                     tmp1[kx] = eterm*str2;
2346                     tmp2[kx] = vterm*str2;
2347                 }
2348             }
2349
2350             for (kx = kxstart; kx < kxend; kx++)
2351             {
2352                 ets2     = corner_fac*tmp1[kx];
2353                 vterm    = 2.0*factor*tmp2[kx];
2354                 energy  += ets2;
2355                 ets2vf   = corner_fac*vterm;
2356                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2357                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2358                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2359                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2360                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2361                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2362             }
2363         }
2364         else
2365         {
2366             /* We don't need to calculate the energy and the virial.
2367              * In this case the triclinic overhead is small.
2368              */
2369
2370             /* Two explicit loops to avoid a conditional inside the loop */
2371
2372             for (kx = kxstart; kx < maxkx; kx++)
2373             {
2374                 mx = kx;
2375
2376                 mhxk      = mx * rxx;
2377                 mhyk      = mx * ryx + my * ryy;
2378                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2379                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2380                 m2[kx]    = m2k;
2381                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2382                 tmp1[kx]  = -factor*m2k;
2383                 tmp2[kx]  = sqrt(factor*m2k);
2384             }
2385
2386             for (kx = maxkx; kx < kxend; kx++)
2387             {
2388                 mx = (kx - nx);
2389
2390                 mhxk      = mx * rxx;
2391                 mhyk      = mx * ryx + my * ryy;
2392                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2393                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2394                 m2[kx]    = m2k;
2395                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2396                 tmp1[kx]  = -factor*m2k;
2397                 tmp2[kx]  = sqrt(factor*m2k);
2398             }
2399
2400             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2401
2402             for (kx = kxstart; kx < kxend; kx++)
2403             {
2404                 m2k    = factor*m2[kx];
2405                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
2406                            + 2.0*m2k*tmp2[kx]);
2407                 tmp1[kx] = eterm*denom[kx];
2408             }
2409             gcount = (bLB ? 7 : 1);
2410             for (ig = 0; ig < gcount; ++ig)
2411             {
2412                 t_complex *p0;
2413
2414                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2415                 for (kx = kxstart; kx < kxend; kx++, p0++)
2416                 {
2417                     d1      = p0->re;
2418                     d2      = p0->im;
2419
2420                     eterm   = tmp1[kx];
2421
2422                     p0->re  = d1*eterm;
2423                     p0->im  = d2*eterm;
2424                 }
2425             }
2426         }
2427     }
2428     if (bEnerVir)
2429     {
2430         work->vir_lj[XX][XX] = 0.25*virxx;
2431         work->vir_lj[YY][YY] = 0.25*viryy;
2432         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
2433         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
2434         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
2435         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
2436
2437         /* This energy should be corrected for a charged system */
2438         work->energy_lj = 0.5*energy;
2439     }
2440     /* Return the loop count */
2441     return local_ndata[YY]*local_ndata[XX];
2442 }
2443
2444 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2445                                real *mesh_energy, matrix vir)
2446 {
2447     /* This function sums output over threads and should therefore
2448      * only be called after thread synchronization.
2449      */
2450     int thread;
2451
2452     *mesh_energy = pme->work[0].energy_q;
2453     copy_mat(pme->work[0].vir_q, vir);
2454
2455     for (thread = 1; thread < nthread; thread++)
2456     {
2457         *mesh_energy += pme->work[thread].energy_q;
2458         m_add(vir, pme->work[thread].vir_q, vir);
2459     }
2460 }
2461
2462 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2463                                 real *mesh_energy, matrix vir)
2464 {
2465     /* This function sums output over threads and should therefore
2466      * only be called after thread synchronization.
2467      */
2468     int thread;
2469
2470     *mesh_energy = pme->work[0].energy_lj;
2471     copy_mat(pme->work[0].vir_lj, vir);
2472
2473     for (thread = 1; thread < nthread; thread++)
2474     {
2475         *mesh_energy += pme->work[thread].energy_lj;
2476         m_add(vir, pme->work[thread].vir_lj, vir);
2477     }
2478 }
2479
2480
2481 #define DO_FSPLINE(order)                      \
2482     for (ithx = 0; (ithx < order); ithx++)              \
2483     {                                              \
2484         index_x = (i0+ithx)*pny*pnz;               \
2485         tx      = thx[ithx];                       \
2486         dx      = dthx[ithx];                      \
2487                                                \
2488         for (ithy = 0; (ithy < order); ithy++)          \
2489         {                                          \
2490             index_xy = index_x+(j0+ithy)*pnz;      \
2491             ty       = thy[ithy];                  \
2492             dy       = dthy[ithy];                 \
2493             fxy1     = fz1 = 0;                    \
2494                                                \
2495             for (ithz = 0; (ithz < order); ithz++)      \
2496             {                                      \
2497                 gval  = grid[index_xy+(k0+ithz)];  \
2498                 fxy1 += thz[ithz]*gval;            \
2499                 fz1  += dthz[ithz]*gval;           \
2500             }                                      \
2501             fx += dx*ty*fxy1;                      \
2502             fy += tx*dy*fxy1;                      \
2503             fz += tx*ty*fz1;                       \
2504         }                                          \
2505     }
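/* DO_FSPLINE accumulates, for one atom, the gradient of the interpolated
 * potential with respect to the fractional grid coordinates: along each
 * dimension the spline derivatives (dthx/dthy/dthz) are combined with the
 * plain spline weights of the other two dimensions. For spline order n this
 * reads n*n*n grid points per atom (e.g. 64 points for order 4).
 */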
2506
2507
2508 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2509                               gmx_bool bClearF, pme_atomcomm_t *atc,
2510                               splinedata_t *spline,
2511                               real scale)
2512 {
2513     /* sum forces for local particles */
2514     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2515     int     index_x, index_xy;
2516     int     nx, ny, nz, pnx, pny, pnz;
2517     int *   idxptr;
2518     real    tx, ty, dx, dy, coefficient;
2519     real    fx, fy, fz, gval;
2520     real    fxy1, fz1;
2521     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2522     int     norder;
2523     real    rxx, ryx, ryy, rzx, rzy, rzz;
2524     int     order;
2525
2526     pme_spline_work_t *work;
2527
2528 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2529     real           thz_buffer[GMX_SIMD4_WIDTH*3],  *thz_aligned;
2530     real           dthz_buffer[GMX_SIMD4_WIDTH*3], *dthz_aligned;
2531
2532     thz_aligned  = gmx_simd4_align_r(thz_buffer);
2533     dthz_aligned = gmx_simd4_align_r(dthz_buffer);
2534 #endif
2535
2536     work = pme->spline_work;
2537
2538     order = pme->pme_order;
2539     thx   = spline->theta[XX];
2540     thy   = spline->theta[YY];
2541     thz   = spline->theta[ZZ];
2542     dthx  = spline->dtheta[XX];
2543     dthy  = spline->dtheta[YY];
2544     dthz  = spline->dtheta[ZZ];
2545     nx    = pme->nkx;
2546     ny    = pme->nky;
2547     nz    = pme->nkz;
2548     pnx   = pme->pmegrid_nx;
2549     pny   = pme->pmegrid_ny;
2550     pnz   = pme->pmegrid_nz;
2551
2552     rxx   = pme->recipbox[XX][XX];
2553     ryx   = pme->recipbox[YY][XX];
2554     ryy   = pme->recipbox[YY][YY];
2555     rzx   = pme->recipbox[ZZ][XX];
2556     rzy   = pme->recipbox[ZZ][YY];
2557     rzz   = pme->recipbox[ZZ][ZZ];
2558
2559     for (nn = 0; nn < spline->n; nn++)
2560     {
2561         n           = spline->ind[nn];
2562         coefficient = scale*atc->coefficient[n];
2563
2564         if (bClearF)
2565         {
2566             atc->f[n][XX] = 0;
2567             atc->f[n][YY] = 0;
2568             atc->f[n][ZZ] = 0;
2569         }
2570         if (coefficient != 0)
2571         {
2572             fx     = 0;
2573             fy     = 0;
2574             fz     = 0;
2575             idxptr = atc->idx[n];
2576             norder = nn*order;
2577
2578             i0   = idxptr[XX];
2579             j0   = idxptr[YY];
2580             k0   = idxptr[ZZ];
2581
2582             /* Pointer arithmetic alert, next six statements */
2583             thx  = spline->theta[XX] + norder;
2584             thy  = spline->theta[YY] + norder;
2585             thz  = spline->theta[ZZ] + norder;
2586             dthx = spline->dtheta[XX] + norder;
2587             dthy = spline->dtheta[YY] + norder;
2588             dthz = spline->dtheta[ZZ] + norder;
2589
2590             switch (order)
2591             {
2592                 case 4:
2593 #ifdef PME_SIMD4_SPREAD_GATHER
2594 #ifdef PME_SIMD4_UNALIGNED
2595 #define PME_GATHER_F_SIMD4_ORDER4
2596 #else
2597 #define PME_GATHER_F_SIMD4_ALIGNED
2598 #define PME_ORDER 4
2599 #endif
2600 #include "pme_simd4.h"
2601 #else
2602                     DO_FSPLINE(4);
2603 #endif
2604                     break;
2605                 case 5:
2606 #ifdef PME_SIMD4_SPREAD_GATHER
2607 #define PME_GATHER_F_SIMD4_ALIGNED
2608 #define PME_ORDER 5
2609 #include "pme_simd4.h"
2610 #else
2611                     DO_FSPLINE(5);
2612 #endif
2613                     break;
2614                 default:
2615                     DO_FSPLINE(order);
2616                     break;
2617             }
2618
2619             atc->f[n][XX] += -coefficient*( fx*nx*rxx );
2620             atc->f[n][YY] += -coefficient*( fx*nx*ryx + fy*ny*ryy );
2621             atc->f[n][ZZ] += -coefficient*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2622         }
2623     }
2624     /* Since the energy rather than the forces is interpolated,
2625      * the net force might not be exactly zero.
2626      * This could be solved by also interpolating F, but
2627      * that comes at a cost.
2628      * A better hack is to remove the net force every
2629      * step, but that must be done at a higher level
2630      * since this routine doesn't see all atoms when running
2631      * in parallel. It is unclear how important this is. EL 990726
2632      */
2633 }
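/* At the end of the loop above, the grid-index derivatives (fx, fy, fz) are
 * converted to a Cartesian force: the factors nx, ny, nz turn derivatives
 * with respect to grid indices into derivatives with respect to fractional
 * coordinates, and the (triclinic) reciprocal-box elements rxx..rzz map
 * those to Cartesian components, with a minus sign since the force is
 * -dE/dr scaled by the atom's coefficient (charge, or C6 for LJ-PME).
 */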
2634
2635
2636 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2637                                    pme_atomcomm_t *atc)
2638 {
2639     splinedata_t *spline;
2640     int     n, ithx, ithy, ithz, i0, j0, k0;
2641     int     index_x, index_xy;
2642     int *   idxptr;
2643     real    energy, pot, tx, ty, coefficient, gval;
2644     real    *thx, *thy, *thz;
2645     int     norder;
2646     int     order;
2647
2648     spline = &atc->spline[0];
2649
2650     order = pme->pme_order;
2651
2652     energy = 0;
2653     for (n = 0; (n < atc->n); n++)
2654     {
2655         coefficient      = atc->coefficient[n];
2656
2657         if (coefficient != 0)
2658         {
2659             idxptr = atc->idx[n];
2660             norder = n*order;
2661
2662             i0   = idxptr[XX];
2663             j0   = idxptr[YY];
2664             k0   = idxptr[ZZ];
2665
2666             /* Pointer arithmetic alert, next three statements */
2667             thx  = spline->theta[XX] + norder;
2668             thy  = spline->theta[YY] + norder;
2669             thz  = spline->theta[ZZ] + norder;
2670
2671             pot = 0;
2672             for (ithx = 0; (ithx < order); ithx++)
2673             {
2674                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2675                 tx      = thx[ithx];
2676
2677                 for (ithy = 0; (ithy < order); ithy++)
2678                 {
2679                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2680                     ty       = thy[ithy];
2681
2682                     for (ithz = 0; (ithz < order); ithz++)
2683                     {
2684                         gval  = grid[index_xy+(k0+ithz)];
2685                         pot  += tx*ty*thz[ithz]*gval;
2686                     }
2687
2688                 }
2689             }
2690
2691             energy += pot*coefficient;
2692         }
2693     }
2694
2695     return energy;
2696 }
2697
2698 /* Macro to force loop unrolling by fixing order.
2699  * This gives a significant performance gain.
2700  */
2701 #define CALC_SPLINE(order)                     \
2702     {                                              \
2703         int j, k, l;                                 \
2704         real dr, div;                               \
2705         real data[PME_ORDER_MAX];                  \
2706         real ddata[PME_ORDER_MAX];                 \
2707                                                \
2708         for (j = 0; (j < DIM); j++)                     \
2709         {                                          \
2710             dr  = xptr[j];                         \
2711                                                \
2712             /* dr is relative offset from lower cell limit */ \
2713             data[order-1] = 0;                     \
2714             data[1]       = dr;                          \
2715             data[0]       = 1 - dr;                      \
2716                                                \
2717             for (k = 3; (k < order); k++)               \
2718             {                                      \
2719                 div       = 1.0/(k - 1.0);               \
2720                 data[k-1] = div*dr*data[k-2];      \
2721                 for (l = 1; (l < (k-1)); l++)           \
2722                 {                                  \
2723                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2724                                        data[k-l-1]);                \
2725                 }                                  \
2726                 data[0] = div*(1-dr)*data[0];      \
2727             }                                      \
2728             /* differentiate */                    \
2729             ddata[0] = -data[0];                   \
2730             for (k = 1; (k < order); k++)               \
2731             {                                      \
2732                 ddata[k] = data[k-1] - data[k];    \
2733             }                                      \
2734                                                \
2735             div           = 1.0/(order - 1);                 \
2736             data[order-1] = div*dr*data[order-2];  \
2737             for (l = 1; (l < (order-1)); l++)           \
2738             {                                      \
2739                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2740                                        (order-l-dr)*data[order-l-1]); \
2741             }                                      \
2742             data[0] = div*(1 - dr)*data[0];        \
2743                                                \
2744             for (k = 0; k < order; k++)                 \
2745             {                                      \
2746                 theta[j][i*order+k]  = data[k];    \
2747                 dtheta[j][i*order+k] = ddata[k];   \
2748             }                                      \
2749         }                                          \
2750     }
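/* CALC_SPLINE evaluates the cardinal B-spline weights by the standard
 * recursion used in smooth PME,
 *   M_n(u) = u/(n-1) * M_{n-1}(u) + (n-u)/(n-1) * M_{n-1}(u-1),
 * and their derivatives via dM_n/du = M_{n-1}(u) - M_{n-1}(u-1), which is
 * what the ddata[k] = data[k-1] - data[k] step implements.
 */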
2751
2752 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2753                    rvec fractx[], int nr, int ind[], real coefficient[],
2754                    gmx_bool bDoSplines)
2755 {
2756     /* construct splines for local atoms */
2757     int  i, ii;
2758     real *xptr;
2759
2760     for (i = 0; i < nr; i++)
2761     {
2762         /* With free energy we do not use the coefficient check.
2763          * In most cases this will be more efficient than calling make_bsplines
2764          * twice, since usually more than half the particles have non-zero coefficients.
2765          */
2766         ii = ind[i];
2767         if (bDoSplines || coefficient[ii] != 0.0)
2768         {
2769             xptr = fractx[ii];
2770             switch (order)
2771             {
2772                 case 4:  CALC_SPLINE(4);     break;
2773                 case 5:  CALC_SPLINE(5);     break;
2774                 default: CALC_SPLINE(order); break;
2775             }
2776         }
2777     }
2778 }
2779
2780
2781 void make_dft_mod(real *mod, real *data, int ndata)
2782 {
2783     int i, j;
2784     real sc, ss, arg;
2785
2786     for (i = 0; i < ndata; i++)
2787     {
2788         sc = ss = 0;
2789         for (j = 0; j < ndata; j++)
2790         {
2791             arg = (2.0*M_PI*i*j)/ndata;
2792             sc += data[j]*cos(arg);
2793             ss += data[j]*sin(arg);
2794         }
2795         mod[i] = sc*sc+ss*ss;
2796     }
2797     for (i = 0; i < ndata; i++)
2798     {
2799         if (mod[i] < 1e-7)
2800         {
2801             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2802         }
2803     }
2804 }
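/* make_dft_mod returns the squared modulus of the discrete Fourier
 * transform of data:
 *   mod[i] = (sum_j data[j]*cos(2*pi*i*j/n))^2 + (sum_j data[j]*sin(2*pi*i*j/n))^2.
 * Entries that are nearly zero are replaced by the average of their
 * neighbors, since the reciprocal-space solver divides by these moduli.
 */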
2805
2806
2807 static void make_bspline_moduli(splinevec bsp_mod,
2808                                 int nx, int ny, int nz, int order)
2809 {
2810     int nmax = max(nx, max(ny, nz));
2811     real *data, *ddata, *bsp_data;
2812     int i, k, l;
2813     real div;
2814
2815     snew(data, order);
2816     snew(ddata, order);
2817     snew(bsp_data, nmax);
2818
2819     data[order-1] = 0;
2820     data[1]       = 0;
2821     data[0]       = 1;
2822
2823     for (k = 3; k < order; k++)
2824     {
2825         div       = 1.0/(k-1.0);
2826         data[k-1] = 0;
2827         for (l = 1; l < (k-1); l++)
2828         {
2829             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2830         }
2831         data[0] = div*data[0];
2832     }
2833     /* differentiate */
2834     ddata[0] = -data[0];
2835     for (k = 1; k < order; k++)
2836     {
2837         ddata[k] = data[k-1]-data[k];
2838     }
2839     div           = 1.0/(order-1);
2840     data[order-1] = 0;
2841     for (l = 1; l < (order-1); l++)
2842     {
2843         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2844     }
2845     data[0] = div*data[0];
2846
2847     for (i = 0; i < nmax; i++)
2848     {
2849         bsp_data[i] = 0;
2850     }
2851     for (i = 1; i <= order; i++)
2852     {
2853         bsp_data[i] = data[i-1];
2854     }
2855
2856     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2857     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2858     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2859
2860     sfree(data);
2861     sfree(ddata);
2862     sfree(bsp_data);
2863 }
2864
2865
2866 /* Return the P3M optimal influence function */
2867 static double do_p3m_influence(double z, int order)
2868 {
2869     double z2, z4;
2870
2871     z2 = z*z;
2872     z4 = z2*z2;
2873
2874     /* The formula and most constants can be found in:
2875      * Ballenegger et al., JCTC 8, 936 (2012)
2876      */
2877     switch (order)
2878     {
2879         case 2:
2880             return 1.0 - 2.0*z2/3.0;
2881             break;
2882         case 3:
2883             return 1.0 - z2 + 2.0*z4/15.0;
2884             break;
2885         case 4:
2886             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2887             break;
2888         case 5:
2889             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2890             break;
2891         case 6:
2892             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2893             break;
2894         case 7:
2895             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2896         case 8:
2897             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2898             break;
2899     }
2900
2901     return 0.0;
2902 }
2903
2904 /* Calculate the P3M B-spline moduli for one dimension */
2905 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2906 {
2907     double zarg, zai, sinzai, infl;
2908     int    maxk, i;
2909
2910     if (order > 8)
2911     {
2912         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2913     }
2914
2915     zarg = M_PI/n;
2916
2917     maxk = (n + 1)/2;
2918
2919     for (i = -maxk; i < 0; i++)
2920     {
2921         zai          = zarg*i;
2922         sinzai       = sin(zai);
2923         infl         = do_p3m_influence(sinzai, order);
2924         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2925     }
2926     bsp_mod[0] = 1.0;
2927     for (i = 1; i < maxk; i++)
2928     {
2929         zai        = zarg*i;
2930         sinzai     = sin(zai);
2931         infl       = do_p3m_influence(sinzai, order);
2932         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2933     }
2934 }
2935
2936 /* Calculate the P3M B-spline moduli */
2937 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2938                                     int nx, int ny, int nz, int order)
2939 {
2940     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2941     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2942     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2943 }
2944
2945
2946 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2947 {
2948     int nslab, n, i;
2949     int fw, bw;
2950
2951     nslab = atc->nslab;
2952
2953     n = 0;
2954     for (i = 1; i <= nslab/2; i++)
2955     {
2956         fw = (atc->nodeid + i) % nslab;
2957         bw = (atc->nodeid - i + nslab) % nslab;
2958         if (n < nslab - 1)
2959         {
2960             atc->node_dest[n] = fw;
2961             atc->node_src[n]  = bw;
2962             n++;
2963         }
2964         if (n < nslab - 1)
2965         {
2966             atc->node_dest[n] = bw;
2967             atc->node_src[n]  = fw;
2968             n++;
2969         }
2970     }
2971 }
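/* The pulses above alternate between the forward and backward neighbor at
 * increasing distance, so all other slabs are reached in nslab-1 pulses.
 * For example, with nslab = 4 and nodeid = 0 this gives
 *   node_dest = { 1, 3, 2 } and node_src = { 3, 1, 2 }.
 */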
2972
2973 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2974 {
2975     int thread, i;
2976
2977     if (NULL != log)
2978     {
2979         fprintf(log, "Destroying PME data structures.\n");
2980     }
2981
2982     sfree((*pmedata)->nnx);
2983     sfree((*pmedata)->nny);
2984     sfree((*pmedata)->nnz);
2985
2986     for (i = 0; i < (*pmedata)->ngrids; ++i)
2987     {
2988         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
2989         sfree((*pmedata)->fftgrid[i]);
2990         sfree((*pmedata)->cfftgrid[i]);
2991         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
2992     }
2993
2994     sfree((*pmedata)->lb_buf1);
2995     sfree((*pmedata)->lb_buf2);
2996
2997     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2998     {
2999         free_work(&(*pmedata)->work[thread]);
3000     }
3001     sfree((*pmedata)->work);
3002
3003     sfree(*pmedata);
3004     *pmedata = NULL;
3005
3006     return 0;
3007 }
3008
3009 static int mult_up(int n, int f)
3010 {
3011     return ((n + f - 1)/f)*f;
3012 }
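/* mult_up(n, f) rounds n up to the next multiple of f, e.g.
 * mult_up(10, 4) == 12. It is used below to account for grid dimensions
 * being padded to a multiple of the number of PME nodes.
 */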
3013
3014
3015 static double pme_load_imbalance(gmx_pme_t pme)
3016 {
3017     int    nma, nmi;
3018     double n1, n2, n3;
3019
3020     nma = pme->nnodes_major;
3021     nmi = pme->nnodes_minor;
3022
3023     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3024     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3025     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3026
3027     /* pme_solve is roughly double the cost of an fft */
3028
3029     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3030 }
3031
3032 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3033                           int dimind, gmx_bool bSpread)
3034 {
3035     int nk, k, s, thread;
3036
3037     atc->dimind    = dimind;
3038     atc->nslab     = 1;
3039     atc->nodeid    = 0;
3040     atc->pd_nalloc = 0;
3041 #ifdef GMX_MPI
3042     if (pme->nnodes > 1)
3043     {
3044         atc->mpi_comm = pme->mpi_comm_d[dimind];
3045         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3046         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3047     }
3048     if (debug)
3049     {
3050         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3051     }
3052 #endif
3053
3054     atc->bSpread   = bSpread;
3055     atc->pme_order = pme->pme_order;
3056
3057     if (atc->nslab > 1)
3058     {
3059         snew(atc->node_dest, atc->nslab);
3060         snew(atc->node_src, atc->nslab);
3061         setup_coordinate_communication(atc);
3062
3063         snew(atc->count_thread, pme->nthread);
3064         for (thread = 0; thread < pme->nthread; thread++)
3065         {
3066             snew(atc->count_thread[thread], atc->nslab);
3067         }
3068         atc->count = atc->count_thread[0];
3069         snew(atc->rcount, atc->nslab);
3070         snew(atc->buf_index, atc->nslab);
3071     }
3072
3073     atc->nthread = pme->nthread;
3074     if (atc->nthread > 1)
3075     {
3076         snew(atc->thread_plist, atc->nthread);
3077     }
3078     snew(atc->spline, atc->nthread);
3079     for (thread = 0; thread < atc->nthread; thread++)
3080     {
3081         if (atc->nthread > 1)
3082         {
3083             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3084             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3085         }
3086         snew(atc->spline[thread].thread_one, pme->nthread);
3087         atc->spline[thread].thread_one[thread] = 1;
3088     }
3089 }
3090
3091 static void
3092 init_overlap_comm(pme_overlap_t *  ol,
3093                   int              norder,
3094 #ifdef GMX_MPI
3095                   MPI_Comm         comm,
3096 #endif
3097                   int              nnodes,
3098                   int              nodeid,
3099                   int              ndata,
3100                   int              commplainsize)
3101 {
3102     int lbnd, rbnd, maxlr, b, i;
3103     int exten;
3104     int nn, nk;
3105     pme_grid_comm_t *pgc;
3106     gmx_bool bCont;
3107     int fft_start, fft_end, send_index1, recv_index1;
3108 #ifdef GMX_MPI
3109     MPI_Status stat;
3110
3111     ol->mpi_comm = comm;
3112 #endif
3113
3114     ol->nnodes = nnodes;
3115     ol->nodeid = nodeid;
3116
3117     /* Linear translation of the PME grid won't affect reciprocal space
3118      * calculations, so to optimize we only interpolate "upwards",
3119      * which also means we only have to consider overlap in one direction.
3120      * I.e., particles on this node might also be spread to grid indices
3121      * that belong to higher nodes (modulo nnodes)
3122      */
3123
3124     snew(ol->s2g0, ol->nnodes+1);
3125     snew(ol->s2g1, ol->nnodes);
3126     if (debug)
3127     {
3128         fprintf(debug, "PME slab boundaries:");
3129     }
3130     for (i = 0; i < nnodes; i++)
3131     {
3132         /* s2g0 is the local interpolation grid start.
3133          * s2g1 is the local interpolation grid end.
3134          * Because grid overlap communication only goes forward,
3135          * the grid slabs for the FFTs should be rounded down.
3136          */
3137         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3138         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3139
3140         if (debug)
3141         {
3142             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3143         }
3144     }
3145     ol->s2g0[nnodes] = ndata;
3146     if (debug)
3147     {
3148         fprintf(debug, "\n");
3149     }
3150
3151     /* Determine with how many nodes we need to communicate the grid overlap */
3152     b = 0;
3153     do
3154     {
3155         b++;
3156         bCont = FALSE;
3157         for (i = 0; i < nnodes; i++)
3158         {
3159             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3160                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3161             {
3162                 bCont = TRUE;
3163             }
3164         }
3165     }
3166     while (bCont && b < nnodes);
3167     ol->noverlap_nodes = b - 1;
3168
3169     snew(ol->send_id, ol->noverlap_nodes);
3170     snew(ol->recv_id, ol->noverlap_nodes);
3171     for (b = 0; b < ol->noverlap_nodes; b++)
3172     {
3173         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3174         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3175     }
3176     snew(ol->comm_data, ol->noverlap_nodes);
3177
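         /* For each forward pulse b, determine which part of our
          * interpolation grid falls inside the FFT slab of send_id[b]
          * (send_index0/send_nindex), and which part of the interpolation
          * grid of recv_id[b] we will receive into the start of our own
          * slab (recv_index0/recv_nindex).
          */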
3178     ol->send_size = 0;
3179     for (b = 0; b < ol->noverlap_nodes; b++)
3180     {
3181         pgc = &ol->comm_data[b];
3182         /* Send */
3183         fft_start        = ol->s2g0[ol->send_id[b]];
3184         fft_end          = ol->s2g0[ol->send_id[b]+1];
3185         if (ol->send_id[b] < nodeid)
3186         {
3187             fft_start += ndata;
3188             fft_end   += ndata;
3189         }
3190         send_index1       = ol->s2g1[nodeid];
3191         send_index1       = min(send_index1, fft_end);
3192         pgc->send_index0  = fft_start;
3193         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3194         ol->send_size    += pgc->send_nindex;
3195
3196         /* We always start receiving to the first index of our slab */
3197         fft_start        = ol->s2g0[ol->nodeid];
3198         fft_end          = ol->s2g0[ol->nodeid+1];
3199         recv_index1      = ol->s2g1[ol->recv_id[b]];
3200         if (ol->recv_id[b] > nodeid)
3201         {
3202             recv_index1 -= ndata;
3203         }
3204         recv_index1      = min(recv_index1, fft_end);
3205         pgc->recv_index0 = fft_start;
3206         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3207     }
3208
3209 #ifdef GMX_MPI
3210     /* Communicate the buffer sizes to receive */
3211     for (b = 0; b < ol->noverlap_nodes; b++)
3212     {
3213         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3214                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3215                      ol->mpi_comm, &stat);
3216     }
3217 #endif
3218
3219     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3220     snew(ol->sendbuf, norder*commplainsize);
3221     snew(ol->recvbuf, norder*commplainsize);
3222 }
3223
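     /* Build lookup tables of length 5*n that map a (possibly shifted)
      * global grid index to a local grid index with periodic wrapping,
      * together with a fraction shift that compensates when an index
      * just outside the local range has to be folded back onto it.
      */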
3224 static void
3225 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3226                               int **global_to_local,
3227                               real **fraction_shift)
3228 {
3229     int i;
3230     int * gtl;
3231     real * fsh;
3232
3233     snew(gtl, 5*n);
3234     snew(fsh, 5*n);
3235     for (i = 0; (i < 5*n); i++)
3236     {
3237         /* Determine the global to local grid index */
3238         gtl[i] = (i - local_start + n) % n;
3239         /* For coordinates that fall within the local grid the fraction
3240          * is correct, we don't need to shift it.
3241          */
3242         fsh[i] = 0;
3243         if (local_range < n)
3244         {
3245             /* Due to rounding issues i could be 1 beyond the lower or
3246              * upper boundary of the local grid. Correct the index for this.
3247              * If we shift the index, we need to shift the fraction by
3248              * the same amount in the other direction to not affect
3249              * the weights.
3250              * Note that due to this shifting the weights at the end of
3251              * the spline might change, but that will only involve values
3252              * between zero and values close to the precision of a real,
3253              * which is anyhow the accuracy of the whole mesh calculation.
3254              */
3255             /* With local_range=0 we should not change i=local_start */
3256             if (i % n != local_start)
3257             {
3258                 if (gtl[i] == n-1)
3259                 {
3260                     gtl[i] = 0;
3261                     fsh[i] = -1;
3262                 }
3263                 else if (gtl[i] == local_range)
3264                 {
3265                     gtl[i] = local_range - 1;
3266                     fsh[i] = 1;
3267                 }
3268             }
3269         }
3270     }
3271
3272     *global_to_local = gtl;
3273     *fraction_shift  = fsh;
3274 }
3275
3276 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
3277 {
3278     pme_spline_work_t *work;
3279
3280 #ifdef PME_SIMD4_SPREAD_GATHER
3281     real             tmp[GMX_SIMD4_WIDTH*3], *tmp_aligned;
3282     gmx_simd4_real_t zero_S;
3283     gmx_simd4_real_t real_mask_S0, real_mask_S1;
3284     int              of, i;
3285
3286     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3287
3288     tmp_aligned = gmx_simd4_align_r(tmp);
3289
3290     zero_S = gmx_simd4_setzero_r();
3291
3292     /* Generate bit masks to mask out the unused grid entries,
3293      * as we only operate on 'order' of the 8 grid entries that are
3294      * loaded into 2 SIMD registers.
3295      */
3296     for (of = 0; of < 2*GMX_SIMD4_WIDTH-(order-1); of++)
3297     {
3298         for (i = 0; i < 2*GMX_SIMD4_WIDTH; i++)
3299         {
3300             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3301         }
3302         real_mask_S0      = gmx_simd4_load_r(tmp_aligned);
3303         real_mask_S1      = gmx_simd4_load_r(tmp_aligned+GMX_SIMD4_WIDTH);
3304         work->mask_S0[of] = gmx_simd4_cmplt_r(real_mask_S0, zero_S);
3305         work->mask_S1[of] = gmx_simd4_cmplt_r(real_mask_S1, zero_S);
3306     }
3307 #else
3308     work = NULL;
3309 #endif
3310
3311     return work;
3312 }
3313
3314 void gmx_pme_check_restrictions(int pme_order,
3315                                 int nkx, int nky, int nkz,
3316                                 int nnodes_major,
3317                                 int nnodes_minor,
3318                                 gmx_bool bUseThreads,
3319                                 gmx_bool bFatal,
3320                                 gmx_bool *bValidSettings)
3321 {
3322     if (pme_order > PME_ORDER_MAX)
3323     {
3324         if (!bFatal)
3325         {
3326             *bValidSettings = FALSE;
3327             return;
3328         }
3329         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3330                   pme_order, PME_ORDER_MAX);
3331     }
3332
3333     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3334         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3335         nkz <= pme_order)
3336     {
3337         if (!bFatal)
3338         {
3339             *bValidSettings = FALSE;
3340             return;
3341         }
3342         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3343                   pme_order);
3344     }
3345
3346     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3347      * We only allow multiple communication pulses in dim 1, not in dim 0.
3348      */
3349     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3350                         nkx != nnodes_major*(pme_order - 1)))
3351     {
3352         if (!bFatal)
3353         {
3354             *bValidSettings = FALSE;
3355             return;
3356         }
3357         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3358                   nkx/(double)nnodes_major, pme_order);
3359     }
3360
3361     if (bValidSettings != NULL)
3362     {
3363         *bValidSettings = TRUE;
3364     }
3365
3366     return;
3367 }
3368
3369 int gmx_pme_init(gmx_pme_t *         pmedata,
3370                  t_commrec *         cr,
3371                  int                 nnodes_major,
3372                  int                 nnodes_minor,
3373                  t_inputrec *        ir,
3374                  int                 homenr,
3375                  gmx_bool            bFreeEnergy_q,
3376                  gmx_bool            bFreeEnergy_lj,
3377                  gmx_bool            bReproducible,
3378                  int                 nthread)
3379 {
3380     gmx_pme_t pme = NULL;
3381
3382     int  use_threads, sum_use_threads, i;
3383     ivec ndata;
3384
3385     if (debug)
3386     {
3387         fprintf(debug, "Creating PME data structures.\n");
3388     }
3389     snew(pme, 1);
3390
3391     pme->sum_qgrid_tmp       = NULL;
3392     pme->sum_qgrid_dd_tmp    = NULL;
3393     pme->buf_nalloc          = 0;
3394
3395     pme->nnodes              = 1;
3396     pme->bPPnode             = TRUE;
3397
3398     pme->nnodes_major        = nnodes_major;
3399     pme->nnodes_minor        = nnodes_minor;
3400
3401 #ifdef GMX_MPI
3402     if (nnodes_major*nnodes_minor > 1)
3403     {
3404         pme->mpi_comm = cr->mpi_comm_mygroup;
3405
3406         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3407         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3408         if (pme->nnodes != nnodes_major*nnodes_minor)
3409         {
3410             gmx_incons("PME node count mismatch");
3411         }
3412     }
3413     else
3414     {
3415         pme->mpi_comm = MPI_COMM_NULL;
3416     }
3417 #endif
3418
3419     if (pme->nnodes == 1)
3420     {
3421 #ifdef GMX_MPI
3422         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3423         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3424 #endif
3425         pme->ndecompdim   = 0;
3426         pme->nodeid_major = 0;
3427         pme->nodeid_minor = 0;
3428 #ifdef GMX_MPI
3429         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3430 #endif
3431     }
3432     else
3433     {
3434         if (nnodes_minor == 1)
3435         {
3436 #ifdef GMX_MPI
3437             pme->mpi_comm_d[0] = pme->mpi_comm;
3438             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3439 #endif
3440             pme->ndecompdim   = 1;
3441             pme->nodeid_major = pme->nodeid;
3442             pme->nodeid_minor = 0;
3443
3444         }
3445         else if (nnodes_major == 1)
3446         {
3447 #ifdef GMX_MPI
3448             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3449             pme->mpi_comm_d[1] = pme->mpi_comm;
3450 #endif
3451             pme->ndecompdim   = 1;
3452             pme->nodeid_major = 0;
3453             pme->nodeid_minor = pme->nodeid;
3454         }
3455         else
3456         {
3457             if (pme->nnodes % nnodes_major != 0)
3458             {
3459                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3460             }
3461             pme->ndecompdim = 2;
3462
3463 #ifdef GMX_MPI
3464             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3465                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3466             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3467                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3468
3469             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3470             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3471             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3472             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3473 #endif
3474         }
3475         pme->bPPnode = (cr->duty & DUTY_PP);
3476     }
3477
3478     pme->nthread = nthread;
3479
3480     /* Check if any of the PME MPI ranks uses threads */
3481     use_threads = (pme->nthread > 1 ? 1 : 0);
3482 #ifdef GMX_MPI
3483     if (pme->nnodes > 1)
3484     {
3485         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3486                       MPI_SUM, pme->mpi_comm);
3487     }
3488     else
3489 #endif
3490     {
3491         sum_use_threads = use_threads;
3492     }
3493     pme->bUseThreads = (sum_use_threads > 0);
3494
3495     if (ir->ePBC == epbcSCREW)
3496     {
3497         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3498     }
3499
3500     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3501     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3502     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3503     pme->nkx         = ir->nkx;
3504     pme->nky         = ir->nky;
3505     pme->nkz         = ir->nkz;
3506     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3507     pme->pme_order   = ir->pme_order;
3508
3509     /* Always constant electrostatics coefficients */
3510     pme->epsilon_r   = ir->epsilon_r;
3511
3512     /* Always constant LJ coefficients */
3513     pme->ljpme_combination_rule = ir->ljpme_combination_rule;
3514
3515     /* If we violate restrictions, generate a fatal error here */
3516     gmx_pme_check_restrictions(pme->pme_order,
3517                                pme->nkx, pme->nky, pme->nkz,
3518                                pme->nnodes_major,
3519                                pme->nnodes_minor,
3520                                pme->bUseThreads,
3521                                TRUE,
3522                                NULL);
3523
3524     if (pme->nnodes > 1)
3525     {
3526         double imbal;
3527
3528 #ifdef GMX_MPI
3529         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3530         MPI_Type_commit(&(pme->rvec_mpi));
3531 #endif
3532
3533         /* Note that the coefficient spreading and force gathering, which usually
3534          * take about the same amount of time as FFT+solve_pme,
3535          * are always fully load balanced
3536          * (unless the coefficient distribution is inhomogeneous).
3537          */
3538
3539         imbal = pme_load_imbalance(pme);
3540         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3541         {
3542             fprintf(stderr,
3543                     "\n"
3544                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3545                     "      For optimal PME load balancing\n"
3546                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3547                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3548                     "\n",
3549                     (int)((imbal-1)*100 + 0.5),
3550                     pme->nkx, pme->nky, pme->nnodes_major,
3551                     pme->nky, pme->nkz, pme->nnodes_minor);
3552         }
3553     }
3554
3555     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3556     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3557      * y is always copied through a buffer: we don't need padding in z,
3558      * but we do need the overlap in x because of the communication order.
3559      */
3560     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3561 #ifdef GMX_MPI
3562                       pme->mpi_comm_d[0],
3563 #endif
3564                       pme->nnodes_major, pme->nodeid_major,
3565                       pme->nkx,
3566                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3567
3568     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3569      * We do this with an offset buffer of equal size, so we need to allocate
3570      * extra for the offset. That's what the (+1)*pme->nkz is for.
3571      */
3572     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3573 #ifdef GMX_MPI
3574                       pme->mpi_comm_d[1],
3575 #endif
3576                       pme->nnodes_minor, pme->nodeid_minor,
3577                       pme->nky,
3578                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3579
3580     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3581      * Note that gmx_pme_check_restrictions checked for this already.
3582      */
3583     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3584     {
3585         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3586     }
3587
3588     snew(pme->bsp_mod[XX], pme->nkx);
3589     snew(pme->bsp_mod[YY], pme->nky);
3590     snew(pme->bsp_mod[ZZ], pme->nkz);
3591
3592     /* The required size of the interpolation grid, including overlap.
3593      * The allocated size (pmegrid_n?) might be slightly larger.
3594      */
3595     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3596         pme->overlap[0].s2g0[pme->nodeid_major];
3597     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3598         pme->overlap[1].s2g0[pme->nodeid_minor];
3599     pme->pmegrid_nz_base = pme->nkz;
3600     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3601     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3602
3603     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3604     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3605     pme->pmegrid_start_iz = 0;
3606
3607     make_gridindex5_to_localindex(pme->nkx,
3608                                   pme->pmegrid_start_ix,
3609                                   pme->pmegrid_nx - (pme->pme_order-1),
3610                                   &pme->nnx, &pme->fshx);
3611     make_gridindex5_to_localindex(pme->nky,
3612                                   pme->pmegrid_start_iy,
3613                                   pme->pmegrid_ny - (pme->pme_order-1),
3614                                   &pme->nny, &pme->fshy);
3615     make_gridindex5_to_localindex(pme->nkz,
3616                                   pme->pmegrid_start_iz,
3617                                   pme->pmegrid_nz_base,
3618                                   &pme->nnz, &pme->fshz);
3619
3620     pme->spline_work = make_pme_spline_work(pme->pme_order);
3621
3622     ndata[0]    = pme->nkx;
3623     ndata[1]    = pme->nky;
3624     ndata[2]    = pme->nkz;
3625     /* It doesn't matter if we allocate too many grids here,
3626      * we only initialize and use the ones we actually need.
3627      */
3628     if (EVDW_PME(ir->vdwtype))
3629     {
3630         pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3631     }
3632     else
3633     {
3634         pme->ngrids = DO_Q;
3635     }
3636     snew(pme->fftgrid, pme->ngrids);
3637     snew(pme->cfftgrid, pme->ngrids);
3638     snew(pme->pfft_setup, pme->ngrids);
3639
3640     for (i = 0; i < pme->ngrids; ++i)
3641     {
3642         if ((i <  DO_Q && EEL_PME(ir->coulombtype) && (i == 0 ||
3643                                                        bFreeEnergy_q)) ||
3644             (i >= DO_Q && EVDW_PME(ir->vdwtype) && (i == 2 ||
3645                                                     bFreeEnergy_lj ||
3646                                                     ir->ljpme_combination_rule == eljpmeLB)))
3647         {
3648             pmegrids_init(&pme->pmegrid[i],
3649                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3650                           pme->pmegrid_nz_base,
3651                           pme->pme_order,
3652                           pme->bUseThreads,
3653                           pme->nthread,
3654                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3655                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3656             /* This routine will allocate the grid data to fit the FFTs */
3657             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3658                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3659                                     pme->mpi_comm_d,
3660                                     bReproducible, pme->nthread);
3661
3662         }
3663     }
3664
3665     if (!pme->bP3M)
3666     {
3667         /* Use plain SPME B-spline interpolation */
3668         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3669     }
3670     else
3671     {
3672         /* Use the P3M grid-optimized influence function */
3673         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3674     }
3675
3676     /* Use atc[0] for spreading */
3677     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3678     if (pme->ndecompdim >= 2)
3679     {
3680         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3681     }
3682
3683     if (pme->nnodes == 1)
3684     {
3685         pme->atc[0].n = homenr;
3686         pme_realloc_atomcomm_things(&pme->atc[0]);
3687     }
3688
3689     pme->lb_buf1       = NULL;
3690     pme->lb_buf2       = NULL;
3691     pme->lb_buf_nalloc = 0;
3692
3693     {
3694         int thread;
3695
3696         /* Use fft5d, order after FFT is y major, z, x minor */
3697
3698         snew(pme->work, pme->nthread);
3699         for (thread = 0; thread < pme->nthread; thread++)
3700         {
3701             realloc_work(&pme->work[thread], pme->nkx);
3702         }
3703     }
3704
3705     *pmedata = pme;
3706
3707     return 0;
3708 }
3709
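     /* Let new reuse the grid storage of old, but only when the new grid
      * fits within the old allocation in every dimension (and, for the
      * per-thread grids, when the thread counts match); otherwise the
      * freshly allocated grids of new are kept.
      */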
3710 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3711 {
3712     int d, t;
3713
3714     for (d = 0; d < DIM; d++)
3715     {
3716         if (new->grid.n[d] > old->grid.n[d])
3717         {
3718             return;
3719         }
3720     }
3721
3722     sfree_aligned(new->grid.grid);
3723     new->grid.grid = old->grid.grid;
3724
3725     if (new->grid_th != NULL && new->nthread == old->nthread)
3726     {
3727         sfree_aligned(new->grid_all);
3728         for (t = 0; t < new->nthread; t++)
3729         {
3730             new->grid_th[t].grid = old->grid_th[t].grid;
3731         }
3732     }
3733 }
3734
3735 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3736                    t_commrec *         cr,
3737                    gmx_pme_t           pme_src,
3738                    const t_inputrec *  ir,
3739                    ivec                grid_size)
3740 {
3741     t_inputrec irc;
3742     int homenr;
3743     int ret;
3744
3745     irc     = *ir;
3746     irc.nkx = grid_size[XX];
3747     irc.nky = grid_size[YY];
3748     irc.nkz = grid_size[ZZ];
3749
3750     if (pme_src->nnodes == 1)
3751     {
3752         homenr = pme_src->atc[0].n;
3753     }
3754     else
3755     {
3756         homenr = -1;
3757     }
3758
3759     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3760                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3761
3762     if (ret == 0)
3763     {
3764         /* We can easily reuse the allocated pme grids in pme_src */
3765         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3766         /* We would like to reuse the fft grids, but that's harder */
3767     }
3768
3769     return ret;
3770 }
3771
3772
3773 static void copy_local_grid(gmx_pme_t pme, pmegrids_t *pmegrids,
3774                             int grid_index, int thread, real *fftgrid)
3775 {
3776     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3777     int  fft_my, fft_mz;
3778     int  nsx, nsy, nsz;
3779     ivec nf;
3780     int  offx, offy, offz, x, y, z, i0, i0t;
3781     int  d;
3782     pmegrid_t *pmegrid;
3783     real *grid_th;
3784
3785     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3786                                    local_fft_ndata,
3787                                    local_fft_offset,
3788                                    local_fft_size);
3789     fft_my = local_fft_size[YY];
3790     fft_mz = local_fft_size[ZZ];
3791
3792     pmegrid = &pmegrids->grid_th[thread];
3793
3794     nsx = pmegrid->s[XX];
3795     nsy = pmegrid->s[YY];
3796     nsz = pmegrid->s[ZZ];
3797
3798     for (d = 0; d < DIM; d++)
3799     {
3800         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3801                     local_fft_ndata[d] - pmegrid->offset[d]);
3802     }
3803
3804     offx = pmegrid->offset[XX];
3805     offy = pmegrid->offset[YY];
3806     offz = pmegrid->offset[ZZ];
3807
3808     /* Directly copy the non-overlapping parts of the local grids.
3809      * This also initializes the full grid.
3810      */
3811     grid_th = pmegrid->grid;
3812     for (x = 0; x < nf[XX]; x++)
3813     {
3814         for (y = 0; y < nf[YY]; y++)
3815         {
3816             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3817             i0t = (x*nsy + y)*nsz;
3818             for (z = 0; z < nf[ZZ]; z++)
3819             {
3820                 fftgrid[i0+z] = grid_th[i0t+z];
3821             }
3822         }
3823     }
3824 }
3825
3826 static void
3827 reduce_threadgrid_overlap(gmx_pme_t pme,
3828                           const pmegrids_t *pmegrids, int thread,
3829                           real *fftgrid, real *commbuf_x, real *commbuf_y,
3830                           int grid_index)
3831 {
3832     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3833     int  fft_nx, fft_ny, fft_nz;
3834     int  fft_my, fft_mz;
3835     int  buf_my = -1;
3836     int  nsx, nsy, nsz;
3837     ivec ne;
3838     int  offx, offy, offz, x, y, z, i0, i0t;
3839     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3840     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3841     gmx_bool bCommX, bCommY;
3842     int  d;
3843     int  thread_f;
3844     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3845     const real *grid_th;
3846     real *commbuf = NULL;
3847
3848     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
3849                                    local_fft_ndata,
3850                                    local_fft_offset,
3851                                    local_fft_size);
3852     fft_nx = local_fft_ndata[XX];
3853     fft_ny = local_fft_ndata[YY];
3854     fft_nz = local_fft_ndata[ZZ];
3855
3856     fft_my = local_fft_size[YY];
3857     fft_mz = local_fft_size[ZZ];
3858
3859     /* This routine is called when all threads have finished spreading.
3860      * Here each thread sums grid contributions calculated by other threads
3861      * into its thread-local grid volume.
3862      * To minimize the number of grid copying operations,
3863      * this routine sums directly from the pmegrid into the fftgrid.
3864      */
3865
3866     /* Determine which part of the full node grid we should operate on,
3867      * this is our thread local part of the full grid.
3868      */
3869     pmegrid = &pmegrids->grid_th[thread];
3870
3871     for (d = 0; d < DIM; d++)
3872     {
3873         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3874                     local_fft_ndata[d]);
3875     }
3876
3877     offx = pmegrid->offset[XX];
3878     offy = pmegrid->offset[YY];
3879     offz = pmegrid->offset[ZZ];
3880
3881
3882     bClearBufX  = TRUE;
3883     bClearBufY  = TRUE;
3884     bClearBufXY = TRUE;
3885
3886     /* Now loop over all the thread data blocks that contribute
3887      * to the grid region we (our thread) are operating on.
3888      */
3889     /* Note that fft_nx/ny is equal to the number of grid points
3890      * between the first point of our node grid and that of the next node.
3891      */
3892     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3893     {
3894         fx     = pmegrid->ci[XX] + sx;
3895         ox     = 0;
3896         bCommX = FALSE;
3897         if (fx < 0)
3898         {
3899             fx    += pmegrids->nc[XX];
3900             ox    -= fft_nx;
3901             bCommX = (pme->nnodes_major > 1);
3902         }
3903         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3904         ox       += pmegrid_g->offset[XX];
3905         if (!bCommX)
3906         {
3907             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3908         }
3909         else
3910         {
3911             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3912         }
3913
3914         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3915         {
3916             fy     = pmegrid->ci[YY] + sy;
3917             oy     = 0;
3918             bCommY = FALSE;
3919             if (fy < 0)
3920             {
3921                 fy    += pmegrids->nc[YY];
3922                 oy    -= fft_ny;
3923                 bCommY = (pme->nnodes_minor > 1);
3924             }
3925             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3926             oy       += pmegrid_g->offset[YY];
3927             if (!bCommY)
3928             {
3929                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3930             }
3931             else
3932             {
3933                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3934             }
3935
3936             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3937             {
3938                 fz = pmegrid->ci[ZZ] + sz;
3939                 oz = 0;
3940                 if (fz < 0)
3941                 {
3942                     fz += pmegrids->nc[ZZ];
3943                     oz -= fft_nz;
3944                 }
3945                 pmegrid_g = &pmegrids->grid_th[fz];
3946                 oz       += pmegrid_g->offset[ZZ];
3947                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3948
3949                 if (sx == 0 && sy == 0 && sz == 0)
3950                 {
3951                     /* We have already added our local contribution
3952                      * before calling this routine, so skip it here.
3953                      */
3954                     continue;
3955                 }
3956
3957                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3958
3959                 pmegrid_f = &pmegrids->grid_th[thread_f];
3960
3961                 grid_th = pmegrid_f->grid;
3962
3963                 nsx = pmegrid_f->s[XX];
3964                 nsy = pmegrid_f->s[YY];
3965                 nsz = pmegrid_f->s[ZZ];
3966
3967 #ifdef DEBUG_PME_REDUCE
3968                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3969                        pme->nodeid, thread, thread_f,
3970                        pme->pmegrid_start_ix,
3971                        pme->pmegrid_start_iy,
3972                        pme->pmegrid_start_iz,
3973                        sx, sy, sz,
3974                        offx-ox, tx1-ox, offx, tx1,
3975                        offy-oy, ty1-oy, offy, ty1,
3976                        offz-oz, tz1-oz, offz, tz1);
3977 #endif
3978
3979                 if (!(bCommX || bCommY))
3980                 {
3981                     /* Copy from the thread local grid to the node grid */
3982                     for (x = offx; x < tx1; x++)
3983                     {
3984                         for (y = offy; y < ty1; y++)
3985                         {
3986                             i0  = (x*fft_my + y)*fft_mz;
3987                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3988                             for (z = offz; z < tz1; z++)
3989                             {
3990                                 fftgrid[i0+z] += grid_th[i0t+z];
3991                             }
3992                         }
3993                     }
3994                 }
3995                 else
3996                 {
3997                     /* The order of this conditional decides
3998                      * where the corner volume gets stored with x+y decomp.
3999                      */
4000                     if (bCommY)
4001                     {
4002                         commbuf = commbuf_y;
4003                         buf_my  = ty1 - offy;
4004                         if (bCommX)
4005                         {
4006                             /* We index commbuf modulo the local grid size */
4007                             commbuf += buf_my*fft_nx*fft_nz;
4008
4009                             bClearBuf   = bClearBufXY;
4010                             bClearBufXY = FALSE;
4011                         }
4012                         else
4013                         {
4014                             bClearBuf  = bClearBufY;
4015                             bClearBufY = FALSE;
4016                         }
4017                     }
4018                     else
4019                     {
4020                         commbuf    = commbuf_x;
4021                         buf_my     = fft_ny;
4022                         bClearBuf  = bClearBufX;
4023                         bClearBufX = FALSE;
4024                     }
4025
4026                     /* Copy to the communication buffer */
4027                     for (x = offx; x < tx1; x++)
4028                     {
4029                         for (y = offy; y < ty1; y++)
4030                         {
4031                             i0  = (x*buf_my + y)*fft_nz;
4032                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4033
4034                             if (bClearBuf)
4035                             {
4036                                 /* First access of commbuf, initialize it */
4037                                 for (z = offz; z < tz1; z++)
4038                                 {
4039                                     commbuf[i0+z]  = grid_th[i0t+z];
4040                                 }
4041                             }
4042                             else
4043                             {
4044                                 for (z = offz; z < tz1; z++)
4045                                 {
4046                                     commbuf[i0+z] += grid_th[i0t+z];
4047                                 }
4048                             }
4049                         }
4050                     }
4051                 }
4052             }
4053         }
4054     }
4055 }
4056
4057
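     /* Sum the overlapping fftgrid parts over the domain decomposition:
      * first along the minor dimension (y), where contributions that also
      * overlap in x are accumulated into overlap[0].sendbuf, then with a
      * single pulse along the major dimension (x).
      */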
4058 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid, int grid_index)
4059 {
4060     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4061     pme_overlap_t *overlap;
4062     int  send_index0, send_nindex;
4063     int  recv_nindex;
4064 #ifdef GMX_MPI
4065     MPI_Status stat;
4066 #endif
4067     int  send_size_y, recv_size_y;
4068     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
4069     real *sendptr, *recvptr;
4070     int  x, y, z, indg, indb;
4071
4072     /* Note that this routine is only used for forward communication.
4073      * Since the force gathering, unlike the coefficient spreading,
4074      * can be trivially parallelized over the particles,
4075      * the backwards process is much simpler and can use the "old"
4076      * communication setup.
4077      */
4078
4079     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
4080                                    local_fft_ndata,
4081                                    local_fft_offset,
4082                                    local_fft_size);
4083
4084     if (pme->nnodes_minor > 1)
4085     {
4086         /* Minor dimension */
4087         overlap = &pme->overlap[1];
4088
4089         if (pme->nnodes_major > 1)
4090         {
4091             size_yx = pme->overlap[0].comm_data[0].send_nindex;
4092         }
4093         else
4094         {
4095             size_yx = 0;
4096         }
4097         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
4098
4099         send_size_y = overlap->send_size;
4100
4101         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
4102         {
4103             send_id       = overlap->send_id[ipulse];
4104             recv_id       = overlap->recv_id[ipulse];
4105             send_index0   =
4106                 overlap->comm_data[ipulse].send_index0 -
4107                 overlap->comm_data[0].send_index0;
4108             send_nindex   = overlap->comm_data[ipulse].send_nindex;
4109             /* We don't use recv_index0, as we always receive starting at 0 */
4110             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4111             recv_size_y   = overlap->comm_data[ipulse].recv_size;
4112
4113             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
4114             recvptr = overlap->recvbuf;
4115
4116 #ifdef GMX_MPI
4117             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
4118                          send_id, ipulse,
4119                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
4120                          recv_id, ipulse,
4121                          overlap->mpi_comm, &stat);
4122 #endif
4123
4124             for (x = 0; x < local_fft_ndata[XX]; x++)
4125             {
4126                 for (y = 0; y < recv_nindex; y++)
4127                 {
4128                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
4129                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
4130                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
4131                     {
4132                         fftgrid[indg+z] += recvptr[indb+z];
4133                     }
4134                 }
4135             }
4136
4137             if (pme->nnodes_major > 1)
4138             {
4139                 /* Copy from the received buffer to the send buffer for dim 0 */
4140                 sendptr = pme->overlap[0].sendbuf;
4141                 for (x = 0; x < size_yx; x++)
4142                 {
4143                     for (y = 0; y < recv_nindex; y++)
4144                     {
4145                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4146                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
4147                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
4148                         {
4149                             sendptr[indg+z] += recvptr[indb+z];
4150                         }
4151                     }
4152                 }
4153             }
4154         }
4155     }
4156
4157     /* We only support a single pulse here.
4158      * This is not a severe limitation, as this code is only used
4159      * with OpenMP and with OpenMP the (PME) domains can be larger.
4160      */
4161     if (pme->nnodes_major > 1)
4162     {
4163         /* Major dimension */
4164         overlap = &pme->overlap[0];
4165
4166         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
4167         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
4168
4169         ipulse = 0;
4170
4171         send_id       = overlap->send_id[ipulse];
4172         recv_id       = overlap->recv_id[ipulse];
4173         send_nindex   = overlap->comm_data[ipulse].send_nindex;
4174         /* We don't use recv_index0, as we always receive starting at 0 */
4175         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4176
4177         sendptr = overlap->sendbuf;
4178         recvptr = overlap->recvbuf;
4179
4180         if (debug != NULL)
4181         {
4182             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
4183                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
4184         }
4185
4186 #ifdef GMX_MPI
4187         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
4188                      send_id, ipulse,
4189                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
4190                      recv_id, ipulse,
4191                      overlap->mpi_comm, &stat);
4192 #endif
4193
4194         for (x = 0; x < recv_nindex; x++)
4195         {
4196             for (y = 0; y < local_fft_ndata[YY]; y++)
4197             {
4198                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
4199                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4200                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
4201                 {
4202                     fftgrid[indg+z] += recvptr[indb+z];
4203                 }
4204             }
4205         }
4206     }
4207 }
4208
4209
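     /* Spread the coefficients onto the local grid(s) in up to three steps:
      * 1) compute the interpolation grid indices for all atoms,
      * 2) per thread, construct the B-splines and, if requested, spread
      *    onto the (thread-local) grid,
      * 3) with thread parallelization, reduce the thread grids into fftgrid
      *    and, when decomposed over nodes, communicate the grid overlap
      *    (sum_fftgrid_dd).
      */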
4210 static void spread_on_grid(gmx_pme_t pme,
4211                            pme_atomcomm_t *atc, pmegrids_t *grids,
4212                            gmx_bool bCalcSplines, gmx_bool bSpread,
4213                            real *fftgrid, gmx_bool bDoSplines, int grid_index)
4214 {
4215     int nthread, thread;
4216 #ifdef PME_TIME_THREADS
4217     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
4218     static double cs1     = 0, cs2 = 0, cs3 = 0;
4219     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
4220     static int cnt        = 0;
4221 #endif
4222
4223     nthread = pme->nthread;
4224     assert(nthread > 0);
4225
4226 #ifdef PME_TIME_THREADS
4227     c1 = omp_cyc_start();
4228 #endif
4229     if (bCalcSplines)
4230     {
4231 #pragma omp parallel for num_threads(nthread) schedule(static)
4232         for (thread = 0; thread < nthread; thread++)
4233         {
4234             int start, end;
4235
4236             start = atc->n* thread   /nthread;
4237             end   = atc->n*(thread+1)/nthread;
4238
4239             /* Compute fftgrid index for all atoms,
4240              * with help of some extra variables.
4241              */
4242             calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
4243         }
4244     }
4245 #ifdef PME_TIME_THREADS
4246     c1   = omp_cyc_end(c1);
4247     cs1 += (double)c1;
4248 #endif
4249
4250 #ifdef PME_TIME_THREADS
4251     c2 = omp_cyc_start();
4252 #endif
4253 #pragma omp parallel for num_threads(nthread) schedule(static)
4254     for (thread = 0; thread < nthread; thread++)
4255     {
4256         splinedata_t *spline;
4257         pmegrid_t *grid = NULL;
4258
4259         /* make local bsplines  */
4260         if (grids == NULL || !pme->bUseThreads)
4261         {
4262             spline = &atc->spline[0];
4263
4264             spline->n = atc->n;
4265
4266             if (bSpread)
4267             {
4268                 grid = &grids->grid;
4269             }
4270         }
4271         else
4272         {
4273             spline = &atc->spline[thread];
4274
4275             if (grids->nthread == 1)
4276             {
4277                 /* One thread, we operate on all coefficients */
4278                 spline->n = atc->n;
4279             }
4280             else
4281             {
4282                 /* Get the indices our thread should operate on */
4283                 make_thread_local_ind(atc, thread, spline);
4284             }
4285
4286             grid = &grids->grid_th[thread];
4287         }
4288
4289         if (bCalcSplines)
4290         {
4291             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4292                           atc->fractx, spline->n, spline->ind, atc->coefficient, bDoSplines);
4293         }
4294
4295         if (bSpread)
4296         {
4297             /* put local atoms on grid. */
4298 #ifdef PME_TIME_SPREAD
4299             ct1a = omp_cyc_start();
4300 #endif
4301             spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
4302
4303             if (pme->bUseThreads)
4304             {
4305                 copy_local_grid(pme, grids, grid_index, thread, fftgrid);
4306             }
4307 #ifdef PME_TIME_SPREAD
4308             ct1a          = omp_cyc_end(ct1a);
4309             cs1a[thread] += (double)ct1a;
4310 #endif
4311         }
4312     }
4313 #ifdef PME_TIME_THREADS
4314     c2   = omp_cyc_end(c2);
4315     cs2 += (double)c2;
4316 #endif
4317
4318     if (bSpread && pme->bUseThreads)
4319     {
4320 #ifdef PME_TIME_THREADS
4321         c3 = omp_cyc_start();
4322 #endif
4323 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4324         for (thread = 0; thread < grids->nthread; thread++)
4325         {
4326             reduce_threadgrid_overlap(pme, grids, thread,
4327                                       fftgrid,
4328                                       pme->overlap[0].sendbuf,
4329                                       pme->overlap[1].sendbuf,
4330                                       grid_index);
4331         }
4332 #ifdef PME_TIME_THREADS
4333         c3   = omp_cyc_end(c3);
4334         cs3 += (double)c3;
4335 #endif
4336
4337         if (pme->nnodes > 1)
4338         {
4339             /* Communicate the overlapping part of the fftgrid.
4340              * For this communication call we need to check pme->bUseThreads
4341              * to have all ranks communicate here, regardless of pme->nthread.
4342              */
4343             sum_fftgrid_dd(pme, fftgrid, grid_index);
4344         }
4345     }
4346
4347 #ifdef PME_TIME_THREADS
4348     cnt++;
4349     if (cnt % 20 == 0)
4350     {
4351         printf("idx %.2f spread %.2f red %.2f",
4352                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4353 #ifdef PME_TIME_SPREAD
4354         for (thread = 0; thread < nthread; thread++)
4355         {
4356             printf(" %.2f", cs1a[thread]*1e-9);
4357         }
4358 #endif
4359         printf("\n");
4360     }
4361 #endif
4362 }
4363
4364
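     /* Debugging helper: print an nx*ny*nz block of grid values starting
      * at global indices (sx, sy, sz), using row strides my and mz.
      */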
4365 static void dump_grid(FILE *fp,
4366                       int sx, int sy, int sz, int nx, int ny, int nz,
4367                       int my, int mz, const real *g)
4368 {
4369     int x, y, z;
4370
4371     for (x = 0; x < nx; x++)
4372     {
4373         for (y = 0; y < ny; y++)
4374         {
4375             for (z = 0; z < nz; z++)
4376             {
4377                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4378                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4379             }
4380         }
4381     }
4382 }
4383
4384 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4385 {
4386     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4387
4388     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4389                                    local_fft_ndata,
4390                                    local_fft_offset,
4391                                    local_fft_size);
4392
4393     dump_grid(stderr,
4394               pme->pmegrid_start_ix,
4395               pme->pmegrid_start_iy,
4396               pme->pmegrid_start_iz,
4397               pme->pmegrid_nx-pme->pme_order+1,
4398               pme->pmegrid_ny-pme->pme_order+1,
4399               pme->pmegrid_nz-pme->pme_order+1,
4400               local_fft_size[YY],
4401               local_fft_size[ZZ],
4402               fftgrid);
4403 }
4404
4405
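     /* Calculate the reciprocal-space energy of the n coefficients q at
      * positions x by constructing their splines and gathering from the
      * A-charge grid; only supported without parallelization and without
      * free energy.
      */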
4406 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4407 {
4408     pme_atomcomm_t *atc;
4409     pmegrids_t *grid;
4410
4411     if (pme->nnodes > 1)
4412     {
4413         gmx_incons("gmx_pme_calc_energy called in parallel");
4414     }
4415     if (pme->bFEP_q)
4416     {
4417         gmx_incons("gmx_pme_calc_energy with free energy");
4418     }
4419
4420     atc            = &pme->atc_energy;
4421     atc->nthread   = 1;
4422     if (atc->spline == NULL)
4423     {
4424         snew(atc->spline, atc->nthread);
4425     }
4426     atc->nslab     = 1;
4427     atc->bSpread   = TRUE;
4428     atc->pme_order = pme->pme_order;
4429     atc->n         = n;
4430     pme_realloc_atomcomm_things(atc);
4431     atc->x           = x;
4432     atc->coefficient = q;
4433
4434     /* We only use the A-charges grid */
4435     grid = &pme->pmegrid[PME_GRID_QA];
4436
4437     /* Only calculate the spline coefficients, don't actually spread */
4438     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE, PME_GRID_QA);
4439
4440     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4441 }
4442
4443
4444 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4445                                    gmx_walltime_accounting_t walltime_accounting,
4446                                    t_nrnb *nrnb, t_inputrec *ir,
4447                                    gmx_int64_t step)
4448 {
4449     /* Reset all the counters related to performance over the run */
4450     wallcycle_stop(wcycle, ewcRUN);
4451     wallcycle_reset_all(wcycle);
4452     init_nrnb(nrnb);
4453     if (ir->nsteps >= 0)
4454     {
4455         /* ir->nsteps is not used here, but we update it for consistency */
4456         ir->nsteps -= step - ir->init_step;
4457     }
4458     ir->init_step = step;
4459     wallcycle_start(wcycle, ewcRUN);
4460     walltime_accounting_start(walltime_accounting);
4461 }
4462
4463
4464 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4465                                ivec grid_size,
4466                                t_commrec *cr, t_inputrec *ir,
4467                                gmx_pme_t *pme_ret)
4468 {
4469     int ind;
4470     gmx_pme_t pme = NULL;
4471
4472     ind = 0;
4473     while (ind < *npmedata)
4474     {
4475         pme = (*pmedata)[ind];
4476         if (pme->nkx == grid_size[XX] &&
4477             pme->nky == grid_size[YY] &&
4478             pme->nkz == grid_size[ZZ])
4479         {
4480             *pme_ret = pme;
4481
4482             return;
4483         }
4484
4485         ind++;
4486     }
4487
4488     (*npmedata)++;
4489     srenew(*pmedata, *npmedata);
4490
4491     /* Generate a new PME data structure, copying part of the old pointers */
4492     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4493
4494     *pme_ret = (*pmedata)[ind];
4495 }
4496
4497 int gmx_pmeonly(gmx_pme_t pme,
4498                 t_commrec *cr,    t_nrnb *mynrnb,
4499                 gmx_wallcycle_t wcycle,
4500                 gmx_walltime_accounting_t walltime_accounting,
4501                 real ewaldcoeff_q, real ewaldcoeff_lj,
4502                 t_inputrec *ir)
4503 {
4504     int npmedata;
4505     gmx_pme_t *pmedata;
4506     gmx_pme_pp_t pme_pp;
4507     int  ret;
4508     int  natoms;
4509     matrix box;
4510     rvec *x_pp      = NULL, *f_pp = NULL;
4511     real *chargeA   = NULL, *chargeB = NULL;
4512     real *c6A       = NULL, *c6B = NULL;
4513     real *sigmaA    = NULL, *sigmaB = NULL;
4514     real lambda_q   = 0;
4515     real lambda_lj  = 0;
4516     int  maxshift_x = 0, maxshift_y = 0;
4517     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4518     matrix vir_q, vir_lj;
4519     float cycles;
4520     int  count;
4521     gmx_bool bEnerVir;
4522     int pme_flags;
4523     gmx_int64_t step, step_rel;
4524     ivec grid_switch;
4525
4526     /* This data will only be used with PME tuning, i.e. switching PME grids */
4527     npmedata = 1;
4528     snew(pmedata, npmedata);
4529     pmedata[0] = pme;
4530
4531     pme_pp = gmx_pme_pp_init(cr);
4532
4533     init_nrnb(mynrnb);
4534
4535     count = 0;
4536     do /****** this is a quasi-loop over time steps! */
4537     {
4538         /* The reason for having a loop here is PME grid tuning/switching */
4539         do
4540         {
4541             /* Domain decomposition */
4542             ret = gmx_pme_recv_coeffs_coords(pme_pp,
4543                                              &natoms,
4544                                              &chargeA, &chargeB,
4545                                              &c6A, &c6B,
4546                                              &sigmaA, &sigmaB,
4547                                              box, &x_pp, &f_pp,
4548                                              &maxshift_x, &maxshift_y,
4549                                              &pme->bFEP_q, &pme->bFEP_lj,
4550                                              &lambda_q, &lambda_lj,
4551                                              &bEnerVir,
4552                                              &pme_flags,
4553                                              &step,
4554                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4555
4556             if (ret == pmerecvqxSWITCHGRID)
4557             {
4558                 /* Switch the PME grid to grid_switch */
4559                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4560             }
4561
4562             if (ret == pmerecvqxRESETCOUNTERS)
4563             {
4564                 /* Reset the cycle and flop counters */
4565                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4566             }
4567         }
4568         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4569
4570         if (ret == pmerecvqxFINISH)
4571         {
4572             /* We should stop: break out of the loop */
4573             break;
4574         }
4575
4576         step_rel = step - ir->init_step;
4577
4578         if (count == 0)
4579         {
4580             wallcycle_start(wcycle, ewcRUN);
4581             walltime_accounting_start(walltime_accounting);
4582         }
4583
4584         wallcycle_start(wcycle, ewcPMEMESH);
4585
4586         dvdlambda_q  = 0;
4587         dvdlambda_lj = 0;
4588         clear_mat(vir_q);
4589         clear_mat(vir_lj);
4590
4591         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4592                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4593                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4594                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4595                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4596                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4597
4598         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4599
4600         gmx_pme_send_force_vir_ener(pme_pp,
4601                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4602                                     dvdlambda_q, dvdlambda_lj, cycles);
4603
4604         count++;
4605     } /***** end of quasi-loop, we stop with the break above */
4606     while (TRUE);
4607
4608     walltime_accounting_end(walltime_accounting);
4609
4610     return 0;
4611 }
4612
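     /* For LJ-PME with Lorentz-Berthelot combination rules:
      * calc_initial_lb_coeffs sets the spreading coefficient of atom i to
      * C6_i/sigma_i^4, and each call to calc_next_lb_coeffs multiplies the
      * coefficients by one more power of sigma_i.
      */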
4613 static void
4614 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4615 {
4616     int  i;
4617
4618     for (i = 0; i < pme->atc[0].n; ++i)
4619     {
4620         real sigma4;
4621
4622         sigma4                     = local_sigma[i];
4623         sigma4                     = sigma4*sigma4;
4624         sigma4                     = sigma4*sigma4;
4625         pme->atc[0].coefficient[i] = local_c6[i] / sigma4;
4626     }
4627 }
4628
4629 static void
4630 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4631 {
4632     int  i;
4633
4634     for (i = 0; i < pme->atc[0].n; ++i)
4635     {
4636         pme->atc[0].coefficient[i] *= local_sigma[i];
4637     }
4638 }
4639
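     /* Redistribute the positions (only when bFirst is set) and the
      * per-atom coefficients over the PME slab decomposition, looping
      * from the last decomposition dimension back to the first.
      */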
4640 static void
4641 do_redist_pos_coeffs(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4642                      gmx_bool bFirst, rvec x[], real *data)
4643 {
4644     int      d;
4645     pme_atomcomm_t *atc;
4646     atc = &pme->atc[0];
4647
4648     for (d = pme->ndecompdim - 1; d >= 0; d--)
4649     {
4650         int             n_d;
4651         rvec           *x_d;
4652         real           *param_d;
4653
4654         if (d == pme->ndecompdim - 1)
4655         {
4656             n_d     = homenr;
4657             x_d     = x + start;
4658             param_d = data;
4659         }
4660         else
4661         {
4662             n_d     = pme->atc[d + 1].n;
4663             x_d     = atc->x;
4664             param_d = atc->coefficient;
4665         }
4666         atc      = &pme->atc[d];
4667         atc->npd = n_d;
4668         if (atc->npd > atc->pd_nalloc)
4669         {
4670             atc->pd_nalloc = over_alloc_dd(atc->npd);
4671             srenew(atc->pd, atc->pd_nalloc);
4672         }
4673         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4674         where();
4675         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4676         if (DOMAINDECOMP(cr))
4677         {
4678             dd_pmeredist_pos_coeffs(pme, n_d, bFirst, x_d, param_d, atc);
4679         }
4680     }
4681 }
4682
4683 int gmx_pme_do(gmx_pme_t pme,
4684                int start,       int homenr,
4685                rvec x[],        rvec f[],
4686                real *chargeA,   real *chargeB,
4687                real *c6A,       real *c6B,
4688                real *sigmaA,    real *sigmaB,
4689                matrix box, t_commrec *cr,
4690                int  maxshift_x, int maxshift_y,
4691                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4692                matrix vir_q,      real ewaldcoeff_q,
4693                matrix vir_lj,   real ewaldcoeff_lj,
4694                real *energy_q,  real *energy_lj,
4695                real lambda_q, real lambda_lj,
4696                real *dvdlambda_q, real *dvdlambda_lj,
4697                int flags)
4698 {
4699     int     d, i, j, k, ntot, npme, grid_index, max_grid_index;
4700     int     nx, ny, nz;
4701     int     n_d, local_ny;
4702     pme_atomcomm_t *atc = NULL;
4703     pmegrids_t *pmegrid = NULL;
4704     real    *grid       = NULL;
4705     real    *ptr;
4706     rvec    *x_d, *f_d;
4707     real    *coefficient = NULL;
4708     real    energy_AB[4];
4709     matrix  vir_AB[4];
4710     real    scale, lambda;
4711     gmx_bool bClearF;
4712     gmx_parallel_3dfft_t pfft_setup;
4713     real *  fftgrid;
4714     t_complex * cfftgrid;
4715     int     thread;
4716     gmx_bool bFirst, bDoSplines;
4717     int fep_state;
4718     int fep_states_lj           = pme->bFEP_lj ? 2 : 1;
4719     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4720     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4721
4722     assert(pme->nnodes > 0);
4723     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4724
4725     if (pme->nnodes > 1)
4726     {
4727         atc      = &pme->atc[0];
4728         atc->npd = homenr;
4729         if (atc->npd > atc->pd_nalloc)
4730         {
4731             atc->pd_nalloc = over_alloc_dd(atc->npd);
4732             srenew(atc->pd, atc->pd_nalloc);
4733         }
4734         for (d = pme->ndecompdim-1; d >= 0; d--)
4735         {
4736             atc           = &pme->atc[d];
4737             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4738         }
4739     }
4740     else
4741     {
4742         atc = &pme->atc[0];
4743         /* This could be necessary for TPI */
4744         pme->atc[0].n = homenr;
4745         if (DOMAINDECOMP(cr))
4746         {
4747             pme_realloc_atomcomm_things(atc);
4748         }
4749         atc->x = x;
4750         atc->f = f;
4751     }
4752
4753     m_inv_ur0(box, pme->recipbox);
4754     bFirst = TRUE;
4755
4756     /* For simplicity, we construct the splines for all particles if
4757      * more than one PME calculation is needed. Some optimization
4758      * could be done by keeping track of which atoms have splines
4759      * constructed, and constructing new splines on each pass for atoms
4760      * that don't yet have them.
4761      */
4762
4763     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4764
4765     /* We need a maximum of four separate PME calculations:
4766      * grid_index=0: Coulomb PME with charges from state A
4767      * grid_index=1: Coulomb PME with charges from state B
4768      * grid_index=2: LJ PME with C6 from state A
4769      * grid_index=3: LJ PME with C6 from state B
4770      * For Lorentz-Berthelot combination rules, a separate loop is used to
4771      * calculate all the terms
4772      */
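         /* As an illustration of the skip logic in the loop below: with
          * both GMX_PME_DO_COULOMB and GMX_PME_DO_LJ set and no free
          * energy perturbation, only grid_index 0 and 2 are used; the
          * B-state grids 1 and 3 are visited only when bFEP_q or bFEP_lj
          * is set.
          */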
4773
4774     /* If we are doing LJ-PME with LB, we only do Q here */
4775     max_grid_index = (pme->ljpme_combination_rule == eljpmeLB) ? DO_Q : DO_Q_AND_LJ;
4776
4777     for (grid_index = 0; grid_index < max_grid_index; ++grid_index)
4778     {
4779         /* Check if we should do calculations at this grid_index.
4780          * If grid_index is odd we should be doing FEP (B-state).
4781          * If grid_index < DO_Q (2) we should be doing electrostatic PME.
4782          * If grid_index >= DO_Q we should be doing LJ-PME.
4783          */
4784         if ((grid_index <  DO_Q && (!(flags & GMX_PME_DO_COULOMB) ||
4785                                     (grid_index == 1 && !pme->bFEP_q))) ||
4786             (grid_index >= DO_Q && (!(flags & GMX_PME_DO_LJ) ||
4787                                     (grid_index == 3 && !pme->bFEP_lj))))
4788         {
4789             continue;
4790         }
4791         /* Unpack structure */
4792         pmegrid    = &pme->pmegrid[grid_index];
4793         fftgrid    = pme->fftgrid[grid_index];
4794         cfftgrid   = pme->cfftgrid[grid_index];
4795         pfft_setup = pme->pfft_setup[grid_index];
4796         switch (grid_index)
4797         {
4798             case 0: coefficient = chargeA + start; break;
4799             case 1: coefficient = chargeB + start; break;
4800             case 2: coefficient = c6A + start; break;
4801             case 3: coefficient = c6B + start; break;
4802         }
4803
4804         grid = pmegrid->grid.grid;
4805
4806         if (debug)
4807         {
4808             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4809                     cr->nnodes, cr->nodeid);
4810             fprintf(debug, "Grid = %p\n", (void*)grid);
4811             if (grid == NULL)
4812             {
4813                 gmx_fatal(FARGS, "No grid!");
4814             }
4815         }
4816         where();
4817
4818         if (pme->nnodes == 1)
4819         {
4820             atc->coefficient = coefficient;
4821         }
4822         else
4823         {
4824             wallcycle_start(wcycle, ewcPME_REDISTXF);
4825             do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, coefficient);
4826             where();
4827
4828             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4829         }
4830
4831         if (debug)
4832         {
4833             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4834                     cr->nodeid, atc->n);
4835         }
4836
4837         if (flags & GMX_PME_SPREAD)
4838         {
4839             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4840
4841             /* Spread the coefficients on a grid */
4842             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
4843
4844             if (bFirst)
4845             {
4846                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4847             }
4848             inc_nrnb(nrnb, eNR_SPREADBSP,
4849                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4850
4851             if (!pme->bUseThreads)
4852             {
4853                 wrap_periodic_pmegrid(pme, grid);
4854
4855                 /* sum contributions to local grid from other nodes */
4856 #ifdef GMX_MPI
4857                 if (pme->nnodes > 1)
4858                 {
4859                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
4860                     where();
4861                 }
4862 #endif
4863
4864                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
4865             }
4866
4867             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4868
4869             /*
4870                dump_local_fftgrid(pme,fftgrid);
4871                exit(0);
4872              */
4873         }
4874
4875         /* Here we start a large thread parallel region */
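             /* thread is declared at function scope, so it has to be made
              * private here: each OpenMP thread needs its own copy. The
              * omp-for loop variables used further down are implicitly
              * private already.
              */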
4876 #pragma omp parallel num_threads(pme->nthread) private(thread)
4877         {
4878             thread = gmx_omp_get_thread_num();
4879             if (flags & GMX_PME_SOLVE)
4880             {
4881                 int loop_count;
4882
4883                 /* do 3d-fft */
4884                 if (thread == 0)
4885                 {
4886                     wallcycle_start(wcycle, ewcPME_FFT);
4887                 }
4888                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4889                                            thread, wcycle);
4890                 if (thread == 0)
4891                 {
4892                     wallcycle_stop(wcycle, ewcPME_FFT);
4893                 }
4894                 where();
4895
4896                 /* solve in k-space for our local cells */
4897                 if (thread == 0)
4898                 {
4899                     wallcycle_start(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4900                 }
4901                 if (grid_index < DO_Q)
4902                 {
4903                     loop_count =
4904                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
4905                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4906                                       bCalcEnerVir,
4907                                       pme->nthread, thread);
4908                 }
4909                 else
4910                 {
4911                     loop_count =
4912                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
4913                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4914                                          bCalcEnerVir,
4915                                          pme->nthread, thread);
4916                 }
4917
4918                 if (thread == 0)
4919                 {
4920                     wallcycle_stop(wcycle, (grid_index < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4921                     where();
4922                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4923                 }
4924             }
4925
4926             if (bCalcF)
4927             {
4928                 /* do 3d-invfft */
4929                 if (thread == 0)
4930                 {
4931                     where();
4932                     wallcycle_start(wcycle, ewcPME_FFT);
4933                 }
4934                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4935                                            thread, wcycle);
4936                 if (thread == 0)
4937                 {
4938                     wallcycle_stop(wcycle, ewcPME_FFT);
4939
4940                     where();
4941
4942                     if (pme->nodeid == 0)
4943                     {
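                             /* Rough operation count for the 3D FFT,
                              * ~ntot*log2(ntot); the factor of 2 below
                              * presumably covers the forward and backward
                              * transforms.
                              */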
4944                         ntot  = pme->nkx*pme->nky*pme->nkz;
4945                         npme  = ntot*log((real)ntot)/log(2.0);
4946                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4947                     }
4948
4949                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4950                 }
4951
4952                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
4953             }
4954         }
4955         /* End of thread parallel section.
4956          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4957          */
4958
4959         if (bCalcF)
4960         {
4961             /* distribute local grid to all nodes */
4962 #ifdef GMX_MPI
4963             if (pme->nnodes > 1)
4964             {
4965                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
4966             }
4967 #endif
4968             where();
4969
4970             unwrap_periodic_pmegrid(pme, grid);
4971
4972             /* interpolate forces for our local atoms */
4973
4974             where();
4975
4976             /* If we are running without parallelization,
4977              * atc->f is the actual force array, not a buffer;
4978              * therefore we should not clear it.
4979              */
4980             lambda  = grid_index < DO_Q ? lambda_q : lambda_lj;
4981             bClearF = (bFirst && PAR(cr));
4982 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4983             for (thread = 0; thread < pme->nthread; thread++)
4984             {
4985                 gather_f_bsplines(pme, grid, bClearF, atc,
4986                                   &atc->spline[thread],
4987                                   pme->bFEP ? (grid_index % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
4988             }
4989
4990             where();
4991
4992             inc_nrnb(nrnb, eNR_GATHERFBSP,
4993                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4994             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4995         }
4996
4997         if (bCalcEnerVir)
4998         {
4999             /* This should only be called on the master thread
5000              * and after the threads have synchronized.
5001              */
5002             if (grid_index < 2)
5003             {
5004                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5005             }
5006             else
5007             {
5008                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[grid_index], vir_AB[grid_index]);
5009             }
5010         }
5011         bFirst = FALSE;
5012     } /* of grid_index-loop */
5013
5014     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
5015      * seven terms. */
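         /* With LB combination rules C6_ij = sqrt(eps_i*eps_j)*(sigma_i+sigma_j)^6/16
          * (up to constant factors), and the binomial expansion
          *   (sigma_i+sigma_j)^6 = sum_{k=0..6} C(6,k) sigma_i^k sigma_j^(6-k)
          * yields seven separable terms, one grid per term (grid_index 2 to 8).
          */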
5016
5017     if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB)
5018     {
5019         /* Loop over A- and B-state if we are doing FEP */
5020         for (fep_state = 0; fep_state < fep_states_lj; ++fep_state)
5021         {
5022             real *local_c6 = NULL, *local_sigma = NULL, *RedistC6 = NULL, *RedistSigma = NULL;
5023             if (pme->nnodes == 1)
5024             {
5025                 if (pme->lb_buf1 == NULL)
5026                 {
5027                     pme->lb_buf_nalloc = pme->atc[0].n;
5028                     snew(pme->lb_buf1, pme->lb_buf_nalloc);
5029                 }
5030                 pme->atc[0].coefficient = pme->lb_buf1;
5031                 switch (fep_state)
5032                 {
5033                     case 0:
5034                         local_c6      = c6A;
5035                         local_sigma   = sigmaA;
5036                         break;
5037                     case 1:
5038                         local_c6      = c6B;
5039                         local_sigma   = sigmaB;
5040                         break;
5041                     default:
5042                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5043                 }
5044             }
5045             else
5046             {
5047                 atc = &pme->atc[0];
5048                 switch (fep_state)
5049                 {
5050                     case 0:
5051                         RedistC6      = c6A;
5052                         RedistSigma   = sigmaA;
5053                         break;
5054                     case 1:
5055                         RedistC6      = c6B;
5056                         RedistSigma   = sigmaB;
5057                         break;
5058                     default:
5059                         gmx_incons("Trying to access wrong FEP-state in LJ-PME routine");
5060                 }
5061                 wallcycle_start(wcycle, ewcPME_REDISTXF);
5062
5063                 do_redist_pos_coeffs(pme, cr, start, homenr, bFirst, x, RedistC6);
5064                 if (pme->lb_buf_nalloc < atc->n)
5065                 {
5066                     pme->lb_buf_nalloc = atc->nalloc;
5067                     srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5068                     srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5069                 }
5070                 local_c6 = pme->lb_buf1;
5071                 for (i = 0; i < atc->n; ++i)
5072                 {
5073                     local_c6[i] = atc->coefficient[i];
5074                 }
5075                 where();
5076
5077                 do_redist_pos_coeffs(pme, cr, start, homenr, FALSE, x, RedistSigma);
5078                 local_sigma = pme->lb_buf2;
5079                 for (i = 0; i < atc->n; ++i)
5080                 {
5081                     local_sigma[i] = atc->coefficient[i];
5082                 }
5083                 where();
5084
5085                 wallcycle_stop(wcycle, ewcPME_REDISTXF);
5086             }
5087             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5088
5089             /* Seven terms in LJ-PME with LB; grid_index < 2 is reserved for electrostatics */
5090             for (grid_index = 2; grid_index < 9; ++grid_index)
5091             {
5092                 /* Unpack structure */
5093                 pmegrid    = &pme->pmegrid[grid_index];
5094                 fftgrid    = pme->fftgrid[grid_index];
5095                 cfftgrid   = pme->cfftgrid[grid_index];
5096                 pfft_setup = pme->pfft_setup[grid_index];
5097                 calc_next_lb_coeffs(pme, local_sigma);
5098                 grid = pmegrid->grid.grid;
5099                 where();
5100
5101                 if (flags & GMX_PME_SPREAD)
5102                 {
5103                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5104                     /* Spread the c6 on a grid */
5105                     spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines, grid_index);
5106
5107                     if (bFirst)
5108                     {
5109                         inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5110                     }
5111
5112                     inc_nrnb(nrnb, eNR_SPREADBSP,
5113                              pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5114                     if (pme->nthread == 1)
5115                     {
5116                         wrap_periodic_pmegrid(pme, grid);
5117                         /* sum contributions to local grid from other nodes */
5118 #ifdef GMX_MPI
5119                         if (pme->nnodes > 1)
5120                         {
5121                             gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_FORWARD);
5122                             where();
5123                         }
5124 #endif
5125                         copy_pmegrid_to_fftgrid(pme, grid, fftgrid, grid_index);
5126                     }
5127                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5128                 }
5129                 /* Here we start a large thread parallel region */
5130 #pragma omp parallel num_threads(pme->nthread) private(thread)
5131                 {
5132                     thread = gmx_omp_get_thread_num();
5133                     if (flags & GMX_PME_SOLVE)
5134                     {
5135                         /* do 3d-fft */
5136                         if (thread == 0)
5137                         {
5138                             wallcycle_start(wcycle, ewcPME_FFT);
5139                         }
5140
5141                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5142                                                    thread, wcycle);
5143                         if (thread == 0)
5144                         {
5145                             wallcycle_stop(wcycle, ewcPME_FFT);
5146                         }
5147                         where();
5148                     }
5149                 }
5150                 bFirst = FALSE;
5151             }
5152             if (flags & GMX_PME_SOLVE)
5153             {
5154                 /* solve in k-space for our local cells */
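                     /* With the LB flag set, solve_pme_lj_yzx() is
                      * expected to combine the seven term grids pairwise
                      * (term k with term 6-k) during the solve, which is
                      * why all grids starting at pme->cfftgrid[2] are
                      * passed in at once.
                      */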
5155 #pragma omp parallel num_threads(pme->nthread) private(thread)
5156                 {
5157                     int loop_count;
5158                     thread = gmx_omp_get_thread_num();
5159                     if (thread == 0)
5160                     {
5161                         wallcycle_start(wcycle, ewcLJPME);
5162                     }
5163
5164                     loop_count =
5165                         solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5166                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5167                                          bCalcEnerVir,
5168                                          pme->nthread, thread);
5169                     if (thread == 0)
5170                     {
5171                         wallcycle_stop(wcycle, ewcLJPME);
5172                         where();
5173                         inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5174                     }
5175                 }
5176             }
5177
5178             if (bCalcEnerVir)
5179             {
5180                 /* This should only be called on the master thread and
5181                  * after the threads have synchronized.
5182                  */
5183                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2+fep_state], vir_AB[2+fep_state]);
5184             }
5185
5186             if (bCalcF)
5187             {
5188                 bFirst = !(flags & GMX_PME_DO_COULOMB);
5189                 calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5190                 for (grid_index = 8; grid_index >= 2; --grid_index)
5191                 {
5192                     /* Unpack structure */
5193                     pmegrid    = &pme->pmegrid[grid_index];
5194                     fftgrid    = pme->fftgrid[grid_index];
5195                     cfftgrid   = pme->cfftgrid[grid_index];
5196                     pfft_setup = pme->pfft_setup[grid_index];
5197                     grid       = pmegrid->grid.grid;
5198                     calc_next_lb_coeffs(pme, local_sigma);
5199                     where();
5200 #pragma omp parallel num_threads(pme->nthread) private(thread)
5201                     {
5202                         thread = gmx_omp_get_thread_num();
5203                         /* do 3d-invfft */
5204                         if (thread == 0)
5205                         {
5206                             where();
5207                             wallcycle_start(wcycle, ewcPME_FFT);
5208                         }
5209
5210                         gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5211                                                    thread, wcycle);
5212                         if (thread == 0)
5213                         {
5214                             wallcycle_stop(wcycle, ewcPME_FFT);
5215
5216                             where();
5217
5218                             if (pme->nodeid == 0)
5219                             {
5220                                 ntot  = pme->nkx*pme->nky*pme->nkz;
5221                                 npme  = ntot*log((real)ntot)/log(2.0);
5222                                 inc_nrnb(nrnb, eNR_FFT, 2*npme);
5223                             }
5224                             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5225                         }
5226
5227                         copy_fftgrid_to_pmegrid(pme, fftgrid, grid, grid_index, pme->nthread, thread);
5228
5229                     } /*#pragma omp parallel*/
5230
5231                     /* distribute local grid to all nodes */
5232 #ifdef GMX_MPI
5233                     if (pme->nnodes > 1)
5234                     {
5235                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_GRID_BACKWARD);
5236                     }
5237 #endif
5238                     where();
5239
5240                     unwrap_periodic_pmegrid(pme, grid);
5241
5242                     /* interpolate forces for our local atoms */
5243                     where();
5244                     bClearF = (bFirst && PAR(cr));
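                         /* The gather scale combines the FEP lambda weight
                          * for this state with lb_scale_factor[k],
                          * presumably the binomial prefactor C(6,k)/2^6 of
                          * term k = grid_index - 2.
                          */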
5245                     scale   = pme->bFEP ? (fep_state < 1 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5246                     scale  *= lb_scale_factor[grid_index-2];
5247 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5248                     for (thread = 0; thread < pme->nthread; thread++)
5249                     {
5250                         gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5251                                           &pme->atc[0].spline[thread],
5252                                           scale);
5253                     }
5254                     where();
5255
5256                     inc_nrnb(nrnb, eNR_GATHERFBSP,
5257                              pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5258                     wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5259
5260                     bFirst = FALSE;
5261                 } /* for (grid_index = 8; grid_index >= 2; --grid_index) */
5262             }     /* if (bCalcF) */
5263         }         /* for (fep_state = 0; fep_state < fep_states_lj; ++fep_state) */
5264     }             /* if ((flags & GMX_PME_DO_LJ) && pme->ljpme_combination_rule == eljpmeLB) */
5265
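     /* Send the forces computed on the PME ranks back along the
      * decomposition dimensions, so that each home atom's force ends up
      * in f on the rank that owns it.
      */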
5266     if (bCalcF && pme->nnodes > 1)
5267     {
5268         wallcycle_start(wcycle, ewcPME_REDISTXF);
5269         for (d = 0; d < pme->ndecompdim; d++)
5270         {
5271             atc = &pme->atc[d];
5272             if (d == pme->ndecompdim - 1)
5273             {
5274                 n_d = homenr;
5275                 f_d = f + start;
5276             }
5277             else
5278             {
5279                 n_d = pme->atc[d+1].n;
5280                 f_d = pme->atc[d+1].f;
5281             }
5282             if (DOMAINDECOMP(cr))
5283             {
5284                 dd_pmeredist_f(pme, atc, n_d, f_d,
5285                                d == pme->ndecompdim-1 && pme->bPPnode);
5286             }
5287         }
5288
5289         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5290     }
5291     where();
5292
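     /* Final energy/virial reduction. With FEP the A- and B-state results
      * are mixed linearly, E = (1-lambda)*E_A + lambda*E_B, and the
      * contribution to dV/dlambda is E_B - E_A.
      */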
5293     if (bCalcEnerVir)
5294     {
5295         if (flags & GMX_PME_DO_COULOMB)
5296         {
5297             if (!pme->bFEP_q)
5298             {
5299                 *energy_q = energy_AB[0];
5300                 m_add(vir_q, vir_AB[0], vir_q);
5301             }
5302             else
5303             {
5304                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5305                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5306                 for (i = 0; i < DIM; i++)
5307                 {
5308                     for (j = 0; j < DIM; j++)
5309                     {
5310                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5311                             lambda_q*vir_AB[1][i][j];
5312                     }
5313                 }
5314             }
5315             if (debug)
5316             {
5317                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5318             }
5319         }
5320         else
5321         {
5322             *energy_q = 0;
5323         }
5324
5325         if (flags & GMX_PME_DO_LJ)
5326         {
5327             if (!pme->bFEP_lj)
5328             {
5329                 *energy_lj = energy_AB[2];
5330                 m_add(vir_lj, vir_AB[2], vir_lj);
5331             }
5332             else
5333             {
5334                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5335                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5336                 for (i = 0; i < DIM; i++)
5337                 {
5338                     for (j = 0; j < DIM; j++)
5339                     {
5340                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5341                     }
5342                 }
5343             }
5344             if (debug)
5345             {
5346                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5347             }
5348         }
5349         else
5350         {
5351             *energy_lj = 0;
5352         }
5353     }
5354     return 0;
5355 }