1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* IMPORTANT FOR DEVELOPERS:
38  *
39  * Triclinic pme stuff isn't entirely trivial, and we've experienced
40  * some bugs during development (many of them due to me). To avoid
41  * this in the future, please check the following things if you make
42  * changes in this file:
43  *
44  * 1. You should obtain identical (at least to the PME precision)
45  *    energies, forces, and virial for
46  *    a rectangular box and a triclinic one where the z (or y) axis is
47  *    tilted a whole box side. For instance you could use these boxes:
48  *
49  *    rectangular       triclinic
50  *     2  0  0           2  0  0
51  *     0  2  0           0  2  0
52  *     0  0  6           2  2  6
53  *
54  * 2. You should check the energy conservation in a triclinic box.
55  *
56  * It might seem like overkill, but better safe than sorry.
57  * /Erik 001109
58  */
59
60 #ifdef HAVE_CONFIG_H
61 #include <config.h>
62 #endif
63
64 #include <stdio.h>
65 #include <string.h>
66 #include <math.h>
67 #include <assert.h>
68 #include "typedefs.h"
69 #include "txtdump.h"
70 #include "vec.h"
71 #include "gmxcomplex.h"
72 #include "smalloc.h"
73 #include "coulomb.h"
74 #include "gmx_fatal.h"
75 #include "pme.h"
76 #include "network.h"
77 #include "physics.h"
78 #include "nrnb.h"
79 #include "macros.h"
80
81 #include "gromacs/fft/parallel_3dfft.h"
82 #include "gromacs/fileio/futil.h"
83 #include "gromacs/fileio/pdbio.h"
84 #include "gromacs/timing/cyclecounter.h"
85 #include "gromacs/timing/wallcycle.h"
86 #include "gromacs/utility/gmxmpi.h"
87 #include "gromacs/utility/gmxomp.h"
88
89 /* Include the SIMD macro file and then check for support */
90 #include "gromacs/simd/macros.h"
91 #if defined GMX_HAVE_SIMD_MACROS && defined GMX_SIMD_HAVE_EXP
92 /* Turn on arbitrary width SIMD intrinsics for PME solve */
93 #define PME_SIMD
94 #endif
95
96 #define PME_GRID_QA    0 /* Gridindex for A-state for Q */
97 #define PME_GRID_C6A   2 /* Gridindex for A-state for LJ */
98 #define DO_Q           2 /* Electrostatic grids have index q<2 */
99 #define DO_Q_AND_LJ    4 /* non-LB LJ grids have index 2 <= q < 4 */
100 #define DO_Q_AND_LJ_LB 9 /* With LB rules we need a total of 2+7 grids */
101
102 /* Pascal triangle coefficients scaled with (1/2)^6 for LJ-PME with LB-rules */
103 const real lb_scale_factor[] = {
104     1.0/64, 6.0/64, 15.0/64, 20.0/64,
105     15.0/64, 6.0/64, 1.0/64
106 };
107
108 /* Pascal triangle coefficients used in solve_pme_lj_yzx, only need to do 4 calculations due to symmetry */
109 const real lb_scale_factor_symm[] = { 2.0/64, 12.0/64, 30.0/64, 20.0/64 };
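
/* Illustrative sketch (an assumption for clarity, not part of the PME code
 * path and not called anywhere): the tables above are the binomial
 * coefficients C(6,k) scaled by (1/2)^6, and the symmetric table combines
 * the k and 6-k terms (doubling all but the central one).
 */
static void check_lb_scale_factors_sketch(void)
{
    int    k;
    double c = 1.0;                /* C(6,0) */

    for (k = 0; k <= 6; k++)
    {
        assert(fabs(lb_scale_factor[k] - c/64.0) < 1e-6);
        if (k < 3)
        {
            assert(fabs(lb_scale_factor_symm[k] - 2.0*c/64.0) < 1e-6);
        }
        c = c*(6 - k)/(k + 1);     /* C(6,k+1) from C(6,k) */
    }
    assert(fabs(lb_scale_factor_symm[3] - 20.0/64.0) < 1e-6);
}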
110
111 /* Include the 4-wide SIMD macro file */
112 #include "gromacs/simd/four_wide_macros.h"
113 /* Check if we have 4-wide SIMD macro support */
114 #ifdef GMX_HAVE_SIMD4_MACROS
115 /* Do PME spread and gather with 4-wide SIMD.
116  * NOTE: SIMD is only used with PME order 4 and 5 (which are the most common).
117  */
118 #define PME_SIMD4_SPREAD_GATHER
119
120 #ifdef GMX_SIMD4_HAVE_UNALIGNED
121 /* With PME-order=4 on x86, unaligned load+store is slightly faster
122  * than doubling all SIMD operations when using aligned load+store.
123  */
124 #define PME_SIMD4_UNALIGNED
125 #endif
126 #endif
127
128 #define DFT_TOL 1e-7
129 /* #define PRT_FORCE */
130 /* conditions for on the fly time-measurement */
131 /* #define TAKETIME (step > 1 && timesteps < 10) */
132 #define TAKETIME FALSE
133
134 /* #define PME_TIME_THREADS */
135
136 #ifdef GMX_DOUBLE
137 #define mpi_type MPI_DOUBLE
138 #else
139 #define mpi_type MPI_FLOAT
140 #endif
141
142 #ifdef PME_SIMD4_SPREAD_GATHER
143 #define SIMD4_ALIGNMENT  (GMX_SIMD4_WIDTH*sizeof(real))
144 #else
145 /* We can use any alignment, apart from 0, so we use 4 reals */
146 #define SIMD4_ALIGNMENT  (4*sizeof(real))
147 #endif
148
149 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
150  * to preserve alignment.
151  */
152 #define GMX_CACHE_SEP 64
153
154 /* We only define a maximum to be able to use local arrays without allocation.
155  * An order larger than 12 should never be needed, even for test cases.
156  * If needed it can be changed here.
157  */
158 #define PME_ORDER_MAX 12
159
160 /* Internal datastructures */
161 typedef struct {
162     int send_index0;
163     int send_nindex;
164     int recv_index0;
165     int recv_nindex;
166     int recv_size;   /* Receive buffer width, used with OpenMP */
167 } pme_grid_comm_t;
168
169 typedef struct {
170 #ifdef GMX_MPI
171     MPI_Comm         mpi_comm;
172 #endif
173     int              nnodes, nodeid;
174     int             *s2g0;
175     int             *s2g1;
176     int              noverlap_nodes;
177     int             *send_id, *recv_id;
178     int              send_size; /* Send buffer width, used with OpenMP */
179     pme_grid_comm_t *comm_data;
180     real            *sendbuf;
181     real            *recvbuf;
182 } pme_overlap_t;
183
184 typedef struct {
185     int *n;      /* Cumulative counts of the number of particles per thread */
186     int  nalloc; /* Allocation size of i */
187     int *i;      /* Particle indices ordered on thread index (n) */
188 } thread_plist_t;
189
190 typedef struct {
191     int      *thread_one;
192     int       n;
193     int      *ind;
194     splinevec theta;
195     real     *ptr_theta_z;
196     splinevec dtheta;
197     real     *ptr_dtheta_z;
198 } splinedata_t;
199
200 typedef struct {
201     int      dimind;        /* The index of the dimension, 0=x, 1=y */
202     int      nslab;
203     int      nodeid;
204 #ifdef GMX_MPI
205     MPI_Comm mpi_comm;
206 #endif
207
208     int     *node_dest;     /* The nodes to send x and q to with DD */
209     int     *node_src;      /* The nodes to receive x and q from with DD */
210     int     *buf_index;     /* Index for commnode into the buffers */
211
212     int      maxshift;
213
214     int      npd;
215     int      pd_nalloc;
216     int     *pd;
217     int     *count;         /* The number of atoms to send to each node */
218     int    **count_thread;
219     int     *rcount;        /* The number of atoms to receive */
220
221     int      n;
222     int      nalloc;
223     rvec    *x;
224     real    *q;
225     rvec    *f;
226     gmx_bool bSpread;       /* These coordinates are used for spreading */
227     int      pme_order;
228     ivec    *idx;
229     rvec    *fractx;            /* Fractional coordinate relative to
230                                  * the lower cell boundary
231                                  */
232     int             nthread;
233     int            *thread_idx; /* Which thread should spread which charge */
234     thread_plist_t *thread_plist;
235     splinedata_t   *spline;
236 } pme_atomcomm_t;
237
238 #define FLBS  3
239 #define FLBSZ 4
240
241 typedef struct {
242     ivec  ci;     /* The spatial location of this grid         */
243     ivec  n;      /* The used size of *grid, including order-1 */
244     ivec  offset; /* The grid offset from the full node grid   */
245     int   order;  /* PME spreading order                       */
246     ivec  s;      /* The allocated size of *grid, s >= n       */
247     real *grid;   /* The grid of the local thread, size n      */
248 } pmegrid_t;
249
250 typedef struct {
251     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
252     int        nthread;      /* The number of threads operating on this grid     */
253     ivec       nc;           /* The local spatial decomposition over the threads */
254     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
255     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
256     int      **g2t;          /* The grid to thread index                         */
257     ivec       nthread_comm; /* The number of threads to communicate with        */
258 } pmegrids_t;
259
260 typedef struct {
261 #ifdef PME_SIMD4_SPREAD_GATHER
262     /* Masks for 4-wide SIMD aligned spreading and gathering */
263     gmx_simd4_pb mask_S0[6], mask_S1[6];
264 #else
265     int          dummy; /* C89 requires that struct has at least one member */
266 #endif
267 } pme_spline_work_t;
268
269 typedef struct {
270     /* work data for solve_pme */
271     int      nalloc;
272     real *   mhx;
273     real *   mhy;
274     real *   mhz;
275     real *   m2;
276     real *   denom;
277     real *   tmp1_alloc;
278     real *   tmp1;
279     real *   tmp2;
280     real *   eterm;
281     real *   m2inv;
282
283     real     energy_q;
284     matrix   vir_q;
285     real     energy_lj;
286     matrix   vir_lj;
287 } pme_work_t;
288
289 typedef struct gmx_pme {
290     int           ndecompdim; /* The number of decomposition dimensions */
291     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
292     int           nodeid_major;
293     int           nodeid_minor;
294     int           nnodes;    /* The number of nodes doing PME */
295     int           nnodes_major;
296     int           nnodes_minor;
297
298     MPI_Comm      mpi_comm;
299     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
300 #ifdef GMX_MPI
301     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
302 #endif
303
304     gmx_bool   bUseThreads;   /* Do any of the PME ranks have nthread>1?     */
305     int        nthread;       /* The number of threads doing PME on our rank */
306
307     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
308     gmx_bool   bFEP;          /* Compute Free energy contribution */
309     gmx_bool   bFEP_q;
310     gmx_bool   bFEP_lj;
311     int        nkx, nky, nkz; /* Grid dimensions */
312     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
313     int        pme_order;
314     real       epsilon_r;
315
316     int        ngrids;                  /* number of grids we maintain for pmegrid, (c)fftgrid and pfft_setups */
317
318     pmegrids_t pmegrid[DO_Q_AND_LJ_LB]; /* Grids on which we do spreading/interpolation,
319                                          * includes overlap. Grid indices are ordered as
320                                          * follows:
321                                          * 0: Coulomb PME, state A
322                                          * 1: Coulomb PME, state B
323                                          * 2-8: LJ-PME
324                                          * This can probably be done in a better way
325                                          * but this simple hack works for now
326                                          */
327     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
328     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
329     /* pmegrid_nz might be larger than strictly necessary to ensure
330      * memory alignment, pmegrid_nz_base gives the real base size.
331      */
332     int     pmegrid_nz_base;
333     /* The local PME grid starting indices */
334     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
335
336     /* Work data for spreading and gathering */
337     pme_spline_work_t     *spline_work;
338
339     real                 **fftgrid; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
340     /* inside the interpolation grid, but separate for 2D PME decomp. */
341     int                    fftgrid_nx, fftgrid_ny, fftgrid_nz;
342
343     t_complex            **cfftgrid;  /* Grids for complex FFT data */
344
345     int                    cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
346
347     gmx_parallel_3dfft_t  *pfft_setup;
348
349     int                   *nnx, *nny, *nnz;
350     real                  *fshx, *fshy, *fshz;
351
352     pme_atomcomm_t         atc[2]; /* Indexed on decomposition index */
353     matrix                 recipbox;
354     splinevec              bsp_mod;
355     /* Buffers to store data for local atoms for L-B combination rule
356      * calculations in LJ-PME. lb_buf1 stores either the coefficients
357      * for spreading/gathering (in serial), or the C6 parameters for
358      * local atoms (in parallel).  lb_buf2 is only used in parallel,
359      * and stores the sigma values for local atoms. */
360     real                 *lb_buf1, *lb_buf2;
361     int                   lb_buf_nalloc; /* Allocation size for the above buffers. */
362
363     pme_overlap_t         overlap[2];    /* Indexed on dimension, 0=x, 1=y */
364
365     pme_atomcomm_t        atc_energy;    /* Only for gmx_pme_calc_energy */
366
367     rvec                 *bufv;          /* Communication buffer */
368     real                 *bufr;          /* Communication buffer */
369     int                   buf_nalloc;    /* The communication buffer size */
370
371     /* thread local work data for solve_pme */
372     pme_work_t *work;
373
374     /* Work data for PME_redist */
375     gmx_bool redist_init;
376     int *    scounts;
377     int *    rcounts;
378     int *    sdispls;
379     int *    rdispls;
380     int *    sidx;
381     int *    idxa;
382     real *   redist_buf;
383     int      redist_buf_nalloc;
384
385     /* Work data for sum_qgrid */
386     real *   sum_qgrid_tmp;
387     real *   sum_qgrid_dd_tmp;
388 } t_gmx_pme;
389
390 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
391                                    int start, int end, int thread)
392 {
393     int             i;
394     int            *idxptr, tix, tiy, tiz;
395     real           *xptr, *fptr, tx, ty, tz;
396     real            rxx, ryx, ryy, rzx, rzy, rzz;
397     int             nx, ny, nz;
398     int             start_ix, start_iy, start_iz;
399     int            *g2tx, *g2ty, *g2tz;
400     gmx_bool        bThreads;
401     int            *thread_idx = NULL;
402     thread_plist_t *tpl        = NULL;
403     int            *tpl_n      = NULL;
404     int             thread_i;
405
406     nx  = pme->nkx;
407     ny  = pme->nky;
408     nz  = pme->nkz;
409
410     start_ix = pme->pmegrid_start_ix;
411     start_iy = pme->pmegrid_start_iy;
412     start_iz = pme->pmegrid_start_iz;
413
414     rxx = pme->recipbox[XX][XX];
415     ryx = pme->recipbox[YY][XX];
416     ryy = pme->recipbox[YY][YY];
417     rzx = pme->recipbox[ZZ][XX];
418     rzy = pme->recipbox[ZZ][YY];
419     rzz = pme->recipbox[ZZ][ZZ];
420
421     g2tx = pme->pmegrid[PME_GRID_QA].g2t[XX];
422     g2ty = pme->pmegrid[PME_GRID_QA].g2t[YY];
423     g2tz = pme->pmegrid[PME_GRID_QA].g2t[ZZ];
424
425     bThreads = (atc->nthread > 1);
426     if (bThreads)
427     {
428         thread_idx = atc->thread_idx;
429
430         tpl   = &atc->thread_plist[thread];
431         tpl_n = tpl->n;
432         for (i = 0; i < atc->nthread; i++)
433         {
434             tpl_n[i] = 0;
435         }
436     }
437
438     for (i = start; i < end; i++)
439     {
440         xptr   = atc->x[i];
441         idxptr = atc->idx[i];
442         fptr   = atc->fractx[i];
443
444         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
445         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
446         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
447         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
448
449         tix = (int)(tx);
450         tiy = (int)(ty);
451         tiz = (int)(tz);
452
453         /* Because decomposition only occurs in x and y,
454          * we never have a fraction correction in z.
455          */
456         fptr[XX] = tx - tix + pme->fshx[tix];
457         fptr[YY] = ty - tiy + pme->fshy[tiy];
458         fptr[ZZ] = tz - tiz;
459
460         idxptr[XX] = pme->nnx[tix];
461         idxptr[YY] = pme->nny[tiy];
462         idxptr[ZZ] = pme->nnz[tiz];
463
464 #ifdef DEBUG
465         range_check(idxptr[XX], 0, pme->pmegrid_nx);
466         range_check(idxptr[YY], 0, pme->pmegrid_ny);
467         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
468 #endif
469
470         if (bThreads)
471         {
472             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
473             thread_idx[i] = thread_i;
474             tpl_n[thread_i]++;
475         }
476     }
477
478     if (bThreads)
479     {
480         /* Make a list of particle indices sorted on thread */
481
482         /* Get the cumulative count */
483         for (i = 1; i < atc->nthread; i++)
484         {
485             tpl_n[i] += tpl_n[i-1];
486         }
487         /* The current implementation distributes particles equally
488          * over the threads, so we could actually allocate for that
489          * in pme_realloc_atomcomm_things.
490          */
491         if (tpl_n[atc->nthread-1] > tpl->nalloc)
492         {
493             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
494             srenew(tpl->i, tpl->nalloc);
495         }
496         /* Set tpl_n to the cumulative start */
497         for (i = atc->nthread-1; i >= 1; i--)
498         {
499             tpl_n[i] = tpl_n[i-1];
500         }
501         tpl_n[0] = 0;
502
503         /* Fill our thread local array with indices sorted on thread */
504         for (i = start; i < end; i++)
505         {
506             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
507         }
508         /* Now tpl_n contains the cumulative count again */
509     }
510 }
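
/* Illustrative sketch only (hypothetical helper, not used above): the x
 * component of the continuous grid coordinate, computed exactly as in
 * calc_interpolation_idx. With an upper-triangular reciprocal box only the
 * x column enters, and the +2.0 shift keeps the value positive for atoms
 * slightly outside the primary triclinic cell before truncation to int.
 */
static double grid_coord_x_sketch(const double x[3],
                                  const double recipbox[3][3], int nkx)
{
    /* integer part = grid index, remainder = B-spline fraction */
    return nkx*(x[0]*recipbox[0][0] + x[1]*recipbox[1][0] +
                x[2]*recipbox[2][0] + 2.0);
}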
511
512 static void make_thread_local_ind(pme_atomcomm_t *atc,
513                                   int thread, splinedata_t *spline)
514 {
515     int             n, t, i, start, end;
516     thread_plist_t *tpl;
517
518     /* Combine the indices made by each thread into one index */
519
520     n     = 0;
521     start = 0;
522     for (t = 0; t < atc->nthread; t++)
523     {
524         tpl = &atc->thread_plist[t];
525         /* Copy our part (start - end) from the list of thread t */
526         if (thread > 0)
527         {
528             start = tpl->n[thread-1];
529         }
530         end = tpl->n[thread];
531         for (i = start; i < end; i++)
532         {
533             spline->ind[n++] = tpl->i[i];
534         }
535     }
536
537     spline->n = n;
538 }
539
540
541 static void pme_calc_pidx(int start, int end,
542                           matrix recipbox, rvec x[],
543                           pme_atomcomm_t *atc, int *count)
544 {
545     int   nslab, i;
546     int   si;
547     real *xptr, s;
548     real  rxx, ryx, rzx, ryy, rzy;
549     int  *pd;
550
551     /* Calculate the PME task index (pidx) for each particle.
552      * Here we always assign equally sized slabs to each node
553      * for load balancing reasons (the PME grid spacing is not used).
554      */
555
556     nslab = atc->nslab;
557     pd    = atc->pd;
558
559     /* Reset the count */
560     for (i = 0; i < nslab; i++)
561     {
562         count[i] = 0;
563     }
564
565     if (atc->dimind == 0)
566     {
567         rxx = recipbox[XX][XX];
568         ryx = recipbox[YY][XX];
569         rzx = recipbox[ZZ][XX];
570         /* Calculate the node index in x-dimension */
571         for (i = start; i < end; i++)
572         {
573             xptr   = x[i];
574             /* Fractional coordinates along box vectors */
575             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
576             si    = (int)(s + 2*nslab) % nslab;
577             pd[i] = si;
578             count[si]++;
579         }
580     }
581     else
582     {
583         ryy = recipbox[YY][YY];
584         rzy = recipbox[ZZ][YY];
585         /* Calculate the node index in y-dimension */
586         for (i = start; i < end; i++)
587         {
588             xptr   = x[i];
589             /* Fractional coordinates along box vectors */
590             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
591             si    = (int)(s + 2*nslab) % nslab;
592             pd[i] = si;
593             count[si]++;
594         }
595     }
596 }
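
/* Illustrative sketch (hypothetical helper, not called from the PME code):
 * the slab assignment above adds 2*nslab before truncating so that the
 * operand of % is non-negative; without the shift a slightly negative
 * fractional coordinate would give a negative slab index.
 */
static int slab_index_sketch(double frac, int nslab)
{
    /* frac is the fractional coordinate along the decomposed dimension */
    return (int)(nslab*frac + 2*nslab) % nslab;
}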
597
598 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
599                                   pme_atomcomm_t *atc)
600 {
601     int nthread, thread, slab;
602
603     nthread = atc->nthread;
604
605 #pragma omp parallel for num_threads(nthread) schedule(static)
606     for (thread = 0; thread < nthread; thread++)
607     {
608         pme_calc_pidx(natoms* thread   /nthread,
609                       natoms*(thread+1)/nthread,
610                       recipbox, x, atc, atc->count_thread[thread]);
611     }
612     /* Non-parallel reduction, since nslab is small */
613
614     for (thread = 1; thread < nthread; thread++)
615     {
616         for (slab = 0; slab < atc->nslab; slab++)
617         {
618             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
619         }
620     }
621 }
622
623 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
624 {
625     const int padding = 4;
626     int       i;
627
628     srenew(th[XX], nalloc);
629     srenew(th[YY], nalloc);
630     /* In z we add padding; this is only required for the aligned SIMD code */
631     sfree_aligned(*ptr_z);
632     snew_aligned(*ptr_z, nalloc+2*padding, SIMD4_ALIGNMENT);
633     th[ZZ] = *ptr_z + padding;
634
635     for (i = 0; i < padding; i++)
636     {
637         (*ptr_z)[               i] = 0;
638         (*ptr_z)[padding+nalloc+i] = 0;
639     }
640 }
641
642 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
643 {
644     int i, d;
645
646     srenew(spline->ind, atc->nalloc);
647     /* Initialize the index to identity so it works without threads */
648     for (i = 0; i < atc->nalloc; i++)
649     {
650         spline->ind[i] = i;
651     }
652
653     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
654                       atc->pme_order*atc->nalloc);
655     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
656                       atc->pme_order*atc->nalloc);
657 }
658
659 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
660 {
661     int nalloc_old, i, j, nalloc_tpl;
662
663     /* We have to avoid a NULL pointer for atc->x to prevent
664      * possible fatal errors in MPI routines.
665      */
666     if (atc->n > atc->nalloc || atc->nalloc == 0)
667     {
668         nalloc_old  = atc->nalloc;
669         atc->nalloc = over_alloc_dd(max(atc->n, 1));
670
671         if (atc->nslab > 1)
672         {
673             srenew(atc->x, atc->nalloc);
674             srenew(atc->q, atc->nalloc);
675             srenew(atc->f, atc->nalloc);
676             for (i = nalloc_old; i < atc->nalloc; i++)
677             {
678                 clear_rvec(atc->f[i]);
679             }
680         }
681         if (atc->bSpread)
682         {
683             srenew(atc->fractx, atc->nalloc);
684             srenew(atc->idx, atc->nalloc);
685
686             if (atc->nthread > 1)
687             {
688                 srenew(atc->thread_idx, atc->nalloc);
689             }
690
691             for (i = 0; i < atc->nthread; i++)
692             {
693                 pme_realloc_splinedata(&atc->spline[i], atc);
694             }
695         }
696     }
697 }
698
699 static void pmeredist_pd(gmx_pme_t pme, gmx_bool gmx_unused forw,
700                          int n, gmx_bool gmx_unused bXF, rvec gmx_unused *x_f,
701                          real gmx_unused *charge, pme_atomcomm_t *atc)
702 /* Redistribute particle data for PME calculation */
703 /* domain decomposition by x coordinate           */
704 {
705     int *idxa;
706     int  i, ii;
707
708     if (FALSE == pme->redist_init)
709     {
710         snew(pme->scounts, atc->nslab);
711         snew(pme->rcounts, atc->nslab);
712         snew(pme->sdispls, atc->nslab);
713         snew(pme->rdispls, atc->nslab);
714         snew(pme->sidx, atc->nslab);
715         pme->redist_init = TRUE;
716     }
717     if (n > pme->redist_buf_nalloc)
718     {
719         pme->redist_buf_nalloc = over_alloc_dd(n);
720         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
721     }
722
723     pme->idxa = atc->pd;
724
725 #ifdef GMX_MPI
726     if (forw && bXF)
727     {
728         /* forward, redistribution from pp to pme */
729
730         /* Calculate send counts and exchange them with other nodes */
731         for (i = 0; (i < atc->nslab); i++)
732         {
733             pme->scounts[i] = 0;
734         }
735         for (i = 0; (i < n); i++)
736         {
737             pme->scounts[pme->idxa[i]]++;
738         }
739         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
740
741         /* Calculate send and receive displacements and index into send
742            buffer */
743         pme->sdispls[0] = 0;
744         pme->rdispls[0] = 0;
745         pme->sidx[0]    = 0;
746         for (i = 1; i < atc->nslab; i++)
747         {
748             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
749             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
750             pme->sidx[i]    = pme->sdispls[i];
751         }
752         /* Total # of particles to be received */
753         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
754
755         pme_realloc_atomcomm_things(atc);
756
757         /* Copy particle coordinates into send buffer and exchange */
758         for (i = 0; (i < n); i++)
759         {
760             ii = DIM*pme->sidx[pme->idxa[i]];
761             pme->sidx[pme->idxa[i]]++;
762             pme->redist_buf[ii+XX] = x_f[i][XX];
763             pme->redist_buf[ii+YY] = x_f[i][YY];
764             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
765         }
766         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
767                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
768                       pme->rvec_mpi, atc->mpi_comm);
769     }
770     if (forw)
771     {
772         /* Copy charge into send buffer and exchange */
773         for (i = 0; i < atc->nslab; i++)
774         {
775             pme->sidx[i] = pme->sdispls[i];
776         }
777         for (i = 0; (i < n); i++)
778         {
779             ii = pme->sidx[pme->idxa[i]];
780             pme->sidx[pme->idxa[i]]++;
781             pme->redist_buf[ii] = charge[i];
782         }
783         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
784                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
785                       atc->mpi_comm);
786     }
787     else   /* backward, redistribution from pme to pp */
788     {
789         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
790                       pme->redist_buf, pme->scounts, pme->sdispls,
791                       pme->rvec_mpi, atc->mpi_comm);
792
793         /* Copy data from receive buffer */
794         for (i = 0; i < atc->nslab; i++)
795         {
796             pme->sidx[i] = pme->sdispls[i];
797         }
798         for (i = 0; (i < n); i++)
799         {
800             ii          = DIM*pme->sidx[pme->idxa[i]];
801             x_f[i][XX] += pme->redist_buf[ii+XX];
802             x_f[i][YY] += pme->redist_buf[ii+YY];
803             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
804             pme->sidx[pme->idxa[i]]++;
805         }
806     }
807 #endif
808 }
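
/* Illustrative sketch of the bookkeeping used above for MPI_Alltoallv
 * (hypothetical helper, not part of the code path): given a destination
 * slab per particle, build per-slab send counts and displacements into a
 * packed send buffer.
 */
static void alltoallv_counts_sketch(int n, const int *dest, int nslab,
                                    int *scounts, int *sdispls)
{
    int i;

    for (i = 0; i < nslab; i++)
    {
        scounts[i] = 0;
    }
    for (i = 0; i < n; i++)
    {
        scounts[dest[i]]++;
    }
    sdispls[0] = 0;
    for (i = 1; i < nslab; i++)
    {
        sdispls[i] = sdispls[i-1] + scounts[i-1];
    }
}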
809
810 static void pme_dd_sendrecv(pme_atomcomm_t gmx_unused *atc,
811                             gmx_bool gmx_unused bBackward, int gmx_unused shift,
812                             void gmx_unused *buf_s, int gmx_unused nbyte_s,
813                             void gmx_unused *buf_r, int gmx_unused nbyte_r)
814 {
815 #ifdef GMX_MPI
816     int        dest, src;
817     MPI_Status stat;
818
819     if (bBackward == FALSE)
820     {
821         dest = atc->node_dest[shift];
822         src  = atc->node_src[shift];
823     }
824     else
825     {
826         dest = atc->node_src[shift];
827         src  = atc->node_dest[shift];
828     }
829
830     if (nbyte_s > 0 && nbyte_r > 0)
831     {
832         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
833                      dest, shift,
834                      buf_r, nbyte_r, MPI_BYTE,
835                      src, shift,
836                      atc->mpi_comm, &stat);
837     }
838     else if (nbyte_s > 0)
839     {
840         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
841                  dest, shift,
842                  atc->mpi_comm);
843     }
844     else if (nbyte_r > 0)
845     {
846         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
847                  src, shift,
848                  atc->mpi_comm, &stat);
849     }
850 #endif
851 }
852
853 static void dd_pmeredist_x_q(gmx_pme_t pme,
854                              int n, gmx_bool bX, rvec *x, real *charge,
855                              pme_atomcomm_t *atc)
856 {
857     int *commnode, *buf_index;
858     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
859
860     commnode  = atc->node_dest;
861     buf_index = atc->buf_index;
862
863     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
864
865     nsend = 0;
866     for (i = 0; i < nnodes_comm; i++)
867     {
868         buf_index[commnode[i]] = nsend;
869         nsend                 += atc->count[commnode[i]];
870     }
871     if (bX)
872     {
873         if (atc->count[atc->nodeid] + nsend != n)
874         {
875             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
876                       "This usually means that your system is not well equilibrated.",
877                       n - (atc->count[atc->nodeid] + nsend),
878                       pme->nodeid, 'x'+atc->dimind);
879         }
880
881         if (nsend > pme->buf_nalloc)
882         {
883             pme->buf_nalloc = over_alloc_dd(nsend);
884             srenew(pme->bufv, pme->buf_nalloc);
885             srenew(pme->bufr, pme->buf_nalloc);
886         }
887
888         atc->n = atc->count[atc->nodeid];
889         for (i = 0; i < nnodes_comm; i++)
890         {
891             scount = atc->count[commnode[i]];
892             /* Communicate the count */
893             if (debug)
894             {
895                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
896                         atc->dimind, atc->nodeid, commnode[i], scount);
897             }
898             pme_dd_sendrecv(atc, FALSE, i,
899                             &scount, sizeof(int),
900                             &atc->rcount[i], sizeof(int));
901             atc->n += atc->rcount[i];
902         }
903
904         pme_realloc_atomcomm_things(atc);
905     }
906
907     local_pos = 0;
908     for (i = 0; i < n; i++)
909     {
910         node = atc->pd[i];
911         if (node == atc->nodeid)
912         {
913             /* Copy directly to the receive buffer */
914             if (bX)
915             {
916                 copy_rvec(x[i], atc->x[local_pos]);
917             }
918             atc->q[local_pos] = charge[i];
919             local_pos++;
920         }
921         else
922         {
923             /* Copy to the send buffer */
924             if (bX)
925             {
926                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
927             }
928             pme->bufr[buf_index[node]] = charge[i];
929             buf_index[node]++;
930         }
931     }
932
933     buf_pos = 0;
934     for (i = 0; i < nnodes_comm; i++)
935     {
936         scount = atc->count[commnode[i]];
937         rcount = atc->rcount[i];
938         if (scount > 0 || rcount > 0)
939         {
940             if (bX)
941             {
942                 /* Communicate the coordinates */
943                 pme_dd_sendrecv(atc, FALSE, i,
944                                 pme->bufv[buf_pos], scount*sizeof(rvec),
945                                 atc->x[local_pos], rcount*sizeof(rvec));
946             }
947             /* Communicate the charges */
948             pme_dd_sendrecv(atc, FALSE, i,
949                             pme->bufr+buf_pos, scount*sizeof(real),
950                             atc->q+local_pos, rcount*sizeof(real));
951             buf_pos   += scount;
952             local_pos += atc->rcount[i];
953         }
954     }
955 }
956
957 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
958                            int n, rvec *f,
959                            gmx_bool bAddF)
960 {
961     int *commnode, *buf_index;
962     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
963
964     commnode  = atc->node_dest;
965     buf_index = atc->buf_index;
966
967     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
968
969     local_pos = atc->count[atc->nodeid];
970     buf_pos   = 0;
971     for (i = 0; i < nnodes_comm; i++)
972     {
973         scount = atc->rcount[i];
974         rcount = atc->count[commnode[i]];
975         if (scount > 0 || rcount > 0)
976         {
977             /* Communicate the forces */
978             pme_dd_sendrecv(atc, TRUE, i,
979                             atc->f[local_pos], scount*sizeof(rvec),
980                             pme->bufv[buf_pos], rcount*sizeof(rvec));
981             local_pos += scount;
982         }
983         buf_index[commnode[i]] = buf_pos;
984         buf_pos               += rcount;
985     }
986
987     local_pos = 0;
988     if (bAddF)
989     {
990         for (i = 0; i < n; i++)
991         {
992             node = atc->pd[i];
993             if (node == atc->nodeid)
994             {
995                 /* Add from the local force array */
996                 rvec_inc(f[i], atc->f[local_pos]);
997                 local_pos++;
998             }
999             else
1000             {
1001                 /* Add from the receive buffer */
1002                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
1003                 buf_index[node]++;
1004             }
1005         }
1006     }
1007     else
1008     {
1009         for (i = 0; i < n; i++)
1010         {
1011             node = atc->pd[i];
1012             if (node == atc->nodeid)
1013             {
1014                 /* Copy from the local force array */
1015                 copy_rvec(atc->f[local_pos], f[i]);
1016                 local_pos++;
1017             }
1018             else
1019             {
1020                 /* Copy from the receive buffer */
1021                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
1022                 buf_index[node]++;
1023             }
1024         }
1025     }
1026 }
1027
1028 #ifdef GMX_MPI
1029 static void gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
1030 {
1031     pme_overlap_t *overlap;
1032     int            send_index0, send_nindex;
1033     int            recv_index0, recv_nindex;
1034     MPI_Status     stat;
1035     int            i, j, k, ix, iy, iz, icnt;
1036     int            ipulse, send_id, recv_id, datasize;
1037     real          *p;
1038     real          *sendptr, *recvptr;
1039
1040     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
1041     overlap = &pme->overlap[1];
1042
1043     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1044     {
1045         /* Since we have already (un)wrapped the overlap in the z-dimension,
1046          * we only have to communicate 0 to nkz (not pmegrid_nz).
1047          */
1048         if (direction == GMX_SUM_QGRID_FORWARD)
1049         {
1050             send_id       = overlap->send_id[ipulse];
1051             recv_id       = overlap->recv_id[ipulse];
1052             send_index0   = overlap->comm_data[ipulse].send_index0;
1053             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1054             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1055             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1056         }
1057         else
1058         {
1059             send_id       = overlap->recv_id[ipulse];
1060             recv_id       = overlap->send_id[ipulse];
1061             send_index0   = overlap->comm_data[ipulse].recv_index0;
1062             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1063             recv_index0   = overlap->comm_data[ipulse].send_index0;
1064             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1065         }
1066
1067         /* Copy data to contiguous send buffer */
1068         if (debug)
1069         {
1070             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1071                     pme->nodeid, overlap->nodeid, send_id,
1072                     pme->pmegrid_start_iy,
1073                     send_index0-pme->pmegrid_start_iy,
1074                     send_index0-pme->pmegrid_start_iy+send_nindex);
1075         }
1076         icnt = 0;
1077         for (i = 0; i < pme->pmegrid_nx; i++)
1078         {
1079             ix = i;
1080             for (j = 0; j < send_nindex; j++)
1081             {
1082                 iy = j + send_index0 - pme->pmegrid_start_iy;
1083                 for (k = 0; k < pme->nkz; k++)
1084                 {
1085                     iz = k;
1086                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1087                 }
1088             }
1089         }
1090
1091         datasize      = pme->pmegrid_nx * pme->nkz;
1092
1093         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1094                      send_id, ipulse,
1095                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1096                      recv_id, ipulse,
1097                      overlap->mpi_comm, &stat);
1098
1099         /* Get data from contiguous recv buffer */
1100         if (debug)
1101         {
1102             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1103                     pme->nodeid, overlap->nodeid, recv_id,
1104                     pme->pmegrid_start_iy,
1105                     recv_index0-pme->pmegrid_start_iy,
1106                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1107         }
1108         icnt = 0;
1109         for (i = 0; i < pme->pmegrid_nx; i++)
1110         {
1111             ix = i;
1112             for (j = 0; j < recv_nindex; j++)
1113             {
1114                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1115                 for (k = 0; k < pme->nkz; k++)
1116                 {
1117                     iz = k;
1118                     if (direction == GMX_SUM_QGRID_FORWARD)
1119                     {
1120                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1121                     }
1122                     else
1123                     {
1124                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1125                     }
1126                 }
1127             }
1128         }
1129     }
1130
1131     /* Major dimension is easier, no copying required,
1132      * but we might have to sum into a separate array.
1133      * Since we don't copy, we have to communicate up to pmegrid_nz,
1134      * not nkz as for the minor direction.
1135      */
1136     overlap = &pme->overlap[0];
1137
1138     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1139     {
1140         if (direction == GMX_SUM_QGRID_FORWARD)
1141         {
1142             send_id       = overlap->send_id[ipulse];
1143             recv_id       = overlap->recv_id[ipulse];
1144             send_index0   = overlap->comm_data[ipulse].send_index0;
1145             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1146             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1147             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1148             recvptr       = overlap->recvbuf;
1149         }
1150         else
1151         {
1152             send_id       = overlap->recv_id[ipulse];
1153             recv_id       = overlap->send_id[ipulse];
1154             send_index0   = overlap->comm_data[ipulse].recv_index0;
1155             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1156             recv_index0   = overlap->comm_data[ipulse].send_index0;
1157             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1158             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1159         }
1160
1161         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1162         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1163
1164         if (debug)
1165         {
1166             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1167                     pme->nodeid, overlap->nodeid, send_id,
1168                     pme->pmegrid_start_ix,
1169                     send_index0-pme->pmegrid_start_ix,
1170                     send_index0-pme->pmegrid_start_ix+send_nindex);
1171             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1172                     pme->nodeid, overlap->nodeid, recv_id,
1173                     pme->pmegrid_start_ix,
1174                     recv_index0-pme->pmegrid_start_ix,
1175                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1176         }
1177
1178         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1179                      send_id, ipulse,
1180                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1181                      recv_id, ipulse,
1182                      overlap->mpi_comm, &stat);
1183
1184         /* ADD data from contiguous recv buffer */
1185         if (direction == GMX_SUM_QGRID_FORWARD)
1186         {
1187             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1188             for (i = 0; i < recv_nindex*datasize; i++)
1189             {
1190                 p[i] += overlap->recvbuf[i];
1191             }
1192         }
1193     }
1194 }
1195 #endif
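
/* Illustrative sketch of the pack step in gmx_sum_qgrid_dd (hypothetical
 * helper, for clarity only): a y-slab of a row-major (x,y,z) grid is not
 * contiguous in memory, so it is copied into a linear buffer before the
 * MPI_Sendrecv; nx, ny, nz are the grid strides, iy0/niy select the slab.
 */
static void pack_y_slab_sketch(const double *grid, double *buf,
                               int nx, int ny, int nz, int iy0, int niy)
{
    int ix, iy, iz, icnt = 0;

    for (ix = 0; ix < nx; ix++)
    {
        for (iy = iy0; iy < iy0 + niy; iy++)
        {
            for (iz = 0; iz < nz; iz++)
            {
                buf[icnt++] = grid[(ix*ny + iy)*nz + iz];
            }
        }
    }
}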
1196
1197
1198 static int copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid, int grid_index)
1199 {
1200     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1201     ivec    local_pme_size;
1202     int     i, ix, iy, iz;
1203     int     pmeidx, fftidx;
1204
1205     /* Dimensions should be identical for A/B grid, so we just use A here */
1206     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1207                                    local_fft_ndata,
1208                                    local_fft_offset,
1209                                    local_fft_size);
1210
1211     local_pme_size[0] = pme->pmegrid_nx;
1212     local_pme_size[1] = pme->pmegrid_ny;
1213     local_pme_size[2] = pme->pmegrid_nz;
1214
1215     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1216        the offset is identical, and the PME grid always has more data (due to overlap)
1217      */
1218     {
1219 #ifdef DEBUG_PME
1220         FILE *fp, *fp2;
1221         char  fn[STRLEN], format[STRLEN];
1222         real  val;
1223         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1224         fp = ffopen(fn, "w");
1225         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1226         fp2 = ffopen(fn, "w");
1227         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1228 #endif
1229
1230         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1231         {
1232             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1233             {
1234                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1235                 {
1236                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1237                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1238                     fftgrid[fftidx] = pmegrid[pmeidx];
1239 #ifdef DEBUG_PME
1240                     val = 100*pmegrid[pmeidx];
1241                     if (pmegrid[pmeidx] != 0)
1242                     {
1243                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1244                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1245                     }
1246                     if (pmegrid[pmeidx] != 0)
1247                     {
1248                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1249                                 "qgrid",
1250                                 pme->pmegrid_start_ix + ix,
1251                                 pme->pmegrid_start_iy + iy,
1252                                 pme->pmegrid_start_iz + iz,
1253                                 pmegrid[pmeidx]);
1254                     }
1255 #endif
1256                 }
1257             }
1258         }
1259 #ifdef DEBUG_PME
1260         ffclose(fp);
1261         ffclose(fp2);
1262 #endif
1263     }
1264     return 0;
1265 }
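
/* Illustrative sketch (hypothetical helper): both the PME grid and the FFT
 * grid store element (ix,iy,iz) in row-major order; only the allocated
 * strides differ, because the PME grid carries the pme_order-1 overlap.
 * The copy above therefore only changes the index arithmetic.
 */
static int rowmajor_index_sketch(int ix, int iy, int iz, const int size[3])
{
    return (ix*size[1] + iy)*size[2] + iz;
}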
1266
1267
1268 static gmx_cycles_t omp_cyc_start()
1269 {
1270     return gmx_cycles_read();
1271 }
1272
1273 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1274 {
1275     return gmx_cycles_read() - c;
1276 }
1277
1278
1279 static int copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid, int grid_index,
1280                                    int nthread, int thread)
1281 {
1282     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1283     ivec          local_pme_size;
1284     int           ixy0, ixy1, ixy, ix, iy, iz;
1285     int           pmeidx, fftidx;
1286 #ifdef PME_TIME_THREADS
1287     gmx_cycles_t  c1;
1288     static double cs1 = 0;
1289     static int    cnt = 0;
1290 #endif
1291
1292 #ifdef PME_TIME_THREADS
1293     c1 = omp_cyc_start();
1294 #endif
1295     /* Dimensions should be identical for A/B grid, so we just use A here */
1296     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
1297                                    local_fft_ndata,
1298                                    local_fft_offset,
1299                                    local_fft_size);
1300
1301     local_pme_size[0] = pme->pmegrid_nx;
1302     local_pme_size[1] = pme->pmegrid_ny;
1303     local_pme_size[2] = pme->pmegrid_nz;
1304
1305     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1306        the offset is identical, and the PME grid always has more data (due to overlap)
1307      */
1308     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1309     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1310
1311     for (ixy = ixy0; ixy < ixy1; ixy++)
1312     {
1313         ix = ixy/local_fft_ndata[YY];
1314         iy = ixy - ix*local_fft_ndata[YY];
1315
1316         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1317         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1318         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1319         {
1320             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1321         }
1322     }
1323
1324 #ifdef PME_TIME_THREADS
1325     c1   = omp_cyc_end(c1);
1326     cs1 += (double)c1;
1327     cnt++;
1328     if (cnt % 20 == 0)
1329     {
1330         printf("copy %.2f\n", cs1*1e-9);
1331     }
1332 #endif
1333
1334     return 0;
1335 }
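
/* Illustrative sketch of the static work split used above (hypothetical
 * helper): thread t of nthread handles xy-rows [t*nrow/nthread,
 * (t+1)*nrow/nthread), which covers every row exactly once even when nrow
 * is not divisible by nthread.
 */
static void thread_row_range_sketch(int nrow, int nthread, int thread,
                                    int *row0, int *row1)
{
    *row0 = ( thread   *nrow)/nthread;
    *row1 = ((thread+1)*nrow)/nthread;
}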
1336
1337
1338 static void wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1339 {
1340     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1341
1342     nx = pme->nkx;
1343     ny = pme->nky;
1344     nz = pme->nkz;
1345
1346     pnx = pme->pmegrid_nx;
1347     pny = pme->pmegrid_ny;
1348     pnz = pme->pmegrid_nz;
1349
1350     overlap = pme->pme_order - 1;
1351
1352     /* Add periodic overlap in z */
1353     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1354     {
1355         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1356         {
1357             for (iz = 0; iz < overlap; iz++)
1358             {
1359                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1360                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1361             }
1362         }
1363     }
1364
1365     if (pme->nnodes_minor == 1)
1366     {
1367         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1368         {
1369             for (iy = 0; iy < overlap; iy++)
1370             {
1371                 for (iz = 0; iz < nz; iz++)
1372                 {
1373                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1374                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1375                 }
1376             }
1377         }
1378     }
1379
1380     if (pme->nnodes_major == 1)
1381     {
1382         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1383
1384         for (ix = 0; ix < overlap; ix++)
1385         {
1386             for (iy = 0; iy < ny_x; iy++)
1387             {
1388                 for (iz = 0; iz < nz; iz++)
1389                 {
1390                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1391                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1392                 }
1393             }
1394         }
1395     }
1396 }
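
/* Illustrative 1-D sketch of the wrapping above (hypothetical helper): the
 * padded line has at least n + order - 1 entries, and the last order-1
 * cells hold spreading contributions that belong to the periodic images of
 * the first cells, so they are folded back; unwrap_periodic_pmegrid below
 * performs the inverse copy before the force gather.
 */
static void wrap_1d_sketch(double *line, int n, int order)
{
    int i;

    for (i = 0; i < order - 1; i++)
    {
        line[i] += line[n + i];
    }
}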
1397
1398
1399 static void unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1400 {
1401     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1402
1403     nx = pme->nkx;
1404     ny = pme->nky;
1405     nz = pme->nkz;
1406
1407     pnx = pme->pmegrid_nx;
1408     pny = pme->pmegrid_ny;
1409     pnz = pme->pmegrid_nz;
1410
1411     overlap = pme->pme_order - 1;
1412
1413     if (pme->nnodes_major == 1)
1414     {
1415         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1416
1417         for (ix = 0; ix < overlap; ix++)
1418         {
1419             int iy, iz;
1420
1421             for (iy = 0; iy < ny_x; iy++)
1422             {
1423                 for (iz = 0; iz < nz; iz++)
1424                 {
1425                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1426                         pmegrid[(ix*pny+iy)*pnz+iz];
1427                 }
1428             }
1429         }
1430     }
1431
1432     if (pme->nnodes_minor == 1)
1433     {
1434 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1435         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1436         {
1437             int iy, iz;
1438
1439             for (iy = 0; iy < overlap; iy++)
1440             {
1441                 for (iz = 0; iz < nz; iz++)
1442                 {
1443                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1444                         pmegrid[(ix*pny+iy)*pnz+iz];
1445                 }
1446             }
1447         }
1448     }
1449
1450     /* Copy periodic overlap in z */
1451 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1452     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1453     {
1454         int iy, iz;
1455
1456         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1457         {
1458             for (iz = 0; iz < overlap; iz++)
1459             {
1460                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1461                     pmegrid[(ix*pny+iy)*pnz+iz];
1462             }
1463         }
1464     }
1465 }
1466
1467
1468 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1469 #define DO_BSPLINE(order)                            \
1470     for (ithx = 0; (ithx < order); ithx++)                    \
1471     {                                                    \
1472         index_x = (i0+ithx)*pny*pnz;                     \
1473         valx    = qn*thx[ithx];                          \
1474                                                      \
1475         for (ithy = 0; (ithy < order); ithy++)                \
1476         {                                                \
1477             valxy    = valx*thy[ithy];                   \
1478             index_xy = index_x+(j0+ithy)*pnz;            \
1479                                                      \
1480             for (ithz = 0; (ithz < order); ithz++)            \
1481             {                                            \
1482                 index_xyz        = index_xy+(k0+ithz);   \
1483                 grid[index_xyz] += valxy*thz[ithz];      \
1484             }                                            \
1485         }                                                \
1486     }
1487
1488
1489 static void spread_q_bsplines_thread(pmegrid_t                    *pmegrid,
1490                                      pme_atomcomm_t               *atc,
1491                                      splinedata_t                 *spline,
1492                                      pme_spline_work_t gmx_unused *work)
1493 {
1494
1495     /* spread charges from home atoms to local grid */
1496     real          *grid;
1497     pme_overlap_t *ol;
1498     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1499     int       *    idxptr;
1500     int            order, norder, index_x, index_xy, index_xyz;
1501     real           valx, valxy, qn;
1502     real          *thx, *thy, *thz;
1503     int            localsize, bndsize;
1504     int            pnx, pny, pnz, ndatatot;
1505     int            offx, offy, offz;
1506
1507 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
1508     real           thz_buffer[12], *thz_aligned;
1509
1510     thz_aligned = gmx_simd4_align_real(thz_buffer);
1511 #endif
1512
1513     pnx = pmegrid->s[XX];
1514     pny = pmegrid->s[YY];
1515     pnz = pmegrid->s[ZZ];
1516
1517     offx = pmegrid->offset[XX];
1518     offy = pmegrid->offset[YY];
1519     offz = pmegrid->offset[ZZ];
1520
1521     ndatatot = pnx*pny*pnz;
1522     grid     = pmegrid->grid;
1523     for (i = 0; i < ndatatot; i++)
1524     {
1525         grid[i] = 0;
1526     }
1527
1528     order = pmegrid->order;
1529
1530     for (nn = 0; nn < spline->n; nn++)
1531     {
1532         n  = spline->ind[nn];
1533         qn = atc->q[n];
1534
1535         if (qn != 0)
1536         {
1537             idxptr = atc->idx[n];
1538             norder = nn*order;
1539
1540             i0   = idxptr[XX] - offx;
1541             j0   = idxptr[YY] - offy;
1542             k0   = idxptr[ZZ] - offz;
1543
1544             thx = spline->theta[XX] + norder;
1545             thy = spline->theta[YY] + norder;
1546             thz = spline->theta[ZZ] + norder;
1547
1548             switch (order)
1549             {
1550                 case 4:
1551 #ifdef PME_SIMD4_SPREAD_GATHER
1552 #ifdef PME_SIMD4_UNALIGNED
1553 #define PME_SPREAD_SIMD4_ORDER4
1554 #else
1555 #define PME_SPREAD_SIMD4_ALIGNED
1556 #define PME_ORDER 4
1557 #endif
1558 #include "pme_simd4.h"
1559 #else
1560                     DO_BSPLINE(4);
1561 #endif
1562                     break;
1563                 case 5:
1564 #ifdef PME_SIMD4_SPREAD_GATHER
1565 #define PME_SPREAD_SIMD4_ALIGNED
1566 #define PME_ORDER 5
1567 #include "pme_simd4.h"
1568 #else
1569                     DO_BSPLINE(5);
1570 #endif
1571                     break;
1572                 default:
1573                     DO_BSPLINE(order);
1574                     break;
1575             }
1576         }
1577     }
1578 }
1579
1580 static void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
1581 {
1582 #ifdef PME_SIMD4_SPREAD_GATHER
1583     if (pme_order == 5
1584 #ifndef PME_SIMD4_UNALIGNED
1585         || pme_order == 4
1586 #endif
1587         )
1588     {
1589         /* Round nz up to a multiple of 4 to ensure alignment */
1590         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1591     }
1592 #endif
1593 }
1594
1595 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1596 {
1597 #ifdef PME_SIMD4_SPREAD_GATHER
1598 #ifndef PME_SIMD4_UNALIGNED
1599     if (pme_order == 4)
1600     {
1601         /* Add extra elements to ensure that aligned operations do not go
1602          * beyond the allocated grid size.
1603          * Note that for pme_order=5, the pme grid z-size alignment
1604          * ensures that we will not go beyond the grid size.
1605          */
1606         *gridsize += 4;
1607     }
1608 #endif
1609 #endif
1610 }
1611
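     /* Initialize a single (sub)grid: ci is the (thread) cell index,
      * (x0,y0,z0) the grid offset and x1-x0 etc. the interior size, to which
      * pme_order-1 overlap lines are added in each dimension. With
      * set_alignment the z size may be rounded up for SIMD4 alignment.
      * If ptr is NULL the grid storage is allocated here, otherwise ptr
      * is used as the storage.
      */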
1612 static void pmegrid_init(pmegrid_t *grid,
1613                          int cx, int cy, int cz,
1614                          int x0, int y0, int z0,
1615                          int x1, int y1, int z1,
1616                          gmx_bool set_alignment,
1617                          int pme_order,
1618                          real *ptr)
1619 {
1620     int nz, gridsize;
1621
1622     grid->ci[XX]     = cx;
1623     grid->ci[YY]     = cy;
1624     grid->ci[ZZ]     = cz;
1625     grid->offset[XX] = x0;
1626     grid->offset[YY] = y0;
1627     grid->offset[ZZ] = z0;
1628     grid->n[XX]      = x1 - x0 + pme_order - 1;
1629     grid->n[YY]      = y1 - y0 + pme_order - 1;
1630     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1631     copy_ivec(grid->n, grid->s);
1632
1633     nz = grid->s[ZZ];
1634     set_grid_alignment(&nz, pme_order);
1635     if (set_alignment)
1636     {
1637         grid->s[ZZ] = nz;
1638     }
1639     else if (nz != grid->s[ZZ])
1640     {
1641         gmx_incons("pmegrid_init call with an unaligned z size");
1642     }
1643
1644     grid->order = pme_order;
1645     if (ptr == NULL)
1646     {
1647         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1648         set_gridsize_alignment(&gridsize, pme_order);
1649         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
1650     }
1651     else
1652     {
1653         grid->grid = ptr;
1654     }
1655 }
1656
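     /* Integer division rounding upwards, e.g. div_round_up(10, 4) = 3 */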
1657 static int div_round_up(int enumerator, int denominator)
1658 {
1659     return (enumerator + denominator - 1)/denominator;
1660 }
1661
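     /* Choose a division nsub[XX] x nsub[YY] x nsub[ZZ] of the grid over
      * nthread threads, such that the product equals nthread, minimizing the
      * number of grid points (including ovl overlap lines) per thread and,
      * secondarily, the number of cuts in the minor dimensions. The division
      * can be overridden with the GMX_PME_THREAD_DIVISION environment
      * variable, but must still multiply up to nthread.
      */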
1662 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1663                                   ivec nsub)
1664 {
1665     int gsize_opt, gsize;
1666     int nsx, nsy, nsz;
1667     char *env;
1668
1669     gsize_opt = -1;
1670     for (nsx = 1; nsx <= nthread; nsx++)
1671     {
1672         if (nthread % nsx == 0)
1673         {
1674             for (nsy = 1; nsy <= nthread; nsy++)
1675             {
1676                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1677                 {
1678                     nsz = nthread/(nsx*nsy);
1679
1680                     /* Determine the number of grid points per thread */
1681                     gsize =
1682                         (div_round_up(n[XX], nsx) + ovl)*
1683                         (div_round_up(n[YY], nsy) + ovl)*
1684                         (div_round_up(n[ZZ], nsz) + ovl);
1685
1686                     /* Minimize the number of grid points per thread
1687                      * and, secondarily, the number of cuts in minor dimensions.
1688                      */
1689                     if (gsize_opt == -1 ||
1690                         gsize < gsize_opt ||
1691                         (gsize == gsize_opt &&
1692                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1693                     {
1694                         nsub[XX]  = nsx;
1695                         nsub[YY]  = nsy;
1696                         nsub[ZZ]  = nsz;
1697                         gsize_opt = gsize;
1698                     }
1699                 }
1700             }
1701         }
1702     }
1703
1704     env = getenv("GMX_PME_THREAD_DIVISION");
1705     if (env != NULL)
1706     {
1707         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1708     }
1709
1710     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1711     {
1712         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1713     }
1714 }
1715
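     /* Set up the full PME grid and, with bUseThreads, one local grid per
      * thread: the grid is divided into nc[XX] x nc[YY] x nc[ZZ] sub-grids
      * (see make_subgrid_division), all stored in one allocation grid_all
      * with GMX_CACHE_SEP padding in between. Also fills g2t, mapping grid
      * lines to thread indices, and nthread_comm, the number of neighboring
      * thread grids that overlap along each dimension.
      */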
1716 static void pmegrids_init(pmegrids_t *grids,
1717                           int nx, int ny, int nz, int nz_base,
1718                           int pme_order,
1719                           gmx_bool bUseThreads,
1720                           int nthread,
1721                           int overlap_x,
1722                           int overlap_y)
1723 {
1724     ivec n, n_base, g0, g1;
1725     int t, x, y, z, d, i, tfac;
1726     int max_comm_lines = -1;
1727
1728     n[XX] = nx - (pme_order - 1);
1729     n[YY] = ny - (pme_order - 1);
1730     n[ZZ] = nz - (pme_order - 1);
1731
1732     copy_ivec(n, n_base);
1733     n_base[ZZ] = nz_base;
1734
1735     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1736                  NULL);
1737
1738     grids->nthread = nthread;
1739
1740     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1741
1742     if (bUseThreads)
1743     {
1744         ivec nst;
1745         int gridsize;
1746
1747         for (d = 0; d < DIM; d++)
1748         {
1749             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1750         }
1751         set_grid_alignment(&nst[ZZ], pme_order);
1752
1753         if (debug)
1754         {
1755             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1756                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1757             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1758                     nx, ny, nz,
1759                     nst[XX], nst[YY], nst[ZZ]);
1760         }
1761
1762         snew(grids->grid_th, grids->nthread);
1763         t        = 0;
1764         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1765         set_gridsize_alignment(&gridsize, pme_order);
1766         snew_aligned(grids->grid_all,
1767                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1768                      SIMD4_ALIGNMENT);
1769
1770         for (x = 0; x < grids->nc[XX]; x++)
1771         {
1772             for (y = 0; y < grids->nc[YY]; y++)
1773             {
1774                 for (z = 0; z < grids->nc[ZZ]; z++)
1775                 {
1776                     pmegrid_init(&grids->grid_th[t],
1777                                  x, y, z,
1778                                  (n[XX]*(x  ))/grids->nc[XX],
1779                                  (n[YY]*(y  ))/grids->nc[YY],
1780                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1781                                  (n[XX]*(x+1))/grids->nc[XX],
1782                                  (n[YY]*(y+1))/grids->nc[YY],
1783                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1784                                  TRUE,
1785                                  pme_order,
1786                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1787                     t++;
1788                 }
1789             }
1790         }
1791     }
1792     else
1793     {
1794         grids->grid_th = NULL;
1795     }
1796
1797     snew(grids->g2t, DIM);
1798     tfac = 1;
1799     for (d = DIM-1; d >= 0; d--)
1800     {
1801         snew(grids->g2t[d], n[d]);
1802         t = 0;
1803         for (i = 0; i < n[d]; i++)
1804         {
1805             /* The second check should match the parameters
1806              * of the pmegrid_init call above.
1807              */
1808             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1809             {
1810                 t++;
1811             }
1812             grids->g2t[d][i] = t*tfac;
1813         }
1814
1815         tfac *= grids->nc[d];
1816
1817         switch (d)
1818         {
1819             case XX: max_comm_lines = overlap_x;     break;
1820             case YY: max_comm_lines = overlap_y;     break;
1821             case ZZ: max_comm_lines = pme_order - 1; break;
1822         }
1823         grids->nthread_comm[d] = 0;
1824         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1825                grids->nthread_comm[d] < grids->nc[d])
1826         {
1827             grids->nthread_comm[d]++;
1828         }
1829         if (debug != NULL)
1830         {
1831             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1832                     'x'+d, grids->nthread_comm[d]);
1833         }
1834         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1835          * work, but this is not a problematic restriction.
1836          */
1837         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1838         {
1839             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1840         }
1841     }
1842 }
1843
1844
1845 static void pmegrids_destroy(pmegrids_t *grids)
1846 {
1847     int t;
1848
1849     if (grids->grid.grid != NULL)
1850     {
1851         sfree(grids->grid.grid);
1852
1853         if (grids->nthread > 0)
1854         {
1855             for (t = 0; t < grids->nthread; t++)
1856             {
1857                 sfree(grids->grid_th[t].grid);
1858             }
1859             sfree(grids->grid_th);
1860         }
1861     }
1862 }
1863
1864
1865 static void realloc_work(pme_work_t *work, int nkx)
1866 {
1867     int simd_width;
1868
1869     if (nkx > work->nalloc)
1870     {
1871         work->nalloc = nkx;
1872         srenew(work->mhx, work->nalloc);
1873         srenew(work->mhy, work->nalloc);
1874         srenew(work->mhz, work->nalloc);
1875         srenew(work->m2, work->nalloc);
1876         /* Allocate an aligned pointer for SIMD operations, including extra
1877          * elements at the end for padding.
1878          */
1879 #ifdef PME_SIMD
1880         simd_width = GMX_SIMD_WIDTH_HERE;
1881 #else
1882         /* We can use any alignment, apart from 0, so we use 4 */
1883         simd_width = 4;
1884 #endif
1885         sfree_aligned(work->denom);
1886         sfree_aligned(work->tmp1);
1887         sfree_aligned(work->tmp2);
1888         sfree_aligned(work->eterm);
1889         snew_aligned(work->denom, work->nalloc+simd_width, simd_width*sizeof(real));
1890         snew_aligned(work->tmp1,  work->nalloc+simd_width, simd_width*sizeof(real));
1891         snew_aligned(work->tmp2,  work->nalloc+simd_width, simd_width*sizeof(real));
1892         snew_aligned(work->eterm, work->nalloc+simd_width, simd_width*sizeof(real));
1893         srenew(work->m2inv, work->nalloc);
1894     }
1895 }
1896
1897
1898 static void free_work(pme_work_t *work)
1899 {
1900     sfree(work->mhx);
1901     sfree(work->mhy);
1902     sfree(work->mhz);
1903     sfree(work->m2);
1904     sfree_aligned(work->denom);
1905     sfree_aligned(work->tmp1);
1906     sfree_aligned(work->tmp2);
1907     sfree_aligned(work->eterm);
1908     sfree(work->m2inv);
1909 }
1910
1911
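     /* Compute e[kx] = f*exp(r[kx])/d[kx] for kx in [start,end); the scalar
      * version also overwrites d and r with 1/d and exp(r). The SIMD version
      * always starts at kx = 0 so that loads and stores stay aligned.
      */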
1912 #ifdef PME_SIMD
1913 /* Calculate exponentials through SIMD */
1914 inline static void calc_exponentials_q(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1915 {
1916     {
1917         const gmx_mm_pr two = gmx_set1_pr(2.0);
1918         gmx_mm_pr f_simd;
1919         gmx_mm_pr lu;
1920         gmx_mm_pr tmp_d1, d_inv, tmp_r, tmp_e;
1921         int kx;
1922         f_simd = gmx_set1_pr(f);
1923         /* We only need to calculate from start. But since start is 0 or 1
1924          * and we want to use aligned loads/stores, we always start from 0.
1925          */
1926         for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1927         {
1928             tmp_d1   = gmx_load_pr(d_aligned+kx);
1929             d_inv    = gmx_inv_pr(tmp_d1);
1930             tmp_r    = gmx_load_pr(r_aligned+kx);
1931             tmp_r    = gmx_exp_pr(tmp_r);
1932             tmp_e    = gmx_mul_pr(f_simd, d_inv);
1933             tmp_e    = gmx_mul_pr(tmp_e, tmp_r);
1934             gmx_store_pr(e_aligned+kx, tmp_e);
1935         }
1936     }
1937 }
1938 #else
1939 inline static void calc_exponentials_q(int start, int end, real f, real *d, real *r, real *e)
1940 {
1941     int kx;
1942     for (kx = start; kx < end; kx++)
1943     {
1944         d[kx] = 1.0/d[kx];
1945     }
1946     for (kx = start; kx < end; kx++)
1947     {
1948         r[kx] = exp(r[kx]);
1949     }
1950     for (kx = start; kx < end; kx++)
1951     {
1952         e[kx] = f*r[kx]*d[kx];
1953     }
1954 }
1955 #endif
1956
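     /* In-place transforms used by the LJ-PME solver: the denominator array
      * is replaced by its reciprocal, the exponent array by exp() of itself,
      * and the factor (tmp2) array by sqrt(pi)*x*erfc(x) of its values.
      * The SIMD version always starts at kx = 0 to keep accesses aligned.
      */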
1957 #ifdef PME_SIMD
1958 /* Calculate exponentials through SIMD */
1959 inline static void calc_exponentials_lj(int gmx_unused start, int end, real *r_aligned, real *factor_aligned, real *d_aligned)
1960 {
1961     gmx_mm_pr tmp_r, tmp_d, tmp_fac, d_inv, tmp_mk;
1962     const gmx_mm_pr sqr_PI = gmx_sqrt_pr(gmx_set1_pr(M_PI));
1963     int kx;
1964     for (kx = 0; kx < end; kx += GMX_SIMD_WIDTH_HERE)
1965     {
1966         /* We only need to calculate from start. But since start is 0 or 1
1967          * and we want to use aligned loads/stores, we always start from 0.
1968          */
1969         tmp_d = gmx_load_pr(d_aligned+kx);
1970         d_inv = gmx_inv_pr(tmp_d);
1971         gmx_store_pr(d_aligned+kx, d_inv);
1972         tmp_r = gmx_load_pr(r_aligned+kx);
1973         tmp_r = gmx_exp_pr(tmp_r);
1974         gmx_store_pr(r_aligned+kx, tmp_r);
1975         tmp_mk  = gmx_load_pr(factor_aligned+kx);
1976         tmp_fac = gmx_mul_pr(sqr_PI, gmx_mul_pr(tmp_mk, gmx_erfc_pr(tmp_mk)));
1977         gmx_store_pr(factor_aligned+kx, tmp_fac);
1978     }
1979 }
1980 #else
1981 inline static void calc_exponentials_lj(int start, int end, real *r, real *tmp2, real *d)
1982 {
1983     int kx;
1984     real mk;
1985     for (kx = start; kx < end; kx++)
1986     {
1987         d[kx] = 1.0/d[kx];
1988     }
1989
1990     for (kx = start; kx < end; kx++)
1991     {
1992         r[kx] = exp(r[kx]);
1993     }
1994
1995     for (kx = start; kx < end; kx++)
1996     {
1997         mk       = tmp2[kx];
1998         tmp2[kx] = sqrt(M_PI)*mk*gmx_erfc(mk);
1999     }
2000 }
2001 #endif
2002
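     /* Reciprocal-space solve for the Coulomb grid. Each complex grid value
      * at reciprocal vector m (components mhx,mhy,mhz obtained from recipbox)
      * is scaled by
      *   eterm = elfac*exp(-pi^2*m^2/ewaldcoeff^2)/(pi*V*m^2*Bx*By*Bz),
      * where Bx,By,Bz are the B-spline moduli (bsp_mod) and
      * elfac = ONE_4PI_EPS0/epsilon_r. With bEnerVir the energy and virial
      * contributions are accumulated as well. The yz-plane range is split
      * statically over nthread threads; the loop count is returned.
      */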
2003 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
2004                          real ewaldcoeff, real vol,
2005                          gmx_bool bEnerVir,
2006                          int nthread, int thread)
2007 {
2008     /* do recip sum over local cells in grid */
2009     /* y major, z middle, x minor or continuous */
2010     t_complex *p0;
2011     int     kx, ky, kz, maxkx, maxky, maxkz;
2012     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
2013     real    mx, my, mz;
2014     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2015     real    ets2, struct2, vfactor, ets2vf;
2016     real    d1, d2, energy = 0;
2017     real    by, bz;
2018     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2019     real    rxx, ryx, ryy, rzx, rzy, rzz;
2020     pme_work_t *work;
2021     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
2022     real    mhxk, mhyk, mhzk, m2k;
2023     real    corner_fac;
2024     ivec    complex_order;
2025     ivec    local_ndata, local_offset, local_size;
2026     real    elfac;
2027
2028     elfac = ONE_4PI_EPS0/pme->epsilon_r;
2029
2030     nx = pme->nkx;
2031     ny = pme->nky;
2032     nz = pme->nkz;
2033
2034     /* Dimensions should be identical for A/B grid, so we just use A here */
2035     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_QA],
2036                                       complex_order,
2037                                       local_ndata,
2038                                       local_offset,
2039                                       local_size);
2040
2041     rxx = pme->recipbox[XX][XX];
2042     ryx = pme->recipbox[YY][XX];
2043     ryy = pme->recipbox[YY][YY];
2044     rzx = pme->recipbox[ZZ][XX];
2045     rzy = pme->recipbox[ZZ][YY];
2046     rzz = pme->recipbox[ZZ][ZZ];
2047
2048     maxkx = (nx+1)/2;
2049     maxky = (ny+1)/2;
2050     maxkz = nz/2+1;
2051
2052     work  = &pme->work[thread];
2053     mhx   = work->mhx;
2054     mhy   = work->mhy;
2055     mhz   = work->mhz;
2056     m2    = work->m2;
2057     denom = work->denom;
2058     tmp1  = work->tmp1;
2059     eterm = work->eterm;
2060     m2inv = work->m2inv;
2061
2062     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2063     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2064
2065     for (iyz = iyz0; iyz < iyz1; iyz++)
2066     {
2067         iy = iyz/local_ndata[ZZ];
2068         iz = iyz - iy*local_ndata[ZZ];
2069
2070         ky = iy + local_offset[YY];
2071
2072         if (ky < maxky)
2073         {
2074             my = ky;
2075         }
2076         else
2077         {
2078             my = (ky - ny);
2079         }
2080
2081         by = M_PI*vol*pme->bsp_mod[YY][ky];
2082
2083         kz = iz + local_offset[ZZ];
2084
2085         mz = kz;
2086
2087         bz = pme->bsp_mod[ZZ][kz];
2088
2089         /* 0.5 correction for corner points */
2090         corner_fac = 1;
2091         if (kz == 0 || kz == (nz+1)/2)
2092         {
2093             corner_fac = 0.5;
2094         }
2095
2096         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2097
2098         /* We should skip the k-space point (0,0,0) */
2099         /* Note that since here x is the minor index, local_offset[XX]=0 */
2100         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2101         {
2102             kxstart = local_offset[XX];
2103         }
2104         else
2105         {
2106             kxstart = local_offset[XX] + 1;
2107             p0++;
2108         }
2109         kxend = local_offset[XX] + local_ndata[XX];
2110
2111         if (bEnerVir)
2112         {
2113             /* More expensive inner loop, especially because of the storage
2114              * of the mh elements in arrays.
2115              * Because x is the minor grid index, all mh elements
2116              * depend on kx for triclinic unit cells.
2117              */
2118
2119             /* Two explicit loops to avoid a conditional inside the loop */
2120             for (kx = kxstart; kx < maxkx; kx++)
2121             {
2122                 mx = kx;
2123
2124                 mhxk      = mx * rxx;
2125                 mhyk      = mx * ryx + my * ryy;
2126                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2127                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2128                 mhx[kx]   = mhxk;
2129                 mhy[kx]   = mhyk;
2130                 mhz[kx]   = mhzk;
2131                 m2[kx]    = m2k;
2132                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2133                 tmp1[kx]  = -factor*m2k;
2134             }
2135
2136             for (kx = maxkx; kx < kxend; kx++)
2137             {
2138                 mx = (kx - nx);
2139
2140                 mhxk      = mx * rxx;
2141                 mhyk      = mx * ryx + my * ryy;
2142                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2143                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2144                 mhx[kx]   = mhxk;
2145                 mhy[kx]   = mhyk;
2146                 mhz[kx]   = mhzk;
2147                 m2[kx]    = m2k;
2148                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2149                 tmp1[kx]  = -factor*m2k;
2150             }
2151
2152             for (kx = kxstart; kx < kxend; kx++)
2153             {
2154                 m2inv[kx] = 1.0/m2[kx];
2155             }
2156
2157             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2158
2159             for (kx = kxstart; kx < kxend; kx++, p0++)
2160             {
2161                 d1      = p0->re;
2162                 d2      = p0->im;
2163
2164                 p0->re  = d1*eterm[kx];
2165                 p0->im  = d2*eterm[kx];
2166
2167                 struct2 = 2.0*(d1*d1+d2*d2);
2168
2169                 tmp1[kx] = eterm[kx]*struct2;
2170             }
2171
2172             for (kx = kxstart; kx < kxend; kx++)
2173             {
2174                 ets2     = corner_fac*tmp1[kx];
2175                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2176                 energy  += ets2;
2177
2178                 ets2vf   = ets2*vfactor;
2179                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2180                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2181                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2182                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2183                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2184                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2185             }
2186         }
2187         else
2188         {
2189             /* We don't need to calculate the energy and the virial.
2190              * In this case the triclinic overhead is small.
2191              */
2192
2193             /* Two explicit loops to avoid a conditional inside the loop */
2194
2195             for (kx = kxstart; kx < maxkx; kx++)
2196             {
2197                 mx = kx;
2198
2199                 mhxk      = mx * rxx;
2200                 mhyk      = mx * ryx + my * ryy;
2201                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2202                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2203                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2204                 tmp1[kx]  = -factor*m2k;
2205             }
2206
2207             for (kx = maxkx; kx < kxend; kx++)
2208             {
2209                 mx = (kx - nx);
2210
2211                 mhxk      = mx * rxx;
2212                 mhyk      = mx * ryx + my * ryy;
2213                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2214                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2215                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2216                 tmp1[kx]  = -factor*m2k;
2217             }
2218
2219             calc_exponentials_q(kxstart, kxend, elfac, denom, tmp1, eterm);
2220
2221             for (kx = kxstart; kx < kxend; kx++, p0++)
2222             {
2223                 d1      = p0->re;
2224                 d2      = p0->im;
2225
2226                 p0->re  = d1*eterm[kx];
2227                 p0->im  = d2*eterm[kx];
2228             }
2229         }
2230     }
2231
2232     if (bEnerVir)
2233     {
2234         /* Update virial with local values.
2235          * The virial is symmetric by definition.
2236          * This virial seems ok for isotropic scaling, but I'm
2237          * experiencing problems on semiisotropic membranes.
2238          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2239          */
2240         work->vir_q[XX][XX] = 0.25*virxx;
2241         work->vir_q[YY][YY] = 0.25*viryy;
2242         work->vir_q[ZZ][ZZ] = 0.25*virzz;
2243         work->vir_q[XX][YY] = work->vir_q[YY][XX] = 0.25*virxy;
2244         work->vir_q[XX][ZZ] = work->vir_q[ZZ][XX] = 0.25*virxz;
2245         work->vir_q[YY][ZZ] = work->vir_q[ZZ][YY] = 0.25*viryz;
2246
2247         /* This energy should be corrected for a charged system */
2248         work->energy_q = 0.5*energy;
2249     }
2250
2251     /* Return the loop count */
2252     return local_ndata[YY]*local_ndata[XX];
2253 }
2254
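     /* Reciprocal-space solve for Lennard-Jones (dispersion) PME, analogous
      * to solve_pme_yzx but with the r^-6 Ewald kernel, which combines an
      * exponential and an erfc term (see calc_exponentials_lj). With bLB the
      * seven C6 grids needed for Lorentz-Berthelot combination rules are
      * combined using the lb_scale_factor_symm weights; otherwise a single
      * grid is scaled.
      */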
2255 static int solve_pme_lj_yzx(gmx_pme_t pme, t_complex **grid, gmx_bool bLB,
2256                             real ewaldcoeff, real vol,
2257                             gmx_bool bEnerVir, int nthread, int thread)
2258 {
2259     /* do recip sum over local cells in grid */
2260     /* y major, z middle, x minor or continuous */
2261     int     ig, gcount;
2262     int     kx, ky, kz, maxkx, maxky, maxkz;
2263     int     nx, ny, nz, iy, iyz0, iyz1, iyz, iz, kxstart, kxend;
2264     real    mx, my, mz;
2265     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
2266     real    ets2, ets2vf;
2267     real    eterm, vterm, d1, d2, energy = 0;
2268     real    by, bz;
2269     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
2270     real    rxx, ryx, ryy, rzx, rzy, rzz;
2271     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *tmp2;
2272     real    mhxk, mhyk, mhzk, m2k;
2273     real    mk;
2274     pme_work_t *work;
2275     real    corner_fac;
2276     ivec    complex_order;
2277     ivec    local_ndata, local_offset, local_size;
2278     nx = pme->nkx;
2279     ny = pme->nky;
2280     nz = pme->nkz;
2281
2282     /* Dimensions should be identical for A/B grid, so we just use A here */
2283     gmx_parallel_3dfft_complex_limits(pme->pfft_setup[PME_GRID_C6A],
2284                                       complex_order,
2285                                       local_ndata,
2286                                       local_offset,
2287                                       local_size);
2288     rxx = pme->recipbox[XX][XX];
2289     ryx = pme->recipbox[YY][XX];
2290     ryy = pme->recipbox[YY][YY];
2291     rzx = pme->recipbox[ZZ][XX];
2292     rzy = pme->recipbox[ZZ][YY];
2293     rzz = pme->recipbox[ZZ][ZZ];
2294
2295     maxkx = (nx+1)/2;
2296     maxky = (ny+1)/2;
2297     maxkz = nz/2+1;
2298
2299     work  = &pme->work[thread];
2300     mhx   = work->mhx;
2301     mhy   = work->mhy;
2302     mhz   = work->mhz;
2303     m2    = work->m2;
2304     denom = work->denom;
2305     tmp1  = work->tmp1;
2306     tmp2  = work->tmp2;
2307
2308     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2309     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2310
2311     for (iyz = iyz0; iyz < iyz1; iyz++)
2312     {
2313         iy = iyz/local_ndata[ZZ];
2314         iz = iyz - iy*local_ndata[ZZ];
2315
2316         ky = iy + local_offset[YY];
2317
2318         if (ky < maxky)
2319         {
2320             my = ky;
2321         }
2322         else
2323         {
2324             my = (ky - ny);
2325         }
2326
2327         by = 3.0*vol*pme->bsp_mod[YY][ky]
2328             / (M_PI*sqrt(M_PI)*ewaldcoeff*ewaldcoeff*ewaldcoeff);
2329
2330         kz = iz + local_offset[ZZ];
2331
2332         mz = kz;
2333
2334         bz = pme->bsp_mod[ZZ][kz];
2335
2336         /* 0.5 correction for corner points */
2337         corner_fac = 1;
2338         if (kz == 0 || kz == (nz+1)/2)
2339         {
2340             corner_fac = 0.5;
2341         }
2342
2343         kxstart = local_offset[XX];
2344         kxend   = local_offset[XX] + local_ndata[XX];
2345         if (bEnerVir)
2346         {
2347             /* More expensive inner loop, especially because of the
2348              * storage of the mh elements in arrays.  Because x is the
2349              * minor grid index, all mh elements depend on kx for
2350              * triclinic unit cells.
2351              */
2352
2353             /* Two explicit loops to avoid a conditional inside the loop */
2354             for (kx = kxstart; kx < maxkx; kx++)
2355             {
2356                 mx = kx;
2357
2358                 mhxk      = mx * rxx;
2359                 mhyk      = mx * ryx + my * ryy;
2360                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2361                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2362                 mhx[kx]   = mhxk;
2363                 mhy[kx]   = mhyk;
2364                 mhz[kx]   = mhzk;
2365                 m2[kx]    = m2k;
2366                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2367                 tmp1[kx]  = -factor*m2k;
2368                 tmp2[kx]  = sqrt(factor*m2k);
2369             }
2370
2371             for (kx = maxkx; kx < kxend; kx++)
2372             {
2373                 mx = (kx - nx);
2374
2375                 mhxk      = mx * rxx;
2376                 mhyk      = mx * ryx + my * ryy;
2377                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2378                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2379                 mhx[kx]   = mhxk;
2380                 mhy[kx]   = mhyk;
2381                 mhz[kx]   = mhzk;
2382                 m2[kx]    = m2k;
2383                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2384                 tmp1[kx]  = -factor*m2k;
2385                 tmp2[kx]  = sqrt(factor*m2k);
2386             }
2387
2388             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2389
2390             for (kx = kxstart; kx < kxend; kx++)
2391             {
2392                 m2k   = factor*m2[kx];
2393                 eterm = -((1.0 - 2.0*m2k)*tmp1[kx]
2394                           + 2.0*m2k*tmp2[kx]);
2395                 vterm    = 3.0*(-tmp1[kx] + tmp2[kx]);
2396                 tmp1[kx] = eterm*denom[kx];
2397                 tmp2[kx] = vterm*denom[kx];
2398             }
2399
2400             if (!bLB)
2401             {
2402                 t_complex *p0;
2403                 real       struct2;
2404
2405                 p0 = grid[0] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2406                 for (kx = kxstart; kx < kxend; kx++, p0++)
2407                 {
2408                     d1      = p0->re;
2409                     d2      = p0->im;
2410
2411                     eterm   = tmp1[kx];
2412                     vterm   = tmp2[kx];
2413                     p0->re  = d1*eterm;
2414                     p0->im  = d2*eterm;
2415
2416                     struct2 = 2.0*(d1*d1+d2*d2);
2417
2418                     tmp1[kx] = eterm*struct2;
2419                     tmp2[kx] = vterm*struct2;
2420                 }
2421             }
2422             else
2423             {
2424                 real *struct2 = denom;
2425                 real  str2;
2426
2427                 for (kx = kxstart; kx < kxend; kx++)
2428                 {
2429                     struct2[kx] = 0.0;
2430                 }
2431                 /* Due to symmetry we only need to calculate 4 of the 7 terms */
2432                 for (ig = 0; ig <= 3; ++ig)
2433                 {
2434                     t_complex *p0, *p1;
2435                     real       scale;
2436
2437                     p0    = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2438                     p1    = grid[6-ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2439                     scale = 2.0*lb_scale_factor_symm[ig];
2440                     for (kx = kxstart; kx < kxend; ++kx, ++p0, ++p1)
2441                     {
2442                         struct2[kx] += scale*(p0->re*p1->re + p0->im*p1->im);
2443                     }
2444
2445                 }
2446                 for (ig = 0; ig <= 6; ++ig)
2447                 {
2448                     t_complex *p0;
2449
2450                     p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2451                     for (kx = kxstart; kx < kxend; kx++, p0++)
2452                     {
2453                         d1     = p0->re;
2454                         d2     = p0->im;
2455
2456                         eterm  = tmp1[kx];
2457                         p0->re = d1*eterm;
2458                         p0->im = d2*eterm;
2459                     }
2460                 }
2461                 for (kx = kxstart; kx < kxend; kx++)
2462                 {
2463                     eterm    = tmp1[kx];
2464                     vterm    = tmp2[kx];
2465                     str2     = struct2[kx];
2466                     tmp1[kx] = eterm*str2;
2467                     tmp2[kx] = vterm*str2;
2468                 }
2469             }
2470
2471             for (kx = kxstart; kx < kxend; kx++)
2472             {
2473                 ets2     = corner_fac*tmp1[kx];
2474                 vterm    = 2.0*factor*tmp2[kx];
2475                 energy  += ets2;
2476                 ets2vf   = corner_fac*vterm;
2477                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2478                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2479                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2480                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2481                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2482                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2483             }
2484         }
2485         else
2486         {
2487             /* We don't need to calculate the energy and the virial.
2488              *  In this case the triclinic overhead is small.
2489              */
2490
2491             /* Two explicit loops to avoid a conditional inside the loop */
2492
2493             for (kx = kxstart; kx < maxkx; kx++)
2494             {
2495                 mx = kx;
2496
2497                 mhxk      = mx * rxx;
2498                 mhyk      = mx * ryx + my * ryy;
2499                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2500                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2501                 m2[kx]    = m2k;
2502                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2503                 tmp1[kx]  = -factor*m2k;
2504                 tmp2[kx]  = sqrt(factor*m2k);
2505             }
2506
2507             for (kx = maxkx; kx < kxend; kx++)
2508             {
2509                 mx = (kx - nx);
2510
2511                 mhxk      = mx * rxx;
2512                 mhyk      = mx * ryx + my * ryy;
2513                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2514                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2515                 m2[kx]    = m2k;
2516                 denom[kx] = bz*by*pme->bsp_mod[XX][kx];
2517                 tmp1[kx]  = -factor*m2k;
2518                 tmp2[kx]  = sqrt(factor*m2k);
2519             }
2520
2521             calc_exponentials_lj(kxstart, kxend, tmp1, tmp2, denom);
2522
2523             for (kx = kxstart; kx < kxend; kx++)
2524             {
2525                 m2k    = factor*m2[kx];
2526                 eterm  = -((1.0 - 2.0*m2k)*tmp1[kx]
2527                            + 2.0*m2k*tmp2[kx]);
2528                 tmp1[kx] = eterm*denom[kx];
2529             }
2530             gcount = (bLB ? 7 : 1);
2531             for (ig = 0; ig < gcount; ++ig)
2532             {
2533                 t_complex *p0;
2534
2535                 p0 = grid[ig] + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2536                 for (kx = kxstart; kx < kxend; kx++, p0++)
2537                 {
2538                     d1      = p0->re;
2539                     d2      = p0->im;
2540
2541                     eterm   = tmp1[kx];
2542
2543                     p0->re  = d1*eterm;
2544                     p0->im  = d2*eterm;
2545                 }
2546             }
2547         }
2548     }
2549     if (bEnerVir)
2550     {
2551         work->vir_lj[XX][XX] = 0.25*virxx;
2552         work->vir_lj[YY][YY] = 0.25*viryy;
2553         work->vir_lj[ZZ][ZZ] = 0.25*virzz;
2554         work->vir_lj[XX][YY] = work->vir_lj[YY][XX] = 0.25*virxy;
2555         work->vir_lj[XX][ZZ] = work->vir_lj[ZZ][XX] = 0.25*virxz;
2556         work->vir_lj[YY][ZZ] = work->vir_lj[ZZ][YY] = 0.25*viryz;
2557         /* This energy should be corrected for a charged system */
2558         work->energy_lj = 0.5*energy;
2559     }
2560     /* Return the loop count */
2561     return local_ndata[YY]*local_ndata[XX];
2562 }
2563
2564 static void get_pme_ener_vir_q(const gmx_pme_t pme, int nthread,
2565                                real *mesh_energy, matrix vir)
2566 {
2567     /* This function sums output over threads and should therefore
2568      * only be called after thread synchronization.
2569      */
2570     int thread;
2571
2572     *mesh_energy = pme->work[0].energy_q;
2573     copy_mat(pme->work[0].vir_q, vir);
2574
2575     for (thread = 1; thread < nthread; thread++)
2576     {
2577         *mesh_energy += pme->work[thread].energy_q;
2578         m_add(vir, pme->work[thread].vir_q, vir);
2579     }
2580 }
2581
2582 static void get_pme_ener_vir_lj(const gmx_pme_t pme, int nthread,
2583                                 real *mesh_energy, matrix vir)
2584 {
2585     /* This function sums output over threads and should therefore
2586      * only be called after thread synchronization.
2587      */
2588     int thread;
2589
2590     *mesh_energy = pme->work[0].energy_lj;
2591     copy_mat(pme->work[0].vir_lj, vir);
2592
2593     for (thread = 1; thread < nthread; thread++)
2594     {
2595         *mesh_energy += pme->work[thread].energy_lj;
2596         m_add(vir, pme->work[thread].vir_lj, vir);
2597     }
2598 }
2599
2600
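     /* Gather the force on one atom from the potential grid: loop over the
      * order^3 surrounding grid points and accumulate in (fx,fy,fz) the
      * gradient of the interpolated potential, combining the grid values
      * with the B-spline weights (thx,thy,thz) and their derivatives
      * (dthx,dthy,dthz).
      */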
2601 #define DO_FSPLINE(order)                      \
2602     for (ithx = 0; (ithx < order); ithx++)              \
2603     {                                              \
2604         index_x = (i0+ithx)*pny*pnz;               \
2605         tx      = thx[ithx];                       \
2606         dx      = dthx[ithx];                      \
2607                                                \
2608         for (ithy = 0; (ithy < order); ithy++)          \
2609         {                                          \
2610             index_xy = index_x+(j0+ithy)*pnz;      \
2611             ty       = thy[ithy];                  \
2612             dy       = dthy[ithy];                 \
2613             fxy1     = fz1 = 0;                    \
2614                                                \
2615             for (ithz = 0; (ithz < order); ithz++)      \
2616             {                                      \
2617                 gval  = grid[index_xy+(k0+ithz)];  \
2618                 fxy1 += thz[ithz]*gval;            \
2619                 fz1  += dthz[ithz]*gval;           \
2620             }                                      \
2621             fx += dx*ty*fxy1;                      \
2622             fy += tx*dy*fxy1;                      \
2623             fz += tx*ty*fz1;                       \
2624         }                                          \
2625     }
2626
2627
2628 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2629                               gmx_bool bClearF, pme_atomcomm_t *atc,
2630                               splinedata_t *spline,
2631                               real scale)
2632 {
2633     /* sum forces for local particles */
2634     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2635     int     index_x, index_xy;
2636     int     nx, ny, nz, pnx, pny, pnz;
2637     int *   idxptr;
2638     real    tx, ty, dx, dy, qn;
2639     real    fx, fy, fz, gval;
2640     real    fxy1, fz1;
2641     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2642     int     norder;
2643     real    rxx, ryx, ryy, rzx, rzy, rzz;
2644     int     order;
2645
2646     pme_spline_work_t *work;
2647
2648 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
2649     real           thz_buffer[12],  *thz_aligned;
2650     real           dthz_buffer[12], *dthz_aligned;
2651
2652     thz_aligned  = gmx_simd4_align_real(thz_buffer);
2653     dthz_aligned = gmx_simd4_align_real(dthz_buffer);
2654 #endif
2655
2656     work = pme->spline_work;
2657
2658     order = pme->pme_order;
2659     thx   = spline->theta[XX];
2660     thy   = spline->theta[YY];
2661     thz   = spline->theta[ZZ];
2662     dthx  = spline->dtheta[XX];
2663     dthy  = spline->dtheta[YY];
2664     dthz  = spline->dtheta[ZZ];
2665     nx    = pme->nkx;
2666     ny    = pme->nky;
2667     nz    = pme->nkz;
2668     pnx   = pme->pmegrid_nx;
2669     pny   = pme->pmegrid_ny;
2670     pnz   = pme->pmegrid_nz;
2671
2672     rxx   = pme->recipbox[XX][XX];
2673     ryx   = pme->recipbox[YY][XX];
2674     ryy   = pme->recipbox[YY][YY];
2675     rzx   = pme->recipbox[ZZ][XX];
2676     rzy   = pme->recipbox[ZZ][YY];
2677     rzz   = pme->recipbox[ZZ][ZZ];
2678
2679     for (nn = 0; nn < spline->n; nn++)
2680     {
2681         n  = spline->ind[nn];
2682         qn = scale*atc->q[n];
2683
2684         if (bClearF)
2685         {
2686             atc->f[n][XX] = 0;
2687             atc->f[n][YY] = 0;
2688             atc->f[n][ZZ] = 0;
2689         }
2690         if (qn != 0)
2691         {
2692             fx     = 0;
2693             fy     = 0;
2694             fz     = 0;
2695             idxptr = atc->idx[n];
2696             norder = nn*order;
2697
2698             i0   = idxptr[XX];
2699             j0   = idxptr[YY];
2700             k0   = idxptr[ZZ];
2701
2702             /* Pointer arithmetic alert, next six statements */
2703             thx  = spline->theta[XX] + norder;
2704             thy  = spline->theta[YY] + norder;
2705             thz  = spline->theta[ZZ] + norder;
2706             dthx = spline->dtheta[XX] + norder;
2707             dthy = spline->dtheta[YY] + norder;
2708             dthz = spline->dtheta[ZZ] + norder;
2709
2710             switch (order)
2711             {
2712                 case 4:
2713 #ifdef PME_SIMD4_SPREAD_GATHER
2714 #ifdef PME_SIMD4_UNALIGNED
2715 #define PME_GATHER_F_SIMD4_ORDER4
2716 #else
2717 #define PME_GATHER_F_SIMD4_ALIGNED
2718 #define PME_ORDER 4
2719 #endif
2720 #include "pme_simd4.h"
2721 #else
2722                     DO_FSPLINE(4);
2723 #endif
2724                     break;
2725                 case 5:
2726 #ifdef PME_SIMD4_SPREAD_GATHER
2727 #define PME_GATHER_F_SIMD4_ALIGNED
2728 #define PME_ORDER 5
2729 #include "pme_simd4.h"
2730 #else
2731                     DO_FSPLINE(5);
2732 #endif
2733                     break;
2734                 default:
2735                     DO_FSPLINE(order);
2736                     break;
2737             }
2738
2739             atc->f[n][XX] += -qn*( fx*nx*rxx );
2740             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2741             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2742         }
2743     }
2744     /* Since the energy and not the forces are interpolated,
2745      * the net force might not be exactly zero.
2746      * This can be solved by also interpolating F, but
2747      * that comes at a cost.
2748      * A better hack is to remove the net force every
2749      * step, but that must be done at a higher level
2750      * since this routine doesn't see all atoms if running
2751      * in parallel. I don't know how important it is.  EL 990726
2752      */
2753 }
2754
2755
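     /* Return sum_n q_n*phi(r_n): interpolate the potential grid back to the
      * atom positions using the B-spline weights already stored in
      * atc->spline[0] and weight with the charges.
      */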
2756 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2757                                    pme_atomcomm_t *atc)
2758 {
2759     splinedata_t *spline;
2760     int     n, ithx, ithy, ithz, i0, j0, k0;
2761     int     index_x, index_xy;
2762     int *   idxptr;
2763     real    energy, pot, tx, ty, qn, gval;
2764     real    *thx, *thy, *thz;
2765     int     norder;
2766     int     order;
2767
2768     spline = &atc->spline[0];
2769
2770     order = pme->pme_order;
2771
2772     energy = 0;
2773     for (n = 0; (n < atc->n); n++)
2774     {
2775         qn      = atc->q[n];
2776
2777         if (qn != 0)
2778         {
2779             idxptr = atc->idx[n];
2780             norder = n*order;
2781
2782             i0   = idxptr[XX];
2783             j0   = idxptr[YY];
2784             k0   = idxptr[ZZ];
2785
2786             /* Pointer arithmetic alert, next three statements */
2787             thx  = spline->theta[XX] + norder;
2788             thy  = spline->theta[YY] + norder;
2789             thz  = spline->theta[ZZ] + norder;
2790
2791             pot = 0;
2792             for (ithx = 0; (ithx < order); ithx++)
2793             {
2794                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2795                 tx      = thx[ithx];
2796
2797                 for (ithy = 0; (ithy < order); ithy++)
2798                 {
2799                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2800                     ty       = thy[ithy];
2801
2802                     for (ithz = 0; (ithz < order); ithz++)
2803                     {
2804                         gval  = grid[index_xy+(k0+ithz)];
2805                         pot  += tx*ty*thz[ithz]*gval;
2806                     }
2807
2808                 }
2809             }
2810
2811             energy += pot*qn;
2812         }
2813     }
2814
2815     return energy;
2816 }
2817
2818 /* Macro to force loop unrolling by fixing order.
2819  * This gives a significant performance gain.
2820  */
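     /* For each dimension, CALC_SPLINE fills data[] with the values of the
      * order-point cardinal B-spline at the grid points around the atom,
      * built with the usual recursion
      *   M_n(u) = (u*M_{n-1}(u) + (n-u)*M_{n-1}(u-1))/(n-1),
      * and ddata[] with the derivatives, using
      *   dM_n(u)/du = M_{n-1}(u) - M_{n-1}(u-1);
      * the results are stored in theta[] and dtheta[] for atom i.
      */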
2821 #define CALC_SPLINE(order)                     \
2822     {                                              \
2823         int j, k, l;                                 \
2824         real dr, div;                               \
2825         real data[PME_ORDER_MAX];                  \
2826         real ddata[PME_ORDER_MAX];                 \
2827                                                \
2828         for (j = 0; (j < DIM); j++)                     \
2829         {                                          \
2830             dr  = xptr[j];                         \
2831                                                \
2832             /* dr is relative offset from lower cell limit */ \
2833             data[order-1] = 0;                     \
2834             data[1]       = dr;                          \
2835             data[0]       = 1 - dr;                      \
2836                                                \
2837             for (k = 3; (k < order); k++)               \
2838             {                                      \
2839                 div       = 1.0/(k - 1.0);               \
2840                 data[k-1] = div*dr*data[k-2];      \
2841                 for (l = 1; (l < (k-1)); l++)           \
2842                 {                                  \
2843                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2844                                        data[k-l-1]);                \
2845                 }                                  \
2846                 data[0] = div*(1-dr)*data[0];      \
2847             }                                      \
2848             /* differentiate */                    \
2849             ddata[0] = -data[0];                   \
2850             for (k = 1; (k < order); k++)               \
2851             {                                      \
2852                 ddata[k] = data[k-1] - data[k];    \
2853             }                                      \
2854                                                \
2855             div           = 1.0/(order - 1);                 \
2856             data[order-1] = div*dr*data[order-2];  \
2857             for (l = 1; (l < (order-1)); l++)           \
2858             {                                      \
2859                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2860                                        (order-l-dr)*data[order-l-1]); \
2861             }                                      \
2862             data[0] = div*(1 - dr)*data[0];        \
2863                                                \
2864             for (k = 0; k < order; k++)                 \
2865             {                                      \
2866                 theta[j][i*order+k]  = data[k];    \
2867                 dtheta[j][i*order+k] = ddata[k];   \
2868             }                                      \
2869         }                                          \
2870     }
2871
2872 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2873                    rvec fractx[], int nr, int ind[], real charge[],
2874                    gmx_bool bDoSplines)
2875 {
2876     /* construct splines for local atoms */
2877     int  i, ii;
2878     real *xptr;
2879
2880     for (i = 0; i < nr; i++)
2881     {
2882         /* With free energy we do not use the charge check.
2883          * In most cases this will be more efficient than calling make_bsplines
2884          * twice, since usually more than half the particles have charges.
2885          */
2886         ii = ind[i];
2887         if (bDoSplines || charge[ii] != 0.0)
2888         {
2889             xptr = fractx[ii];
2890             switch (order)
2891             {
2892                 case 4:  CALC_SPLINE(4);     break;
2893                 case 5:  CALC_SPLINE(5);     break;
2894                 default: CALC_SPLINE(order); break;
2895             }
2896         }
2897     }
2898 }
2899
2900
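     /* Compute mod[i] = |sum_j data[j]*exp(2*pi*I*i*j/ndata)|^2, the squared
      * modulus of the discrete Fourier transform of data; values smaller
      * than 1e-7 are replaced by the average of their neighbors.
      */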
2901 void make_dft_mod(real *mod, real *data, int ndata)
2902 {
2903     int i, j;
2904     real sc, ss, arg;
2905
2906     for (i = 0; i < ndata; i++)
2907     {
2908         sc = ss = 0;
2909         for (j = 0; j < ndata; j++)
2910         {
2911             arg = (2.0*M_PI*i*j)/ndata;
2912             sc += data[j]*cos(arg);
2913             ss += data[j]*sin(arg);
2914         }
2915         mod[i] = sc*sc+ss*ss;
2916     }
2917     for (i = 0; i < ndata; i++)
2918     {
2919         if (mod[i] < 1e-7)
2920         {
2921             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2922         }
2923     }
2924 }
2925
2926
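     /* Tabulate the order-point B-spline at the integer grid points and
      * store the squared modulus of its DFT for each dimension in bsp_mod;
      * these moduli enter the denominators used in the solve_pme_*_yzx
      * routines.
      */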
2927 static void make_bspline_moduli(splinevec bsp_mod,
2928                                 int nx, int ny, int nz, int order)
2929 {
2930     int nmax = max(nx, max(ny, nz));
2931     real *data, *ddata, *bsp_data;
2932     int i, k, l;
2933     real div;
2934
2935     snew(data, order);
2936     snew(ddata, order);
2937     snew(bsp_data, nmax);
2938
2939     data[order-1] = 0;
2940     data[1]       = 0;
2941     data[0]       = 1;
2942
2943     for (k = 3; k < order; k++)
2944     {
2945         div       = 1.0/(k-1.0);
2946         data[k-1] = 0;
2947         for (l = 1; l < (k-1); l++)
2948         {
2949             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2950         }
2951         data[0] = div*data[0];
2952     }
2953     /* differentiate */
2954     ddata[0] = -data[0];
2955     for (k = 1; k < order; k++)
2956     {
2957         ddata[k] = data[k-1]-data[k];
2958     }
2959     div           = 1.0/(order-1);
2960     data[order-1] = 0;
2961     for (l = 1; l < (order-1); l++)
2962     {
2963         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2964     }
2965     data[0] = div*data[0];
2966
2967     for (i = 0; i < nmax; i++)
2968     {
2969         bsp_data[i] = 0;
2970     }
2971     for (i = 1; i <= order; i++)
2972     {
2973         bsp_data[i] = data[i-1];
2974     }
2975
2976     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2977     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2978     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2979
2980     sfree(data);
2981     sfree(ddata);
2982     sfree(bsp_data);
2983 }
2984
2985
2986 /* Return the P3M optimal influence function */
2987 static double do_p3m_influence(double z, int order)
2988 {
2989     double z2, z4;
2990
2991     z2 = z*z;
2992     z4 = z2*z2;
2993
2994     /* The formula and most constants can be found in:
2995      * Ballenegger et al., JCTC 8, 936 (2012)
2996      */
2997     switch (order)
2998     {
2999         case 2:
3000             return 1.0 - 2.0*z2/3.0;
3001             break;
3002         case 3:
3003             return 1.0 - z2 + 2.0*z4/15.0;
3004             break;
3005         case 4:
3006             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
3007             break;
3008         case 5:
3009             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
3010             break;
3011         case 6:
3012             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
3013             break;
3014         case 7:
3015             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
3016         case 8:
3017             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
3018             break;
3019     }
3020
3021     return 0.0;
3022 }
3023
3024 /* Calculate the P3M B-spline moduli for one dimension */
3025 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
3026 {
3027     double zarg, zai, sinzai, infl;
3028     int    maxk, i;
3029
3030     if (order > 8)
3031     {
3032         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
3033     }
3034
3035     zarg = M_PI/n;
3036
3037     maxk = (n + 1)/2;
3038
3039     for (i = -maxk; i < 0; i++)
3040     {
3041         zai          = zarg*i;
3042         sinzai       = sin(zai);
3043         infl         = do_p3m_influence(sinzai, order);
3044         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
3045     }
3046     bsp_mod[0] = 1.0;
3047     for (i = 1; i < maxk; i++)
3048     {
3049         zai        = zarg*i;
3050         sinzai     = sin(zai);
3051         infl       = do_p3m_influence(sinzai, order);
3052         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
3053     }
3054 }
3055
3056 /* Calculate the P3M B-spline moduli */
3057 static void make_p3m_bspline_moduli(splinevec bsp_mod,
3058                                     int nx, int ny, int nz, int order)
3059 {
3060     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
3061     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
3062     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
3063 }
3064
3065
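     /* Fill node_dest/node_src with the nslab-1 ranks to communicate with,
      * ordered by increasing slab distance, alternating between the forward
      * and backward neighbor at each distance.
      */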
3066 static void setup_coordinate_communication(pme_atomcomm_t *atc)
3067 {
3068     int nslab, n, i;
3069     int fw, bw;
3070
3071     nslab = atc->nslab;
3072
3073     n = 0;
3074     for (i = 1; i <= nslab/2; i++)
3075     {
3076         fw = (atc->nodeid + i) % nslab;
3077         bw = (atc->nodeid - i + nslab) % nslab;
3078         if (n < nslab - 1)
3079         {
3080             atc->node_dest[n] = fw;
3081             atc->node_src[n]  = bw;
3082             n++;
3083         }
3084         if (n < nslab - 1)
3085         {
3086             atc->node_dest[n] = bw;
3087             atc->node_src[n]  = fw;
3088             n++;
3089         }
3090     }
3091 }
3092
3093 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
3094 {
3095     int thread, i;
3096
3097     if (NULL != log)
3098     {
3099         fprintf(log, "Destroying PME data structures.\n");
3100     }
3101
3102     sfree((*pmedata)->nnx);
3103     sfree((*pmedata)->nny);
3104     sfree((*pmedata)->nnz);
3105
3106     for (i = 0; i < (*pmedata)->ngrids; ++i)
3107     {
3108         pmegrids_destroy(&(*pmedata)->pmegrid[i]);
3109         sfree((*pmedata)->fftgrid[i]);
3110         sfree((*pmedata)->cfftgrid[i]);
3111         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setup[i]);
3112     }
3113
3114     sfree((*pmedata)->lb_buf1);
3115     sfree((*pmedata)->lb_buf2);
3116
3117     for (thread = 0; thread < (*pmedata)->nthread; thread++)
3118     {
3119         free_work(&(*pmedata)->work[thread]);
3120     }
3121     sfree((*pmedata)->work);
3122
3123     sfree(*pmedata);
3124     *pmedata = NULL;
3125
3126     return 0;
3127 }
3128
3129 static int mult_up(int n, int f)
3130 {
3131     return ((n + f - 1)/f)*f;
3132 }
3133
3134
3135 static double pme_load_imbalance(gmx_pme_t pme)
3136 {
3137     int    nma, nmi;
3138     double n1, n2, n3;
3139
3140     nma = pme->nnodes_major;
3141     nmi = pme->nnodes_minor;
3142
3143     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
3144     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
3145     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
3146
3147     /* pme_solve is roughly double the cost of an fft */
3148
3149     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
3150 }
3151
3152 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
3153                           int dimind, gmx_bool bSpread)
3154 {
3155     int nk, k, s, thread;
3156
3157     atc->dimind    = dimind;
3158     atc->nslab     = 1;
3159     atc->nodeid    = 0;
3160     atc->pd_nalloc = 0;
3161 #ifdef GMX_MPI
3162     if (pme->nnodes > 1)
3163     {
3164         atc->mpi_comm = pme->mpi_comm_d[dimind];
3165         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
3166         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
3167     }
3168     if (debug)
3169     {
3170         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
3171     }
3172 #endif
3173
3174     atc->bSpread   = bSpread;
3175     atc->pme_order = pme->pme_order;
3176
3177     if (atc->nslab > 1)
3178     {
3179         /* These three allocations are not required for particle decomp. */
3180         snew(atc->node_dest, atc->nslab);
3181         snew(atc->node_src, atc->nslab);
3182         setup_coordinate_communication(atc);
3183
3184         snew(atc->count_thread, pme->nthread);
3185         for (thread = 0; thread < pme->nthread; thread++)
3186         {
3187             snew(atc->count_thread[thread], atc->nslab);
3188         }
3189         atc->count = atc->count_thread[0];
3190         snew(atc->rcount, atc->nslab);
3191         snew(atc->buf_index, atc->nslab);
3192     }
3193
3194     atc->nthread = pme->nthread;
3195     if (atc->nthread > 1)
3196     {
3197         snew(atc->thread_plist, atc->nthread);
3198     }
3199     snew(atc->spline, atc->nthread);
3200     for (thread = 0; thread < atc->nthread; thread++)
3201     {
3202         if (atc->nthread > 1)
3203         {
3204             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
3205             atc->thread_plist[thread].n += GMX_CACHE_SEP;
3206         }
3207         snew(atc->spline[thread].thread_one, pme->nthread);
3208         atc->spline[thread].thread_one[thread] = 1;
3209     }
3210 }
3211
3212 static void
3213 init_overlap_comm(pme_overlap_t *  ol,
3214                   int              norder,
3215 #ifdef GMX_MPI
3216                   MPI_Comm         comm,
3217 #endif
3218                   int              nnodes,
3219                   int              nodeid,
3220                   int              ndata,
3221                   int              commplainsize)
3222 {
3223     int lbnd, rbnd, maxlr, b, i;
3224     int exten;
3225     int nn, nk;
3226     pme_grid_comm_t *pgc;
3227     gmx_bool bCont;
3228     int fft_start, fft_end, send_index1, recv_index1;
3229 #ifdef GMX_MPI
3230     MPI_Status stat;
3231
3232     ol->mpi_comm = comm;
3233 #endif
3234
3235     ol->nnodes = nnodes;
3236     ol->nodeid = nodeid;
3237
3238     /* Linear translation of the PME grid won't affect reciprocal space
3239      * calculations, so to optimize we only interpolate "upwards",
3240      * which also means we only have to consider overlap in one direction.
3241      * I.e., particles on this node might also be spread to grid indices
3242      * that belong to higher nodes (modulo nnodes).
3243      */
3244
3245     snew(ol->s2g0, ol->nnodes+1);
3246     snew(ol->s2g1, ol->nnodes);
3247     if (debug)
3248     {
3249         fprintf(debug, "PME slab boundaries:");
3250     }
3251     for (i = 0; i < nnodes; i++)
3252     {
3253         /* s2g0 the local interpolation grid start.
3254          * s2g1 the local interpolation grid end.
3255          * Because grid overlap communication only goes forward,
3256          * the grid slabs used for the FFTs should be rounded down.
3257          */
3258         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
3259         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
3260
3261         if (debug)
3262         {
3263             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
3264         }
3265     }
3266     ol->s2g0[nnodes] = ndata;
3267     if (debug)
3268     {
3269         fprintf(debug, "\n");
3270     }
3271
3272     /* Determine with how many nodes we need to communicate the grid overlap */
3273     b = 0;
3274     do
3275     {
3276         b++;
3277         bCont = FALSE;
3278         for (i = 0; i < nnodes; i++)
3279         {
3280             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
3281                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
3282             {
3283                 bCont = TRUE;
3284             }
3285         }
3286     }
3287     while (bCont && b < nnodes);
3288     ol->noverlap_nodes = b - 1;
3289
3290     snew(ol->send_id, ol->noverlap_nodes);
3291     snew(ol->recv_id, ol->noverlap_nodes);
3292     for (b = 0; b < ol->noverlap_nodes; b++)
3293     {
3294         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
3295         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
3296     }
3297     snew(ol->comm_data, ol->noverlap_nodes);
3298
3299     ol->send_size = 0;
3300     for (b = 0; b < ol->noverlap_nodes; b++)
3301     {
3302         pgc = &ol->comm_data[b];
3303         /* Send */
3304         fft_start        = ol->s2g0[ol->send_id[b]];
3305         fft_end          = ol->s2g0[ol->send_id[b]+1];
3306         if (ol->send_id[b] < nodeid)
3307         {
3308             fft_start += ndata;
3309             fft_end   += ndata;
3310         }
3311         send_index1       = ol->s2g1[nodeid];
3312         send_index1       = min(send_index1, fft_end);
3313         pgc->send_index0  = fft_start;
3314         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
3315         ol->send_size    += pgc->send_nindex;
3316
3317         /* We always start receiving at the first index of our slab */
3318         fft_start        = ol->s2g0[ol->nodeid];
3319         fft_end          = ol->s2g0[ol->nodeid+1];
3320         recv_index1      = ol->s2g1[ol->recv_id[b]];
3321         if (ol->recv_id[b] > nodeid)
3322         {
3323             recv_index1 -= ndata;
3324         }
3325         recv_index1      = min(recv_index1, fft_end);
3326         pgc->recv_index0 = fft_start;
3327         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
3328     }
3329
3330 #ifdef GMX_MPI
3331     /* Communicate the buffer sizes to receive */
3332     for (b = 0; b < ol->noverlap_nodes; b++)
3333     {
3334         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
3335                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
3336                      ol->mpi_comm, &stat);
3337     }
3338 #endif
3339
3340     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3341     snew(ol->sendbuf, norder*commplainsize);
3342     snew(ol->recvbuf, norder*commplainsize);
3343 }
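/* Worked example with assumed numbers for the slab setup above: for
 * ndata = 30 grid lines over nnodes = 4 slabs with norder = 4 the formulas
 * give s2g0 = {0, 7, 15, 22, 30} and s2g1 = {11, 18, 26, 33}. No s2g1[i]
 * reaches past s2g0[i+2] (modulo ndata), so the pulse-counting loop exits
 * at b = 2 and noverlap_nodes = 1: a single forward communication pulse
 * suffices for this dimension.
 */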
3344
3345 static void
3346 make_gridindex5_to_localindex(int n, int local_start, int local_range,
3347                               int **global_to_local,
3348                               real **fraction_shift)
3349 {
3350     int i;
3351     int * gtl;
3352     real * fsh;
3353
3354     snew(gtl, 5*n);
3355     snew(fsh, 5*n);
3356     for (i = 0; (i < 5*n); i++)
3357     {
3358         /* Determine the global to local grid index */
3359         gtl[i] = (i - local_start + n) % n;
3360         /* For coordinates that fall within the local grid the fraction
3361          * is correct, so we don't need to shift it.
3362          */
3363         fsh[i] = 0;
3364         if (local_range < n)
3365         {
3366             /* Due to rounding issues i could be 1 beyond the lower or
3367              * upper boundary of the local grid. Correct the index for this.
3368              * If we shift the index, we need to shift the fraction by
3369              * the same amount in the other direction to not affect
3370              * the weights.
3371              * Note that due to this shifting the weights at the end of
3372              * the spline might change, but that will only involve values
3373              * between zero and values close to the precision of a real,
3374              * which is anyhow the accuracy of the whole mesh calculation.
3375              */
3376             /* With local_range=0 we should not change i=local_start */
3377             if (i % n != local_start)
3378             {
3379                 if (gtl[i] == n-1)
3380                 {
3381                     gtl[i] = 0;
3382                     fsh[i] = -1;
3383                 }
3384                 else if (gtl[i] == local_range)
3385                 {
3386                     gtl[i] = local_range - 1;
3387                     fsh[i] = 1;
3388                 }
3389             }
3390         }
3391     }
3392
3393     *global_to_local = gtl;
3394     *fraction_shift  = fsh;
3395 }
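/* Worked example with assumed numbers: for n = 10, local_start = 3 and
 * local_range = 4, indices with i % n in 3..6 map to local indices 0..3
 * with fsh = 0. An index that rounds one cell below the slab (i % n == 2)
 * wraps to gtl = n-1 = 9 and is reset to 0 with fsh = -1; one cell above
 * (i % n == 7) gives gtl == local_range and is clamped to 3 with fsh = +1,
 * so the interpolation fraction is shifted by one full cell to compensate.
 */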
3396
3397 static pme_spline_work_t *make_pme_spline_work(int gmx_unused order)
3398 {
3399     pme_spline_work_t *work;
3400
3401 #ifdef PME_SIMD4_SPREAD_GATHER
3402     real         tmp[12], *tmp_aligned;
3403     gmx_simd4_pr zero_S;
3404     gmx_simd4_pr real_mask_S0, real_mask_S1;
3405     int          of, i;
3406
3407     snew_aligned(work, 1, SIMD4_ALIGNMENT);
3408
3409     tmp_aligned = gmx_simd4_align_real(tmp);
3410
3411     zero_S = gmx_simd4_setzero_pr();
3412
3413     /* Generate bit masks to mask out the unused grid entries,
3414      * as we only operate on order of the 8 grid entries that are
3415      * loaded into 2 SIMD registers.
3416      */
3417     for (of = 0; of < 8-(order-1); of++)
3418     {
3419         for (i = 0; i < 8; i++)
3420         {
3421             tmp_aligned[i] = (i >= of && i < of+order ? -1.0 : 1.0);
3422         }
3423         real_mask_S0      = gmx_simd4_load_pr(tmp_aligned);
3424         real_mask_S1      = gmx_simd4_load_pr(tmp_aligned+4);
3425         work->mask_S0[of] = gmx_simd4_cmplt_pr(real_mask_S0, zero_S);
3426         work->mask_S1[of] = gmx_simd4_cmplt_pr(real_mask_S1, zero_S);
3427     }
3428 #else
3429     work = NULL;
3430 #endif
3431
3432     return work;
3433 }
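/* Illustrative sketch of the mask layout built above (guarded out, no SIMD
 * intrinsics; helper name is an assumption and printf relies on <stdio.h>,
 * already included above): for order = 4 and offset of = 1 the 8-entry
 * selection pattern is {F,T,T,T, T,F,F,F}, i.e. lanes 1-3 of the first
 * 4-wide register and lane 0 of the second.
 */
#if 0
static void print_spread_masks(int order)
{
    int of, i;

    for (of = 0; of < 8 - (order - 1); of++)
    {
        printf("of %d:", of);
        for (i = 0; i < 8; i++)
        {
            /* T where the grid entry belongs to this spline stencil */
            printf(" %c", (i >= of && i < of + order) ? 'T' : 'F');
        }
        printf("\n");
    }
}
#endif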
3434
3435 void gmx_pme_check_restrictions(int pme_order,
3436                                 int nkx, int nky, int nkz,
3437                                 int nnodes_major,
3438                                 int nnodes_minor,
3439                                 gmx_bool bUseThreads,
3440                                 gmx_bool bFatal,
3441                                 gmx_bool *bValidSettings)
3442 {
3443     if (pme_order > PME_ORDER_MAX)
3444     {
3445         if (!bFatal)
3446         {
3447             *bValidSettings = FALSE;
3448             return;
3449         }
3450         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3451                   pme_order, PME_ORDER_MAX);
3452     }
3453
3454     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3455         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3456         nkz <= pme_order)
3457     {
3458         if (!bFatal)
3459         {
3460             *bValidSettings = FALSE;
3461             return;
3462         }
3463         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3464                   pme_order);
3465     }
3466
3467     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3468      * We only allow multiple communication pulses in dim 1, not in dim 0.
3469      */
3470     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3471                         nkx != nnodes_major*(pme_order - 1)))
3472     {
3473         if (!bFatal)
3474         {
3475             *bValidSettings = FALSE;
3476             return;
3477         }
3478         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3479                   nkx/(double)nnodes_major, pme_order);
3480     }
3481
3482     if (bValidSettings != NULL)
3483     {
3484         *bValidSettings = TRUE;
3485     }
3486
3487     return;
3488 }
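/* Usage sketch, illustrative only: with bFatal = FALSE the routine reports
 * back through bValidSettings instead of calling gmx_fatal, so a caller can
 * probe whether a candidate grid works for a given decomposition. The
 * guarded helper below is hypothetical; order 4 and a 4 x 2 PME rank grid
 * with OpenMP are assumptions for the example.
 */
#if 0
static gmx_bool example_grid_is_usable(int nkx, int nky, int nkz)
{
    gmx_bool bValid;

    gmx_pme_check_restrictions(4, nkx, nky, nkz, 4, 2, TRUE, FALSE, &bValid);

    return bValid;
}
#endif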
3489
3490 int gmx_pme_init(gmx_pme_t *         pmedata,
3491                  t_commrec *         cr,
3492                  int                 nnodes_major,
3493                  int                 nnodes_minor,
3494                  t_inputrec *        ir,
3495                  int                 homenr,
3496                  gmx_bool            bFreeEnergy_q,
3497                  gmx_bool            bFreeEnergy_lj,
3498                  gmx_bool            bReproducible,
3499                  int                 nthread)
3500 {
3501     gmx_pme_t pme = NULL;
3502
3503     int  use_threads, sum_use_threads, i;
3504     ivec ndata;
3505
3506     if (debug)
3507     {
3508         fprintf(debug, "Creating PME data structures.\n");
3509     }
3510     snew(pme, 1);
3511
3512     pme->redist_init         = FALSE;
3513     pme->sum_qgrid_tmp       = NULL;
3514     pme->sum_qgrid_dd_tmp    = NULL;
3515     pme->buf_nalloc          = 0;
3516     pme->redist_buf_nalloc   = 0;
3517
3518     pme->nnodes              = 1;
3519     pme->bPPnode             = TRUE;
3520
3521     pme->nnodes_major        = nnodes_major;
3522     pme->nnodes_minor        = nnodes_minor;
3523
3524 #ifdef GMX_MPI
3525     if (nnodes_major*nnodes_minor > 1)
3526     {
3527         pme->mpi_comm = cr->mpi_comm_mygroup;
3528
3529         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3530         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3531         if (pme->nnodes != nnodes_major*nnodes_minor)
3532         {
3533             gmx_incons("PME node count mismatch");
3534         }
3535     }
3536     else
3537     {
3538         pme->mpi_comm = MPI_COMM_NULL;
3539     }
3540 #endif
3541
3542     if (pme->nnodes == 1)
3543     {
3544 #ifdef GMX_MPI
3545         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3546         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3547 #endif
3548         pme->ndecompdim   = 0;
3549         pme->nodeid_major = 0;
3550         pme->nodeid_minor = 0;
3554     }
3555     else
3556     {
3557         if (nnodes_minor == 1)
3558         {
3559 #ifdef GMX_MPI
3560             pme->mpi_comm_d[0] = pme->mpi_comm;
3561             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3562 #endif
3563             pme->ndecompdim   = 1;
3564             pme->nodeid_major = pme->nodeid;
3565             pme->nodeid_minor = 0;
3566
3567         }
3568         else if (nnodes_major == 1)
3569         {
3570 #ifdef GMX_MPI
3571             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3572             pme->mpi_comm_d[1] = pme->mpi_comm;
3573 #endif
3574             pme->ndecompdim   = 1;
3575             pme->nodeid_major = 0;
3576             pme->nodeid_minor = pme->nodeid;
3577         }
3578         else
3579         {
3580             if (pme->nnodes % nnodes_major != 0)
3581             {
3582                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3583             }
3584             pme->ndecompdim = 2;
3585
3586 #ifdef GMX_MPI
3587             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3588                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3589             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3590                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3591
3592             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3593             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3594             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3595             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3596 #endif
3597         }
3598         pme->bPPnode = (cr->duty & DUTY_PP);
3599     }
3600
3601     pme->nthread = nthread;
3602
3603     /* Check if any of the PME MPI ranks uses threads */
3604     use_threads = (pme->nthread > 1 ? 1 : 0);
3605 #ifdef GMX_MPI
3606     if (pme->nnodes > 1)
3607     {
3608         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3609                       MPI_SUM, pme->mpi_comm);
3610     }
3611     else
3612 #endif
3613     {
3614         sum_use_threads = use_threads;
3615     }
3616     pme->bUseThreads = (sum_use_threads > 0);
3617
3618     if (ir->ePBC == epbcSCREW)
3619     {
3620         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3621     }
3622
3623     pme->bFEP_q      = ((ir->efep != efepNO) && bFreeEnergy_q);
3624     pme->bFEP_lj     = ((ir->efep != efepNO) && bFreeEnergy_lj);
3625     pme->bFEP        = (pme->bFEP_q || pme->bFEP_lj);
3626     pme->nkx         = ir->nkx;
3627     pme->nky         = ir->nky;
3628     pme->nkz         = ir->nkz;
3629     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3630     pme->pme_order   = ir->pme_order;
3631     pme->epsilon_r   = ir->epsilon_r;
3632
3633     /* If we violate restrictions, generate a fatal error here */
3634     gmx_pme_check_restrictions(pme->pme_order,
3635                                pme->nkx, pme->nky, pme->nkz,
3636                                pme->nnodes_major,
3637                                pme->nnodes_minor,
3638                                pme->bUseThreads,
3639                                TRUE,
3640                                NULL);
3641
3642     if (pme->nnodes > 1)
3643     {
3644         double imbal;
3645
3646 #ifdef GMX_MPI
3647         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3648         MPI_Type_commit(&(pme->rvec_mpi));
3649 #endif
3650
3651         /* Note that the charge spreading and force gathering, which usually
3652          * takes about the same amount of time as FFT+solve_pme,
3653          * is always fully load balanced
3654          * (unless the charge distribution is inhomogeneous).
3655          */
3656
3657         imbal = pme_load_imbalance(pme);
3658         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3659         {
3660             fprintf(stderr,
3661                     "\n"
3662                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3663                     "      For optimal PME load balancing\n"
3664                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3665                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3666                     "\n",
3667                     (int)((imbal-1)*100 + 0.5),
3668                     pme->nkx, pme->nky, pme->nnodes_major,
3669                     pme->nky, pme->nkz, pme->nnodes_minor);
3670         }
3671     }
3672
3673     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3674     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3675      * y is always copied through a buffer: we don't need padding in z,
3676      * but we do need the overlap in x because of the communication order.
3677      */
3678     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3679 #ifdef GMX_MPI
3680                       pme->mpi_comm_d[0],
3681 #endif
3682                       pme->nnodes_major, pme->nodeid_major,
3683                       pme->nkx,
3684                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3685
3686     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3687      * We do this with an offset buffer of equal size, so we need to allocate
3688      * extra for the offset. That's what the (+1)*pme->nkz is for.
3689      */
3690     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3691 #ifdef GMX_MPI
3692                       pme->mpi_comm_d[1],
3693 #endif
3694                       pme->nnodes_minor, pme->nodeid_minor,
3695                       pme->nky,
3696                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3697
3698     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3699      * Note that gmx_pme_check_restrictions checked for this already.
3700      */
3701     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3702     {
3703         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3704     }
3705
3706     snew(pme->bsp_mod[XX], pme->nkx);
3707     snew(pme->bsp_mod[YY], pme->nky);
3708     snew(pme->bsp_mod[ZZ], pme->nkz);
3709
3710     /* The required size of the interpolation grid, including overlap.
3711      * The allocated size (pmegrid_n?) might be slightly larger.
3712      */
3713     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3714         pme->overlap[0].s2g0[pme->nodeid_major];
3715     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3716         pme->overlap[1].s2g0[pme->nodeid_minor];
3717     pme->pmegrid_nz_base = pme->nkz;
3718     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3719     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3720
3721     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3722     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3723     pme->pmegrid_start_iz = 0;
3724
3725     make_gridindex5_to_localindex(pme->nkx,
3726                                   pme->pmegrid_start_ix,
3727                                   pme->pmegrid_nx - (pme->pme_order-1),
3728                                   &pme->nnx, &pme->fshx);
3729     make_gridindex5_to_localindex(pme->nky,
3730                                   pme->pmegrid_start_iy,
3731                                   pme->pmegrid_ny - (pme->pme_order-1),
3732                                   &pme->nny, &pme->fshy);
3733     make_gridindex5_to_localindex(pme->nkz,
3734                                   pme->pmegrid_start_iz,
3735                                   pme->pmegrid_nz_base,
3736                                   &pme->nnz, &pme->fshz);
3737
3738     pme->spline_work = make_pme_spline_work(pme->pme_order);
3739
3740     ndata[0]    = pme->nkx;
3741     ndata[1]    = pme->nky;
3742     ndata[2]    = pme->nkz;
3743     pme->ngrids = ((ir->ljpme_combination_rule == eljpmeLB) ? DO_Q_AND_LJ_LB : DO_Q_AND_LJ);
3744     snew(pme->fftgrid, pme->ngrids);
3745     snew(pme->cfftgrid, pme->ngrids);
3746     snew(pme->pfft_setup, pme->ngrids);
3747
3748     for (i = 0; i < pme->ngrids; ++i)
3749     {
3750         if (((ir->ljpme_combination_rule == eljpmeLB) && i >= 2) || i % 2 == 0 || bFreeEnergy_q || bFreeEnergy_lj)
3751         {
3752             pmegrids_init(&pme->pmegrid[i],
3753                           pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3754                           pme->pmegrid_nz_base,
3755                           pme->pme_order,
3756                           pme->bUseThreads,
3757                           pme->nthread,
3758                           pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3759                           pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3760             /* This routine will allocate the grid data to fit the FFTs */
3761             gmx_parallel_3dfft_init(&pme->pfft_setup[i], ndata,
3762                                     &pme->fftgrid[i], &pme->cfftgrid[i],
3763                                     pme->mpi_comm_d,
3764                                     bReproducible, pme->nthread);
3765
3766         }
3767     }
3768
3769     if (!pme->bP3M)
3770     {
3771         /* Use plain SPME B-spline interpolation */
3772         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3773     }
3774     else
3775     {
3776         /* Use the P3M grid-optimized influence function */
3777         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3778     }
3779
3780     /* Use atc[0] for spreading */
3781     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3782     if (pme->ndecompdim >= 2)
3783     {
3784         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3785     }
3786
3787     if (pme->nnodes == 1)
3788     {
3789         pme->atc[0].n = homenr;
3790         pme_realloc_atomcomm_things(&pme->atc[0]);
3791     }
3792
3793     pme->lb_buf1       = NULL;
3794     pme->lb_buf2       = NULL;
3795     pme->lb_buf_nalloc = 0;
3796
3797     {
3798         int thread;
3799
3800         /* Use fft5d, order after FFT is y major, z, x minor */
3801
3802         snew(pme->work, pme->nthread);
3803         for (thread = 0; thread < pme->nthread; thread++)
3804         {
3805             realloc_work(&pme->work[thread], pme->nkx);
3806         }
3807     }
3808
3809     *pmedata = pme;
3810
3811     return 0;
3812 }
3813
3814 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3815 {
3816     int d, t;
3817
3818     for (d = 0; d < DIM; d++)
3819     {
3820         if (new->grid.n[d] > old->grid.n[d])
3821         {
3822             return;
3823         }
3824     }
3825
3826     sfree_aligned(new->grid.grid);
3827     new->grid.grid = old->grid.grid;
3828
3829     if (new->grid_th != NULL && new->nthread == old->nthread)
3830     {
3831         sfree_aligned(new->grid_all);
3832         for (t = 0; t < new->nthread; t++)
3833         {
3834             new->grid_th[t].grid = old->grid_th[t].grid;
3835         }
3836     }
3837 }
3838
3839 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3840                    t_commrec *         cr,
3841                    gmx_pme_t           pme_src,
3842                    const t_inputrec *  ir,
3843                    ivec                grid_size)
3844 {
3845     t_inputrec irc;
3846     int homenr;
3847     int ret;
3848
3849     irc     = *ir;
3850     irc.nkx = grid_size[XX];
3851     irc.nky = grid_size[YY];
3852     irc.nkz = grid_size[ZZ];
3853
3854     if (pme_src->nnodes == 1)
3855     {
3856         homenr = pme_src->atc[0].n;
3857     }
3858     else
3859     {
3860         homenr = -1;
3861     }
3862
3863     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3864                        &irc, homenr, pme_src->bFEP_q, pme_src->bFEP_lj, FALSE, pme_src->nthread);
3865
3866     if (ret == 0)
3867     {
3868         /* We can easily reuse the allocated pme grids in pme_src */
3869         reuse_pmegrids(&pme_src->pmegrid[PME_GRID_QA], &(*pmedata)->pmegrid[PME_GRID_QA]);
3870         /* We would like to reuse the fft grids, but that's harder */
3871     }
3872
3873     return ret;
3874 }
3875
3876
3877 static void copy_local_grid(gmx_pme_t pme,
3878                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3879 {
3880     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3881     int  fft_my, fft_mz;
3882     int  nsx, nsy, nsz;
3883     ivec nf;
3884     int  offx, offy, offz, x, y, z, i0, i0t;
3885     int  d;
3886     pmegrid_t *pmegrid;
3887     real *grid_th;
3888
3889     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
3890                                    local_fft_ndata,
3891                                    local_fft_offset,
3892                                    local_fft_size);
3893     fft_my = local_fft_size[YY];
3894     fft_mz = local_fft_size[ZZ];
3895
3896     pmegrid = &pmegrids->grid_th[thread];
3897
3898     nsx = pmegrid->s[XX];
3899     nsy = pmegrid->s[YY];
3900     nsz = pmegrid->s[ZZ];
3901
3902     for (d = 0; d < DIM; d++)
3903     {
3904         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3905                     local_fft_ndata[d] - pmegrid->offset[d]);
3906     }
3907
3908     offx = pmegrid->offset[XX];
3909     offy = pmegrid->offset[YY];
3910     offz = pmegrid->offset[ZZ];
3911
3912     /* Directly copy the non-overlapping parts of the local grids.
3913      * This also initializes the full grid.
3914      */
3915     grid_th = pmegrid->grid;
3916     for (x = 0; x < nf[XX]; x++)
3917     {
3918         for (y = 0; y < nf[YY]; y++)
3919         {
3920             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3921             i0t = (x*nsy + y)*nsz;
3922             for (z = 0; z < nf[ZZ]; z++)
3923             {
3924                 fftgrid[i0+z] = grid_th[i0t+z];
3925             }
3926         }
3927     }
3928 }
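/* Illustrative sketch: both grids above use row-major (x, y, z) indexing,
 * element (x, y, z) living at (x*my + y)*mz + z for pitches my and mz in
 * y and z. The hypothetical helper below only spells out that convention;
 * it is not used by the PME code itself.
 */
#if 0
static int grid_index(int x, int y, int z, int my, int mz)
{
    /* my and mz are the allocated (padded) sizes in y and z */
    return (x*my + y)*mz + z;
}
#endif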
3929
3930 static void
3931 reduce_threadgrid_overlap(gmx_pme_t pme,
3932                           const pmegrids_t *pmegrids, int thread,
3933                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3934 {
3935     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3936     int  fft_nx, fft_ny, fft_nz;
3937     int  fft_my, fft_mz;
3938     int  buf_my = -1;
3939     int  nsx, nsy, nsz;
3940     ivec ne;
3941     int  offx, offy, offz, x, y, z, i0, i0t;
3942     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3943     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3944     gmx_bool bCommX, bCommY;
3945     int  d;
3946     int  thread_f;
3947     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3948     const real *grid_th;
3949     real *commbuf = NULL;
3950
3951     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
3952                                    local_fft_ndata,
3953                                    local_fft_offset,
3954                                    local_fft_size);
3955     fft_nx = local_fft_ndata[XX];
3956     fft_ny = local_fft_ndata[YY];
3957     fft_nz = local_fft_ndata[ZZ];
3958
3959     fft_my = local_fft_size[YY];
3960     fft_mz = local_fft_size[ZZ];
3961
3962     /* This routine is called when all threads have finished spreading.
3963      * Here each thread sums grid contributions calculated by other threads
3964      * to the thread local grid volume.
3965      * To minimize the number of grid copying operations,
3966      * this routine sums immediately from the pmegrid to the fftgrid.
3967      */
3968
3969     /* Determine which part of the full node grid we should operate on,
3970      * this is our thread local part of the full grid.
3971      */
3972     pmegrid = &pmegrids->grid_th[thread];
3973
3974     for (d = 0; d < DIM; d++)
3975     {
3976         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3977                     local_fft_ndata[d]);
3978     }
3979
3980     offx = pmegrid->offset[XX];
3981     offy = pmegrid->offset[YY];
3982     offz = pmegrid->offset[ZZ];
3983
3984
3985     bClearBufX  = TRUE;
3986     bClearBufY  = TRUE;
3987     bClearBufXY = TRUE;
3988
3989     /* Now loop over all the thread data blocks that contribute
3990      * to the grid region we (our thread) are operating on.
3991      */
3992     /* Note that fft_nx/y is equal to the number of grid points
3993      * between the first point of our node grid and the one of the next node.
3994      */
3995     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3996     {
3997         fx     = pmegrid->ci[XX] + sx;
3998         ox     = 0;
3999         bCommX = FALSE;
4000         if (fx < 0)
4001         {
4002             fx    += pmegrids->nc[XX];
4003             ox    -= fft_nx;
4004             bCommX = (pme->nnodes_major > 1);
4005         }
4006         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
4007         ox       += pmegrid_g->offset[XX];
4008         if (!bCommX)
4009         {
4010             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
4011         }
4012         else
4013         {
4014             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
4015         }
4016
4017         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
4018         {
4019             fy     = pmegrid->ci[YY] + sy;
4020             oy     = 0;
4021             bCommY = FALSE;
4022             if (fy < 0)
4023             {
4024                 fy    += pmegrids->nc[YY];
4025                 oy    -= fft_ny;
4026                 bCommY = (pme->nnodes_minor > 1);
4027             }
4028             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
4029             oy       += pmegrid_g->offset[YY];
4030             if (!bCommY)
4031             {
4032                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
4033             }
4034             else
4035             {
4036                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
4037             }
4038
4039             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
4040             {
4041                 fz = pmegrid->ci[ZZ] + sz;
4042                 oz = 0;
4043                 if (fz < 0)
4044                 {
4045                     fz += pmegrids->nc[ZZ];
4046                     oz -= fft_nz;
4047                 }
4048                 pmegrid_g = &pmegrids->grid_th[fz];
4049                 oz       += pmegrid_g->offset[ZZ];
4050                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
4051
4052                 if (sx == 0 && sy == 0 && sz == 0)
4053                 {
4054                     /* We have already added our local contribution
4055                      * before calling this routine, so skip it here.
4056                      */
4057                     continue;
4058                 }
4059
4060                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
4061
4062                 pmegrid_f = &pmegrids->grid_th[thread_f];
4063
4064                 grid_th = pmegrid_f->grid;
4065
4066                 nsx = pmegrid_f->s[XX];
4067                 nsy = pmegrid_f->s[YY];
4068                 nsz = pmegrid_f->s[ZZ];
4069
4070 #ifdef DEBUG_PME_REDUCE
4071                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
4072                        pme->nodeid, thread, thread_f,
4073                        pme->pmegrid_start_ix,
4074                        pme->pmegrid_start_iy,
4075                        pme->pmegrid_start_iz,
4076                        sx, sy, sz,
4077                        offx-ox, tx1-ox, offx, tx1,
4078                        offy-oy, ty1-oy, offy, ty1,
4079                        offz-oz, tz1-oz, offz, tz1);
4080 #endif
4081
4082                 if (!(bCommX || bCommY))
4083                 {
4084                     /* Copy from the thread local grid to the node grid */
4085                     for (x = offx; x < tx1; x++)
4086                     {
4087                         for (y = offy; y < ty1; y++)
4088                         {
4089                             i0  = (x*fft_my + y)*fft_mz;
4090                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4091                             for (z = offz; z < tz1; z++)
4092                             {
4093                                 fftgrid[i0+z] += grid_th[i0t+z];
4094                             }
4095                         }
4096                     }
4097                 }
4098                 else
4099                 {
4100                     /* The order of this conditional decides
4101                      * where the corner volume gets stored with x+y decomp.
4102                      */
4103                     if (bCommY)
4104                     {
4105                         commbuf = commbuf_y;
4106                         buf_my  = ty1 - offy;
4107                         if (bCommX)
4108                         {
4109                             /* We index commbuf modulo the local grid size */
4110                             commbuf += buf_my*fft_nx*fft_nz;
4111
4112                             bClearBuf   = bClearBufXY;
4113                             bClearBufXY = FALSE;
4114                         }
4115                         else
4116                         {
4117                             bClearBuf  = bClearBufY;
4118                             bClearBufY = FALSE;
4119                         }
4120                     }
4121                     else
4122                     {
4123                         commbuf    = commbuf_x;
4124                         buf_my     = fft_ny;
4125                         bClearBuf  = bClearBufX;
4126                         bClearBufX = FALSE;
4127                     }
4128
4129                     /* Copy to the communication buffer */
4130                     for (x = offx; x < tx1; x++)
4131                     {
4132                         for (y = offy; y < ty1; y++)
4133                         {
4134                             i0  = (x*buf_my + y)*fft_nz;
4135                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
4136
4137                             if (bClearBuf)
4138                             {
4139                                 /* First access of commbuf, initialize it */
4140                                 for (z = offz; z < tz1; z++)
4141                                 {
4142                                     commbuf[i0+z]  = grid_th[i0t+z];
4143                                 }
4144                             }
4145                             else
4146                             {
4147                                 for (z = offz; z < tz1; z++)
4148                                 {
4149                                     commbuf[i0+z] += grid_th[i0t+z];
4150                                 }
4151                             }
4152                         }
4153                     }
4154                 }
4155             }
4156         }
4157     }
4158 }
4159
4160
4161 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
4162 {
4163     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4164     pme_overlap_t *overlap;
4165     int  send_index0, send_nindex;
4166     int  recv_nindex;
4167 #ifdef GMX_MPI
4168     MPI_Status stat;
4169 #endif
4170     int  send_size_y, recv_size_y;
4171     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
4172     real *sendptr, *recvptr;
4173     int  x, y, z, indg, indb;
4174
4175     /* Note that this routine is only used for forward communication.
4176      * Since the force gathering, unlike the charge spreading,
4177      * can be trivially parallelized over the particles,
4178      * the backwards process is much simpler and can use the "old"
4179      * communication setup.
4180      */
4181
4182     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4183                                    local_fft_ndata,
4184                                    local_fft_offset,
4185                                    local_fft_size);
4186
4187     if (pme->nnodes_minor > 1)
4188     {
4189         /* Minor dimension */
4190         overlap = &pme->overlap[1];
4191
4192         if (pme->nnodes_major > 1)
4193         {
4194             size_yx = pme->overlap[0].comm_data[0].send_nindex;
4195         }
4196         else
4197         {
4198             size_yx = 0;
4199         }
4200         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
4201
4202         send_size_y = overlap->send_size;
4203
4204         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
4205         {
4206             send_id       = overlap->send_id[ipulse];
4207             recv_id       = overlap->recv_id[ipulse];
4208             send_index0   =
4209                 overlap->comm_data[ipulse].send_index0 -
4210                 overlap->comm_data[0].send_index0;
4211             send_nindex   = overlap->comm_data[ipulse].send_nindex;
4212             /* We don't use recv_index0, as we always receive starting at 0 */
4213             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4214             recv_size_y   = overlap->comm_data[ipulse].recv_size;
4215
4216             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
4217             recvptr = overlap->recvbuf;
4218
4219 #ifdef GMX_MPI
4220             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
4221                          send_id, ipulse,
4222                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
4223                          recv_id, ipulse,
4224                          overlap->mpi_comm, &stat);
4225 #endif
4226
4227             for (x = 0; x < local_fft_ndata[XX]; x++)
4228             {
4229                 for (y = 0; y < recv_nindex; y++)
4230                 {
4231                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
4232                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
4233                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
4234                     {
4235                         fftgrid[indg+z] += recvptr[indb+z];
4236                     }
4237                 }
4238             }
4239
4240             if (pme->nnodes_major > 1)
4241             {
4242                 /* Copy from the received buffer to the send buffer for dim 0 */
4243                 sendptr = pme->overlap[0].sendbuf;
4244                 for (x = 0; x < size_yx; x++)
4245                 {
4246                     for (y = 0; y < recv_nindex; y++)
4247                     {
4248                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4249                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
4250                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
4251                         {
4252                             sendptr[indg+z] += recvptr[indb+z];
4253                         }
4254                     }
4255                 }
4256             }
4257         }
4258     }
4259
4260     /* We only support a single pulse here.
4261      * This is not a severe limitation, as this code is only used
4262      * with OpenMP and with OpenMP the (PME) domains can be larger.
4263      */
4264     if (pme->nnodes_major > 1)
4265     {
4266         /* Major dimension */
4267         overlap = &pme->overlap[0];
4268
4269         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
4270         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
4271
4272         ipulse = 0;
4273
4274         send_id       = overlap->send_id[ipulse];
4275         recv_id       = overlap->recv_id[ipulse];
4276         send_nindex   = overlap->comm_data[ipulse].send_nindex;
4277         /* We don't use recv_index0, as we always receive starting at 0 */
4278         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
4279
4280         sendptr = overlap->sendbuf;
4281         recvptr = overlap->recvbuf;
4282
4283         if (debug != NULL)
4284         {
4285             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
4286                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
4287         }
4288
4289 #ifdef GMX_MPI
4290         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
4291                      send_id, ipulse,
4292                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
4293                      recv_id, ipulse,
4294                      overlap->mpi_comm, &stat);
4295 #endif
4296
4297         for (x = 0; x < recv_nindex; x++)
4298         {
4299             for (y = 0; y < local_fft_ndata[YY]; y++)
4300             {
4301                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
4302                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
4303                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
4304                 {
4305                     fftgrid[indg+z] += recvptr[indb+z];
4306                 }
4307             }
4308         }
4309     }
4310 }
4311
4312
4313 static void spread_on_grid(gmx_pme_t pme,
4314                            pme_atomcomm_t *atc, pmegrids_t *grids,
4315                            gmx_bool bCalcSplines, gmx_bool bSpread,
4316                            real *fftgrid, gmx_bool bDoSplines )
4317 {
4318     int nthread, thread;
4319 #ifdef PME_TIME_THREADS
4320     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
4321     static double cs1     = 0, cs2 = 0, cs3 = 0;
4322     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
4323     static int cnt        = 0;
4324 #endif
4325
4326     nthread = pme->nthread;
4327     assert(nthread > 0);
4328
4329 #ifdef PME_TIME_THREADS
4330     c1 = omp_cyc_start();
4331 #endif
4332     if (bCalcSplines)
4333     {
4334 #pragma omp parallel for num_threads(nthread) schedule(static)
4335         for (thread = 0; thread < nthread; thread++)
4336         {
4337             int start, end;
4338
4339             start = atc->n* thread   /nthread;
4340             end   = atc->n*(thread+1)/nthread;
4341
4342             /* Compute fftgrid index for all atoms,
4343              * with help of some extra variables.
4344              */
4345             calc_interpolation_idx(pme, atc, start, end, thread);
4346         }
4347     }
4348 #ifdef PME_TIME_THREADS
4349     c1   = omp_cyc_end(c1);
4350     cs1 += (double)c1;
4351 #endif
4352
4353 #ifdef PME_TIME_THREADS
4354     c2 = omp_cyc_start();
4355 #endif
4356 #pragma omp parallel for num_threads(nthread) schedule(static)
4357     for (thread = 0; thread < nthread; thread++)
4358     {
4359         splinedata_t *spline;
4360         pmegrid_t *grid = NULL;
4361
4362         /* make local bsplines  */
4363         if (grids == NULL || !pme->bUseThreads)
4364         {
4365             spline = &atc->spline[0];
4366
4367             spline->n = atc->n;
4368
4369             if (bSpread)
4370             {
4371                 grid = &grids->grid;
4372             }
4373         }
4374         else
4375         {
4376             spline = &atc->spline[thread];
4377
4378             if (grids->nthread == 1)
4379             {
4380                 /* One thread, we operate on all charges */
4381                 spline->n = atc->n;
4382             }
4383             else
4384             {
4385                 /* Get the indices our thread should operate on */
4386                 make_thread_local_ind(atc, thread, spline);
4387             }
4388
4389             grid = &grids->grid_th[thread];
4390         }
4391
4392         if (bCalcSplines)
4393         {
4394             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4395                           atc->fractx, spline->n, spline->ind, atc->q, bDoSplines);
4396         }
4397
4398         if (bSpread)
4399         {
4400             /* put local atoms on grid. */
4401 #ifdef PME_TIME_SPREAD
4402             ct1a = omp_cyc_start();
4403 #endif
4404             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4405
4406             if (pme->bUseThreads)
4407             {
4408                 copy_local_grid(pme, grids, thread, fftgrid);
4409             }
4410 #ifdef PME_TIME_SPREAD
4411             ct1a          = omp_cyc_end(ct1a);
4412             cs1a[thread] += (double)ct1a;
4413 #endif
4414         }
4415     }
4416 #ifdef PME_TIME_THREADS
4417     c2   = omp_cyc_end(c2);
4418     cs2 += (double)c2;
4419 #endif
4420
4421     if (bSpread && pme->bUseThreads)
4422     {
4423 #ifdef PME_TIME_THREADS
4424         c3 = omp_cyc_start();
4425 #endif
4426 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4427         for (thread = 0; thread < grids->nthread; thread++)
4428         {
4429             reduce_threadgrid_overlap(pme, grids, thread,
4430                                       fftgrid,
4431                                       pme->overlap[0].sendbuf,
4432                                       pme->overlap[1].sendbuf);
4433         }
4434 #ifdef PME_TIME_THREADS
4435         c3   = omp_cyc_end(c3);
4436         cs3 += (double)c3;
4437 #endif
4438
4439         if (pme->nnodes > 1)
4440         {
4441             /* Communicate the overlapping part of the fftgrid.
4442              * For this communication call we need to check pme->bUseThreads
4443              * to have all ranks communicate here, regardless of pme->nthread.
4444              */
4445             sum_fftgrid_dd(pme, fftgrid);
4446         }
4447     }
4448
4449 #ifdef PME_TIME_THREADS
4450     cnt++;
4451     if (cnt % 20 == 0)
4452     {
4453         printf("idx %.2f spread %.2f red %.2f",
4454                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4455 #ifdef PME_TIME_SPREAD
4456         for (thread = 0; thread < nthread; thread++)
4457         {
4458             printf(" %.2f", cs1a[thread]*1e-9);
4459         }
4460 #endif
4461         printf("\n");
4462     }
4463 #endif
4464 }
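/* Worked example with assumed numbers for the thread ranges used above:
 * start = atc->n*thread/nthread and end = atc->n*(thread+1)/nthread split
 * the atoms nearly evenly instead of lumping the remainder onto one thread;
 * e.g. atc->n = 10 atoms on nthread = 4 threads gives the ranges
 * [0,2), [2,5), [5,7) and [7,10).
 */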
4465
4466
4467 static void dump_grid(FILE *fp,
4468                       int sx, int sy, int sz, int nx, int ny, int nz,
4469                       int my, int mz, const real *g)
4470 {
4471     int x, y, z;
4472
4473     for (x = 0; x < nx; x++)
4474     {
4475         for (y = 0; y < ny; y++)
4476         {
4477             for (z = 0; z < nz; z++)
4478             {
4479                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4480                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4481             }
4482         }
4483     }
4484 }
4485
4486 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4487 {
4488     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4489
4490     gmx_parallel_3dfft_real_limits(pme->pfft_setup[PME_GRID_QA],
4491                                    local_fft_ndata,
4492                                    local_fft_offset,
4493                                    local_fft_size);
4494
4495     dump_grid(stderr,
4496               pme->pmegrid_start_ix,
4497               pme->pmegrid_start_iy,
4498               pme->pmegrid_start_iz,
4499               pme->pmegrid_nx-pme->pme_order+1,
4500               pme->pmegrid_ny-pme->pme_order+1,
4501               pme->pmegrid_nz-pme->pme_order+1,
4502               local_fft_size[YY],
4503               local_fft_size[ZZ],
4504               fftgrid);
4505 }
4506
4507
4508 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4509 {
4510     pme_atomcomm_t *atc;
4511     pmegrids_t *grid;
4512
4513     if (pme->nnodes > 1)
4514     {
4515         gmx_incons("gmx_pme_calc_energy called in parallel");
4516     }
4517     if (pme->bFEP)
4518     {
4519         gmx_incons("gmx_pme_calc_energy with free energy");
4520     }
4521
4522     atc            = &pme->atc_energy;
4523     atc->nthread   = 1;
4524     if (atc->spline == NULL)
4525     {
4526         snew(atc->spline, atc->nthread);
4527     }
4528     atc->nslab     = 1;
4529     atc->bSpread   = TRUE;
4530     atc->pme_order = pme->pme_order;
4531     atc->n         = n;
4532     pme_realloc_atomcomm_things(atc);
4533     atc->x         = x;
4534     atc->q         = q;
4535
4536     /* We only use the A-charges grid */
4537     grid = &pme->pmegrid[PME_GRID_QA];
4538
4539     /* Only calculate the spline coefficients, don't actually spread */
4540     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgrid[PME_GRID_QA], FALSE);
4541
4542     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4543 }
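/* Usage sketch, illustrative only (helper name is an assumption): the
 * routine requires a serial, non-free-energy PME setup, computes spline
 * coefficients for the given positions and charges without spreading, and
 * gathers an energy from the A-charge grid; no forces are produced.
 */
#if 0
static real example_recip_energy(gmx_pme_t pme, int n, rvec *x, real *q)
{
    real V = 0;

    /* pme must have been set up with gmx_pme_init on a single rank */
    gmx_pme_calc_energy(pme, n, x, q, &V);

    return V;
}
#endif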
4544
4545
4546 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4547                                    gmx_walltime_accounting_t walltime_accounting,
4548                                    t_nrnb *nrnb, t_inputrec *ir,
4549                                    gmx_int64_t step)
4550 {
4551     /* Reset all the counters related to performance over the run */
4552     wallcycle_stop(wcycle, ewcRUN);
4553     wallcycle_reset_all(wcycle);
4554     init_nrnb(nrnb);
4555     if (ir->nsteps >= 0)
4556     {
4557         /* ir->nsteps is not used here, but we update it for consistency */
4558         ir->nsteps -= step - ir->init_step;
4559     }
4560     ir->init_step = step;
4561     wallcycle_start(wcycle, ewcRUN);
4562     walltime_accounting_start(walltime_accounting);
4563 }
4564
4565
4566 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4567                                ivec grid_size,
4568                                t_commrec *cr, t_inputrec *ir,
4569                                gmx_pme_t *pme_ret)
4570 {
4571     int ind;
4572     gmx_pme_t pme = NULL;
4573
4574     ind = 0;
4575     while (ind < *npmedata)
4576     {
4577         pme = (*pmedata)[ind];
4578         if (pme->nkx == grid_size[XX] &&
4579             pme->nky == grid_size[YY] &&
4580             pme->nkz == grid_size[ZZ])
4581         {
4582             *pme_ret = pme;
4583
4584             return;
4585         }
4586
4587         ind++;
4588     }
4589
4590     (*npmedata)++;
4591     srenew(*pmedata, *npmedata);
4592
4593     /* Generate a new PME data structure, copying part of the old pointers */
4594     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4595
4596     *pme_ret = (*pmedata)[ind];
4597 }
4598
4599 int gmx_pmeonly(gmx_pme_t pme,
4600                 t_commrec *cr,    t_nrnb *mynrnb,
4601                 gmx_wallcycle_t wcycle,
4602                 gmx_walltime_accounting_t walltime_accounting,
4603                 real ewaldcoeff_q, real ewaldcoeff_lj,
4604                 t_inputrec *ir)
4605 {
4606     int npmedata;
4607     gmx_pme_t *pmedata;
4608     gmx_pme_pp_t pme_pp;
4609     int  ret;
4610     int  natoms;
4611     matrix box;
4612     rvec *x_pp      = NULL, *f_pp = NULL;
4613     real *chargeA   = NULL, *chargeB = NULL;
4614     real *c6A       = NULL, *c6B = NULL;
4615     real *sigmaA    = NULL, *sigmaB = NULL;
4616     real lambda_q   = 0;
4617     real lambda_lj  = 0;
4618     int  maxshift_x = 0, maxshift_y = 0;
4619     real energy_q, energy_lj, dvdlambda_q, dvdlambda_lj;
4620     matrix vir_q, vir_lj;
4621     float cycles;
4622     int  count;
4623     gmx_bool bEnerVir;
4624     int pme_flags;
4625     gmx_int64_t step, step_rel;
4626     ivec grid_switch;
4627
4628     /* This data will only be used with PME tuning, i.e. switching PME grids */
4629     npmedata = 1;
4630     snew(pmedata, npmedata);
4631     pmedata[0] = pme;
4632
4633     pme_pp = gmx_pme_pp_init(cr);
4634
4635     init_nrnb(mynrnb);
4636
4637     count = 0;
4638     do /****** this is a quasi-loop over time steps! */
4639     {
4640         /* The reason for having a loop here is PME grid tuning/switching */
4641         do
4642         {
4643             /* Domain decomposition */
4644             ret = gmx_pme_recv_params_coords(pme_pp,
4645                                              &natoms,
4646                                              &chargeA, &chargeB,
4647                                              &c6A, &c6B,
4648                                              &sigmaA, &sigmaB,
4649                                              box, &x_pp, &f_pp,
4650                                              &maxshift_x, &maxshift_y,
4651                                              &pme->bFEP_q, &pme->bFEP_lj,
4652                                              &lambda_q, &lambda_lj,
4653                                              &bEnerVir,
4654                                              &pme_flags,
4655                                              &step,
4656                                              grid_switch, &ewaldcoeff_q, &ewaldcoeff_lj);
4657
4658             if (ret == pmerecvqxSWITCHGRID)
4659             {
4660                 /* Switch the PME grid to grid_switch */
4661                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4662             }
4663
4664             if (ret == pmerecvqxRESETCOUNTERS)
4665             {
4666                 /* Reset the cycle and flop counters */
4667                 reset_pmeonly_counters(wcycle, walltime_accounting, mynrnb, ir, step);
4668             }
4669         }
4670         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4671
4672         if (ret == pmerecvqxFINISH)
4673         {
4674             /* We should stop: break out of the loop */
4675             break;
4676         }
4677
4678         step_rel = step - ir->init_step;
4679
4680         if (count == 0)
4681         {
4682             wallcycle_start(wcycle, ewcRUN);
4683             walltime_accounting_start(walltime_accounting);
4684         }
4685
4686         wallcycle_start(wcycle, ewcPMEMESH);
4687
4688         dvdlambda_q  = 0;
4689         dvdlambda_lj = 0;
4690         clear_mat(vir_q);
4691         clear_mat(vir_lj);
4692
4693         gmx_pme_do(pme, 0, natoms, x_pp, f_pp,
4694                    chargeA, chargeB, c6A, c6B, sigmaA, sigmaB, box,
4695                    cr, maxshift_x, maxshift_y, mynrnb, wcycle,
4696                    vir_q, ewaldcoeff_q, vir_lj, ewaldcoeff_lj,
4697                    &energy_q, &energy_lj, lambda_q, lambda_lj, &dvdlambda_q, &dvdlambda_lj,
4698                    pme_flags | GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4699
4700         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4701
4702         gmx_pme_send_force_vir_ener(pme_pp,
4703                                     f_pp, vir_q, energy_q, vir_lj, energy_lj,
4704                                     dvdlambda_q, dvdlambda_lj, cycles);
4705
4706         count++;
4707     } /***** end of quasi-loop, we stop with the break above */
4708     while (TRUE);
4709
4710     walltime_accounting_end(walltime_accounting);
4711
4712     return 0;
4713 }
4714
4715 static void
4716 calc_initial_lb_coeffs(gmx_pme_t pme, real *local_c6, real *local_sigma)
4717 {
4718     int  i;
4719
4720     for (i = 0; i < pme->atc[0].n; ++i)
4721     {
4722         real sigma4;
4723
4724         sigma4           = local_sigma[i];
4725         sigma4           = sigma4*sigma4;
4726         sigma4           = sigma4*sigma4;
4727         pme->atc[0].q[i] = local_c6[i] / sigma4;
4728     }
4729 }
4730
4731 static void
4732 calc_next_lb_coeffs(gmx_pme_t pme, real *local_sigma)
4733 {
4734     int  i;
4735
4736     for (i = 0; i < pme->atc[0].n; ++i)
4737     {
4738         pme->atc[0].q[i] *= local_sigma[i];
4739     }
4740 }
4741
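     /* Redistribute the coordinates (only on the first call, as flagged by
      * bFirst) and the per-atom coefficients (charges or C6) over the PME
      * decomposition dimensions, so that each rank ends up with the data for
      * the atoms in its slab(s) in pme->atc[].
      */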
4742 static void
4743 do_redist_x_q(gmx_pme_t pme, t_commrec *cr, int start, int homenr,
4744               gmx_bool bFirst, rvec x[], real *data)
4745 {
4746     int      d;
4747     pme_atomcomm_t *atc;
4748     atc = &pme->atc[0];
4749
4750     for (d = pme->ndecompdim - 1; d >= 0; d--)
4751     {
4752         int             n_d;
4753         rvec           *x_d;
4754         real           *q_d;
4755
4756         if (d == pme->ndecompdim - 1)
4757         {
4758             n_d = homenr;
4759             x_d = x + start;
4760             q_d = data;
4761         }
4762         else
4763         {
4764             n_d = pme->atc[d + 1].n;
4765             x_d = atc->x;
4766             q_d = atc->q;
4767         }
4768         atc      = &pme->atc[d];
4769         atc->npd = n_d;
4770         if (atc->npd > atc->pd_nalloc)
4771         {
4772             atc->pd_nalloc = over_alloc_dd(atc->npd);
4773             srenew(atc->pd, atc->pd_nalloc);
4774         }
4775         pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4776         where();
4777         /* Redistribute x (only once) and qA/c6A or qB/c6B */
4778         if (DOMAINDECOMP(cr))
4779         {
4780             dd_pmeredist_x_q(pme, n_d, bFirst, x_d, q_d, atc);
4781         }
4782         else
4783         {
4784             pmeredist_pd(pme, TRUE, n_d, bFirst, x_d, q_d, atc);
4785         }
4786     }
4787 }
4788
4789 /* TODO: when adding free-energy support, sigmaB will no longer be
4790  * unused */
4791 int gmx_pme_do(gmx_pme_t pme,
4792                int start,       int homenr,
4793                rvec x[],        rvec f[],
4794                real *chargeA,   real *chargeB,
4795                real *c6A,       real *c6B,
4796                real *sigmaA,    real gmx_unused *sigmaB,
4797                matrix box, t_commrec *cr,
4798                int  maxshift_x, int maxshift_y,
4799                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4800                matrix vir_q,      real ewaldcoeff_q,
4801                matrix vir_lj,   real ewaldcoeff_lj,
4802                real *energy_q,  real *energy_lj,
4803                real lambda_q, real lambda_lj,
4804                real *dvdlambda_q, real *dvdlambda_lj,
4805                int flags)
4806 {
4807     int     q, d, i, j, ntot, npme, qmax;
4808     int     nx, ny, nz;
4809     int     n_d, local_ny;
4810     pme_atomcomm_t *atc = NULL;
4811     pmegrids_t *pmegrid = NULL;
4812     real    *grid       = NULL;
4813     real    *ptr;
4814     rvec    *x_d, *f_d;
4815     real    *charge = NULL, *q_d;
4816     real    energy_AB[4];
4817     matrix  vir_AB[4];
4818     real    scale, lambda;
4819     gmx_bool bClearF;
4820     gmx_parallel_3dfft_t pfft_setup;
4821     real *  fftgrid;
4822     t_complex * cfftgrid;
4823     int     thread;
4824     gmx_bool bFirst, bDoSplines;
4825     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4826     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4827
4828     assert(pme->nnodes > 0);
4829     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4830
4831     if (pme->nnodes > 1)
4832     {
4833         atc      = &pme->atc[0];
4834         atc->npd = homenr;
4835         if (atc->npd > atc->pd_nalloc)
4836         {
4837             atc->pd_nalloc = over_alloc_dd(atc->npd);
4838             srenew(atc->pd, atc->pd_nalloc);
4839         }
4840         for (d = pme->ndecompdim-1; d >= 0; d--)
4841         {
4842             atc           = &pme->atc[d];
4843             atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4844         }
4845     }
4846     else
4847     {
4848         atc = &pme->atc[0];
4849         /* This could be necessary for TPI */
4850         pme->atc[0].n = homenr;
4851         if (DOMAINDECOMP(cr))
4852         {
4853             pme_realloc_atomcomm_things(atc);
4854         }
4855         atc->x = x;
4856         atc->f = f;
4857     }
4858
4859     m_inv_ur0(box, pme->recipbox);
4860     bFirst = TRUE;
4861
4862     /* For simplicity, we construct the splines for all particles if
4863      * more than one PME calculation is needed. Some optimization
4864      * could be done by keeping track of which atoms have splines
4865      * constructed, and constructing new splines on each pass for atoms
4866      * that don't yet have them.
4867      */
4868
4869     bDoSplines = pme->bFEP || ((flags & GMX_PME_DO_COULOMB) && (flags & GMX_PME_DO_LJ));
4870
4871     /* We need a maximum of four separate PME calculations:
4872      * q=0: Coulomb PME with charges from state A
4873      * q=1: Coulomb PME with charges from state B
4874      * q=2: LJ PME with C6 from state A
4875      * q=3: LJ PME with C6 from state B
4876      * For Lorentz-Berthelot combination rules, a separate loop is used to
4877      * calculate all the terms
4878      */
4879
4880     /* If we are doing LJ-PME with LB, we only do Q here */
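         /* With LB the dispersion grids are not handled in the q-loop below,
          * but in the separate seven-term section further down, so qmax is
          * limited to the Coulomb grids (q < DO_Q).
          */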
4881     qmax = (flags & GMX_PME_LJ_LB) ? DO_Q : DO_Q_AND_LJ;
4882
4883     for (q = 0; q < qmax; ++q)
4884     {
4885         /* Check if we should do calculations at this q
4886          * If q is odd we are doing the B state, which requires FEP
4887          * If q < 2 we should be doing electrostatic PME
4888          * If q >= 2 we should be doing LJ-PME
4889          */
4890         if ((!pme->bFEP_q && q == 1)
4891             || (!pme->bFEP_lj && q == 3)
4892             || (!(flags & GMX_PME_DO_COULOMB) && q < 2)
4893             || (!(flags & GMX_PME_DO_LJ) && q >= 2))
4894         {
4895             continue;
4896         }
4897         /* Unpack structure */
4898         pmegrid    = &pme->pmegrid[q];
4899         fftgrid    = pme->fftgrid[q];
4900         cfftgrid   = pme->cfftgrid[q];
4901         pfft_setup = pme->pfft_setup[q];
4902         switch (q)
4903         {
4904             case 0: charge = chargeA + start; break;
4905             case 1: charge = chargeB + start; break;
4906             case 2: charge = c6A + start; break;
4907             case 3: charge = c6B + start; break;
4908         }
4909
4910         grid = pmegrid->grid.grid;
4911
4912         if (debug)
4913         {
4914             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4915                     cr->nnodes, cr->nodeid);
4916             fprintf(debug, "Grid = %p\n", (void*)grid);
4917             if (grid == NULL)
4918             {
4919                 gmx_fatal(FARGS, "No grid!");
4920             }
4921         }
4922         where();
4923
4924         if (pme->nnodes == 1)
4925         {
4926             atc->q = charge;
4927         }
4928         else
4929         {
4930             wallcycle_start(wcycle, ewcPME_REDISTXF);
4931             do_redist_x_q(pme, cr, start, homenr, bFirst, x, charge);
4932             where();
4933
4934             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4935         }
4936
4937         if (debug)
4938         {
4939             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4940                     cr->nodeid, atc->n);
4941         }
4942
4943         if (flags & GMX_PME_SPREAD_Q)
4944         {
4945             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4946
4947             /* Spread the charges on a grid */
4948             spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines);
4949
4950             if (bFirst)
4951             {
4952                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4953             }
4954             inc_nrnb(nrnb, eNR_SPREADQBSP,
4955                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4956
4957             if (!pme->bUseThreads)
4958             {
4959                 wrap_periodic_pmegrid(pme, grid);
4960
4961                 /* sum contributions to local grid from other nodes */
4962 #ifdef GMX_MPI
4963                 if (pme->nnodes > 1)
4964                 {
4965                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4966                     where();
4967                 }
4968 #endif
4969
4970                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid, q);
4971             }
4972
4973             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4974
4975             /*
4976                dump_local_fftgrid(pme,fftgrid);
4977                exit(0);
4978              */
4979         }
4980
4981         /* Here we start a large thread parallel region */
4982 #pragma omp parallel num_threads(pme->nthread) private(thread)
4983         {
4984             thread = gmx_omp_get_thread_num();
4985             if (flags & GMX_PME_SOLVE)
4986             {
4987                 int loop_count;
4988
4989                 /* do 3d-fft */
4990                 if (thread == 0)
4991                 {
4992                     wallcycle_start(wcycle, ewcPME_FFT);
4993                 }
4994                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4995                                            thread, wcycle);
4996                 if (thread == 0)
4997                 {
4998                     wallcycle_stop(wcycle, ewcPME_FFT);
4999                 }
5000                 where();
5001
5002                 /* solve in k-space for our local cells */
5003                 if (thread == 0)
5004                 {
5005                     wallcycle_start(wcycle, (q < DO_Q ? ewcPME_SOLVE : ewcLJPME));
5006                 }
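                     /* The diagonal product passed below is the box volume:
                      * GROMACS stores the (possibly triclinic) box as a
                      * triangular matrix, so its determinant is simply
                      * box[XX][XX]*box[YY][YY]*box[ZZ][ZZ].
                      */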
5007                 if (q < DO_Q)
5008                 {
5009                     loop_count =
5010                         solve_pme_yzx(pme, cfftgrid, ewaldcoeff_q,
5011                                       box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5012                                       bCalcEnerVir,
5013                                       pme->nthread, thread);
5014                 }
5015                 else
5016                 {
5017                     loop_count =
5018                         solve_pme_lj_yzx(pme, &cfftgrid, FALSE, ewaldcoeff_lj,
5019                                          box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5020                                          bCalcEnerVir,
5021                                          pme->nthread, thread);
5022                 }
5023
5024                 if (thread == 0)
5025                 {
5026                     wallcycle_stop(wcycle, (q < DO_Q ? ewcPME_SOLVE : ewcLJPME));
5027                     where();
5028                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5029                 }
5030             }
5031
5032             if (bCalcF)
5033             {
5034                 /* do 3d-invfft */
5035                 if (thread == 0)
5036                 {
5037                     where();
5038                     wallcycle_start(wcycle, ewcPME_FFT);
5039                 }
5040                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5041                                            thread, wcycle);
5042                 if (thread == 0)
5043                 {
5044                     wallcycle_stop(wcycle, ewcPME_FFT);
5045
5046                     where();
5047
5048                     if (pme->nodeid == 0)
5049                     {
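                             /* Crude flop bookkeeping: a 3D FFT is counted as
                              * ~N*log2(N) operations; the factor 2 below
                              * presumably covers the forward plus the backward
                              * transform done for this grid.
                              */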
5050                         ntot  = pme->nkx*pme->nky*pme->nkz;
5051                         npme  = ntot*log((real)ntot)/log(2.0);
5052                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
5053                     }
5054
5055                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5056                 }
5057
5058                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, q, pme->nthread, thread);
5059             }
5060         }
5061         /* End of thread parallel section.
5062          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
5063          */
5064
5065         if (bCalcF)
5066         {
5067             /* distribute local grid to all nodes */
5068 #ifdef GMX_MPI
5069             if (pme->nnodes > 1)
5070             {
5071                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
5072             }
5073 #endif
5074             where();
5075
5076             unwrap_periodic_pmegrid(pme, grid);
5077
5078             /* interpolate forces for our local atoms */
5079
5080             where();
5081
5082             /* If we are running without parallelization,
5083              * atc->f is the actual force array, not a buffer,
5084              * therefore we should not clear it.
5085              */
5086             lambda  = q < DO_Q ? lambda_q : lambda_lj;
5087             bClearF = (bFirst && PAR(cr));
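                 /* With FEP the A-state grids (even q) contribute with weight
                  * 1-lambda and the B-state grids (odd q) with weight lambda.
                  */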
5088 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5089             for (thread = 0; thread < pme->nthread; thread++)
5090             {
5091                 gather_f_bsplines(pme, grid, bClearF, atc,
5092                                   &atc->spline[thread],
5093                                   pme->bFEP ? (q % 2 == 0 ? 1.0-lambda : lambda) : 1.0);
5094             }
5095
5096             where();
5097
5098             inc_nrnb(nrnb, eNR_GATHERFBSP,
5099                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5100             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5101         }
5102
5103         if (bCalcEnerVir)
5104         {
5105             /* This should only be called on the master thread
5106              * and after the threads have synchronized.
5107              */
5108             if (q < 2)
5109             {
5110                 get_pme_ener_vir_q(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
5111             }
5112             else
5113             {
5114                 get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
5115             }
5116         }
5117         bFirst = FALSE;
5118     } /* of q-loop */
5119
5120     /* For Lorentz-Berthelot combination rules in LJ-PME, we need to calculate
5121      * seven terms. */
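         /* Illustrative sketch of where the seven terms come from, assuming
          * standard LB rules: with sigma_ij = (sigma_i + sigma_j)/2 the pair
          * dispersion coefficient contains (sigma_i + sigma_j)^6, which
          * expands binomially as
          *    sum_{k=0..6} C(6,k) * sigma_i^k * sigma_j^(6-k)
          * Each power k is separable in i and j and can therefore be handled
          * by its own mesh; these are the grids q = 2..8 below.
          */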
5122
5123     if (flags & GMX_PME_LJ_LB)
5124     {
5125         real *local_c6 = NULL, *local_sigma = NULL;
5126         if (pme->nnodes == 1)
5127         {
5128             if (pme->lb_buf1 == NULL)
5129             {
5130                 pme->lb_buf_nalloc = pme->atc[0].n;
5131                 snew(pme->lb_buf1, pme->lb_buf_nalloc);
5132             }
5133             pme->atc[0].q = pme->lb_buf1;
5134             local_c6      = c6A;
5135             local_sigma   = sigmaA;
5136         }
5137         else
5138         {
5139             atc = &pme->atc[0];
5140
5141             wallcycle_start(wcycle, ewcPME_REDISTXF);
5142
5143             do_redist_x_q(pme, cr, start, homenr, bFirst, x, c6A);
5144             if (pme->lb_buf_nalloc < atc->n)
5145             {
5146                 pme->lb_buf_nalloc = atc->nalloc;
5147                 srenew(pme->lb_buf1, pme->lb_buf_nalloc);
5148                 srenew(pme->lb_buf2, pme->lb_buf_nalloc);
5149             }
5150             local_c6 = pme->lb_buf1;
5151             for (i = 0; i < atc->n; ++i)
5152             {
5153                 local_c6[i] = atc->q[i];
5154             }
5155             where();
5156
5157             do_redist_x_q(pme, cr, start, homenr, FALSE, x, sigmaA);
5158             local_sigma = pme->lb_buf2;
5159             for (i = 0; i < atc->n; ++i)
5160             {
5161                 local_sigma[i] = atc->q[i];
5162             }
5163             where();
5164
5165             wallcycle_stop(wcycle, ewcPME_REDISTXF);
5166         }
5167         calc_initial_lb_coeffs(pme, local_c6, local_sigma);
5168
5169         /* Seven terms in LJ-PME with LB, q < 2 reserved for electrostatics */
5170         for (q = 2; q < 9; ++q)
5171         {
5172             /* Unpack structure */
5173             pmegrid    = &pme->pmegrid[q];
5174             fftgrid    = pme->fftgrid[q];
5175             cfftgrid   = pme->cfftgrid[q];
5176             pfft_setup = pme->pfft_setup[q];
5177             calc_next_lb_coeffs(pme, local_sigma);
5178             grid = pmegrid->grid.grid;
5179             where();
5180
5181             if (flags & GMX_PME_SPREAD_Q)
5182             {
5183                 wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5184                 /* Spread the charges on a grid */
5185                 spread_on_grid(pme, &pme->atc[0], pmegrid, bFirst, TRUE, fftgrid, bDoSplines);
5186
5187                 if (bFirst)
5188                 {
5189                     inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
5190                 }
5191                 inc_nrnb(nrnb, eNR_SPREADQBSP,
5192                          pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
5193                 if (pme->nthread == 1)
5194                 {
5195                     wrap_periodic_pmegrid(pme, grid);
5196
5197                     /* sum contributions to local grid from other nodes */
5198 #ifdef GMX_MPI
5199                     if (pme->nnodes > 1)
5200                     {
5201                         gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
5202                         where();
5203                     }
5204 #endif
5205                     copy_pmegrid_to_fftgrid(pme, grid, fftgrid, q);
5206                 }
5207
5208                 wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5209             }
5210
5211             /* Here we start a large thread parallel region */
5212 #pragma omp parallel num_threads(pme->nthread) private(thread)
5213             {
5214                 thread = gmx_omp_get_thread_num();
5215                 if (flags & GMX_PME_SOLVE)
5216                 {
5217                     /* do 3d-fft */
5218                     if (thread == 0)
5219                     {
5220                         wallcycle_start(wcycle, ewcPME_FFT);
5221                     }
5222
5223                     gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
5224                                                thread, wcycle);
5225                     if (thread == 0)
5226                     {
5227                         wallcycle_stop(wcycle, ewcPME_FFT);
5228                     }
5229                     where();
5230                 }
5231             }
5232             bFirst = FALSE;
5233         }
5234         if (flags & GMX_PME_SOLVE)
5235         {
5236             int loop_count;
5237             /* solve in k-space for our local cells */
5238 #pragma omp parallel num_threads(pme->nthread) private(thread)
5239             {
5240                 thread = gmx_omp_get_thread_num();
5241                 if (thread == 0)
5242                 {
5243                     wallcycle_start(wcycle, ewcLJPME);
5244                 }
5245
5246                 loop_count =
5247                     solve_pme_lj_yzx(pme, &pme->cfftgrid[2], TRUE, ewaldcoeff_lj,
5248                                      box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
5249                                      bCalcEnerVir,
5250                                      pme->nthread, thread);
5251                 if (thread == 0)
5252                 {
5253                     wallcycle_stop(wcycle, ewcLJPME);
5254                     where();
5255                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
5256                 }
5257             }
5258         }
5259
5260         if (bCalcEnerVir)
5261         {
5262             /* This should only be called on the master thread and
5263              * after the threads have synchronized.
5264              */
5265             get_pme_ener_vir_lj(pme, pme->nthread, &energy_AB[2], vir_AB[2]);
5266         }
5267
5268         if (bCalcF)
5269         {
5270             bFirst = !(flags & GMX_PME_DO_COULOMB);
5271             calc_initial_lb_coeffs(pme, local_c6, local_sigma);
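                 /* The LB grids are revisited in reverse order (q = 8..2):
                  * the coefficients are regenerated from C6/sigma^4 and
                  * stepped through the powers of sigma again, and the force
                  * contribution of each term is weighted by
                  * lb_scale_factor[q-2]. As the TODO above gmx_pme_do()
                  * notes, free-energy perturbation is not yet supported for
                  * this LB path.
                  */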
5272             for (q = 8; q >= 2; --q)
5273             {
5274                 /* Unpack structure */
5275                 pmegrid    = &pme->pmegrid[q];
5276                 fftgrid    = pme->fftgrid[q];
5277                 cfftgrid   = pme->cfftgrid[q];
5278                 pfft_setup = pme->pfft_setup[q];
5279                 grid       = pmegrid->grid.grid;
5280                 calc_next_lb_coeffs(pme, local_sigma);
5281                 where();
5282 #pragma omp parallel num_threads(pme->nthread) private(thread)
5283                 {
5284                     thread = gmx_omp_get_thread_num();
5285                     /* do 3d-invfft */
5286                     if (thread == 0)
5287                     {
5288                         where();
5289                         wallcycle_start(wcycle, ewcPME_FFT);
5290                     }
5291
5292                     gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
5293                                                thread, wcycle);
5294                     if (thread == 0)
5295                     {
5296                         wallcycle_stop(wcycle, ewcPME_FFT);
5297
5298                         where();
5299
5300                         if (pme->nodeid == 0)
5301                         {
5302                             ntot  = pme->nkx*pme->nky*pme->nkz;
5303                             npme  = ntot*log((real)ntot)/log(2.0);
5304                             inc_nrnb(nrnb, eNR_FFT, 2*npme);
5305                         }
5306                         wallcycle_start(wcycle, ewcPME_SPREADGATHER);
5307                     }
5308
5309                     copy_fftgrid_to_pmegrid(pme, fftgrid, grid, q, pme->nthread, thread);
5310
5311                 } /*#pragma omp parallel*/
5312
5313                 /* distribute local grid to all nodes */
5314 #ifdef GMX_MPI
5315                 if (pme->nnodes > 1)
5316                 {
5317                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
5318                 }
5319 #endif
5320                 where();
5321
5322                 unwrap_periodic_pmegrid(pme, grid);
5323
5324                 /* interpolate forces for our local atoms */
5325                 where();
5326                 bClearF = (bFirst && PAR(cr));
5327                 scale   = pme->bFEP ? (q < 9 ? 1.0-lambda_lj : lambda_lj) : 1.0;
5328                 scale  *= lb_scale_factor[q-2];
5329 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
5330                 for (thread = 0; thread < pme->nthread; thread++)
5331                 {
5332                     gather_f_bsplines(pme, grid, bClearF, &pme->atc[0],
5333                                       &pme->atc[0].spline[thread],
5334                                       scale);
5335                 }
5336                 where();
5337
5338                 inc_nrnb(nrnb, eNR_GATHERFBSP,
5339                          pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
5340                 wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
5341
5342                 bFirst = FALSE;
5343             } /*for (q = 8; q >= 2; --q)*/
5344         }     /*if (bCalcF)*/
5345     }         /*if (flags & GMX_PME_LJ_LB)*/
5346
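         /* Redistribute the forces from the PME decomposition back to the
          * original per-rank force array f; this walks the decomposition
          * dimensions in the opposite order of do_redist_x_q().
          */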
5347     if (bCalcF && pme->nnodes > 1)
5348     {
5349         wallcycle_start(wcycle, ewcPME_REDISTXF);
5350         for (d = 0; d < pme->ndecompdim; d++)
5351         {
5352             atc = &pme->atc[d];
5353             if (d == pme->ndecompdim - 1)
5354             {
5355                 n_d = homenr;
5356                 f_d = f + start;
5357             }
5358             else
5359             {
5360                 n_d = pme->atc[d+1].n;
5361                 f_d = pme->atc[d+1].f;
5362             }
5363             if (DOMAINDECOMP(cr))
5364             {
5365                 dd_pmeredist_f(pme, atc, n_d, f_d,
5366                                d == pme->ndecompdim-1 && pme->bPPnode);
5367             }
5368             else
5369             {
5370                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
5371             }
5372         }
5373
5374         wallcycle_stop(wcycle, ewcPME_REDISTXF);
5375     }
5376     where();
5377
5378     if (bCalcEnerVir)
5379     {
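             /* With FEP the mesh energy and virial are interpolated linearly
              * between the A and B states:
              *    E      = (1 - lambda)*E_A + lambda*E_B
              *    dV/dl += E_B - E_A
              * and likewise for each virial component.
              */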
5380         if (flags & GMX_PME_DO_COULOMB)
5381         {
5382             if (!pme->bFEP_q)
5383             {
5384                 *energy_q = energy_AB[0];
5385                 m_add(vir_q, vir_AB[0], vir_q);
5386             }
5387             else
5388             {
5389                 *energy_q       = (1.0-lambda_q)*energy_AB[0] + lambda_q*energy_AB[1];
5390                 *dvdlambda_q   += energy_AB[1] - energy_AB[0];
5391                 for (i = 0; i < DIM; i++)
5392                 {
5393                     for (j = 0; j < DIM; j++)
5394                     {
5395                         vir_q[i][j] += (1.0-lambda_q)*vir_AB[0][i][j] +
5396                             lambda_q*vir_AB[1][i][j];
5397                     }
5398                 }
5399             }
5400             if (debug)
5401             {
5402                 fprintf(debug, "Electrostatic PME mesh energy: %g\n", *energy_q);
5403             }
5404         }
5405         else
5406         {
5407             *energy_q = 0;
5408         }
5409
5410         if (flags & GMX_PME_DO_LJ)
5411         {
5412             if (!pme->bFEP_lj)
5413             {
5414                 *energy_lj = energy_AB[2];
5415                 m_add(vir_lj, vir_AB[2], vir_lj);
5416             }
5417             else
5418             {
5419                 *energy_lj     = (1.0-lambda_lj)*energy_AB[2] + lambda_lj*energy_AB[3];
5420                 *dvdlambda_lj += energy_AB[3] - energy_AB[2];
5421                 for (i = 0; i < DIM; i++)
5422                 {
5423                     for (j = 0; j < DIM; j++)
5424                     {
5425                         vir_lj[i][j] += (1.0-lambda_lj)*vir_AB[2][i][j] + lambda_lj*vir_AB[3][i][j];
5426                     }
5427                 }
5428             }
5429             if (debug)
5430             {
5431                 fprintf(debug, "Lennard-Jones PME mesh energy: %g\n", *energy_lj);
5432             }
5433         }
5434         else
5435         {
5436             *energy_lj = 0;
5437         }
5438     }
5439     return 0;
5440 }