Remove particle decomposition
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
 * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #ifdef HAVE_CONFIG_H
61 #include <config.h>
62 #endif
63
64 #include <stdio.h>
65 #include <string.h>
66 #include <math.h>
67 #include <assert.h>
68 #include "typedefs.h"
69 #include "txtdump.h"
70 #include "vec.h"
71 #include "smalloc.h"
72 #include "coulomb.h"
73 #include "gmx_fatal.h"
74 #include "pme.h"
75 #include "network.h"
76 #include "physics.h"
77 #include "nrnb.h"
78 #include "macros.h"
79
80 #include "gromacs/fft/parallel_3dfft.h"
81 #include "gromacs/fileio/futil.h"
82 #include "gromacs/fileio/pdbio.h"
83 #include "gromacs/math/gmxcomplex.h"
84 #include "gromacs/timing/cyclecounter.h"
85 #include "gromacs/timing/wallcycle.h"
86 #include "gromacs/utility/gmxmpi.h"
87 #include "gromacs/utility/gmxomp.h"
88
89 /* Include the SIMD macro file and then check for support */
90 #include "gromacs/simd/macros.h"
91 #if defined GMX_HAVE_SIMD_MACROS
92 /* Turn on arbitrary width SIMD intrinsics for PME solve */
93 #define PME_SIMD_SOLVE
94 #endif
95
96 #define PME_GRID_QA    0 /* Gridindex for A-state for Q */
97 #define PME_GRID_C6A   2 /* Gridindex for A-state for LJ */
98 #define DO_Q           2 /* Electrostatic grids have index q<2 */
99 #define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
100 #define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */
101
102 /* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules */
103 const real lb_scale_factor[] = {
104     1.0/64, 6.0/64, 15.0/64, 20.0/64,
105     15.0/64, 6.0/64, 1.0/64
106 };
107
108 /* Pascal triangle coefficients used in solve_pme_lj_yzx, only need to do 4 calculations due to symmetry */
109 const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
110
111 /* Include the 4-wide SIMD macro file */
112 #include "gromacs/simd/four_wide_macros.h"
113 /* Check if we have 4-wide SIMD macro support */
114 #ifdef GMX_HAVE_SIMD4_MACROS
115 /* Do PME spread and gather with 4-wide SIMD.
116  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
117  */
118 #define PME_SIMD4_SPREAD_GATHER
119
120 #ifdef GMX_SIMD4_HAVE_UNALIGNED
121 /* With PME-order=4 on x86, unaligned load+store is slightly faster
122  * than doubling all SIMD operations when using aligned load+store.
123  */
124 #define PME_SIMD4_UNALIGNED
125 #endif
126 #endif
127
128 #define DFT_TOL 1e-7
129 /* #define PRT_FORCE */
130 /* conditions for on the fly time-measurement */
131 /* #define TAKETIME (step > 1 && timesteps < 10) */
132 #define TAKETIME FALSE
133
134 /* #define PME_TIME_THREADS */
135
136 #ifdef GMX_DOUBLE
137 #define mpi_type MPI_DOUBLE
138 #else
139 #define mpi_type MPI_FLOAT
140 #endif
141
142 #ifdef PME_SIMD4_SPREAD_GATHER
143 #define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
144 #else
145 /* We can use any alignment, apart from 0, so we use 4 reals */
146 #define SIMD4_ALIGNMENT  (4*sizeof(real))
147 #endif
148
149 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
150  * to preserve alignment.
151  */
152 #define GMX_CACHE_SEP 64
153
154 /* We only define a maximum to be able to use local arrays without allocation.
155  * An order larger than 12 should never be needed, even for test cases.
156  * If needed it can be changed here.
157  */
158 #define PME_ORDER_MAX 12
159
/* Internal datastructures */

/* Grid index ranges exchanged in one pulse of the PME grid-overlap
 * communication; pme_overlap_t keeps one of these per pulse in comm_data.
 */
typedef struct {
    int send_index0; /* Start of the grid range we send */
    int send_nindex; /* Number of grid indices we send */
    int recv_index0; /* Start of the grid range we receive */
    int recv_nindex; /* Number of grid indices we receive */
    int recv_size;   /* Receive buffer width, used with OpenMP */
} pme_grid_comm_t;

/* Grid-overlap communication setup along one decomposition dimension
 * (t_gmx_pme keeps one per dimension, 0=x, 1=y).
 */
typedef struct {
#ifdef GMX_MPI
    MPI_Comm         mpi_comm;
#endif
    int              nnodes, nodeid;
    /* NOTE(review): presumably slab-to-grid start/end indices per node;
     * they are not set or read in this part of the file - TODO confirm.
     */
    int             *s2g0;
    int             *s2g1;
    int              noverlap_nodes;    /* Number of communication pulses */
    int             *send_id, *recv_id; /* Per pulse: node to send to / receive from */
    int              send_size; /* Send buffer width, used with OpenMP */
    pme_grid_comm_t *comm_data;         /* Per pulse: grid index ranges */
    real            *sendbuf;
    real            *recvbuf;
} pme_overlap_t;

/* Particle indices bucketed on the thread whose grid region they fall in.
 * Each thread fills its own instance in calc_interpolation_idx.
 */
typedef struct {
    int *n;      /* Cumulative counts of the number of particles per thread */
    int  nalloc; /* Allocation size of i */
    int *i;      /* Particle indices ordered on thread index (n) */
} thread_plist_t;
189
/* Spline data for a set of particles; pme_atomcomm_t keeps one per thread.
 * theta and dtheta hold pme_order values per particle for each of the three
 * dimensions (sizes are set in pme_realloc_splinedata).
 */
typedef struct {
    int      *thread_one;
    int       n;             /* Number of valid entries in ind */
    int      *ind;           /* Indices of the particles handled with this data;
                              * identity after reallocation, a thread-local
                              * subset after make_thread_local_ind */
    splinevec theta;
    real     *ptr_theta_z;   /* Allocation base of theta[ZZ]; theta[ZZ] points
                              * past a zeroed padding region */
    splinevec dtheta;
    real     *ptr_dtheta_z;  /* Allocation base of dtheta[ZZ], padded likewise */
} splinedata_t;

/* Particle bookkeeping for one PME decomposition dimension: slab
 * assignment, redistribution counts and the local coordinates, charges,
 * forces and interpolation data.
 */
typedef struct {
    int      dimind;        /* The index of the dimension, 0=x, 1=y */
    int      nslab;         /* Number of slabs (PME nodes) along this dimension */
    int      nodeid;        /* Our slab index */
#ifdef GMX_MPI
    MPI_Comm mpi_comm;
#endif

    int     *node_dest;     /* The nodes to send x and q to with DD */
    int     *node_src;      /* The nodes to receive x and q from with DD */
    int     *buf_index;     /* Index for commnode into the buffers */

    int      maxshift;      /* Redistribution talks to min(2*maxshift, nslab-1) nodes */

    int      npd;
    int      pd_nalloc;
    int     *pd;            /* Slab index of each particle, set by pme_calc_pidx */
    int     *count;         /* The number of atoms to send to each node */
    int    **count_thread;  /* Per-thread slab counts, summed into count_thread[0] */
    int     *rcount;        /* The number of atoms to receive */

    int      n;             /* Number of local particles after redistribution */
    int      nalloc;        /* Allocation size of the per-particle arrays */
    rvec    *x;             /* x, q and f are allocated by
                             * pme_realloc_atomcomm_things only with nslab > 1 */
    real    *q;
    rvec    *f;
    gmx_bool bSpread;       /* These coordinates are used for spreading */
    int      pme_order;     /* Spline order, used to size the spline arrays */
    ivec    *idx;           /* Grid index of each particle (calc_interpolation_idx) */
    rvec    *fractx;            /* Fractional coordinate relative to
                                 * the lower cell boundary
                                 */
    int             nthread;
    int            *thread_idx; /* Which thread should spread which charge */
    thread_plist_t *thread_plist; /* One particle list per thread */
    splinedata_t   *spline;       /* Spline data per thread */
} pme_atomcomm_t;
237
/* NOTE(review): FLBS and FLBSZ are not used in this part of the file;
 * their meaning (presumably block sizes) should be confirmed.
 */
#define FLBS  3
#define FLBSZ 4

/* One (sub-)grid used for charge spreading and interpolation */
typedef struct {
    ivec  ci;     /* The spatial location of this grid         */
    ivec  n;      /* The used size of *grid, including order-1 */
    ivec  offset; /* The grid offset from the full node grid   */
    int   order;  /* PME spreading order                       */
    ivec  s;      /* The allocated size of *grid, s >= n       */
    real *grid;   /* The grid of the local thread, size n      */
} pmegrid_t;

/* A full node grid together with its decomposition over threads;
 * t_gmx_pme keeps one per grid index in pmegrid[].
 */
typedef struct {
    pmegrid_t  grid;         /* The full node grid (non thread-local)            */
    int        nthread;      /* The number of threads operating on this grid     */
    ivec       nc;           /* The local spatial decomposition over the threads */
    pmegrid_t *grid_th;      /* Array of grids for each thread                   */
    real      *grid_all;     /* Allocated array for the grids in *grid_th        */
    int      **g2t;          /* The grid to thread index                         */
    ivec       nthread_comm; /* The number of threads to communicate with        */
} pmegrids_t;

/* Work data for spreading and gathering */
typedef struct {
#ifdef PME_SIMD4_SPREAD_GATHER
    /* Masks for 4-wide SIMD aligned spreading and gathering */
    gmx_simd4_bool_t mask_S0[6], mask_S1[6];
#else
    int              dummy; /* C89 requires that struct has at least one member */
#endif
} pme_spline_work_t;
268
/* Thread-local work data for solve_pme: scratch arrays plus the
 * energy and virial contributions of the Coulomb (q) and LJ parts.
 */
typedef struct {
    /* work data for solve_pme */
    int      nalloc;    /* Allocation size, presumably of the arrays below */
    real *   mhx;
    real *   mhy;
    real *   mhz;
    real *   m2;
    real *   denom;
    real *   tmp1_alloc;
    real *   tmp1;
    real *   tmp2;
    real *   eterm;
    real *   m2inv;

    real     energy_q;  /* Coulomb energy contribution */
    matrix   vir_q;     /* Coulomb virial contribution */
    real     energy_lj; /* LJ energy contribution */
    matrix   vir_lj;    /* LJ virial contribution */
} pme_work_t;
288
/* The complete PME state of one rank; code in this file accesses it
 * through gmx_pme_t, which is used as a pointer to this structure.
 */
typedef struct gmx_pme {
    int           ndecompdim; /* The number of decomposition dimensions */
    int           nodeid;     /* Our nodeid in mpi_comm */
    int           nodeid_major;
    int           nodeid_minor;
    int           nnodes;    /* The number of nodes doing PME */
    int           nnodes_major;
    int           nnodes_minor;

    MPI_Comm      mpi_comm;
    MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
#ifdef GMX_MPI
    MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
#endif

    gmx_bool   bUseThreads;   /* Do any of the PME ranks have nthread>1? */
    int        nthread;       /* The number of threads doing PME on our rank */

    gmx_bool   bPPnode;       /* Node also does particle-particle forces */
    gmx_bool   bFEP;          /* Compute Free energy contribution */
    gmx_bool   bFEP_q;
    gmx_bool   bFEP_lj;
    int        nkx, nky, nkz; /* Grid dimensions */
    gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
    int        pme_order;
    real       epsilon_r;

    int        ngrids;                  /* Number of grids we maintain for pmegrid, (c)fftgrid and pfft_setup */

    pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
                                         * including overlap. Grid indices are ordered as
                                         * follows:
                                         * 0: Coulomb PME, state A
                                         * 1: Coulomb PME, state B
                                         * 2-8: LJ-PME
                                         * This can probably be done in a better way
                                         * but this simple hack works for now
                                         */
    /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
    int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
    /* pmegrid_nz might be larger than strictly necessary to ensure
     * memory alignment, pmegrid_nz_base gives the real base size.
     */
    int     pmegrid_nz_base;
    /* The local PME grid starting indices */
    int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;

    /* Work data for spreading and gathering */
    pme_spline_work_t     *spline_work;

    real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition these can point */
    /* inside the interpolation grids, but they are separate with 2D PME decomposition. */
    int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;

    t_complex            **cfftgrid;  /* Grids for complex FFT data */

    int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;

    gmx_parallel_3dfft_t  *pfft_setup;

    /* Lookup tables indexed on the integer part of the shifted fractional
     * grid coordinate: nnx/nny/nnz give the grid index, fshx/fshy the
     * fractional shift correction (as used in calc_interpolation_idx).
     */
    int                   *nnx, *nny, *nnz;
    real                  *fshx, *fshy, *fshz;

    pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
    matrix                 recipbox;
    splinevec              bsp_mod;
    /* Buffers to store data for local atoms for L-B combination rule
     * calculations in LJ-PME. lb_buf1 stores either the coefficients
     * for spreading/gathering (in serial), or the C6 parameters for
     * local atoms (in parallel).  lb_buf2 is only used in parallel,
     * and stores the sigma values for local atoms. */
    real                 *lb_buf1, *lb_buf2;
    int                   lb_buf_nalloc; /* Allocation size for the above buffers. */

    pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */

    pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */

    /* Send/receive buffers for redistributing x, q and f over the slabs
     * (dd_pmeredist_x_q, dd_pmeredist_f); both have buf_nalloc elements.
     */
    rvec                 *bufv;          /* Communication buffer */
    real                 *bufr;          /* Communication buffer */
    int                   buf_nalloc;    /* The communication buffer size */

    /* thread local work data for solve_pme */
    pme_work_t *work;

    /* Work data for sum_qgrid */
    real *   sum_qgrid_tmp;
    real *   sum_qgrid_dd_tmp;
} t_gmx_pme;
378
379 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
380                                    int start, int end, int thread)
381 {
382     int             i;
383     int            *idxptr, tix, tiy, tiz;
384     real           *xptr, *fptr, tx, ty, tz;
385     real            rxx, ryx, ryy, rzx, rzy, rzz;
386     int             nx, ny, nz;
387     int             start_ix, start_iy, start_iz;
388     int            *g2tx, *g2ty, *g2tz;
389     gmx_bool        bThreads;
390     int            *thread_idx = NULL;
391     thread_plist_t *tpl        = NULL;
392     int            *tpl_n      = NULL;
393     int             thread_i;
394
395     nx  = pme->nkx;
396     ny  = pme->nky;
397     nz  = pme->nkz;
398
399     start_ix = pme->pmegrid_start_ix;
400     start_iy = pme->pmegrid_start_iy;
401     start_iz = pme->pmegrid_start_iz;
402
403     rxx = pme->recipbox[XX][XX];
404     ryx = pme->recipbox[YY][XX];
405     ryy = pme->recipbox[YY][YY];
406     rzx = pme->recipbox[ZZ][XX];
407     rzy = pme->recipbox[ZZ][YY];
408     rzz = pme->recipbox[ZZ][ZZ];
409
410     g2tx = pme->pmegrid[PME_GRID_QA].g2t[XX];
411     g2ty = pme->pmegrid[PME_GRID_QA].g2t[YY];
412     g2tz = pme->pmegrid[PME_GRID_QA].g2t[ZZ];
413
414     bThreads = (atc->nthread > 1);
415     if (bThreads)
416     {
417         thread_idx = atc->thread_idx;
418
419         tpl   = &atc->thread_plist[thread];
420         tpl_n = tpl->n;
421         for (i = 0; i < atc->nthread; i++)
422         {
423             tpl_n[i] = 0;
424         }
425     }
426
427     for (i = start; i < end; i++)
428     {
429         xptr   = atc->x[i];
430         idxptr = atc->idx[i];
431         fptr   = atc->fractx[i];
432
433         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
434         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
435         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
436         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
437
438         tix = (int)(tx);
439         tiy = (int)(ty);
440         tiz = (int)(tz);
441
442         /* Because decomposition only occurs in x and y,
443          * we never have a fraction correction in z.
444          */
445         fptr[XX] = tx - tix + pme->fshx[tix];
446         fptr[YY] = ty - tiy + pme->fshy[tiy];
447         fptr[ZZ] = tz - tiz;
448
449         idxptr[XX] = pme->nnx[tix];
450         idxptr[YY] = pme->nny[tiy];
451         idxptr[ZZ] = pme->nnz[tiz];
452
453 #ifdef DEBUG
454         range_check(idxptr[XX], 0, pme->pmegrid_nx);
455         range_check(idxptr[YY], 0, pme->pmegrid_ny);
456         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
457 #endif
458
459         if (bThreads)
460         {
461             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
462             thread_idx[i] = thread_i;
463             tpl_n[thread_i]++;
464         }
465     }
466
467     if (bThreads)
468     {
469         /* Make a list of particle indices sorted on thread */
470
471         /* Get the cumulative count */
472         for (i = 1; i < atc->nthread; i++)
473         {
474             tpl_n[i] += tpl_n[i-1];
475         }
476         /* The current implementation distributes particles equally
477          * over the threads, so we could actually allocate for that
478          * in pme_realloc_atomcomm_things.
479          */
480         if (tpl_n[atc->nthread-1] > tpl->nalloc)
481         {
482             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
483             srenew(tpl->i, tpl->nalloc);
484         }
485         /* Set tpl_n to the cumulative start */
486         for (i = atc->nthread-1; i >= 1; i--)
487         {
488             tpl_n[i] = tpl_n[i-1];
489         }
490         tpl_n[0] = 0;
491
492         /* Fill our thread local array with indices sorted on thread */
493         for (i = start; i < end; i++)
494         {
495             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
496         }
497         /* Now tpl_n contains the cummulative count again */
498     }
499 }
500
501 static void make_thread_local_ind(pme_atomcomm_t *atc,
502                                   int thread, splinedata_t *spline)
503 {
504     int             n, t, i, start, end;
505     thread_plist_t *tpl;
506
507     /* Combine the indices made by each thread into one index */
508
509     n     = 0;
510     start = 0;
511     for (t = 0; t < atc->nthread; t++)
512     {
513         tpl = &atc->thread_plist[t];
514         /* Copy our part (start - end) from the list of thread t */
515         if (thread > 0)
516         {
517             start = tpl->n[thread-1];
518         }
519         end = tpl->n[thread];
520         for (i = start; i < end; i++)
521         {
522             spline->ind[n++] = tpl->i[i];
523         }
524     }
525
526     spline->n = n;
527 }
528
529
530 static void pme_calc_pidx(int start, int end,
531                           matrix recipbox, rvec x[],
532                           pme_atomcomm_t *atc, int *count)
533 {
534     int   nslab, i;
535     int   si;
536     real *xptr, s;
537     real  rxx, ryx, rzx, ryy, rzy;
538     int  *pd;
539
540     /* Calculate PME task index (pidx) for each grid index.
541      * Here we always assign equally sized slabs to each node
542      * for load balancing reasons (the PME grid spacing is not used).
543      */
544
545     nslab = atc->nslab;
546     pd    = atc->pd;
547
548     /* Reset the count */
549     for (i = 0; i < nslab; i++)
550     {
551         count[i] = 0;
552     }
553
554     if (atc->dimind == 0)
555     {
556         rxx = recipbox[XX][XX];
557         ryx = recipbox[YY][XX];
558         rzx = recipbox[ZZ][XX];
559         /* Calculate the node index in x-dimension */
560         for (i = start; i < end; i++)
561         {
562             xptr   = x[i];
563             /* Fractional coordinates along box vectors */
564             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
565             si    = (int)(s + 2*nslab) % nslab;
566             pd[i] = si;
567             count[si]++;
568         }
569     }
570     else
571     {
572         ryy = recipbox[YY][YY];
573         rzy = recipbox[ZZ][YY];
574         /* Calculate the node index in y-dimension */
575         for (i = start; i < end; i++)
576         {
577             xptr   = x[i];
578             /* Fractional coordinates along box vectors */
579             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
580             si    = (int)(s + 2*nslab) % nslab;
581             pd[i] = si;
582             count[si]++;
583         }
584     }
585 }
586
587 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
588                                   pme_atomcomm_t *atc)
589 {
590     int nthread, thread, slab;
591
592     nthread = atc->nthread;
593
594 #pragma omp parallel for num_threads(nthread) schedule(static)
595     for (thread = 0; thread < nthread; thread++)
596     {
597         pme_calc_pidx(natoms* thread   /nthread,
598                       natoms*(thread+1)/nthread,
599                       recipbox, x, atc, atc->count_thread[thread]);
600     }
601     /* Non-parallel reduction, since nslab is small */
602
603     for (thread = 1; thread < nthread; thread++)
604     {
605         for (slab = 0; slab < atc->nslab; slab++)
606         {
607             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
608         }
609     }
610 }
611
612 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
613 {
614     const int padding = 4;
615     int       i;
616
617     srenew(th[XX], nalloc);
618     srenew(th[YY], nalloc);
619     /* In z we add padding, this is only required for the aligned SIMD code */
620     sfree_aligned(*ptr_z);
621     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
622     th[ZZ] = *ptr_z + padding;
623
624     for (i = 0; i < padding; i++)
625     {
626         (*ptr_z)[               i] = 0;
627         (*ptr_z)[padding+nalloc+i] = 0;
628     }
629 }
630
631 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
632 {
633     int i, d;
634
635     srenew(spline->ind, atc->nalloc);
636     /* Initialize the index to identity so it works without threads */
637     for (i = 0; i < atc->nalloc; i++)
638     {
639         spline->ind[i] = i;
640     }
641
642     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
643                       atc->pme_order*atc->nalloc);
644     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
645                       atc->pme_order*atc->nalloc);
646 }
647
648 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
649 {
650     int nalloc_old, i, j, nalloc_tpl;
651
652     /* We have to avoid a NULL pointer for atc->x to avoid
653      * possible fatal errors in MPI routines.
654      */
655     if (atc->n > atc->nalloc || atc->nalloc == 0)
656     {
657         nalloc_old  = atc->nalloc;
658         atc->nalloc = over_alloc_dd(max(atc->n, 1));
659
660         if (atc->nslab > 1)
661         {
662             srenew(atc->x, atc->nalloc);
663             srenew(atc->q, atc->nalloc);
664             srenew(atc->f, atc->nalloc);
665             for (i = nalloc_old; i < atc->nalloc; i++)
666             {
667                 clear_rvec(atc->f[i]);
668             }
669         }
670         if (atc->bSpread)
671         {
672             srenew(atc->fractx, atc->nalloc);
673             srenew(atc->idx, atc->nalloc);
674
675             if (atc->nthread > 1)
676             {
677                 srenew(atc->thread_idx, atc->nalloc);
678             }
679
680             for (i = 0; i < atc->nthread; i++)
681             {
682                 pme_realloc_splinedata(&atc->spline[i], atc);
683             }
684         }
685     }
686 }
687
688 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
689                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
690                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
691                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
692 {
693 #ifdef GMX_MPI
694     int        dest, src;
695     MPI_Status stat;
696
697     if (bBackward == FALSE)
698     {
699         dest = atc->node_dest[shift];
700         src  = atc->node_src[shift];
701     }
702     else
703     {
704         dest = atc->node_src[shift];
705         src  = atc->node_dest[shift];
706     }
707
708     if (nbyte_s > 0 && nbyte_r > 0)
709     {
710         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
711                      dest, shift,
712                      buf_r, nbyte_r, MPI_BYTE,
713                      src, shift,
714                      atc->mpi_comm, &stat);
715     }
716     else if (nbyte_s > 0)
717     {
718         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
719                  dest, shift,
720                  atc->mpi_comm);
721     }
722     else if (nbyte_r > 0)
723     {
724         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
725                  src, shift,
726                  atc->mpi_comm, &stat);
727     }
728 #endif
729 }
730
731 static void dd_pmeredist_x_q(gmx_pme_t pme,
732                              int n, gmx_bool bX, rvec *x, real *charge,
733                              pme_atomcomm_t *atc)
734 {
735     int *commnode, *buf_index;
736     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
737
738     commnode  = atc->node_dest;
739     buf_index = atc->buf_index;
740
741     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
742
743     nsend = 0;
744     for (i = 0; i < nnodes_comm; i++)
745     {
746         buf_index[commnode[i]] = nsend;
747         nsend                 += atc->count[commnode[i]];
748     }
749     if (bX)
750     {
751         if (atc->count[atc->nodeid] + nsend != n)
752         {
753             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
754                       "This usually means that your system is not well equilibrated.",
755                       n - (atc->count[atc->nodeid] + nsend),
756                       pme->nodeid, 'x'+atc->dimind);
757         }
758
759         if (nsend > pme->buf_nalloc)
760         {
761             pme->buf_nalloc = over_alloc_dd(nsend);
762             srenew(pme->bufv, pme->buf_nalloc);
763             srenew(pme->bufr, pme->buf_nalloc);
764         }
765
766         atc->n = atc->count[atc->nodeid];
767         for (i = 0; i < nnodes_comm; i++)
768         {
769             scount = atc->count[commnode[i]];
770             /* Communicate the count */
771             if (debug)
772             {
773                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
774                         atc->dimind, atc->nodeid, commnode[i], scount);
775             }
776             pme_dd_sendrecv(atc, FALSE, i,
777                             &scount, sizeof(int),
778                             &atc->rcount[i], sizeof(int));
779             atc->n += atc->rcount[i];
780         }
781
782         pme_realloc_atomcomm_things(atc);
783     }
784
785     local_pos = 0;
786     for (i = 0; i < n; i++)
787     {
788         node = atc->pd[i];
789         if (node == atc->nodeid)
790         {
791             /* Copy direct to the receive buffer */
792             if (bX)
793             {
794                 copy_rvec(x[i], atc->x[local_pos]);
795             }
796             atc->q[local_pos] = charge[i];
797             local_pos++;
798         }
799         else
800         {
801             /* Copy to the send buffer */
802             if (bX)
803             {
804                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
805             }
806             pme->bufr[buf_index[node]] = charge[i];
807             buf_index[node]++;
808         }
809     }
810
811     buf_pos = 0;
812     for (i = 0; i < nnodes_comm; i++)
813     {
814         scount = atc->count[commnode[i]];
815         rcount = atc->rcount[i];
816         if (scount > 0 || rcount > 0)
817         {
818             if (bX)
819             {
820                 /* Communicate the coordinates */
821                 pme_dd_sendrecv(atc, FALSE, i,
822                                 pme->bufv[buf_pos], scount*sizeof(rvec),
823                                 atc->x[local_pos], rcount*sizeof(rvec));
824             }
825             /* Communicate the charges */
826             pme_dd_sendrecv(atc, FALSE, i,
827                             pme->bufr+buf_pos, scount*sizeof(real),
828                             atc->q+local_pos, rcount*sizeof(real));
829             buf_pos   += scount;
830             local_pos += atc->rcount[i];
831         }
832     }
833 }
834
835 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
836                            int n, rvec *f,
837                            gmx_bool bAddF)
838 {
839     int *commnode, *buf_index;
840     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
841
842     commnode  = atc->node_dest;
843     buf_index = atc->buf_index;
844
845     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
846
847     local_pos = atc->count[atc->nodeid];
848     buf_pos   = 0;
849     for (i = 0; i < nnodes_comm; i++)
850     {
851         scount = atc->rcount[i];
852         rcount = atc->count[commnode[i]];
853         if (scount > 0 || rcount > 0)
854         {
855             /* Communicate the forces */
856             pme_dd_sendrecv(atc, TRUE, i,
857                             atc->f[local_pos], scount*sizeof(rvec),
858                             pme->bufv[buf_pos], rcount*sizeof(rvec));
859             local_pos += scount;
860         }
861         buf_index[commnode[i]] = buf_pos;
862         buf_pos               += rcount;
863     }
864
865     local_pos = 0;
866     if (bAddF)
867     {
868         for (i = 0; i < n; i++)
869         {
870             node = atc->pd[i];
871             if (node == atc->nodeid)
872             {
873                 /* Add from the local force array */
874                 rvec_inc(f[i], atc->f[local_pos]);
875                 local_pos++;
876             }
877             else
878             {
879                 /* Add from the receive buffer */
880                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
881                 buf_index[node]++;
882             }
883         }
884     }
885     else
886     {
887         for (i = 0; i < n; i++)
888         {
889             node = atc->pd[i];
890             if (node == atc->nodeid)
891             {
892                 /* Copy from the local force array */
893                 copy_rvec(atc->f[local_pos], f[i]);
894                 local_pos++;
895             }
896             else
897             {
898                 /* Copy from the receive buffer */
899                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
900                 buf_index[node]++;
901             }
902         }
903     }
904 }
905
#ifdef GMX_MPI
/* Exchange the overlapping (halo) parts of the local PME grid with the
 * neighboring PME ranks over MPI. The minor decomposition dimension
 * (pme->overlap[1], offsets relative to pmegrid_start_iy) is handled first,
 * then the major one (pme->overlap[0], offsets relative to pmegrid_start_ix).
 *
 * grid is the local PME grid, indexed as
 * grid[ix*(pmegrid_ny*pmegrid_nz) + iy*pmegrid_nz + iz].
 *
 * direction == GMX_SUM_QGRID_FORWARD: the received overlap data is ADDED
 * to the local grid. Any other value is treated as the reverse direction:
 * the send/receive roles of every pulse are swapped and the received data
 * OVERWRITES the local grid (in the major dimension it is received
 * directly into grid).
 */
static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            ipulse, send_id, recv_id, datasize;
    real          *p;
    real          *sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            /* Reverse direction: swap the roles of sender and receiver */
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        /* Each exchanged y index carries pmegrid_nx*nkz values */
        datasize      = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    /* Forward: accumulate; reverse: overwrite */
                    if (direction == GMX_SUM_QGRID_FORWARD)
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        /* Forward: receive into the separate recvbuf and sum into grid below.
         * Reverse: swap roles and receive directly into grid.
         */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recvptr       = overlap->recvbuf;
        }
        else
        {
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        /* x slabs are contiguous in grid, so we can send straight from it */
        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for (i = 0; i < recv_nindex*datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif
1074
1075
1076 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1077 {
1078     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1079     ivec    local_pme_size;
1080     int     i, ix, iy, iz;
1081     int     pmeidx, fftidx;
1082
1083     /* Dimensions should be identical for A/B grid, so we just use A here */
1084     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1085                                    local_fft_ndata,
1086                                    local_fft_offset,
1087                                    local_fft_size);
1088
1089     local_pme_size[0] = pme->pmegrid_nx;
1090     local_pme_size[1] = pme->pmegrid_ny;
1091     local_pme_size[2] = pme->pmegrid_nz;
1092
1093     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1094        the offset is identical, and the PME grid always has more data (due to overlap)
1095      */
1096     {
1097 #ifdef DEBUG_PME
1098         FILE *fp, *fp2;
1099         char  fn[STRLEN], format[STRLEN];
1100         real  val;
1101         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1102         fp = gmx_ffopen(fn, "w");
1103         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1104         fp2 = gmx_ffopen(fn, "w");
1105         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1106 #endif
1107
1108         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1109         {
1110             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1111             {
1112                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1113                 {
1114                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1115                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1116                     fftgrid[fftidx] = pmegrid[pmeidx];
1117 #ifdef DEBUG_PME
1118                     val = 100*pmegrid[pmeidx];
1119                     if (pmegrid[pmeidx] != 0)
1120                     {
1121                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1122                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1123                     }
1124                     if (pmegrid[pmeidx] != 0)
1125                     {
1126                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1127                                 "qgrid",
1128                                 pme->pmegrid_start_ix + ix,
1129                                 pme->pmegrid_start_iy + iy,
1130                                 pme->pmegrid_start_iz + iz,
1131                                 pmegrid[pmeidx]);
1132                     }
1133 #endif
1134                 }
1135             }
1136         }
1137 #ifdef DEBUG_PME
1138         gmx_ffclose(fp);
1139         gmx_ffclose(fp2);
1140 #endif
1141     }
1142     return 0;
1143 }
1144
1145
1146 static gmx_cycles_t omp_cyc_start()
1147 {
1148     return gmx_cycles_read();
1149 }
1150
1151 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1152 {
1153     return gmx_cycles_read() - c;
1154 }
1155
1156
1157 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1158                                    int nthread, int thread)
1159 {
1160     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1161     ivec          local_pme_size;
1162     int           ixy0, ixy1, ixy, ix, iy, iz;
1163     int           pmeidx, fftidx;
1164 #ifdef PME_TIME_THREADS
1165     gmx_cycles_t  c1;
1166     static double cs1 = 0;
1167     static int    cnt = 0;
1168 #endif
1169
1170 #ifdef PME_TIME_THREADS
1171     c1 = omp_cyc_start();
1172 #endif
1173     /* Dimensions should be identical for A/B grid, so we just use A here */
1174     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1175                                    local_fft_ndata,
1176                                    local_fft_offset,
1177                                    local_fft_size);
1178
1179     local_pme_size[0] = pme->pmegrid_nx;
1180     local_pme_size[1] = pme->pmegrid_ny;
1181     local_pme_size[2] = pme->pmegrid_nz;
1182
1183     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1184        the offset is identical, and the PME grid always has more data (due to overlap)
1185      */
1186     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1187     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1188
1189     for (ixy = ixy0; ixy < ixy1; ixy++)
1190     {
1191         ix = ixy/local_fft_ndata[YY];
1192         iy = ixy - ix*local_fft_ndata[YY];
1193
1194         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1195         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1196         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1197         {
1198             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1199         }
1200     }
1201
1202 #ifdef PME_TIME_THREADS
1203     c1   = omp_cyc_end(c1);
1204     cs1 += (double)c1;
1205     cnt++;
1206     if (cnt % 20 == 0)
1207     {
1208         printf("copy %.2f\n", cs1*1e-9);
1209     }
1210 #endif
1211
1212     return 0;
1213 }
1214
1215
1216 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1217 {
1218     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1219
1220     nx = pme->nkx;
1221     ny = pme->nky;
1222     nz = pme->nkz;
1223
1224     pnx = pme->pmegrid_nx;
1225     pny = pme->pmegrid_ny;
1226     pnz = pme->pmegrid_nz;
1227
1228     overlap = pme->pme_order - 1;
1229
1230     /* Add periodic overlap in z */
1231     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1232     {
1233         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1234         {
1235             for (iz = 0; iz < overlap; iz++)
1236             {
1237                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1238                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1239             }
1240         }
1241     }
1242
1243     if (pme->nnodes_minor == 1)
1244     {
1245         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1246         {
1247             for (iy = 0; iy < overlap; iy++)
1248             {
1249                 for (iz = 0; iz < nz; iz++)
1250                 {
1251                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1252                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1253                 }
1254             }
1255         }
1256     }
1257
1258     if (pme->nnodes_major == 1)
1259     {
1260         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1261
1262         for (ix = 0; ix < overlap; ix++)
1263         {
1264             for (iy = 0; iy < ny_x; iy++)
1265             {
1266                 for (iz = 0; iz < nz; iz++)
1267                 {
1268                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1269                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1270                 }
1271             }
1272         }
1273     }
1274 }
1275
1276
1277 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1278 {
1279     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1280
1281     nx = pme->nkx;
1282     ny = pme->nky;
1283     nz = pme->nkz;
1284
1285     pnx = pme->pmegrid_nx;
1286     pny = pme->pmegrid_ny;
1287     pnz = pme->pmegrid_nz;
1288
1289     overlap = pme->pme_order - 1;
1290
1291     if (pme->nnodes_major == 1)
1292     {
1293         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1294
1295         for (ix = 0; ix < overlap; ix++)
1296         {
1297             int iy, iz;
1298
1299             for (iy = 0; iy < ny_x; iy++)
1300             {
1301                 for (iz = 0; iz < nz; iz++)
1302                 {
1303                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1304                         pmegrid[(ix*pny+iy)*pnz+iz];
1305                 }
1306             }
1307         }
1308     }
1309
1310     if (pme->nnodes_minor == 1)
1311     {
1312 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1313         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1314         {
1315             int iy, iz;
1316
1317             for (iy = 0; iy < overlap; iy++)
1318             {
1319                 for (iz = 0; iz < nz; iz++)
1320                 {
1321                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1322                         pmegrid[(ix*pny+iy)*pnz+iz];
1323                 }
1324             }
1325         }
1326     }
1327
1328     /* Copy periodic overlap in z */
1329 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1330     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1331     {
1332         int iy, iz;
1333
1334         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1335         {
1336             for (iz = 0; iz < overlap; iz++)
1337             {
1338                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1339                     pmegrid[(ix*pny+iy)*pnz+iz];
1340             }
1341         }
1342     }
1343 }
1344
1345
/* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
/* Spread one atom's charge qn onto grid with the order^3 products of the
 * B-spline weights thx[], thy[], thz[], starting at local grid index
 * (i0, j0, k0):
 *   grid[((i0+ithx)*pny + (j0+ithy))*pnz + (k0+ithz)] += qn*thx[ithx]*thy[ithy]*thz[ithz]
 * The macro relies on these names from the enclosing scope:
 *   written: ithx, ithy, ithz, index_x, index_xy, index_xyz, valx, valxy,
 *            grid[] (accumulated into)
 *   read:    qn, thx, thy, thz, i0, j0, k0, pny, pnz
 */
#define DO_BSPLINE(order)                            \
    for (ithx = 0; (ithx < order); ithx++)                    \
    {                                                    \
        index_x = (i0+ithx)*pny*pnz;                     \
        valx    = qn*thx[ithx];                          \
                                                     \
        for (ithy = 0; (ithy < order); ithy++)                \
        {                                                \
            valxy    = valx*thy[ithy];                   \
            index_xy = index_x+(j0+ithy)*pnz;            \
                                                     \
            for (ithz = 0; (ithz < order); ithz++)            \
            {                                            \
                index_xyz        = index_xy+(k0+ithz);   \
                grid[index_xyz] += valxy*thz[ithz];      \
            }                                            \
        }                                                \
    }
1365
1366
/* Spread the charges atc->q[] of the atoms listed in spline->ind onto the
 * (possibly thread-local) grid pmegrid, using the B-spline coefficients in
 * spline->theta[].
 * The whole grid (s[XX]*s[YY]*s[ZZ] values) is cleared first, so the result
 * contains only the contributions of the atoms in spline.
 * Atom grid indices atc->idx[] are converted to this grid's local frame by
 * subtracting pmegrid->offset[].
 */
static void spread_q_bsplines_thread(pmegrid_t                    *pmegrid,
                                     pme_atomcomm_t               *atc,
                                     splinedata_t                 *spline,
                                     pme_spline_work_t gmx_unused *work)
{

    /* spread charges from home atoms to local grid */
    /* NOTE(review): b, ol, localsize and bndsize look unused here, unless
     * the inlined pme_simd4.h references them - confirm before removing.
     */
    real          *grid;
    pme_overlap_t *ol;
    int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    int       *    idxptr;
    int            order, norder, index_x, index_xy, index_xyz;
    real           valx, valxy, qn;
    real          *thx, *thy, *thz;
    int            localsize, bndsize;
    int            pnx, pny, pnz, ndatatot;
    int            offx, offy, offz;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    /* Aligned scratch space for the z weights, presumably used by the
     * aligned SIMD4 kernel in pme_simd4.h.
     */
    real           thz_buffer[12], *thz_aligned;

    thz_aligned = gmx_simd4_align_real(thz_buffer);
#endif

    /* Allocated (strided) sizes of the local grid */
    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    /* Global grid index of this grid's first point */
    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the whole local grid, including any alignment padding in z */
    ndatatot = pnx*pny*pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = atc->q[n];

        /* Atoms without charge contribute nothing */
        if (qn != 0)
        {
            idxptr = atc->idx[n];
            /* The order coefficients of spline entry nn start at nn*order */
            norder = nn*order;

            /* Starting grid index relative to this grid's offset */
            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta[XX] + norder;
            thy = spline->theta[YY] + norder;
            thz = spline->theta[ZZ] + norder;

            /* Orders 4 and 5 may use the SIMD4 kernel from pme_simd4.h,
             * which is included inline and presumably works on the local
             * variables of this function. Other orders use DO_BSPLINE.
             */
            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_SPREAD_SIMD4_ORDER4
#else
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_simd4.h"
#else
                    DO_BSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 5
#include "pme_simd4.h"
#else
                    DO_BSPLINE(5);
#endif
                    break;
                default:
                    DO_BSPLINE(order);
                    break;
            }
        }
    }
}
1457
1458 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1459 {
1460 #ifdef PME_SIMD4_SPREAD_GATHER
1461     if (pme_order == 5
1462 #ifndef PME_SIMD4_UNALIGNED
1463         || pme_order == 4
1464 #endif
1465         )
1466     {
1467         /* Round nz up to a multiple of 4 to ensure alignment */
1468         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1469     }
1470 #endif
1471 }
1472
1473 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1474 {
1475 #ifdef PME_SIMD4_SPREAD_GATHER
1476 #ifndef PME_SIMD4_UNALIGNED
1477     if (pme_order == 4)
1478     {
1479         /* Add extra elements to ensured aligned operations do not go
1480          * beyond the allocated grid size.
1481          * Note that for pme_order=5, the pme grid z-size alignment
1482          * ensures that we will not go beyond the grid size.
1483          */
1484         *gridsize += 4;
1485     }
1486 #endif
1487 #endif
1488 }
1489
1490 static void pmegrid_init(pmegrid_t *grid,
1491                          int cx, int cy, int cz,
1492                          int x0, int y0, int z0,
1493                          int x1, int y1, int z1,
1494                          gmx_bool set_alignment,
1495                          int pme_order,
1496                          real *ptr)
1497 {
1498     int nz, gridsize;
1499
1500     grid->ci[XX]     = cx;
1501     grid->ci[YY]     = cy;
1502     grid->ci[ZZ]     = cz;
1503     grid->offset[XX] = x0;
1504     grid->offset[YY] = y0;
1505     grid->offset[ZZ] = z0;
1506     grid->n[XX]      = x1 - x0 + pme_order - 1;
1507     grid->n[YY]      = y1 - y0 + pme_order - 1;
1508     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1509     copy_ivec(grid->n, grid->s);
1510
1511     nz = grid->s[ZZ];
1512     set_grid_alignment(&nz, pme_order);
1513     if (set_alignment)
1514     {
1515         grid->s[ZZ] = nz;
1516     }
1517     else if (nz != grid->s[ZZ])
1518     {
1519         gmx_incons("pmegrid_init call with an unaligned z size");
1520     }
1521
1522     grid->order = pme_order;
1523     if (ptr == NULL)
1524     {
1525         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1526         set_gridsize_alignment(&gridsize, pme_order);
1527         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1528     }
1529     else
1530     {
1531         grid->grid = ptr;
1532     }
1533 }
1534
/* Divide enumerator by denominator, rounding the quotient up. The result is
 * exact for a non-negative enumerator and a positive denominator.
 */
static int div_round_up(int enumerator, int denominator)
{
    int biased;

    biased = enumerator + (denominator - 1);

    return biased/denominator;
}
1539
1540 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1541                                   ivec nsub)
1542 {
1543     int gsize_opt, gsize;
1544     int nsx, nsy, nsz;
1545     char *env;
1546
1547     gsize_opt = -1;
1548     for (nsx = 1; nsx <= nthread; nsx++)
1549     {
1550         if (nthread % nsx == 0)
1551         {
1552             for (nsy = 1; nsy <= nthread; nsy++)
1553             {
1554                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1555                 {
1556                     nsz = nthread/(nsx*nsy);
1557
1558                     /* Determine the number of grid points per thread */
1559                     gsize =
1560                         (div_round_up(n[XX], nsx) + ovl)*
1561                         (div_round_up(n[YY], nsy) + ovl)*
1562                         (div_round_up(n[ZZ], nsz) + ovl);
1563
1564                     /* Minimize the number of grids points per thread
1565                      * and, secondarily, the number of cuts in minor dimensions.
1566                      */
1567                     if (gsize_opt == -1 ||
1568                         gsize < gsize_opt ||
1569                         (gsize == gsize_opt &&
1570                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1571                     {
1572                         nsub[XX]  = nsx;
1573                         nsub[YY]  = nsy;
1574                         nsub[ZZ]  = nsz;
1575                         gsize_opt = gsize;
1576                     }
1577                 }
1578             }
1579         }
1580     }
1581
1582     env = getenv("GMX_PME_THREAD_DIVISION");
1583     if (env != NULL)
1584     {
1585         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1586     }
1587
1588     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1589     {
1590         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1591     }
1592 }
1593
/* Set up the PME grid set grids.
 *  - grids->grid: the full local grid. nx, ny, nz include the pme_order-1
 *    overlap, which is removed here and added back by pmegrid_init.
 *  - When bUseThreads: nthread thread-local sub-grids grids->grid_th[].
 *    They are all carved out of one aligned allocation grids->grid_all,
 *    with GMX_CACHE_SEP elements before each sub-grid (presumably to avoid
 *    false sharing between threads). When !bUseThreads, grid_th is NULL.
 *  - grids->g2t[d][i]: for grid line i along dimension d, the thread
 *    coordinate along d multiplied by the product of grids->nc over the
 *    faster dimensions. g2t[XX][ix]+g2t[YY][iy]+g2t[ZZ][iz] therefore equals
 *    the thread index t used when creating grid_th below.
 *  - grids->nthread_comm[d]: the number of thread slabs along d needed to
 *    cover max_comm_lines grid lines (overlap_x, overlap_y, or pme_order-1
 *    for z), capped at grids->nc[d].
 * nz_base replaces the z size only when choosing the thread division.
 */
static void pmegrids_init(pmegrids_t *grids,
                          int nx, int ny, int nz, int nz_base,
                          int pme_order,
                          gmx_bool bUseThreads,
                          int nthread,
                          int overlap_x,
                          int overlap_y)
{
    /* NOTE(review): g0 and g1 appear unused in this function */
    ivec n, n_base, g0, g1;
    int t, x, y, z, d, i, tfac;
    int max_comm_lines = -1;

    /* Grid sizes without the spreading overlap */
    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
                 NULL);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);

    if (bUseThreads)
    {
        ivec nst;
        int gridsize;

        /* Largest thread-local grid size, including overlap */
        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
                    nx, ny, nz,
                    nst[XX], nst[YY], nst[ZZ]);
        }

        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        /* One allocation for all thread grids, with GMX_CACHE_SEP separators */
        snew_aligned(grids->grid_all,
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                     SIMD4_ALIGNMENT);

        /* Thread index t runs with z fastest, then y, then x */
        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    pmegrid_init(&grids->grid_th[t],
                                 x, y, z,
                                 (n[XX]*(x  ))/grids->nc[XX],
                                 (n[YY]*(y  ))/grids->nc[YY],
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
                                 (n[XX]*(x+1))/grids->nc[XX],
                                 (n[YY]*(y+1))/grids->nc[YY],
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }
    else
    {
        grids->grid_th = NULL;
    }

    snew(grids->g2t, DIM);
    /* tfac is the product of grids->nc over the dimensions after d */
    tfac = 1;
    for (d = DIM-1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t*tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x;     break;
            case YY: max_comm_lines = overlap_y;     break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        grids->nthread_comm[d] = 0;
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
               grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != NULL)
        {
            fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
                    'x'+d, grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         */
        /* NOTE(review): the while loop above never lets nthread_comm[d]
         * exceed nc[d], so this condition cannot be true as written. The
         * comment above suggests '>=' may have been intended. Confirm
         * before changing, because that would make some current setups fatal.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
        }
    }
}
1721
1722
1723 static void pmegrids_destroy(pmegrids_t *grids)
1724 {
1725     int t;
1726
1727     if (grids->grid.grid != NULL)
1728     {
1729         sfree(grids->grid.grid);
1730
1731         if (grids->nthread > 0)
1732         {
1733             for (t = 0; t < grids->nthread; t++)
1734             {
1735                 sfree(grids->grid_th[t].grid);
1736             }
1737             sfree(grids->grid_th);
1738         }
1739     }
1740 }
1741
1742
1743 static void realloc_work(pme_work_t *work, int nkx)
1744 {
1745     int simd_width;
1746
1747     if (nkx > work->nalloc)
1748     {
1749         work->nalloc = nkx;
1750         srenew(work->mhx, work->nalloc);
1751         srenew(work->mhy, work->nalloc);
1752         srenew(work->mhz, work->nalloc);
1753         srenew(work->m2, work->nalloc);
1754         /* Allocate an aligned pointer for SIMD operations, including extra
1755          * elements at the end for padding.
1756          */
1757 #ifdef PME_SIMD_SOLVE
1758         simd_width = GMX_SIMD_REAL_WIDTH;
1759 #else
1760         /* We can use any alignment, apart from 0, so we use 4 */
1761         simd_width = 4;
1762 #endif
1763         sfree_aligned(work->denom);
1764         sfree_aligned(work->tmp1);
1765         sfree_aligned(work->tmp2);
1766         sfree_aligned(work->eterm);
1767         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1768         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1769         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1770         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1771         srenew(work->m2inv, work->nalloc);
1772     }
1773 }
1774
1775
1776 static void free_work(pme_work_t *work)
1777 {
1778     sfree(work->mhx);
1779     sfree(work->mhy);
1780     sfree(work->mhz);
1781     sfree(work->m2);
1782     sfree_aligned(work->denom);
1783     sfree_aligned(work->tmp1);
1784     sfree_aligned(work->tmp2);
1785     sfree_aligned(work->eterm);
1786     sfree(work->m2inv);
1787 }
1788
1789
1790 #if defined PME_SIMD_SOLVE && defined GMX_SIMD_HAVE_EXP
1791 /* Calculate exponentials through SIMD */
1792 inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1793 {
1794     {
1795         const gmx_simd_real_t two = gmx_simd_set1_r(2.0);
1796         gmx_simd_real_t f_simd;
1797         gmx_simd_real_t lu;
1798         gmx_simd_real_t tmp_d1, d_inv, tmp_r, tmp_e;
1799         int kx;
1800         f_simd = gmx_simd_set1_r(f);
1801         /* We only need to calculate from start. But since start is 0 or 1
1802          * and we want to use aligned loads/stores, we always start from 0.
1803          */
1804         for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1805         {
1806             tmp_d1   = gmx_simd_load_r(d_aligned+kx);
1807             d_inv    = gmx_simd_inv_r(tmp_d1);
1808             tmp_r    = gmx_simd_load_r(r_aligned+kx);
1809             tmp_r    = gmx_simd_exp_r(tmp_r);
1810             tmp_e    = gmx_simd_mul_r(f_simd, d_inv);
1811             tmp_e    = gmx_simd_mul_r(tmp_e, tmp_r);
1812             gmx_simd_store_r(e_aligned+kx, tmp_e);
1813         }
1814     }
1815 }
1816 #else
1817 inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1818 {
1819     int kx;
1820     for (kx = start; kx < end; kx++)
1821     {
1822         d[kx] = 1.0/d[kx];
1823     }
1824     for (kx = start; kx < end; kx++)
1825     {
1826         r[kx] = exp(r[kx]);
1827     }
1828     for (kx = start; kx < end; kx++)
1829     {
1830         e[kx] = f*r[kx]*d[kx];
1831     }
1832 }
1833 #endif
1834
1835 #if defined PME_SIMD_SOLVE && defined GMX_SIMD_HAVE_ERFC
1836 /* Calculate exponentials through SIMD */
1837 inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1838 {
1839     gmx_simd_real_t tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1840     const gmx_simd_real_t sqr_PI = gmx_simd_sqrt_r(gmx_simd_set1_r(M_PI));
1841     int kx;
1842     for (kx = 0; kx < end; kx += GMX_SIMD_REAL_WIDTH)
1843     {
1844         /* We only need to calculate from start. But since start is 0 or 1
1845          * and we want to use aligned loads/stores, we always start from 0.
1846          */
1847         tmp_d = gmx_simd_load_r(d_aligned+kx);
1848         d_inv = gmx_simd_inv_r(tmp_d);
1849         gmx_simd_store_r(d_aligned+kx, d_inv);
1850         tmp_r = gmx_simd_load_r(r_aligned+kx);
1851         tmp_r = gmx_simd_exp_r(tmp_r);
1852         gmx_simd_store_r(r_aligned+kx, tmp_r);
1853         tmp_mk  = gmx_simd_load_r(factor_aligned+kx);
1854         tmp_fac = gmx_simd_mul_r(sqr_PI, gmx_simd_mul_r(tmp_mk, gmx_simd_erfc_r(tmp_mk)));
1855         gmx_simd_store_r(factor_aligned+kx, tmp_fac);
1856     }
1857 }
1858 #else
1859 inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1860 {
1861     int kx;
1862     real mk;
1863     for (kx = start; kx < end; kx++)
1864     {
1865         d[kx] = 1.0/d[kx];
1866     }
1867
1868     for (kx = start; kx < end; kx++)
1869     {
1870         r[kx] = exp(r[kx]);
1871     }
1872
1873     for (kx = start; kx < end; kx++)
1874     {
1875         mk       = tmp2[kx];
1876         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1877     }
1878 }
1879 #endif
1880
1881 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1882                          real ewaldcoeff, real vol,
1883                          gmx_bool bEnerVir,
1884                          int nthread, int thread)
1885 {
1886     /* do recip sum over local cells in grid */
1887     /* y major, z middle, x minor or continuous */
1888     t_complex *p0;
1889     int     kx, ky, kz, maxkx, maxky, maxkz;
1890     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1891     real    mx, my, mz;
1892     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1893     real    ets2, struct2, vfactor, ets2vf;
1894     real    d1, d2, energy = 0;
1895     real    by, bz;
1896     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1897     real    rxx, ryx, ryy, rzx, rzy, rzz;
1898     pme_work_t *work;
1899     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1900     real    mhxk, mhyk, mhzk, m2k;
1901     real    corner_fac;
1902     ivec    complex_order;
1903     ivec    local_ndata, local_offset, local_size;
1904     real    elfac;
1905
1906     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1907
1908     nx = pme->nkx;
1909     ny = pme->nky;
1910     nz = pme->nkz;
1911
1912     /* Dimensions should be identical for A/B grid, so we just use A here */
1913     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
1914                                       complex_order,
1915                                       local_ndata,
1916                                       local_offset,
1917                                       local_size);
1918
1919     rxx = pme->recipbox[XX][XX];
1920     ryx = pme->recipbox[YY][XX];
1921     ryy = pme->recipbox[YY][YY];
1922     rzx = pme->recipbox[ZZ][XX];
1923     rzy = pme->recipbox[ZZ][YY];
1924     rzz = pme->recipbox[ZZ][ZZ];
1925
1926     maxkx = (nx+1)/2;
1927     maxky = (ny+1)/2;
1928     maxkz = nz/2+1;
1929
1930     work  = &pme->work[thread];
1931     mhx   = work->mhx;
1932     mhy   = work->mhy;
1933     mhz   = work->mhz;
1934     m2    = work->m2;
1935     denom = work->denom;
1936     tmp1  = work->tmp1;
1937     eterm = work->eterm;
1938     m2inv = work->m2inv;
1939
1940     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1941     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1942
1943     for (iyz = iyz0; iyz < iyz1; iyz++)
1944     {
1945         iy = iyz/local_ndata[ZZ];
1946         iz = iyz - iy*local_ndata[ZZ];
1947
1948         ky = iy + local_offset[YY];
1949
1950         if (ky < maxky)
1951         {
1952             my = ky;
1953         }
1954         else
1955         {
1956             my = (ky - ny);
1957         }
1958
1959         by = M_PI*vol*pme->bsp_mod[YY][ky];
1960
1961         kz = iz + local_offset[ZZ];
1962
1963         mz = kz;
1964
1965         bz = pme->bsp_mod[ZZ][kz];
1966
1967         /* 0.5 correction for corner points */
1968         corner_fac = 1;
1969         if (kz == 0 || kz == (nz+1)/2)
1970         {
1971             corner_fac = 0.5;
1972         }
1973
1974         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1975
1976         /* We should skip the k-space point (0,0,0) */
1977         /* Note that since here x is the minor index, local_offset[XX]=0 */
1978         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1979         {
1980             kxstart = local_offset[XX];
1981         }
1982         else
1983         {
1984             kxstart = local_offset[XX] + 1;
1985             p0++;
1986         }
1987         kxend = local_offset[XX] + local_ndata[XX];
1988
1989         if (bEnerVir)
1990         {
1991             /* More expensive inner loop, especially because of the storage
1992              * of the mh elements in array's.
1993              * Because x is the minor grid index, all mh elements
1994              * depend on kx for triclinic unit cells.
1995              */
1996
1997             /* Two explicit loops to avoid a conditional inside the loop */
1998             for (kx = kxstart; kx < maxkx; kx++)
1999             {
2000                 mx = kx;
2001
2002                 mhxk      = mx * rxx;
2003                 mhyk      = mx * ryx + my * ryy;
2004                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2005                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2006                 mhx[kx]   = mhxk;
2007                 mhy[kx]   = mhyk;
2008                 mhz[kx]   = mhzk;
2009                 m2[kx]    = m2k;
2010                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2011                 tmp1[kx]  = -factor*m2k;
2012             }
2013
2014             for (kx = maxkx; kx < kxend; kx++)
2015             {
2016                 mx = (kx - nx);
2017
2018                 mhxk      = mx * rxx;
2019                 mhyk      = mx * ryx + my * ryy;
2020                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2021                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2022                 mhx[kx]   = mhxk;
2023                 mhy[kx]   = mhyk;
2024                 mhz[kx]   = mhzk;
2025                 m2[kx]    = m2k;
2026                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2027                 tmp1[kx]  = -factor*m2k;
2028             }
2029
2030             for (kx = kxstart; kx < kxend; kx++)
2031             {
2032                 m2inv[kx] = 1.0/m2[kx];
2033             }
2034
2035             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2036
2037             for (kx = kxstart; kx < kxend; kx++, p0++)
2038             {
2039                 d1      = p0->re;
2040                 d2      = p0->im;
2041
2042                 p0->re  = d1*eterm[kx];
2043                 p0->im  = d2*eterm[kx];
2044
2045                 struct2 = 2.0*(d1*d1+d2*d2);
2046
2047                 tmp1[kx] = eterm[kx]*struct2;
2048             }
2049
2050             for (kx = kxstart; kx < kxend; kx++)
2051             {
2052                 ets2     = corner_fac*tmp1[kx];
2053                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2054                 energy  += ets2;
2055
2056                 ets2vf   = ets2*vfactor;
2057                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2058                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2059                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2060                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2061                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2062                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2063             }
2064         }
2065         else
2066         {
2067             /* We don't need to calculate the energy and the virial.
2068              * In this case the triclinic overhead is small.
2069              */
2070
2071             /* Two explicit loops to avoid a conditional inside the loop */
2072
2073             for (kx = kxstart; kx < maxkx; kx++)
2074             {
2075                 mx = kx;
2076
2077                 mhxk      = mx * rxx;
2078                 mhyk      = mx * ryx + my * ryy;
2079                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2080                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2081                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2082                 tmp1[kx]  = -factor*m2k;
2083             }
2084
2085             for (kx = maxkx; kx < kxend; kx++)
2086             {
2087                 mx = (kx - nx);
2088
2089                 mhxk      = mx * rxx;
2090                 mhyk      = mx * ryx + my * ryy;
2091                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2092                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2093                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2094                 tmp1[kx]  = -factor*m2k;
2095             }
2096
2097             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2098
2099             for (kx = kxstart; kx < kxend; kx++, p0++)
2100             {
2101                 d1      = p0->re;
2102                 d2      = p0->im;
2103
2104                 p0->re  = d1*eterm[kx];
2105                 p0->im  = d2*eterm[kx];
2106             }
2107         }
2108     }
2109
2110     if (bEnerVir)
2111     {
2112         /* Update virial with local values.
2113          * The virial is symmetric by definition.
2114          * this virial seems ok for isotropic scaling, but I'm
2115          * experiencing problems on semiisotropic membranes.
2116          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2117          */
2118         work->vir_q[XX][XX] = 0.25*virxx;
2119         work->vir_q[YY][YY] = 0.25*viryy;
2120         work->vir_q[ZZ][ZZ] = 0.25*virzz;
2121         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
2122         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
2123         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
2124
2125         /* This energy should be corrected for a charged system */
2126         work->energy_q = 0.5*energy;
2127     }
2128
2129     /* Return the loop count */
2130     return local_ndata[YY]*local_ndata[XX];
2131 }
2132
2133 static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
2134                             real ewaldcoeff, real vol,
2135                             gmx_bool bEnerVir, int nthread, int thread)
2136 {
2137     /* do recip sum over local cells in grid */
2138     /* y major, z middle, x minor or continuous */
2139     int     ig, gcount;
2140     int     kx, ky, kz, maxkx, maxky, maxkz;
2141     int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
2142     real    mx, my, mz;
2143     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2144     real    ets2, ets2vf;
2145     real    eterm, vterm, d1, d2, energy = 0;
2146     real    by, bz;
2147     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2148     real    rxx, ryx, ryy, rzx, rzy, rzz;
2149     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
2150     real    mhxk, mhyk, mhzk, m2k;
2151     real    mk;
2152     pme_work_t *work;
2153     real    corner_fac;
2154     ivec    complex_order;
2155     ivec    local_ndata, local_offset, local_size;
2156     nx = pme->nkx;
2157     ny = pme->nky;
2158     nz = pme->nkz;
2159
2160     /* Dimensions should be identical for A/B grid, so we just use A here */
2161     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
2162                                       complex_order,
2163                                       local_ndata,
2164                                       local_offset,
2165                                       local_size);
2166     rxx = pme->recipbox[XX][XX];
2167     ryx = pme->recipbox[YY][XX];
2168     ryy = pme->recipbox[YY][YY];
2169     rzx = pme->recipbox[ZZ][XX];
2170     rzy = pme->recipbox[ZZ][YY];
2171     rzz = pme->recipbox[ZZ][ZZ];
2172
2173     maxkx = (nx+1)/2;
2174     maxky = (ny+1)/2;
2175     maxkz = nz/2+1;
2176
2177     work  = &pme->work[thread];
2178     mhx   = work->mhx;
2179     mhy   = work->mhy;
2180     mhz   = work->mhz;
2181     m2    = work->m2;
2182     denom = work->denom;
2183     tmp1  = work->tmp1;
2184     tmp2  = work->tmp2;
2185
2186     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2187     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2188
2189     for (iyz = iyz0; iyz < iyz1; iyz++)
2190     {
2191         iy = iyz/local_ndata[ZZ];
2192         iz = iyz - iy*local_ndata[ZZ];
2193
2194         ky = iy + local_offset[YY];
2195
2196         if (ky < maxky)
2197         {
2198             my = ky;
2199         }
2200         else
2201         {
2202             my = (ky - ny);
2203         }
2204
2205         by = 3.0*vol*pme->bsp_mod[YY][ky]
2206             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
2207
2208         kz = iz + local_offset[ZZ];
2209
2210         mz = kz;
2211
2212         bz = pme->bsp_mod[ZZ][kz];
2213
2214         /* 0.5 correction for corner points */
2215         corner_fac = 1;
2216         if (kz == 0 || kz == (nz+1)/2)
2217         {
2218             corner_fac = 0.5;
2219         }
2220
2221         kxstart = local_offset[XX];
2222         kxend   = local_offset[XX] + local_ndata[XX];
2223         if (bEnerVir)
2224         {
2225             /* More expensive inner loop, especially because of the
2226              * storage of the mh elements in array's.  Because x is the
2227              * minor grid index, all mh elements depend on kx for
2228              * triclinic unit cells.
2229              */
2230
2231             /* Two explicit loops to avoid a conditional inside the loop */
2232             for (kx = kxstart; kx < maxkx; kx++)
2233             {
2234                 mx = kx;
2235
2236                 mhxk      = mx * rxx;
2237                 mhyk      = mx * ryx + my * ryy;
2238                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2239                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2240                 mhx[kx]   = mhxk;
2241                 mhy[kx]   = mhyk;
2242                 mhz[kx]   = mhzk;
2243                 m2[kx]    = m2k;
2244                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2245                 tmp1[kx]  = -factor*m2k;
2246                 tmp2[kx]  = sqrt(factor*m2k);
2247             }
2248
2249             for (kx = maxkx; kx < kxend; kx++)
2250             {
2251                 mx = (kx - nx);
2252
2253                 mhxk      = mx * rxx;
2254                 mhyk      = mx * ryx + my * ryy;
2255                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2256                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2257                 mhx[kx]   = mhxk;
2258                 mhy[kx]   = mhyk;
2259                 mhz[kx]   = mhzk;
2260                 m2[kx]    = m2k;
2261                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2262                 tmp1[kx]  = -factor*m2k;
2263                 tmp2[kx]  = sqrt(factor*m2k);
2264             }
2265
2266             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2267
2268             for (kx = kxstart; kx < kxend; kx++)
2269             {
2270                 m2k   = factor*m2[kx];
2271                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
2272                           + 2.0*m2k*tmp2[kx]);
2273                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
2274                 tmp1[kx] = eterm*denom[kx];
2275                 tmp2[kx] = vterm*denom[kx];
2276             }
2277
2278             if (!bLB)
2279             {
2280                 t_complex *p0;
2281                 real       struct2;
2282
2283                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2284                 for (kx = kxstart; kx < kxend; kx++, p0++)
2285                 {
2286                     d1      = p0->re;
2287                     d2      = p0->im;
2288
2289                     eterm   = tmp1[kx];
2290                     vterm   = tmp2[kx];
2291                     p0->re  = d1*eterm;
2292                     p0->im  = d2*eterm;
2293
2294                     struct2 = 2.0*(d1*d1+d2*d2);
2295
2296                     tmp1[kx] = eterm*struct2;
2297                     tmp2[kx] = vterm*struct2;
2298                 }
2299             }
2300             else
2301             {
2302                 real *struct2 = denom;
2303                 real  str2;
2304
2305                 for (kx = kxstart; kx < kxend; kx++)
2306                 {
2307                     struct2[kx] = 0.0;
2308                 }
2309                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
2310                 for (ig = 0; ig <= 3; ++ig)
2311                 {
2312                     t_complex *p0, *p1;
2313                     real       scale;
2314
2315                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2316                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2317                     scale = 2.0*lb_scale_factor_symm[ig];
2318                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
2319                     {
2320                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
2321                     }
2322
2323                 }
2324                 for (ig = 0; ig <= 6; ++ig)
2325                 {
2326                     t_complex *p0;
2327
2328                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2329                     for (kx = kxstart; kx < kxend; kx++, p0++)
2330                     {
2331                         d1     = p0->re;
2332                         d2     = p0->im;
2333
2334                         eterm  = tmp1[kx];
2335                         p0->re = d1*eterm;
2336                         p0->im = d2*eterm;
2337                     }
2338                 }
2339                 for (kx = kxstart; kx < kxend; kx++)
2340                 {
2341                     eterm    = tmp1[kx];
2342                     vterm    = tmp2[kx];
2343                     str2     = struct2[kx];
2344                     tmp1[kx] = eterm*str2;
2345                     tmp2[kx] = vterm*str2;
2346                 }
2347             }
2348
2349             for (kx = kxstart; kx < kxend; kx++)
2350             {
2351                 ets2     = corner_fac*tmp1[kx];
2352                 vterm    = 2.0*factor*tmp2[kx];
2353                 energy  += ets2;
2354                 ets2vf   = corner_fac*vterm;
2355                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2356                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2357                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2358                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2359                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2360                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2361             }
2362         }
2363         else
2364         {
2365             /* We don't need to calculate the energy and the virial.
2366              *  In this case the triclinic overhead is small.
2367              */
2368
2369             /* Two explicit loops to avoid a conditional inside the loop */
2370
2371             for (kx = kxstart; kx < maxkx; kx++)
2372             {
2373                 mx = kx;
2374
2375                 mhxk      = mx * rxx;
2376                 mhyk      = mx * ryx + my * ryy;
2377                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2378                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2379                 m2[kx]    = m2k;
2380                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2381                 tmp1[kx]  = -factor*m2k;
2382                 tmp2[kx]  = sqrt(factor*m2k);
2383             }
2384
2385             for (kx = maxkx; kx < kxend; kx++)
2386             {
2387                 mx = (kx - nx);
2388
2389                 mhxk      = mx * rxx;
2390                 mhyk      = mx * ryx + my * ryy;
2391                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2392                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2393                 m2[kx]    = m2k;
2394                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2395                 tmp1[kx]  = -factor*m2k;
2396                 tmp2[kx]  = sqrt(factor*m2k);
2397             }
2398
2399             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2400
2401             for (kx = kxstart; kx < kxend; kx++)
2402             {
2403                 m2k    = factor*m2[kx];
2404                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
2405                            + 2.0*m2k*tmp2[kx]);
2406                 tmp1[kx] = eterm*denom[kx];
2407             }
2408             gcount = (bLB ? 7 : 1);
2409             for (ig = 0; ig < gcount; ++ig)
2410             {
2411                 t_complex *p0;
2412
2413                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2414                 for (kx = kxstart; kx < kxend; kx++, p0++)
2415                 {
2416                     d1      = p0->re;
2417                     d2      = p0->im;
2418
2419                     eterm   = tmp1[kx];
2420
2421                     p0->re  = d1*eterm;
2422                     p0->im  = d2*eterm;
2423                 }
2424             }
2425         }
2426     }
2427     if (bEnerVir)
2428     {
2429         work->vir_lj[XX][XX] = 0.25*virxx;
2430         work->vir_lj[YY][YY] = 0.25*viryy;
2431         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
2432         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
2433         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
2434         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
2435         /* This energy should be corrected for a charged system */
2436         work->energy_lj = 0.5*energy;
2437     }
2438     /* Return the loop count */
2439     return local_ndata[YY]*local_ndata[XX];
2440 }
2441
2442 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2443                                real *mesh_energy, matrix vir)
2444 {
2445     /* This function sums output over threads and should therefore
2446      * only be called after thread synchronization.
2447      */
2448     int thread;
2449
2450     *mesh_energy = pme->work[0].energy_q;
2451     copy_mat(pme->work[0].vir_q, vir);
2452
2453     for (thread = 1; thread < nthread; thread++)
2454     {
2455         *mesh_energy += pme->work[thread].energy_q;
2456         m_add(vir, pme->work[thread].vir_q, vir);
2457     }
2458 }
2459
2460 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2461                                 real *mesh_energy, matrix vir)
2462 {
2463     /* This function sums output over threads and should therefore
2464      * only be called after thread synchronization.
2465      */
2466     int thread;
2467
2468     *mesh_energy = pme->work[0].energy_lj;
2469     copy_mat(pme->work[0].vir_lj, vir);
2470
2471     for (thread = 1; thread < nthread; thread++)
2472     {
2473         *mesh_energy += pme->work[thread].energy_lj;
2474         m_add(vir, pme->work[thread].vir_lj, vir);
2475     }
2476 }
2477
2478
/* Accumulate the B-spline-interpolated grid gradient for one atom.
 *
 * Expands to a loop nest over the order^3 grid points starting at
 * (i0, j0, k0), with x stride pny*pnz and y stride pnz. For each point it
 * adds to fx, fy and fz the grid value weighted by the spline derivative
 * along that dimension times the spline values along the other two:
 *   fx += dthx*thy*thz*grid, fy += thx*dthy*thz*grid, fz += thx*thy*dthz*grid.
 * It uses variables of the enclosing scope (ithx, ithy, ithz, index_x,
 * index_xy, tx, ty, dx, dy, fxy1, fz1, gval, fx, fy, fz, grid, thx, thy, thz,
 * dthx, dthy, dthz, i0, j0, k0, pny, pnz), which the caller must declare.
 * fx, fy and fz are only added to, so they must be initialized beforehand.
 * The argument is parenthesized so that any expression can be passed.
 */
#define DO_FSPLINE(order)                      \
    for (ithx = 0; (ithx < (order)); ithx++)            \
    {                                              \
        index_x = (i0+ithx)*pny*pnz;               \
        tx      = thx[ithx];                       \
        dx      = dthx[ithx];                      \
                                               \
        for (ithy = 0; (ithy < (order)); ithy++)        \
        {                                          \
            index_xy = index_x+(j0+ithy)*pnz;      \
            ty       = thy[ithy];                  \
            dy       = dthy[ithy];                 \
            fxy1     = fz1 = 0;                    \
                                               \
            for (ithz = 0; (ithz < (order)); ithz++)    \
            {                                      \
                gval  = grid[index_xy+(k0+ithz)];  \
                fxy1 += thz[ithz]*gval;            \
                fz1  += dthz[ithz]*gval;           \
            }                                      \
            fx += dx*ty*fxy1;                      \
            fy += tx*dy*fxy1;                      \
            fz += tx*ty*fz1;                       \
        }                                          \
    }
2504
2505
/* Interpolate PME forces from the real-space grid onto atoms.
 *
 * For each entry nn of the spline set, atom n = spline->ind[nn] gets the
 * grid gradient (fx, fy, fz) accumulated with the B-spline coefficients
 * theta/dtheta of that entry. The gradient is converted to a force with the
 * grid sizes and the reciprocal box, scaled by -scale*atc->q[n], and added
 * to atc->f[n].
 *
 * grid    - real-space grid; x stride pmegrid_ny*pmegrid_nz, y stride pmegrid_nz.
 * bClearF - zero atc->f[n] first (this also happens for atoms with zero charge).
 * spline  - spline data; its theta/dtheta arrays are indexed by nn*order.
 * scale   - factor applied to every charge.
 *
 * Orders 4 and 5 use the SIMD4 code in pme_simd4.h when
 * PME_SIMD4_SPREAD_GATHER is defined; otherwise, and for other orders,
 * DO_FSPLINE is used.
 */
static void gather_f_bsplines(gmx_pme_t pme, real *grid,
                              gmx_bool bClearF, pme_atomcomm_t *atc,
                              splinedata_t *spline,
                              real scale)
{
    /* sum forces for local particles */
    int     nn, n, ithx, ithy, ithz, i0, j0, k0;
    int     index_x, index_xy;
    int     nx, ny, nz, pnx, pny, pnz;
    int *   idxptr;
    real    tx, ty, dx, dy, qn;
    real    fx, fy, fz, gval;
    real    fxy1, fz1;
    real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
    int     norder;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    int     order;

    /* NOTE(review): not referenced directly below; presumably used by the
     * code pulled in from pme_simd4.h.
     */
    pme_spline_work_t *work;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    /* Aligned scratch buffers for the aligned SIMD4 gather path */
    real           thz_buffer[12],  *thz_aligned;
    real           dthz_buffer[12], *dthz_aligned;

    thz_aligned  = gmx_simd4_align_real(thz_buffer);
    dthz_aligned = gmx_simd4_align_real(dthz_buffer);
#endif

    work = pme->spline_work;

    /* The spline pointers set here are re-pointed per atom in the loop */
    order = pme->pme_order;
    thx   = spline->theta[XX];
    thy   = spline->theta[YY];
    thz   = spline->theta[ZZ];
    dthx  = spline->dtheta[XX];
    dthy  = spline->dtheta[YY];
    dthz  = spline->dtheta[ZZ];
    nx    = pme->nkx;
    ny    = pme->nky;
    nz    = pme->nkz;
    pnx   = pme->pmegrid_nx;
    pny   = pme->pmegrid_ny;
    pnz   = pme->pmegrid_nz;

    rxx   = pme->recipbox[XX][XX];
    ryx   = pme->recipbox[YY][XX];
    ryy   = pme->recipbox[YY][YY];
    rzx   = pme->recipbox[ZZ][XX];
    rzy   = pme->recipbox[ZZ][YY];
    rzz   = pme->recipbox[ZZ][ZZ];

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = scale*atc->q[n];

        if (bClearF)
        {
            atc->f[n][XX] = 0;
            atc->f[n][YY] = 0;
            atc->f[n][ZZ] = 0;
        }
        /* Uncharged atoms get no grid force */
        if (qn != 0)
        {
            fx     = 0;
            fy     = 0;
            fz     = 0;
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Lower grid corner of this atom's order^3 stencil */
            i0   = idxptr[XX];
            j0   = idxptr[YY];
            k0   = idxptr[ZZ];

            /* Pointer arithmetic alert, next six statements */
            thx  = spline->theta[XX] + norder;
            thy  = spline->theta[YY] + norder;
            thz  = spline->theta[ZZ] + norder;
            dthx = spline->dtheta[XX] + norder;
            dthy = spline->dtheta[YY] + norder;
            dthz = spline->dtheta[ZZ] + norder;

            switch (order)
            {
                case 4:
                    /* The included header expands to the gather code selected
                     * by the PME_GATHER_F_SIMD4_* / PME_ORDER defines.
                     */
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_GATHER_F_SIMD4_ORDER4
#else
#define PME_GATHER_F_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_simd4.h"
#else
                    DO_FSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_GATHER_F_SIMD4_ALIGNED
#define PME_ORDER 5
#include "pme_simd4.h"
#else
                    DO_FSPLINE(5);
#endif
                    break;
                default:
                    DO_FSPLINE(order);
                    break;
            }

            /* Convert the grid-index gradient to a Cartesian force */
            atc->f[n][XX] += -qn*( fx*nx*rxx );
            atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
            atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
        }
    }
    /* Since the energy and not forces are interpolated
     * the net force might not be exactly zero.
     * This can be solved by also interpolating F, but
     * that comes at a cost.
     * A better hack is to remove the net force every
     * step, but that must be done at a higher level
     * since this routine doesn't see all atoms if running
     * in parallel. Don't know how important it is?  EL 990726
     */
}
2632
2633
2634 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2635                                    pme_atomcomm_t *atc)
2636 {
2637     splinedata_t *spline;
2638     int     n, ithx, ithy, ithz, i0, j0, k0;
2639     int     index_x, index_xy;
2640     int *   idxptr;
2641     real    energy, pot, tx, ty, qn, gval;
2642     real    *thx, *thy, *thz;
2643     int     norder;
2644     int     order;
2645
2646     spline = &atc->spline[0];
2647
2648     order = pme->pme_order;
2649
2650     energy = 0;
2651     for (n = 0; (n < atc->n); n++)
2652     {
2653         qn      = atc->q[n];
2654
2655         if (qn != 0)
2656         {
2657             idxptr = atc->idx[n];
2658             norder = n*order;
2659
2660             i0   = idxptr[XX];
2661             j0   = idxptr[YY];
2662             k0   = idxptr[ZZ];
2663
2664             /* Pointer arithmetic alert, next three statements */
2665             thx  = spline->theta[XX] + norder;
2666             thy  = spline->theta[YY] + norder;
2667             thz  = spline->theta[ZZ] + norder;
2668
2669             pot = 0;
2670             for (ithx = 0; (ithx < order); ithx++)
2671             {
2672                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2673                 tx      = thx[ithx];
2674
2675                 for (ithy = 0; (ithy < order); ithy++)
2676                 {
2677                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2678                     ty       = thy[ithy];
2679
2680                     for (ithz = 0; (ithz < order); ithz++)
2681                     {
2682                         gval  = grid[index_xy+(k0+ithz)];
2683                         pot  += tx*ty*thz[ithz]*gval;
2684                     }
2685
2686                 }
2687             }
2688
2689             energy += pot*qn;
2690         }
2691     }
2692
2693     return energy;
2694 }
2695
/* Compute spline weights and their derivatives for one atom and store them
 * in theta[j][i*order+k] and dtheta[j][i*order+k], for every dimension j and
 * k = 0 .. order-1.
 *
 * This macro is expanded inside make_bsplines and uses that function's
 * variables i, xptr, theta and dtheta by name.  xptr[j] is the offset dr of
 * the atom relative to the lower cell limit in dimension j.
 *
 * The weights start at order 2 ({1-dr, dr}) and are raised one order per
 * step.  The derivatives are formed as differences of the order-1 weights,
 * i.e. before the final raising step.  order must lie in 2..PME_ORDER_MAX:
 * the local arrays hold PME_ORDER_MAX entries and data[order-2] is read.
 *
 * A macro with a literal order lets the compiler fully unroll the loops,
 * which gives a significant performance gain.
 */
#define CALC_SPLINE(order)                     \
    {                                              \
        int j, k, l;                                 \
        real dr, div;                               \
        real data[PME_ORDER_MAX];                  \
        real ddata[PME_ORDER_MAX];                 \
                                               \
        for (j = 0; (j < DIM); j++)                     \
        {                                          \
            dr  = xptr[j];                         \
                                               \
            /* dr is relative offset from lower cell limit */ \
            data[order-1] = 0;                     \
            data[1]       = dr;                          \
            data[0]       = 1 - dr;                      \
                                               \
            for (k = 3; (k < order); k++)               \
            {                                      \
                div       = 1.0/(k - 1.0);               \
                data[k-1] = div*dr*data[k-2];      \
                for (l = 1; (l < (k-1)); l++)           \
                {                                  \
                    data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
                                       data[k-l-1]);                \
                }                                  \
                data[0] = div*(1-dr)*data[0];      \
            }                                      \
            /* differentiate */                    \
            ddata[0] = -data[0];                   \
            for (k = 1; (k < order); k++)               \
            {                                      \
                ddata[k] = data[k-1] - data[k];    \
            }                                      \
                                               \
            div           = 1.0/(order - 1);                 \
            data[order-1] = div*dr*data[order-2];  \
            for (l = 1; (l < (order-1)); l++)           \
            {                                      \
                data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
                                       (order-l-dr)*data[order-l-1]); \
            }                                      \
            data[0] = div*(1 - dr)*data[0];        \
                                               \
            for (k = 0; k < order; k++)                 \
            {                                      \
                theta[j][i*order+k]  = data[k];    \
                dtheta[j][i*order+k] = ddata[k];   \
            }                                      \
        }                                          \
    }
2749
2750 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2751                    rvec fractx[], int nr, int ind[], real charge[],
2752                    gmx_bool bDoSplines)
2753 {
2754     /* construct splines for local atoms */
2755     int  i, ii;
2756     real *xptr;
2757
2758     for (i = 0; i < nr; i++)
2759     {
2760         /* With free energy we do not use the charge check.
2761          * In most cases this will be more efficient than calling make_bsplines
2762          * twice, since usually more than half the particles have charges.
2763          */
2764         ii = ind[i];
2765         if (bDoSplines || charge[ii] != 0.0)
2766         {
2767             xptr = fractx[ii];
2768             switch (order)
2769             {
2770                 case 4:  CALC_SPLINE(4);     break;
2771                 case 5:  CALC_SPLINE(5);     break;
2772                 default: CALC_SPLINE(order); break;
2773             }
2774         }
2775     }
2776 }
2777
2778
2779 void make_dft_mod(real *mod, real *data, int ndata)
2780 {
2781     int i, j;
2782     real sc, ss, arg;
2783
2784     for (i = 0; i < ndata; i++)
2785     {
2786         sc = ss = 0;
2787         for (j = 0; j < ndata; j++)
2788         {
2789             arg = (2.0*M_PI*i*j)/ndata;
2790             sc += data[j]*cos(arg);
2791             ss += data[j]*sin(arg);
2792         }
2793         mod[i] = sc*sc+ss*ss;
2794     }
2795     for (i = 0; i < ndata; i++)
2796     {
2797         if (mod[i] < 1e-7)
2798         {
2799             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2800         }
2801     }
2802 }
2803
2804
2805 static void make_bspline_moduli(splinevec bsp_mod,
2806                                 int nx, int ny, int nz, int order)
2807 {
2808     int nmax = max(nx, max(ny, nz));
2809     real *data, *ddata, *bsp_data;
2810     int i, k, l;
2811     real div;
2812
2813     snew(data, order);
2814     snew(ddata, order);
2815     snew(bsp_data, nmax);
2816
2817     data[order-1] = 0;
2818     data[1]       = 0;
2819     data[0]       = 1;
2820
2821     for (k = 3; k < order; k++)
2822     {
2823         div       = 1.0/(k-1.0);
2824         data[k-1] = 0;
2825         for (l = 1; l < (k-1); l++)
2826         {
2827             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2828         }
2829         data[0] = div*data[0];
2830     }
2831     /* differentiate */
2832     ddata[0] = -data[0];
2833     for (k = 1; k < order; k++)
2834     {
2835         ddata[k] = data[k-1]-data[k];
2836     }
2837     div           = 1.0/(order-1);
2838     data[order-1] = 0;
2839     for (l = 1; l < (order-1); l++)
2840     {
2841         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2842     }
2843     data[0] = div*data[0];
2844
2845     for (i = 0; i < nmax; i++)
2846     {
2847         bsp_data[i] = 0;
2848     }
2849     for (i = 1; i <= order; i++)
2850     {
2851         bsp_data[i] = data[i-1];
2852     }
2853
2854     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2855     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2856     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2857
2858     sfree(data);
2859     sfree(ddata);
2860     sfree(bsp_data);
2861 }
2862
2863
/* P3M optimal influence function for assignment orders 2..8, evaluated as a
 * polynomial in z^2.  Any other order yields 0.0.
 * The formula and most constants can be found in:
 * Ballenegger et al., JCTC 8, 936 (2012)
 */
static double do_p3m_influence(double z, int order)
{
    const double s2 = z*z;
    const double s4 = s2*s2;
    double       result;

    switch (order)
    {
        case 2:
            result = 1.0 - 2.0*s2/3.0;
            break;
        case 3:
            result = 1.0 - s2 + 2.0*s4/15.0;
            break;
        case 4:
            result = 1.0 - 4.0*s2/3.0 + 2.0*s4/5.0 + 4.0*s2*s4/315.0;
            break;
        case 5:
            result = 1.0 - 5.0*s2/3.0 + 7.0*s4/9.0 - 17.0*s2*s4/189.0 + 2.0*s4*s4/2835.0;
            break;
        case 6:
            result = 1.0 - 2.0*s2 + 19.0*s4/15.0 - 256.0*s2*s4/945.0 + 62.0*s4*s4/4725.0 + 4.0*s2*s4*s4/155925.0;
            break;
        case 7:
            result = 1.0 - 7.0*s2/3.0 + 28.0*s4/15.0 - 16.0*s2*s4/27.0 + 26.0*s4*s4/405.0 - 2.0*s2*s4*s4/1485.0 + 4.0*s4*s4*s4/6081075.0;
            break;
        case 8:
            result = 1.0 - 8.0*s2/3.0 + 116.0*s4/45.0 - 344.0*s2*s4/315.0 + 914.0*s4*s4/4725.0 - 248.0*s4*s4*s2/22275.0 + 21844.0*s4*s4*s4/212837625.0 - 8.0*s4*s4*s4*s2/638512875.0;
            break;
        default:
            result = 0.0;
            break;
    }

    return result;
}
2901
2902 /* Calculate the P3M B-spline moduli for one dimension */
2903 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2904 {
2905     double zarg, zai, sinzai, infl;
2906     int    maxk, i;
2907
2908     if (order > 8)
2909     {
2910         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2911     }
2912
2913     zarg = M_PI/n;
2914
2915     maxk = (n + 1)/2;
2916
2917     for (i = -maxk; i < 0; i++)
2918     {
2919         zai          = zarg*i;
2920         sinzai       = sin(zai);
2921         infl         = do_p3m_influence(sinzai, order);
2922         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2923     }
2924     bsp_mod[0] = 1.0;
2925     for (i = 1; i < maxk; i++)
2926     {
2927         zai        = zarg*i;
2928         sinzai     = sin(zai);
2929         infl       = do_p3m_influence(sinzai, order);
2930         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2931     }
2932 }
2933
2934 /* Calculate the P3M B-spline moduli */
2935 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2936                                     int nx, int ny, int nz, int order)
2937 {
2938     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2939     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2940     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2941 }
2942
2943
2944 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2945 {
2946     int nslab, n, i;
2947     int fw, bw;
2948
2949     nslab = atc->nslab;
2950
2951     n = 0;
2952     for (i = 1; i <= nslab/2; i++)
2953     {
2954         fw = (atc->nodeid + i) % nslab;
2955         bw = (atc->nodeid - i + nslab) % nslab;
2956         if (n < nslab - 1)
2957         {
2958             atc->node_dest[n] = fw;
2959             atc->node_src[n]  = bw;
2960             n++;
2961         }
2962         if (n < nslab - 1)
2963         {
2964             atc->node_dest[n] = bw;
2965             atc->node_src[n]  = fw;
2966             n++;
2967         }
2968     }
2969 }
2970
2971 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2972 {
2973     int thread, i;
2974
2975     if (NULL != log)
2976     {
2977         fprintf(log, "Destroying PME data structures.\n");
2978     }
2979
2980     sfree((*pmedata)->nnx);
2981     sfree((*pmedata)->nny);
2982     sfree((*pmedata)->nnz);
2983
2984     for (i = 0; i < (*pmedata)->ngrids; ++i)
2985     {
2986         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
2987         sfree((*pmedata)->fftgrid[i]);
2988         sfree((*pmedata)->cfftgrid[i]);
2989         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
2990     }
2991
2992     sfree((*pmedata)->lb_buf1);
2993     sfree((*pmedata)->lb_buf2);
2994
2995     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2996     {
2997         free_work(&(*pmedata)->work[thread]);
2998     }
2999     sfree((*pmedata)->work);
3000
3001     sfree(*pmedata);
3002     *pmedata = NULL;
3003
3004     return 0;
3005 }
3006
/* Round n up to the next multiple of f (intended for n >= 0, f > 0). */
static int mult_up(int n, int f)
{
    const int nblocks = (n + f - 1)/f;

    return nblocks*f;
}
3011
3012
3013 static double pme_load_imbalance(gmx_pme_t pme)
3014 {
3015     int    nma, nmi;
3016     double n1, n2, n3;
3017
3018     nma = pme->nnodes_major;
3019     nmi = pme->nnodes_minor;
3020
3021     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3022     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3023     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3024
3025     /* pme_solve is roughly double the cost of an fft */
3026
3027     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3028 }
3029
3030 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3031                           int dimind, gmx_bool bSpread)
3032 {
3033     int nk, k, s, thread;
3034
3035     atc->dimind    = dimind;
3036     atc->nslab     = 1;
3037     atc->nodeid    = 0;
3038     atc->pd_nalloc = 0;
3039 #ifdef GMX_MPI
3040     if (pme->nnodes > 1)
3041     {
3042         atc->mpi_comm = pme->mpi_comm_d[dimind];
3043         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3044         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3045     }
3046     if (debug)
3047     {
3048         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3049     }
3050 #endif
3051
3052     atc->bSpread   = bSpread;
3053     atc->pme_order = pme->pme_order;
3054
3055     if (atc->nslab > 1)
3056     {
3057         snew(atc->node_dest, atc->nslab);
3058         snew(atc->node_src, atc->nslab);
3059         setup_coordinate_communication(atc);
3060
3061         snew(atc->count_thread, pme->nthread);
3062         for (thread = 0; thread < pme->nthread; thread++)
3063         {
3064             snew(atc->count_thread[thread], atc->nslab);
3065         }
3066         atc->count = atc->count_thread[0];
3067         snew(atc->rcount, atc->nslab);
3068         snew(atc->buf_index, atc->nslab);
3069     }
3070
3071     atc->nthread = pme->nthread;
3072     if (atc->nthread > 1)
3073     {
3074         snew(atc->thread_plist, atc->nthread);
3075     }
3076     snew(atc->spline, atc->nthread);
3077     for (thread = 0; thread < atc->nthread; thread++)
3078     {
3079         if (atc->nthread > 1)
3080         {
3081             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3082             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3083         }
3084         snew(atc->spline[thread].thread_one, pme->nthread);
3085         atc->spline[thread].thread_one[thread] = 1;
3086     }
3087 }
3088
3089 static void
3090 init_overlap_comm(pme_overlap_t *  ol,
3091                   int              norder,
3092 #ifdef GMX_MPI
3093                   MPI_Comm         comm,
3094 #endif
3095                   int              nnodes,
3096                   int              nodeid,
3097                   int              ndata,
3098                   int              commplainsize)
3099 {
3100     int lbnd, rbnd, maxlr, b, i;
3101     int exten;
3102     int nn, nk;
3103     pme_grid_comm_t *pgc;
3104     gmx_bool bCont;
3105     int fft_start, fft_end, send_index1, recv_index1;
3106 #ifdef GMX_MPI
3107     MPI_Status stat;
3108
3109     ol->mpi_comm = comm;
3110 #endif
3111
3112     ol->nnodes = nnodes;
3113     ol->nodeid = nodeid;
3114
3115     /* Linear translation of the PME grid won't affect reciprocal space
3116      * calculations, so to optimize we only interpolate "upwards",
3117      * which also means we only have to consider overlap in one direction.
3118      * I.e., particles on this node might also be spread to grid indices
3119      * that belong to higher nodes (modulo nnodes)
3120      */
3121
3122     snew(ol->s2g0, ol->nnodes+1);
3123     snew(ol->s2g1, ol->nnodes);
3124     if (debug)
3125     {
3126         fprintf(debug, "PME slab boundaries:");
3127     }
3128     for (i = 0; i < nnodes; i++)
3129     {
3130         /* s2g0 the local interpolation grid start.
3131          * s2g1 the local interpolation grid end.
3132          * Because grid overlap communication only goes forward,
3133          * the grid the slabs for fft's should be rounded down.
3134          */
3135         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3136         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3137
3138         if (debug)
3139         {
3140             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3141         }
3142     }
3143     ol->s2g0[nnodes] = ndata;
3144     if (debug)
3145     {
3146         fprintf(debug, "\n");
3147     }
3148
3149     /* Determine with how many nodes we need to communicate the grid overlap */
3150     b = 0;
3151     do
3152     {
3153         b++;
3154         bCont = FALSE;
3155         for (i = 0; i < nnodes; i++)
3156         {
3157             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3158                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3159             {
3160                 bCont = TRUE;
3161             }
3162         }
3163     }
3164     while (bCont && b < nnodes);
3165     ol->noverlap_nodes = b - 1;
3166
3167     snew(ol->send_id, ol->noverlap_nodes);
3168     snew(ol->recv_id, ol->noverlap_nodes);
3169     for (b = 0; b < ol->noverlap_nodes; b++)
3170     {
3171         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3172         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3173     }
3174     snew(ol->comm_data, ol->noverlap_nodes);
3175
3176     ol->send_size = 0;
3177     for (b = 0; b < ol->noverlap_nodes; b++)
3178     {
3179         pgc = &ol->comm_data[b];
3180         /* Send */
3181         fft_start        = ol->s2g0[ol->send_id[b]];
3182         fft_end          = ol->s2g0[ol->send_id[b]+1];
3183         if (ol->send_id[b] < nodeid)
3184         {
3185             fft_start += ndata;
3186             fft_end   += ndata;
3187         }
3188         send_index1       = ol->s2g1[nodeid];
3189         send_index1       = min(send_index1, fft_end);
3190         pgc->send_index0  = fft_start;
3191         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3192         ol->send_size    += pgc->send_nindex;
3193
3194         /* We always start receiving to the first index of our slab */
3195         fft_start        = ol->s2g0[ol->nodeid];
3196         fft_end          = ol->s2g0[ol->nodeid+1];
3197         recv_index1      = ol->s2g1[ol->recv_id[b]];
3198         if (ol->recv_id[b] > nodeid)
3199         {
3200             recv_index1 -= ndata;
3201         }
3202         recv_index1      = min(recv_index1, fft_end);
3203         pgc->recv_index0 = fft_start;
3204         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3205     }
3206
3207 #ifdef GMX_MPI
3208     /* Communicate the buffer sizes to receive */
3209     for (b = 0; b < ol->noverlap_nodes; b++)
3210     {
3211         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3212                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3213                      ol->mpi_comm, &stat);
3214     }
3215 #endif
3216
3217     /* For non-divisible grid we need pme_order iso pme_order-1 */
3218     snew(ol->sendbuf, norder*commplainsize);
3219     snew(ol->recvbuf, norder*commplainsize);
3220 }
3221
3222 static void
3223 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3224                               int **global_to_local,
3225                               real **fraction_shift)
3226 {
3227     int i;
3228     int * gtl;
3229     real * fsh;
3230
3231     snew(gtl, 5*n);
3232     snew(fsh, 5*n);
3233     for (i = 0; (i < 5*n); i++)
3234     {
3235         /* Determine the global to local grid index */
3236         gtl[i] = (i - local_start + n) % n;
3237         /* For coordinates that fall within the local grid the fraction
3238          * is correct, we don't need to shift it.
3239          */
3240         fsh[i] = 0;
3241         if (local_range < n)
3242         {
3243             /* Due to rounding issues i could be 1 beyond the lower or
3244              * upper boundary of the local grid. Correct the index for this.
3245              * If we shift the index, we need to shift the fraction by
3246              * the same amount in the other direction to not affect
3247              * the weights.
3248              * Note that due to this shifting the weights at the end of
3249              * the spline might change, but that will only involve values
3250              * between zero and values close to the precision of a real,
3251              * which is anyhow the accuracy of the whole mesh calculation.
3252              */
3253             /* With local_range=0 we should not change i=local_start */
3254             if (i % n != local_start)
3255             {
3256                 if (gtl[i] == n-1)
3257                 {
3258                     gtl[i] = 0;
3259                     fsh[i] = -1;
3260                 }
3261                 else if (gtl[i] == local_range)
3262                 {
3263                     gtl[i] = local_range - 1;
3264                     fsh[i] = 1;
3265                 }
3266             }
3267         }
3268     }
3269
3270     *global_to_local = gtl;
3271     *fraction_shift  = fsh;
3272 }
3273
/* Allocate the SIMD4 helper data for spreading/gathering when
 * PME_SIMD4_SPREAD_GATHER is defined; otherwise return NULL.
 *
 * For each start offset 'of' = 0 .. 8-order, mask_S0[of] and mask_S1[of]
 * select, among the 8 grid entries held in two SIMD4 registers, exactly the
 * lanes i with of <= i < of+order.
 * order is marked gmx_unused because only the SIMD branch references it.
 */
static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
{
    pme_spline_work_t *work;

#ifdef PME_SIMD4_SPREAD_GATHER
    /* NOTE(review): tmp has 4 spare entries, presumably so an aligned 8-entry
     * window always fits after gmx_simd4_align_real; confirm.
     */
    real         tmp[12], *tmp_aligned;
    gmx_simd4_real_t zero_S;
    gmx_simd4_real_t real_mask_S0, real_mask_S1;
    int          of, i;

    snew_aligned(work, 1, SIMD4_ALIGNMENT);

    tmp_aligned = gmx_simd4_align_real(tmp);

    zero_S = gmx_simd4_setzero_r();

    /* Generate bit masks to mask out the unused grid entries,
     * as we only operate on order of the 8 grid entries that are
     * loaded into 2 SIMD registers.
     */
    for (of = 0; of < 8-(order-1); of++)
    {
        for (i = 0; i < 8; i++)
        {
            /* Used entries get -1, so the "< 0" comparison below sets them */
            tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
        }
        real_mask_S0      = gmx_simd4_load_r(tmp_aligned);
        real_mask_S1      = gmx_simd4_load_r(tmp_aligned+4);
        work->mask_S0[of] = gmx_simd4_cmplt_r(real_mask_S0, zero_S);
        work->mask_S1[of] = gmx_simd4_cmplt_r(real_mask_S1, zero_S);
    }
#else
    work = NULL;
#endif

    return work;
}
3311
3312 void gmx_pme_check_restrictions(int pme_order,
3313                                 int nkx, int nky, int nkz,
3314                                 int nnodes_major,
3315                                 int nnodes_minor,
3316                                 gmx_bool bUseThreads,
3317                                 gmx_bool bFatal,
3318                                 gmx_bool *bValidSettings)
3319 {
3320     if (pme_order > PME_ORDER_MAX)
3321     {
3322         if (!bFatal)
3323         {
3324             *bValidSettings = FALSE;
3325             return;
3326         }
3327         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3328                   pme_order, PME_ORDER_MAX);
3329     }
3330
3331     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3332         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3333         nkz <= pme_order)
3334     {
3335         if (!bFatal)
3336         {
3337             *bValidSettings = FALSE;
3338             return;
3339         }
3340         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3341                   pme_order);
3342     }
3343
3344     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3345      * We only allow multiple communication pulses in dim 1, not in dim 0.
3346      */
3347     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3348                         nkx != nnodes_major*(pme_order - 1)))
3349     {
3350         if (!bFatal)
3351         {
3352             *bValidSettings = FALSE;
3353             return;
3354         }
3355         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pmeorder-1. To resolve this issue, use less nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3356                   nkx/(double)nnodes_major, pme_order);
3357     }
3358
3359     if (bValidSettings != NULL)
3360     {
3361         *bValidSettings = TRUE;
3362     }
3363
3364     return;
3365 }
3366
3367 int gmx_pme_init(gmx_pme_t *         pmedata,
3368                  t_commrec *         cr,
3369                  int                 nnodes_major,
3370                  int                 nnodes_minor,
3371                  t_inputrec *        ir,
3372                  int                 homenr,
3373                  gmx_bool            bFreeEnergy_q,
3374                  gmx_bool            bFreeEnergy_lj,
3375                  gmx_bool            bReproducible,
3376                  int                 nthread)
3377 {
3378     gmx_pme_t pme = NULL;
3379
3380     int  use_threads, sum_use_threads, i;
3381     ivec ndata;
3382
3383     if (debug)
3384     {
3385         fprintf(debug, "Creating PME data structures.\n");
3386     }
3387     snew(pme, 1);
3388
3389     pme->sum_qgrid_tmp       = NULL;
3390     pme->sum_qgrid_dd_tmp    = NULL;
3391     pme->buf_nalloc          = 0;
3392
3393     pme->nnodes              = 1;
3394     pme->bPPnode             = TRUE;
3395
3396     pme->nnodes_major        = nnodes_major;
3397     pme->nnodes_minor        = nnodes_minor;
3398
3399 #ifdef GMX_MPI
3400     if (nnodes_major*nnodes_minor > 1)
3401     {
3402         pme->mpi_comm = cr->mpi_comm_mygroup;
3403
3404         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3405         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3406         if (pme->nnodes != nnodes_major*nnodes_minor)
3407         {
3408             gmx_incons("PME node count mismatch");
3409         }
3410     }
3411     else
3412     {
3413         pme->mpi_comm = MPI_COMM_NULL;
3414     }
3415 #endif
3416
3417     if (pme->nnodes == 1)
3418     {
3419 #ifdef GMX_MPI
3420         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3421         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3422 #endif
3423         pme->ndecompdim   = 0;
3424         pme->nodeid_major = 0;
3425         pme->nodeid_minor = 0;
3426 #ifdef GMX_MPI
3427         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3428 #endif
3429     }
3430     else
3431     {
3432         if (nnodes_minor == 1)
3433         {
3434 #ifdef GMX_MPI
3435             pme->mpi_comm_d[0] = pme->mpi_comm;
3436             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3437 #endif
3438             pme->ndecompdim   = 1;
3439             pme->nodeid_major = pme->nodeid;
3440             pme->nodeid_minor = 0;
3441
3442         }
3443         else if (nnodes_major == 1)
3444         {
3445 #ifdef GMX_MPI
3446             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3447             pme->mpi_comm_d[1] = pme->mpi_comm;
3448 #endif
3449             pme->ndecompdim   = 1;
3450             pme->nodeid_major = 0;
3451             pme->nodeid_minor = pme->nodeid;
3452         }
3453         else
3454         {
3455             if (pme->nnodes % nnodes_major != 0)
3456             {
3457                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3458             }
3459             pme->ndecompdim = 2;
3460
3461 #ifdef GMX_MPI
3462             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3463                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3464             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3465                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3466
3467             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3468             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3469             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3470             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3471 #endif
3472         }
3473         pme->bPPnode = (cr->duty & DUTY_PP);
3474     }
3475
3476     pme->nthread = nthread;
3477
3478     /* Check if any of the PME MPI ranks uses threads */
3479     use_threads = (pme->nthread > 1 ? 1 : 0);
3480 #ifdef GMX_MPI
3481     if (pme->nnodes > 1)
3482     {
3483         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3484                       MPI_SUM, pme->mpi_comm);
3485     }
3486     else
3487 #endif
3488     {
3489         sum_use_threads = use_threads;
3490     }
3491     pme->bUseThreads = (sum_use_threads > 0);
3492
3493     if (ir->ePBC == epbcSCREW)
3494     {
3495         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3496     }
3497
3498     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3499     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3500     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3501     pme->nkx         = ir->nkx;
3502     pme->nky         = ir->nky;
3503     pme->nkz         = ir->nkz;
3504     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3505     pme->pme_order   = ir->pme_order;
3506     pme->epsilon_r   = ir->epsilon_r;
3507
3508     /* If we violate restrictions, generate a fatal error here */
3509     gmx_pme_check_restrictions(pme->pme_order,
3510                                pme->nkx, pme->nky, pme->nkz,
3511                                pme->nnodes_major,
3512                                pme->nnodes_minor,
3513                                pme->bUseThreads,
3514                                TRUE,
3515                                NULL);
3516
3517     if (pme->nnodes > 1)
3518     {
3519         double imbal;
3520
3521 #ifdef GMX_MPI
3522         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3523         MPI_Type_commit(&(pme->rvec_mpi));
3524 #endif
3525
3526         /* Note that the charge spreading and force gathering, which usually
3527          * takes about the same amount of time as FFT+solve_pme,
3528          * is always fully load balanced
3529          * (unless the charge distribution is inhomogeneous).
3530          */
3531
3532         imbal = pme_load_imbalance(pme);
3533         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3534         {
3535             fprintf(stderr,
3536                     "\n"
3537                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3538                     "      For optimal PME load balancing\n"
3539                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3540                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3541                     "\n",
3542                     (int)((imbal-1)*100 + 0.5),
3543                     pme->nkx, pme->nky, pme->nnodes_major,
3544                     pme->nky, pme->nkz, pme->nnodes_minor);
3545         }
3546     }
3547
3548     /* For non-divisible grid we need pme_order iso pme_order-1 */
3549     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3550      * y is always copied through a buffer: we don't need padding in z,
3551      * but we do need the overlap in x because of the communication order.
3552      */
3553     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3554 #ifdef GMX_MPI
3555                       pme->mpi_comm_d[0],
3556 #endif
3557                       pme->nnodes_major, pme->nodeid_major,
3558                       pme->nkx,
3559                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3560
3561     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3562      * We do this with an offset buffer of equal size, so we need to allocate
3563      * extra for the offset. That's what the (+1)*pme->nkz is for.
3564      */
3565     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3566 #ifdef GMX_MPI
3567                       pme->mpi_comm_d[1],
3568 #endif
3569                       pme->nnodes_minor, pme->nodeid_minor,
3570                       pme->nky,
3571                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3572
3573     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3574      * Note that gmx_pme_check_restrictions checked for this already.
3575      */
3576     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3577     {
3578         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3579     }
3580
3581     snew(pme->bsp_mod[XX], pme->nkx);
3582     snew(pme->bsp_mod[YY], pme->nky);
3583     snew(pme->bsp_mod[ZZ], pme->nkz);
3584
3585     /* The required size of the interpolation grid, including overlap.
3586      * The allocated size (pmegrid_n?) might be slightly larger.
3587      */
3588     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3589         pme->overlap[0].s2g0[pme->nodeid_major];
3590     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3591         pme->overlap[1].s2g0[pme->nodeid_minor];
3592     pme->pmegrid_nz_base = pme->nkz;
3593     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3594     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3595
3596     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3597     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3598     pme->pmegrid_start_iz = 0;
3599
3600     make_gridindex5_to_localindex(pme->nkx,
3601                                   pme->pmegrid_start_ix,
3602                                   pme->pmegrid_nx - (pme->pme_order-1),
3603                                   &pme->nnx, &pme->fshx);
3604     make_gridindex5_to_localindex(pme->nky,
3605                                   pme->pmegrid_start_iy,
3606                                   pme->pmegrid_ny - (pme->pme_order-1),
3607                                   &pme->nny, &pme->fshy);
3608     make_gridindex5_to_localindex(pme->nkz,
3609                                   pme->pmegrid_start_iz,
3610                                   pme->pmegrid_nz_base,
3611                                   &pme->nnz, &pme->fshz);
3612
3613     pme->spline_work = make_pme_spline_work(pme->pme_order);
3614
3615     ndata[0]    = pme->nkx;
3616     ndata[1]    = pme->nky;
3617     ndata[2]    = pme->nkz;
3618     pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3619     snew(pme->fftgrid, pme->ngrids);
3620     snew(pme->cfftgrid, pme->ngrids);
3621     snew(pme->pfft_setup, pme->ngrids);
3622
3623     for (i = 0; i < pme->ngrids; ++i)
3624     {
3625         if (((ir->ljpme_combination_rule == eljpmeLB) && i >= 2) || i % 2 == 0 || bFreeEnergy_q || bFreeEnergy_lj)
3626         {
3627             pmegrids_init(&pme->pmegrid[i],
3628                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3629                           pme->pmegrid_nz_base,
3630                           pme->pme_order,
3631                           pme->bUseThreads,
3632                           pme->nthread,
3633                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3634                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3635             /* This routine will allocate the grid data to fit the FFTs */
3636             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3637                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3638                                     pme->mpi_comm_d,
3639                                     bReproducible, pme->nthread);
3640
3641         }
3642     }
3643
3644     if (!pme->bP3M)
3645     {
3646         /* Use plain SPME B-spline interpolation */
3647         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3648     }
3649     else
3650     {
3651         /* Use the P3M grid-optimized influence function */
3652         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3653     }
3654
3655     /* Use atc[0] for spreading */
3656     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3657     if (pme->ndecompdim >= 2)
3658     {
3659         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3660     }
3661
3662     if (pme->nnodes == 1)
3663     {
3664         pme->atc[0].n = homenr;
3665         pme_realloc_atomcomm_things(&pme->atc[0]);
3666     }
3667
3668     pme->lb_buf1       = NULL;
3669     pme->lb_buf2       = NULL;
3670     pme->lb_buf_nalloc = 0;
3671
3672     {
3673         int thread;
3674
3675         /* Use fft5d, order after FFT is y major, z, x minor */
3676
3677         snew(pme->work, pme->nthread);
3678         for (thread = 0; thread < pme->nthread; thread++)
3679         {
3680             realloc_work(&pme->work[thread], pme->nkx);
3681         }
3682     }
3683
3684     *pmedata = pme;
3685
3686     return 0;
3687 }
3688
3689 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3690 {
3691     int d, t;
3692
3693     for (d = 0; d < DIM; d++)
3694     {
3695         if (new->grid.n[d] > old->grid.n[d])
3696         {
3697             return;
3698         }
3699     }
3700
3701     sfree_aligned(new->grid.grid);
3702     new->grid.grid = old->grid.grid;
3703
3704     if (new->grid_th != NULL && new->nthread == old->nthread)
3705     {
3706         sfree_aligned(new->grid_all);
3707         for (t = 0; t < new->nthread; t++)
3708         {
3709             new->grid_th[t].grid = old->grid_th[t].grid;
3710         }
3711     }
3712 }
3713
3714 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3715                    t_commrec *         cr,
3716                    gmx_pme_t           pme_src,
3717                    const t_inputrec *  ir,
3718                    ivec                grid_size)
3719 {
3720     t_inputrec irc;
3721     int homenr;
3722     int ret;
3723
3724     irc     = *ir;
3725     irc.nkx = grid_size[XX];
3726     irc.nky = grid_size[YY];
3727     irc.nkz = grid_size[ZZ];
3728
3729     if (pme_src->nnodes == 1)
3730     {
3731         homenr = pme_src->atc[0].n;
3732     }
3733     else
3734     {
3735         homenr = -1;
3736     }
3737
3738     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3739                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3740
3741     if (ret == 0)
3742     {
3743         /* We can easily reuse the allocated pme grids in pme_src */
3744         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3745         /* We would like to reuse the fft grids, but that's harder */
3746     }
3747
3748     return ret;
3749 }
3750
3751
3752 static void copy_local_grid(gmx_pme_t pme,
3753                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3754 {
3755     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3756     int  fft_my, fft_mz;
3757     int  nsx, nsy, nsz;
3758     ivec nf;
3759     int  offx, offy, offz, x, y, z, i0, i0t;
3760     int  d;
3761     pmegrid_t *pmegrid;
3762     real *grid_th;
3763
3764     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
3765                                    local_fft_ndata,
3766                                    local_fft_offset,
3767                                    local_fft_size);
3768     fft_my = local_fft_size[YY];
3769     fft_mz = local_fft_size[ZZ];
3770
3771     pmegrid = &pmegrids->grid_th[thread];
3772
3773     nsx = pmegrid->s[XX];
3774     nsy = pmegrid->s[YY];
3775     nsz = pmegrid->s[ZZ];
3776
3777     for (d = 0; d < DIM; d++)
3778     {
3779         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3780                     local_fft_ndata[d] - pmegrid->offset[d]);
3781     }
3782
3783     offx = pmegrid->offset[XX];
3784     offy = pmegrid->offset[YY];
3785     offz = pmegrid->offset[ZZ];
3786
3787     /* Directly copy the non-overlapping parts of the local grids.
3788      * This also initializes the full grid.
3789      */
3790     grid_th = pmegrid->grid;
3791     for (x = 0; x < nf[XX]; x++)
3792     {
3793         for (y = 0; y < nf[YY]; y++)
3794         {
3795             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3796             i0t = (x*nsy + y)*nsz;
3797             for (z = 0; z < nf[ZZ]; z++)
3798             {
3799                 fftgrid[i0+z] = grid_th[i0t+z];
3800             }
3801         }
3802     }
3803 }
3804
/* Sum into the region of the node-local grid owned by thread 'thread' the
 * contributions that other threads spread there. It is called after all
 * threads have spread and copied their own non-overlapping part with
 * copy_local_grid.
 *
 * Contributing thread grids are those whose cell index is the same or lower
 * in x, y and z (up to pmegrids->nthread_comm[d] cells back, wrapping around
 * below 0). The part that overlaps our region is added:
 * - directly into fftgrid when it stays on this PME rank;
 * - into commbuf_y when the y index wraps and there are several minor ranks,
 *   with the x+y corner stored at an offset in commbuf_y;
 * - into commbuf_x when only the x index wraps and there are several
 *   major ranks.
 * spread_on_grid passes pme->overlap[0/1].sendbuf as commbuf_x/commbuf_y;
 * these are sent by sum_fftgrid_dd. The first write by this call to each
 * buffer section initializes it instead of adding to it.
 */
static void
reduce_threadgrid_overlap(gmx_pme_t pme,
                          const pmegrids_t *pmegrids, int thread,
                          real *fftgrid, real *commbuf_x, real *commbuf_y)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    int  fft_nx, fft_ny, fft_nz;
    int  fft_my, fft_mz;
    int  buf_my = -1;
    int  nsx, nsy, nsz;
    ivec ne;
    int  offx, offy, offz, x, y, z, i0, i0t;
    int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
    gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
    gmx_bool bCommX, bCommY;
    int  d;
    int  thread_f;
    const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
    const real *grid_th;
    real *commbuf = NULL;

    gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);
    /* Number of local FFT grid points along each dimension */
    fft_nx = local_fft_ndata[XX];
    fft_ny = local_fft_ndata[YY];
    fft_nz = local_fft_ndata[ZZ];

    /* Allocated (stride) sizes of the local FFT grid */
    fft_my = local_fft_size[YY];
    fft_mz = local_fft_size[ZZ];

    /* This routine is called when all threads have finished spreading.
     * Here each thread sums grid contributions calculated by other threads
     * to the thread local grid volume.
     * To minimize the number of grid copying operations,
     * this routine sums immediately from the pmegrid to the fftgrid.
     */

    /* Determine which part of the full node grid we should operate on,
     * this is our thread local part of the full grid.
     */
    pmegrid = &pmegrids->grid_th[thread];

    /* ne: exclusive end of our region in local FFT grid coordinates,
     * i.e. our start offset plus our size without the pme_order-1 overlap,
     * clipped to the local FFT grid.
     */
    for (d = 0; d < DIM; d++)
    {
        ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
                    local_fft_ndata[d]);
    }

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];


    /* Each communication buffer section is initialized (not accumulated)
     * on its first write within this call.
     */
    bClearBufX  = TRUE;
    bClearBufY  = TRUE;
    bClearBufXY = TRUE;

    /* Now loop over all the thread data blocks that contribute
     * to the grid region we (our thread) are operating on.
     */
    /* Note that fft_nx/y is equal to the number of grid points
     * between the first point of our node grid and the one of the next node.
     */
    for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
    {
        /* x cell index of the contributing thread grid. Below 0 it wraps
         * around: its coordinates are shifted by -fft_nx and, with several
         * major ranks, the data has to be communicated along x.
         */
        fx     = pmegrid->ci[XX] + sx;
        ox     = 0;
        bCommX = FALSE;
        if (fx < 0)
        {
            fx    += pmegrids->nc[XX];
            ox    -= fft_nx;
            bCommX = (pme->nnodes_major > 1);
        }
        /* Thread grid at cell (fx,0,0); only its x offset and size are used */
        pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
        ox       += pmegrid_g->offset[XX];
        /* Upper x bound: our region end, or pme_order for communicated data */
        if (!bCommX)
        {
            tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
        }
        else
        {
            tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
        }

        for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
        {
            /* Same as for x above, now along y with the minor ranks */
            fy     = pmegrid->ci[YY] + sy;
            oy     = 0;
            bCommY = FALSE;
            if (fy < 0)
            {
                fy    += pmegrids->nc[YY];
                oy    -= fft_ny;
                bCommY = (pme->nnodes_minor > 1);
            }
            pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
            oy       += pmegrid_g->offset[YY];
            if (!bCommY)
            {
                ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
            }
            else
            {
                ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
            }

            for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
            {
                /* z is never communicated, only wrapped within the node */
                fz = pmegrid->ci[ZZ] + sz;
                oz = 0;
                if (fz < 0)
                {
                    fz += pmegrids->nc[ZZ];
                    oz -= fft_nz;
                }
                pmegrid_g = &pmegrids->grid_th[fz];
                oz       += pmegrid_g->offset[ZZ];
                tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);

                if (sx == 0 && sy == 0 && sz == 0)
                {
                    /* We have already added our local contribution
                     * before calling this routine, so skip it here.
                     */
                    continue;
                }

                /* Linear index of the contributing thread grid */
                thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;

                pmegrid_f = &pmegrids->grid_th[thread_f];

                grid_th = pmegrid_f->grid;

                /* Allocated strides of the contributing thread grid */
                nsx = pmegrid_f->s[XX];
                nsy = pmegrid_f->s[YY];
                nsz = pmegrid_f->s[ZZ];

#ifdef DEBUG_PME_REDUCE
                printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
                       pme->nodeid, thread, thread_f,
                       pme->pmegrid_start_ix,
                       pme->pmegrid_start_iy,
                       pme->pmegrid_start_iz,
                       sx, sy, sz,
                       offx-ox, tx1-ox, offx, tx1,
                       offy-oy, ty1-oy, offy, ty1,
                       offz-oz, tz1-oz, offz, tz1);
#endif

                if (!(bCommX || bCommY))
                {
                    /* Copy from the thread local grid to the node grid */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x*fft_my + y)*fft_mz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
                            for (z = offz; z < tz1; z++)
                            {
                                fftgrid[i0+z] += grid_th[i0t+z];
                            }
                        }
                    }
                }
                else
                {
                    /* The order of this conditional decides
                     * where the corner volume gets stored with x+y decomp.
                     */
                    if (bCommY)
                    {
                        commbuf = commbuf_y;
                        buf_my  = ty1 - offy;
                        if (bCommX)
                        {
                            /* We index commbuf modulo the local grid size */
                            commbuf += buf_my*fft_nx*fft_nz;

                            bClearBuf   = bClearBufXY;
                            bClearBufXY = FALSE;
                        }
                        else
                        {
                            bClearBuf  = bClearBufY;
                            bClearBufY = FALSE;
                        }
                    }
                    else
                    {
                        commbuf    = commbuf_x;
                        buf_my     = fft_ny;
                        bClearBuf  = bClearBufX;
                        bClearBufX = FALSE;
                    }

                    /* Copy to the communication buffer */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            /* commbuf is packed with y size buf_my and z size fft_nz */
                            i0  = (x*buf_my + y)*fft_nz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;

                            if (bClearBuf)
                            {
                                /* First access of commbuf, initialize it */
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0+z]  = grid_th[i0t+z];
                                }
                            }
                            else
                            {
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0+z] += grid_th[i0t+z];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
4034
4035
/* Add the grid overlap contributions of neighboring PME ranks to fftgrid.
 *
 * First, along the minor dimension (pme->overlap[1]), overlap->sendbuf is
 * exchanged with MPI_Sendrecv in one or more pulses. The received data is
 * added into fftgrid. When there are also several major ranks, the received
 * x+y corner part is added into pme->overlap[0].sendbuf so it is forwarded in
 * the second step.
 * Second, along the major dimension (pme->overlap[0]), a single pulse is
 * exchanged and added into fftgrid.
 * The send buffers are expected to have been filled beforehand; spread_on_grid
 * does this through reduce_threadgrid_overlap.
 */
static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    pme_overlap_t *overlap;
    int  send_index0, send_nindex;
    int  recv_nindex;
#ifdef GMX_MPI
    MPI_Status stat;
#endif
    int  send_size_y, recv_size_y;
    /* NOTE(review): gridsize is assigned below but never read */
    int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
    real *sendptr, *recvptr;
    int  x, y, z, indg, indb;

    /* Note that this routine is only used for forward communication.
     * Since the force gathering, unlike the charge spreading,
     * can be trivially parallelized over the particles,
     * the backwards process is much simpler and can use the "old"
     * communication setup.
     */

    gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);

    if (pme->nnodes_minor > 1)
    {
        /* Minor dimension */
        overlap = &pme->overlap[1];

        /* With major decomposition as well, the y buffer also carries
         * size_yx extra x planes: the x+y corner that is forwarded below.
         */
        if (pme->nnodes_major > 1)
        {
            size_yx = pme->overlap[0].comm_data[0].send_nindex;
        }
        else
        {
            size_yx = 0;
        }
        /* Number of reals per y plane in the communication buffers */
        datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];

        send_size_y = overlap->send_size;

        for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            /* Start of this pulse's data relative to the first pulse */
            send_index0   =
                overlap->comm_data[ipulse].send_index0 -
                overlap->comm_data[0].send_index0;
            /* NOTE(review): send_nindex is not used in this loop */
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            /* We don't use recv_index0, as we always receive starting at 0 */
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_size_y   = overlap->comm_data[ipulse].recv_size;

            sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
            recvptr = overlap->recvbuf;

#ifdef GMX_MPI
            MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
                         send_id, ipulse,
                         recvptr, recv_size_y*datasize, GMX_MPI_REAL,
                         recv_id, ipulse,
                         overlap->mpi_comm, &stat);
#endif

            /* Add the received y overlap to our local grid */
            for (x = 0; x < local_fft_ndata[XX]; x++)
            {
                for (y = 0; y < recv_nindex; y++)
                {
                    indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
                    indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
                    for (z = 0; z < local_fft_ndata[ZZ]; z++)
                    {
                        fftgrid[indg+z] += recvptr[indb+z];
                    }
                }
            }

            if (pme->nnodes_major > 1)
            {
                /* Copy from the received buffer to the send buffer for dim 0 */
                sendptr = pme->overlap[0].sendbuf;
                for (x = 0; x < size_yx; x++)
                {
                    for (y = 0; y < recv_nindex; y++)
                    {
                        indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
                        indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
                        for (z = 0; z < local_fft_ndata[ZZ]; z++)
                        {
                            sendptr[indg+z] += recvptr[indb+z];
                        }
                    }
                }
            }
        }
    }

    /* We only support a single pulse here.
     * This is not a severe limitation, as this code is only used
     * with OpenMP and with OpenMP the (PME) domains can be larger.
     */
    if (pme->nnodes_major > 1)
    {
        /* Major dimension */
        overlap = &pme->overlap[0];

        datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
        gridsize = local_fft_size[YY] *local_fft_size[ZZ];

        ipulse = 0;

        send_id       = overlap->send_id[ipulse];
        recv_id       = overlap->recv_id[ipulse];
        send_nindex   = overlap->comm_data[ipulse].send_nindex;
        /* We don't use recv_index0, as we always receive starting at 0 */
        recv_nindex   = overlap->comm_data[ipulse].recv_nindex;

        sendptr = overlap->sendbuf;
        recvptr = overlap->recvbuf;

        if (debug != NULL)
        {
            fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
                    send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
        }

#ifdef GMX_MPI
        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);
#endif

        /* Add the received x overlap planes to our local grid */
        for (x = 0; x < recv_nindex; x++)
        {
            for (y = 0; y < local_fft_ndata[YY]; y++)
            {
                indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
                indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
                for (z = 0; z < local_fft_ndata[ZZ]; z++)
                {
                    fftgrid[indg+z] += recvptr[indb+z];
                }
            }
        }
    }
}
4186
4187
/* Compute interpolation data for the atoms in atc and optionally spread
 * them on the grid.
 *
 * bCalcSplines: compute the grid indices (calc_interpolation_idx) and the
 *               B-spline coefficients (make_bsplines, which also receives
 *               bDoSplines).
 * bSpread:      spread the charges with spread_q_bsplines_thread.
 *
 * With pme->bUseThreads, each thread spreads onto grids->grid_th[thread] and
 * copies its non-overlapping part into fftgrid. Afterwards the thread
 * overlaps are reduced with reduce_threadgrid_overlap and, with more than
 * one PME rank, summed across ranks with sum_fftgrid_dd. Without threads,
 * spreading goes to grids->grid.
 *
 * grids may only be NULL when bSpread is FALSE: the non-threaded branch
 * takes &grids->grid whenever bSpread is set.
 */
static void spread_on_grid(gmx_pme_t pme,
                           pme_atomcomm_t *atc, pmegrids_t *grids,
                           gmx_bool bCalcSplines, gmx_bool bSpread,
                           real *fftgrid, gmx_bool bDoSplines )
{
    int nthread, thread;
#ifdef PME_TIME_THREADS
    gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
    static double cs1     = 0, cs2 = 0, cs3 = 0;
    static double cs1a[6] = {0, 0, 0, 0, 0, 0};
    static int cnt        = 0;
#endif
    /* NOTE(review): the PME_TIME_SPREAD code below uses ct1a, which is only
     * declared when PME_TIME_THREADS is defined. It is also shared by all
     * OpenMP threads, and cs1a has room for only 6 threads.
     */

    nthread = pme->nthread;
    assert(nthread > 0);

#ifdef PME_TIME_THREADS
    c1 = omp_cyc_start();
#endif
    if (bCalcSplines)
    {
#pragma omp parallel for num_threads(nthread) schedule(static)
        for (thread = 0; thread < nthread; thread++)
        {
            int start, end;

            /* Contiguous, evenly sized atom range for this thread */
            start = atc->n* thread   /nthread;
            end   = atc->n*(thread+1)/nthread;

            /* Compute fftgrid index for all atoms,
             * with help of some extra variables.
             */
            calc_interpolation_idx(pme, atc, start, end, thread);
        }
    }
#ifdef PME_TIME_THREADS
    c1   = omp_cyc_end(c1);
    cs1 += (double)c1;
#endif

#ifdef PME_TIME_THREADS
    c2 = omp_cyc_start();
#endif
#pragma omp parallel for num_threads(nthread) schedule(static)
    for (thread = 0; thread < nthread; thread++)
    {
        splinedata_t *spline;
        pmegrid_t *grid = NULL;

        /* make local bsplines  */
        if (grids == NULL || !pme->bUseThreads)
        {
            /* All atoms use spline set 0 and the single node grid.
             * NOTE(review): presumably nthread == 1 whenever
             * !pme->bUseThreads, otherwise threads would duplicate this
             * work on shared data - TODO confirm.
             */
            spline = &atc->spline[0];

            spline->n = atc->n;

            if (bSpread)
            {
                grid = &grids->grid;
            }
        }
        else
        {
            spline = &atc->spline[thread];

            if (grids->nthread == 1)
            {
                /* One thread, we operate on all charges */
                spline->n = atc->n;
            }
            else
            {
                /* Get the indices our thread should operate on */
                make_thread_local_ind(atc, thread, spline);
            }

            grid = &grids->grid_th[thread];
        }

        if (bCalcSplines)
        {
            make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
                          atc->fractx, spline->n, spline->ind, atc->q, bDoSplines);
        }

        if (bSpread)
        {
            /* put local atoms on grid. */
#ifdef PME_TIME_SPREAD
            ct1a = omp_cyc_start();
#endif
            spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);

            if (pme->bUseThreads)
            {
                /* Copy our non-overlapping part into the node FFT grid */
                copy_local_grid(pme, grids, thread, fftgrid);
            }
#ifdef PME_TIME_SPREAD
            ct1a          = omp_cyc_end(ct1a);
            cs1a[thread] += (double)ct1a;
#endif
        }
    }
#ifdef PME_TIME_THREADS
    c2   = omp_cyc_end(c2);
    cs2 += (double)c2;
#endif

    if (bSpread && pme->bUseThreads)
    {
#ifdef PME_TIME_THREADS
        c3 = omp_cyc_start();
#endif
        /* Add the overlap parts of the other thread grids; parts that
         * belong to other ranks go into the overlap send buffers.
         */
#pragma omp parallel for num_threads(grids->nthread) schedule(static)
        for (thread = 0; thread < grids->nthread; thread++)
        {
            reduce_threadgrid_overlap(pme, grids, thread,
                                      fftgrid,
                                      pme->overlap[0].sendbuf,
                                      pme->overlap[1].sendbuf);
        }
#ifdef PME_TIME_THREADS
        c3   = omp_cyc_end(c3);
        cs3 += (double)c3;
#endif

        if (pme->nnodes > 1)
        {
            /* Communicate the overlapping part of the fftgrid.
             * For this communication call we need to check pme->bUseThreads
             * to have all ranks communicate here, regardless of pme->nthread.
             */
            sum_fftgrid_dd(pme, fftgrid);
        }
    }

#ifdef PME_TIME_THREADS
    cnt++;
    if (cnt % 20 == 0)
    {
        printf("idx %.2f spread %.2f red %.2f",
               cs1*1e-9, cs2*1e-9, cs3*1e-9);
#ifdef PME_TIME_SPREAD
        for (thread = 0; thread < nthread; thread++)
        {
            printf(" %.2f", cs1a[thread]*1e-9);
        }
#endif
        printf("\n");
    }
#endif
}
4340
4341
4342 static void dump_grid(FILE *fp,
4343                       int sx, int sy, int sz, int nx, int ny, int nz,
4344                       int my, int mz, const real *g)
4345 {
4346     int x, y, z;
4347
4348     for (x = 0; x < nx; x++)
4349     {
4350         for (y = 0; y < ny; y++)
4351         {
4352             for (z = 0; z < nz; z++)
4353             {
4354                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4355                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4356             }
4357         }
4358     }
4359 }
4360
4361 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4362 {
4363     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4364
4365     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4366                                    local_fft_ndata,
4367                                    local_fft_offset,
4368                                    local_fft_size);
4369
4370     dump_grid(stderr,
4371               pme->pmegrid_start_ix,
4372               pme->pmegrid_start_iy,
4373               pme->pmegrid_start_iz,
4374               pme->pmegrid_nx-pme->pme_order+1,
4375               pme->pmegrid_ny-pme->pme_order+1,
4376               pme->pmegrid_nz-pme->pme_order+1,
4377               local_fft_size[YY],
4378               local_fft_size[ZZ],
4379               fftgrid);
4380 }
4381
4382
4383 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4384 {
4385     pme_atomcomm_t *atc;
4386     pmegrids_t *grid;
4387
4388     if (pme->nnodes > 1)
4389     {
4390         gmx_incons("gmx_pme_calc_energy called in parallel");
4391     }
4392     if (pme->bFEP > 1)
4393     {
4394         gmx_incons("gmx_pme_calc_energy with free energy");
4395     }
4396
4397     atc            = &pme->atc_energy;
4398     atc->nthread   = 1;
4399     if (atc->spline == NULL)
4400     {
4401         snew(atc->spline, atc->nthread);
4402     }
4403     atc->nslab     = 1;
4404     atc->bSpread   = TRUE;
4405     atc->pme_order = pme->pme_order;
4406     atc->n         = n;
4407     pme_realloc_atomcomm_things(atc);
4408     atc->x         = x;
4409     atc->q         = q;
4410
4411     /* We only use the A-charges grid */
4412     grid = &pme->pmegrid[PME_GRID_QA];
4413
4414     /* Only calculate the spline coefficients, don't actually spread */
4415     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE);
4416
4417     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4418 }
4419
4420
4421 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4422                                    gmx_walltime_accounting_t walltime_accounting,
4423                                    t_nrnb *nrnb, t_inputrec *ir,
4424                                    gmx_int64_t step)
4425 {
4426     /* Reset all the counters related to performance over the run */
4427     wallcycle_stop(wcycle, ewcRUN);
4428     wallcycle_reset_all(wcycle);
4429     init_nrnb(nrnb);
4430     if (ir->nsteps >= 0)
4431     {
4432         /* ir->nsteps is not used here, but we update it for consistency */
4433         ir->nsteps -= step - ir->init_step;
4434     }
4435     ir->init_step = step;
4436     wallcycle_start(wcycle, ewcRUN);
4437     walltime_accounting_start(walltime_accounting);
4438 }
4439
4440
4441 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4442                                ivec grid_size,
4443                                t_commrec *cr, t_inputrec *ir,
4444                                gmx_pme_t *pme_ret)
4445 {
4446     int ind;
4447     gmx_pme_t pme = NULL;
4448
4449     ind = 0;
4450     while (ind < *npmedata)
4451     {
4452         pme = (*pmedata)[ind];
4453         if (pme->nkx == grid_size[XX] &&
4454             pme->nky == grid_size[YY] &&
4455             pme->nkz == grid_size[ZZ])
4456         {
4457             *pme_ret = pme;
4458
4459             return;
4460         }
4461
4462         ind++;
4463     }
4464
4465     (*npmedata)++;
4466     srenew(*pmedata, *npmedata);
4467
4468     /* Generate a new PME data structure, copying part of the old pointers */
4469     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4470
4471     *pme_ret = (*pmedata)[ind];
4472 }
4473
4474 int gmx_pmeonly(gmx_pme_t pme,
4475                 t_commrec *cr,    t_nrnb *mynrnb,
4476                 gmx_wallcycle_t wcycle,
4477                 gmx_walltime_accounting_t walltime_accounting,
4478                 real ewaldcoeff_q, real ewaldcoeff_lj,
4479                 t_inputrec *ir)
4480 {
4481     int npmedata;
4482     gmx_pme_t *pmedata;
4483     gmx_pme_pp_t pme_pp;
4484     int  ret;
4485     int  natoms;
4486     matrix box;
4487     rvec *x_pp      = NULL, *f_pp = NULL;
4488     real *chargeA   = NULL, *chargeB = NULL;
4489     real *c6A       = NULL, *c6B = NULL;
4490     real *sigmaA    = NULL, *sigmaB = NULL;
4491     real lambda_q   = 0;
4492     real lambda_lj  = 0;
4493     int  maxshift_x = 0, maxshift_y = 0;
4494     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4495     matrix vir_q, vir_lj;
4496     float cycles;
4497     int  count;
4498     gmx_bool bEnerVir;
4499     int pme_flags;
4500     gmx_int64_t step, step_rel;
4501     ivec grid_switch;
4502
4503     /* This data will only use with PME tuning, i.e. switching PME grids */
4504     npmedata = 1;
4505     snew(pmedata, npmedata);
4506     pmedata[0] = pme;
4507
4508     pme_pp = gmx_pme_pp_init(cr);
4509
4510     init_nrnb(mynrnb);
4511
4512     count = 0;
4513     do /****** this is a quasi-loop over time steps! */
4514     {
4515         /* The reason for having a loop here is PME grid tuning/switching */
4516         do
4517         {
4518             /* Domain decomposition */
4519             ret = gmx_pme_recv_params_coords(pme_pp,
4520                                              &natoms,
4521                                              &chargeA, &chargeB,
4522                                              &c6A, &c6B,
4523                                              &sigmaA, &sigmaB,
4524                                              box, &x_pp, &f_pp,
4525                                              &maxshift_x, &maxshift_y,
4526                                              &pme->bFEP_q, &pme->bFEP_lj,
4527                                              &lambda_q, &lambda_lj,
4528                                              &bEnerVir,
4529                                              &pme_flags,
4530                                              &step,
4531                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4532
4533             if (ret == pmerecvqxSWITCHGRID)
4534             {
4535                 /* Switch the PME grid to grid_switch */
4536                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4537             }
4538
4539             if (ret == pmerecvqxRESETCOUNTERS)
4540             {
4541                 /* Reset the cycle and flop counters */
4542                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4543             }
4544         }
4545         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4546
4547         if (ret == pmerecvqxFINISH)
4548         {
4549             /* We should stop: break out of the loop */
4550             break;
4551         }
4552
4553         step_rel = step - ir->init_step;
4554
4555         if (count == 0)
4556         {
4557             wallcycle_start(wcycle, ewcRUN);
4558             walltime_accounting_start(walltime_accounting);
4559         }
4560
4561         wallcycle_start(wcycle, ewcPMEMESH);
4562
4563         dvdlambda_q  = 0;
4564         dvdlambda_lj = 0;
4565         clear_mat(vir_q);
4566         clear_mat(vir_lj);
4567
4568         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4569                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4570                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4571                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4572                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4573                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4574
4575         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4576
4577         gmx_pme_send_force_vir_ener(pme_pp,
4578                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4579                                     dvdlambda_q, dvdlambda_lj, cycles);
4580
4581         count++;
4582     } /***** end of quasi-loop, we stop with the break above */
4583     while (TRUE);
4584
4585     walltime_accounting_end(walltime_accounting);
4586
4587     return 0;
4588 }
4589
4590 static void
4591 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4592 {
4593     int  i;
4594
4595     for (i = 0; i < pme->atc[0].n; ++i)
4596     {
4597         real sigma4;
4598
4599         sigma4           = local_sigma[i];
4600         sigma4           = sigma4*sigma4;
4601         sigma4           = sigma4*sigma4;
4602         pme->atc[0].q[i] = local_c6[i] / sigma4;
4603     }
4604 }
4605
4606 static void
4607 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4608 {
4609     int  i;
4610
4611     for (i = 0; i < pme->atc[0].n; ++i)
4612     {
4613         pme->atc[0].q[i] *= local_sigma[i];
4614     }
4615 }
4616
4617 static void
4618 do_redist_x_q(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4619               gmx_bool bFirst, rvec x[], real *data)
4620 {
4621     int      d;
4622     pme_atomcomm_t *atc;
4623     atc = &pme->atc[0];
4624
4625     for (d = pme->ndecompdim - 1; d >= 0; d--)
4626     {
4627         int             n_d;
4628         rvec           *x_d;
4629         real           *q_d;
4630
4631         if (d == pme->ndecompdim - 1)
4632         {
4633             n_d = homenr;
4634             x_d = x + start;
4635             q_d = data;
4636         }
4637         else
4638         {
4639             n_d = pme->atc[d + 1].n;
4640             x_d = atc->x;
4641             q_d = atc->q;
4642         }
4643         atc      = &pme->atc[d];
4644         atc->npd = n_d;
4645         if (atc->npd > atc->pd_nalloc)
4646         {
4647             atc->pd_nalloc = over_alloc_dd(atc->npd);
4648             srenew(atc->pd, atc->pd_nalloc);
4649         }
4650         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4651         where();
4652         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4653         if (DOMAINDECOMP(cr))
4654         {
4655             dd_pmeredist_x_q(pme, n_d, bFirst, x_d, q_d, atc);
4656         }
4657     }
4658 }
4659
4660 /* TODO: when adding free-energy support, sigmaB will no longer be
4661  * unused */
4662 int gmx_pme_do(gmx_pme_t pme,
4663                int start,       int homenr,
4664                rvec x[],        rvec f[],
4665                real *chargeA,   real *chargeB,
4666                real *c6A,       real *c6B,
4667                real *sigmaA,    real gmx_unused *sigmaB,
4668                matrix box, t_commrec *cr,
4669                int  maxshift_x, int maxshift_y,
4670                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4671                matrix vir_q,      real ewaldcoeff_q,
4672                matrix vir_lj,   real ewaldcoeff_lj,
4673                real *energy_q,  real *energy_lj,
4674                real lambda_q, real lambda_lj,
4675                real *dvdlambda_q, real *dvdlambda_lj,
4676                int flags)
4677 {
4678     int     q, d, i, j, ntot, npme, qmax;
4679     int     nx, ny, nz;
4680     int     n_d, local_ny;
4681     pme_atomcomm_t *atc = NULL;
4682     pmegrids_t *pmegrid = NULL;
4683     real    *grid       = NULL;
4684     real    *ptr;
4685     rvec    *x_d, *f_d;
4686     real    *charge = NULL, *q_d;
4687     real    energy_AB[4];
4688     matrix  vir_AB[4];
4689     real    scale, lambda;
4690     gmx_bool bClearF;
4691     gmx_parallel_3dfft_t pfft_setup;
4692     real *  fftgrid;
4693     t_complex * cfftgrid;
4694     int     thread;
4695     gmx_bool bFirst, bDoSplines;
4696     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4697     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4698
4699     assert(pme->nnodes > 0);
4700     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4701
4702     if (pme->nnodes > 1)
4703     {
4704         atc      = &pme->atc[0];
4705         atc->npd = homenr;
4706         if (atc->npd > atc->pd_nalloc)
4707         {
4708             atc->pd_nalloc = over_alloc_dd(atc->npd);
4709             srenew(atc->pd, atc->pd_nalloc);
4710         }
4711         for (d = pme->ndecompdim-1; d >= 0; d--)
4712         {
4713             atc           = &pme->atc[d];
4714             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4715         }
4716     }
4717     else
4718     {
4719         atc = &pme->atc[0];
4720         /* This could be necessary for TPI */
4721         pme->atc[0].n = homenr;
4722         if (DOMAINDECOMP(cr))
4723         {
4724             pme_realloc_atomcomm_things(atc);
4725         }
4726         atc->x = x;
4727         atc->f = f;
4728     }
4729
4730     m_inv_ur0(box, pme->recipbox);
4731     bFirst = TRUE;
4732
4733     /* For simplicity, we construct the splines for all particles if
4734      * more than one PME calculations is needed. Some optimization
4735      * could be done by keeping track of which atoms have splines
4736      * constructed, and construct new splines on each pass for atoms
4737      * that don't yet have them.
4738      */
4739
4740     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4741
4742     /* We need a maximum of four separate PME calculations:
4743      * q=0: Coulomb PME with charges from state A
4744      * q=1: Coulomb PME with charges from state B
4745      * q=2: LJ PME with C6 from state A
4746      * q=3: LJ PME with C6 from state B
4747      * For Lorentz-Berthelot combination rules, a separate loop is used to
4748      * calculate all the terms
4749      */
4750
4751     /* If we are doing LJ-PME with LB, we only do Q here */
4752     qmax = (flags & GMX_PME_LJ_LB) ? DO_Q : DO_Q_AND_LJ;
4753
4754     for (q = 0; q < qmax; ++q)
4755     {
4756         /* Check if we should do calculations at this q
4757          * If q is odd we should be doing FEP
4758          * If q < 2 we should be doing electrostatic PME
4759          * If q >= 2 we should be doing LJ-PME
4760          */
4761         if ((!pme->bFEP_q && q == 1)
4762             || (!pme->bFEP_lj && q == 3)
4763             || (!(flags & GMX_PME_DO_COULOMB) && q < 2)
4764             || (!(flags & GMX_PME_DO_LJ) && q >= 2))
4765         {
4766             continue;
4767         }
4768         /* Unpack structure */
4769         pmegrid    = &pme->pmegrid[q];
4770         fftgrid    = pme->fftgrid[q];
4771         cfftgrid   = pme->cfftgrid[q];
4772         pfft_setup = pme->pfft_setup[q];
4773         switch (q)
4774         {
4775             case 0: charge = chargeA + start; break;
4776             case 1: charge = chargeB + start; break;
4777             case 2: charge = c6A + start; break;
4778             case 3: charge = c6B + start; break;
4779         }
4780
4781         grid = pmegrid->grid.grid;
4782
4783         if (debug)
4784         {
4785             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4786                     cr->nnodes, cr->nodeid);
4787             fprintf(debug, "Grid = %p\n", (void*)grid);
4788             if (grid == NULL)
4789             {
4790                 gmx_fatal(FARGS, "No grid!");
4791             }
4792         }
4793         where();
4794
4795         if (pme->nnodes == 1)
4796         {
4797             atc->q = charge;
4798         }
4799         else
4800         {
4801             wallcycle_start(wcycle, ewcPME_REDISTXF);
4802             do_redist_x_q(pme, cr, start, homenr, bFirst, x, charge);
4803             where();
4804
4805             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4806         }
4807
4808         if (debug)
4809         {
4810             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4811                     cr->nodeid, atc->n);
4812         }
4813
4814         if (flags & GMX_PME_SPREAD_Q)
4815         {
4816             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4817
4818             /* Spread the charges on a grid */
4819             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines);
4820
4821             if (bFirst)
4822             {
4823                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4824             }
4825             inc_nrnb(nrnb, eNR_SPREADQBSP,
4826                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4827
4828             if (!pme->bUseThreads)
4829             {
4830                 wrap_periodic_pmegrid(pme, grid);
4831
4832                 /* sum contributions to local grid from other nodes */
4833 #ifdef GMX_MPI
4834                 if (pme->nnodes > 1)
4835                 {
4836                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4837                     where();
4838                 }
4839 #endif
4840
4841                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, q);
4842             }
4843
4844             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4845
4846             /*
4847                dump_local_fftgrid(pme,fftgrid);
4848                exit(0);
4849              */
4850         }
4851
4852         /* Here we start a large thread parallel region */
4853 #pragma omp parallel num_threads(pme->nthread) private(thread)
4854         {
4855             thread = gmx_omp_get_thread_num();
4856             if (flags & GMX_PME_SOLVE)
4857             {
4858                 int loop_count;
4859
4860                 /* do 3d-fft */
4861                 if (thread == 0)
4862                 {
4863                     wallcycle_start(wcycle, ewcPME_FFT);
4864                 }
4865                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4866                                            thread, wcycle);
4867                 if (thread == 0)
4868                 {
4869                     wallcycle_stop(wcycle, ewcPME_FFT);
4870                 }
4871                 where();
4872
4873                 /* solve in k-space for our local cells */
4874                 if (thread == 0)
4875                 {
4876                     wallcycle_start(wcycle, (q < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4877                 }
4878                 if (q < DO_Q)
4879                 {
4880                     loop_count =
4881                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
4882                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4883                                       bCalcEnerVir,
4884                                       pme->nthread, thread);
4885                 }
4886                 else
4887                 {
4888                     loop_count =
4889                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
4890                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4891                                          bCalcEnerVir,
4892                                          pme->nthread, thread);
4893                 }
4894
4895                 if (thread == 0)
4896                 {
4897                     wallcycle_stop(wcycle, (q < DO_Q ? ewcPME_SOLVE : ewcLJPME));
4898                     where();
4899                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4900                 }
4901             }
4902
4903             if (bCalcF)
4904             {
4905                 /* do 3d-invfft */
4906                 if (thread == 0)
4907                 {
4908                     where();
4909                     wallcycle_start(wcycle, ewcPME_FFT);
4910                 }
4911                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4912                                            thread, wcycle);
4913                 if (thread == 0)
4914                 {
4915                     wallcycle_stop(wcycle, ewcPME_FFT);
4916
4917                     where();
4918
4919                     if (pme->nodeid == 0)
4920                     {
4921                         ntot  = pme->nkx*pme->nky*pme->nkz;
4922                         npme  = ntot*log((real)ntot)/log(2.0);
4923                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4924                     }
4925
4926                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4927                 }
4928
4929                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, q, pme->nthread, thread);
4930             }
4931         }
4932         /* End of thread parallel section.
4933          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4934          */
4935
4936         if (bCalcF)
4937         {
4938             /* distribute local grid to all nodes */
4939 #ifdef GMX_MPI
4940             if (pme->nnodes > 1)
4941             {
4942                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4943             }
4944 #endif
4945             where();
4946
4947             unwrap_periodic_pmegrid(pme, grid);
4948
4949             /* interpolate forces for our local atoms */
4950
4951             where();
4952
4953             /* If we are running without parallelization,
4954              * atc->f is the actual force array, not a buffer,
4955              * therefore we should not clear it.
4956              */
4957             lambda  = q < DO_Q ? lambda_q : lambda_lj;
4958             bClearF = (bFirst && PAR(cr));
4959 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4960             for (thread = 0; thread < pme->nthread; thread++)
4961             {
4962                 gather_f_bsplines(pme, grid, bClearF, atc,
4963                                   &atc->spline[thread],
4964                                   pme->bFEP ? (q % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
4965             }
4966
4967             where();
4968
4969             inc_nrnb(nrnb, eNR_GATHERFBSP,
4970                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4971             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4972         }
4973
4974         if (bCalcEnerVir)
4975         {
4976             /* This should only be called on the master thread
4977              * and after the threads have synchronized.
4978              */
4979             if (q < 2)
4980             {
4981                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4982             }
4983             else
4984             {
4985                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4986             }
4987         }
4988         bFirst = FALSE;
4989     } /* of q-loop */
4990
4991     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
4992      * seven terms. */
4993
4994     if (flags & GMX_PME_LJ_LB)
4995     {
4996         real *local_c6 = NULL, *local_sigma = NULL;
4997         if (pme->nnodes == 1)
4998         {
4999             if (pme->lb_buf1 == NULL)
5000             {
5001                 pme->lb_buf_nalloc = pme->atc[0].n;
5002                 snew(pme->lb_buf1, pme->lb_buf_nalloc);
5003             }
5004             pme->atc[0].q = pme->lb_buf1;
5005             local_c6      = c6A;
5006             local_sigma   = sigmaA;
5007         }
5008         else
5009         {
5010             atc = &pme->atc[0];
5011
5012             wallcycle_start(wcycle, ewcPME_REDISTXF);
5013
5014             do_redist_x_q(pme, cr, start, homenr, bFirst, x, c6A);
5015             if (pme->lb_buf_nalloc < atc->n)
5016             {
5017                 pme->lb_buf_nalloc = atc->nalloc;
5018                 srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5019                 srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5020             }
5021             local_c6 = pme->lb_buf1;
5022             for (i = 0; i < atc->n; ++i)
5023             {
5024                 local_c6[i] = atc->q[i];
5025             }
5026             where();
5027
5028             do_redist_x_q(pme, cr, start, homenr, FALSE, x, sigmaA);
5029             local_sigma = pme->lb_buf2;
5030             for (i = 0; i < atc->n; ++i)
5031             {
5032                 local_sigma[i] = atc->q[i];
5033             }
5034             where();
5035
5036             wallcycle_stop(wcycle, ewcPME_REDISTXF);
5037         }
5038         calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5039
5040         /*Seven terms in LJ-PME with LB, q < 2 reserved for electrostatics*/
5041         for (q = 2; q < 9; ++q)
5042         {
5043             /* Unpack structure */
5044             pmegrid    = &pme->pmegrid[q];
5045             fftgrid    = pme->fftgrid[q];
5046             cfftgrid   = pme->cfftgrid[q];
5047             pfft_setup = pme->pfft_setup[q];
5048             calc_next_lb_coeffs(pme, local_sigma);
5049             grid = pmegrid->grid.grid;
5050             where();
5051
5052             if (flags & GMX_PME_SPREAD_Q)
5053             {
5054                 wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5055                 /* Spread the charges on a grid */
5056                 spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines);
5057
5058                 if (bFirst)
5059                 {
5060                     inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5061                 }
5062                 inc_nrnb(nrnb, eNR_SPREADQBSP,
5063                          pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5064                 if (pme->nthread == 1)
5065                 {
5066                     wrap_periodic_pmegrid(pme, grid);
5067
5068                     /* sum contributions to local grid from other nodes */
5069 #ifdef GMX_MPI
5070                     if (pme->nnodes > 1)
5071                     {
5072                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
5073                         where();
5074                     }
5075 #endif
5076                     copy_pmegrid_to_fftgrid(pme, grid, fftgrid, q);
5077                 }
5078
5079                 wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5080             }
5081
5082             /*Here we start a large thread parallel region*/
5083 #pragma omp parallel num_threads(pme->nthread) private(thread)
5084             {
5085                 thread = gmx_omp_get_thread_num();
5086                 if (flags & GMX_PME_SOLVE)
5087                 {
5088                     /* do 3d-fft */
5089                     if (thread == 0)
5090                     {
5091                         wallcycle_start(wcycle, ewcPME_FFT);
5092                     }
5093
5094                     gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5095                                                thread, wcycle);
5096                     if (thread == 0)
5097                     {
5098                         wallcycle_stop(wcycle, ewcPME_FFT);
5099                     }
5100                     where();
5101                 }
5102             }
5103             bFirst = FALSE;
5104         }
5105         if (flags & GMX_PME_SOLVE)
5106         {
5107             int loop_count;
5108             /* solve in k-space for our local cells */
5109 #pragma omp parallel num_threads(pme->nthread) private(thread)
5110             {
5111                 thread = gmx_omp_get_thread_num();
5112                 if (thread == 0)
5113                 {
5114                     wallcycle_start(wcycle, ewcLJPME);
5115                 }
5116
5117                 loop_count =
5118                     solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5119                                      box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5120                                      bCalcEnerVir,
5121                                      pme->nthread, thread);
5122                 if (thread == 0)
5123                 {
5124                     wallcycle_stop(wcycle, ewcLJPME);
5125                     where();
5126                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5127                 }
5128             }
5129         }
5130
5131         if (bCalcEnerVir)
5132         {
5133             /* This should only be called on the master thread and
5134              * after the threads have synchronized.
5135              */
5136             get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2], vir_AB[2]);
5137         }
5138
5139         if (bCalcF)
5140         {
5141             bFirst = !(flags & GMX_PME_DO_COULOMB);
5142             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5143             for (q = 8; q >= 2; --q)
5144             {
5145                 /* Unpack structure */
5146                 pmegrid    = &pme->pmegrid[q];
5147                 fftgrid    = pme->fftgrid[q];
5148                 cfftgrid   = pme->cfftgrid[q];
5149                 pfft_setup = pme->pfft_setup[q];
5150                 grid       = pmegrid->grid.grid;
5151                 calc_next_lb_coeffs(pme, local_sigma);
5152                 where();
5153 #pragma omp parallel num_threads(pme->nthread) private(thread)
5154                 {
5155                     thread = gmx_omp_get_thread_num();
5156                     /* do 3d-invfft */
5157                     if (thread == 0)
5158                     {
5159                         where();
5160                         wallcycle_start(wcycle, ewcPME_FFT);
5161                     }
5162
5163                     gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5164                                                thread, wcycle);
5165                     if (thread == 0)
5166                     {
5167                         wallcycle_stop(wcycle, ewcPME_FFT);
5168
5169                         where();
5170
5171                         if (pme->nodeid == 0)
5172                         {
5173                             ntot  = pme->nkx*pme->nky*pme->nkz;
5174                             npme  = ntot*log((real)ntot)/log(2.0);
5175                             inc_nrnb(nrnb, eNR_FFT, 2*npme);
5176                         }
5177                         wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5178                     }
5179
5180                     copy_fftgrid_to_pmegrid(pme, fftgrid, grid, q, pme->nthread, thread);
5181
5182                 } /*#pragma omp parallel*/
5183
5184                 /* distribute local grid to all nodes */
5185 #ifdef GMX_MPI
5186                 if (pme->nnodes > 1)
5187                 {
5188                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
5189                 }
5190 #endif
5191                 where();
5192
5193                 unwrap_periodic_pmegrid(pme, grid);
5194
5195                 /* interpolate forces for our local atoms */
5196                 where();
5197                 bClearF = (bFirst && PAR(cr));
5198                 scale   = pme->bFEP ? (q < 9 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5199                 scale  *= lb_scale_factor[q-2];
5200 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5201                 for (thread = 0; thread < pme->nthread; thread++)
5202                 {
5203                     gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5204                                       &pme->atc[0].spline[thread],
5205                                       scale);
5206                 }
5207                 where();
5208
5209                 inc_nrnb(nrnb, eNR_GATHERFBSP,
5210                          pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5211                 wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5212
5213                 bFirst = FALSE;
5214             } /*for (q = 8; q >= 2; --q)*/
5215         }     /*if (bCalcF)*/
5216     }         /*if (flags & GMX_PME_LJ_LB)*/
5217
5218     if (bCalcF && pme->nnodes > 1)
5219     {
5220         wallcycle_start(wcycle, ewcPME_REDISTXF);
5221         for (d = 0; d < pme->ndecompdim; d++)
5222         {
5223             atc = &pme->atc[d];
5224             if (d == pme->ndecompdim - 1)
5225             {
5226                 n_d = homenr;
5227                 f_d = f + start;
5228             }
5229             else
5230             {
5231                 n_d = pme->atc[d+1].n;
5232                 f_d = pme->atc[d+1].f;
5233             }
5234             if (DOMAINDECOMP(cr))
5235             {
5236                 dd_pmeredist_f(pme, atc, n_d, f_d,
5237                                d == pme->ndecompdim-1 && pme->bPPnode);
5238             }
5239         }
5240
5241         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5242     }
5243     where();
5244
5245     if (bCalcEnerVir)
5246     {
5247         if (flags & GMX_PME_DO_COULOMB)
5248         {
5249             if (!pme->bFEP_q)
5250             {
5251                 *energy_q = energy_AB[0];
5252                 m_add(vir_q, vir_AB[0], vir_q);
5253             }
5254             else
5255             {
5256                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5257                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5258                 for (i = 0; i < DIM; i++)
5259                 {
5260                     for (j = 0; j < DIM; j++)
5261                     {
5262                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5263                             lambda_q*vir_AB[1][i][j];
5264                     }
5265                 }
5266             }
5267             if (debug)
5268             {
5269                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5270             }
5271         }
5272         else
5273         {
5274             *energy_q = 0;
5275         }
5276
5277         if (flags & GMX_PME_DO_LJ)
5278         {
5279             if (!pme->bFEP_lj)
5280             {
5281                 *energy_lj = energy_AB[2];
5282                 m_add(vir_lj, vir_AB[2], vir_lj);
5283             }
5284             else
5285             {
5286                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5287                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5288                 for (i = 0; i < DIM; i++)
5289                 {
5290                     for (j = 0; j < DIM; j++)
5291                     {
5292                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5293                     }
5294                 }
5295             }
5296             if (debug)
5297             {
5298                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5299             }
5300         }
5301         else
5302         {
5303             *energy_lj = 0;
5304         }
5305     }
5306     return 0;
5307 }