PME load balancing now checks for PME grid restrictions
[alexxy/gromacs.git] / src / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team,
6  * check out http://www.gromacs.org for more information.
7  * Copyright (c) 2012,2013, by the GROMACS development team, led by
8  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
9  * others, as listed in the AUTHORS file in the top-level source
10  * directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* IMPORTANT FOR DEVELOPERS:
39  *
40  * Triclinic pme stuff isn't entirely trivial, and we've experienced
41  * some bugs during development (many of them due to me). To avoid
42  * this in the future, please check the following things if you make
43  * changes in this file:
44  *
45  * 1. You should obtain identical (at least to the PME precision)
46  *    energies, forces, and virial for
47  *    a rectangular box and a triclinic one where the z (or y) axis is
48  *    tilted by a whole box side. For instance you could use these boxes:
49  *
50  *    rectangular       triclinic
51  *     2  0  0           2  0  0
52  *     0  2  0           0  2  0
53  *     0  0  6           2  2  6
54  *
55  * 2. You should check the energy conservation in a triclinic box.
56  *
57  * It might seem like overkill, but better safe than sorry.
58  * /Erik 001109
59  */
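/* For reference, the reciprocal boxes (the recipbox matrix used by
 * calc_interpolation_idx() below, i.e. the inverse of the lower-triangular
 * box matrix) for the two example boxes above are
 *
 *    rectangular            triclinic
 *     1/2   0    0           1/2   0    0
 *     0     1/2  0           0     1/2  0
 *     0     0    1/6        -1/6  -1/6  1/6
 *
 * (worked out by hand here purely as an illustration). The fractional
 * coordinates s_x = x*rxx + y*ryx + z*rzx, s_y = y*ryy + z*rzy and
 * s_z = z*rzz then lie in [0,1) for atoms inside either unit cell,
 * which is what the spreading code below relies on.
 */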
60
61 #ifdef HAVE_CONFIG_H
62 #include <config.h>
63 #endif
64
65 #ifdef GMX_LIB_MPI
66 #include <mpi.h>
67 #endif
68 #ifdef GMX_THREAD_MPI
69 #include "tmpi.h"
70 #endif
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93 #include "gmx_omp.h"
94
95 /* Single precision, with SSE2 or higher available */
96 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
97
98 #include "gmx_x86_simd_single.h"
99
100 #define PME_SSE
101 /* Some old AMD processors could have problems with unaligned loads+stores */
102 #ifndef GMX_FAHCORE
103 #define PME_SSE_UNALIGNED
104 #endif
105 #endif
106
107 #include "mpelogging.h"
108
109 #define DFT_TOL 1e-7
110 /* #define PRT_FORCE */
111 /* conditions for on-the-fly time measurement */
112 /* #define TAKETIME (step > 1 && timesteps < 10) */
113 #define TAKETIME FALSE
114
115 /* #define PME_TIME_THREADS */
116
117 #ifdef GMX_DOUBLE
118 #define mpi_type MPI_DOUBLE
119 #else
120 #define mpi_type MPI_FLOAT
121 #endif
122
123 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
124 #define GMX_CACHE_SEP 64
125
126 /* We only define a maximum to be able to use local arrays without allocation.
127  * An order larger than 12 should never be needed, even for test cases.
128  * If needed it can be changed here.
129  */
130 #define PME_ORDER_MAX 12
131
132 /* Internal datastructures */
133 typedef struct {
134     int send_index0;
135     int send_nindex;
136     int recv_index0;
137     int recv_nindex;
138     int recv_size;   /* Receive buffer width, used with OpenMP */
139 } pme_grid_comm_t;
140
141 typedef struct {
142 #ifdef GMX_MPI
143     MPI_Comm         mpi_comm;
144 #endif
145     int              nnodes, nodeid;
146     int             *s2g0;
147     int             *s2g1;
148     int              noverlap_nodes;
149     int             *send_id, *recv_id;
150     int              send_size; /* Send buffer width, used with OpenMP */
151     pme_grid_comm_t *comm_data;
152     real            *sendbuf;
153     real            *recvbuf;
154 } pme_overlap_t;
155
156 typedef struct {
157     int *n;      /* Cumulative counts of the number of particles per thread */
158     int  nalloc; /* Allocation size of i */
159     int *i;      /* Particle indices ordered on thread index (n) */
160 } thread_plist_t;
161
162 typedef struct {
163     int      *thread_one;
164     int       n;
165     int      *ind;
166     splinevec theta;
167     real     *ptr_theta_z;
168     splinevec dtheta;
169     real     *ptr_dtheta_z;
170 } splinedata_t;
171
172 typedef struct {
173     int      dimind;        /* The index of the dimension, 0=x, 1=y */
174     int      nslab;
175     int      nodeid;
176 #ifdef GMX_MPI
177     MPI_Comm mpi_comm;
178 #endif
179
180     int     *node_dest;     /* The nodes to send x and q to with DD */
181     int     *node_src;      /* The nodes to receive x and q from with DD */
182     int     *buf_index;     /* Index for commnode into the buffers */
183
184     int      maxshift;
185
186     int      npd;
187     int      pd_nalloc;
188     int     *pd;
189     int     *count;         /* The number of atoms to send to each node */
190     int    **count_thread;
191     int     *rcount;        /* The number of atoms to receive */
192
193     int      n;
194     int      nalloc;
195     rvec    *x;
196     real    *q;
197     rvec    *f;
198     gmx_bool bSpread;       /* These coordinates are used for spreading */
199     int      pme_order;
200     ivec    *idx;
201     rvec    *fractx;            /* Fractional coordinate relative to the
202                                  * lower cell boundary
203                                  */
204     int             nthread;
205     int            *thread_idx; /* Which thread should spread which charge */
206     thread_plist_t *thread_plist;
207     splinedata_t   *spline;
208 } pme_atomcomm_t;
209
210 #define FLBS  3
211 #define FLBSZ 4
212
213 typedef struct {
214     ivec  ci;     /* The spatial location of this grid         */
215     ivec  n;      /* The used size of *grid, including order-1 */
216     ivec  offset; /* The grid offset from the full node grid   */
217     int   order;  /* PME spreading order                       */
218     ivec  s;      /* The allocated size of *grid, s >= n       */
219     real *grid;   /* The local grid of this thread, size n     */
220 } pmegrid_t;
221
222 typedef struct {
223     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
224     int        nthread;      /* The number of threads operating on this grid     */
225     ivec       nc;           /* The local spatial decomposition over the threads */
226     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
227     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
228     int      **g2t;          /* The grid to thread index                         */
229     ivec       nthread_comm; /* The number of threads to communicate with        */
230 } pmegrids_t;
231
232
233 typedef struct {
234 #ifdef PME_SSE
235     /* Masks for SSE aligned spreading and gathering */
236     __m128 mask_SSE0[6], mask_SSE1[6];
237 #else
238     int    dummy; /* C89 requires that struct has at least one member */
239 #endif
240 } pme_spline_work_t;
241
242 typedef struct {
243     /* work data for solve_pme */
244     int      nalloc;
245     real *   mhx;
246     real *   mhy;
247     real *   mhz;
248     real *   m2;
249     real *   denom;
250     real *   tmp1_alloc;
251     real *   tmp1;
252     real *   eterm;
253     real *   m2inv;
254
255     real     energy;
256     matrix   vir;
257 } pme_work_t;
258
259 typedef struct gmx_pme {
260     int           ndecompdim; /* The number of decomposition dimensions */
261     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
262     int           nodeid_major;
263     int           nodeid_minor;
264     int           nnodes;    /* The number of nodes doing PME */
265     int           nnodes_major;
266     int           nnodes_minor;
267
268     MPI_Comm      mpi_comm;
269     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
270 #ifdef GMX_MPI
271     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
272 #endif
273
274     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
275     int        nthread;       /* The number of threads doing PME on our rank */
276
277     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
278     gmx_bool   bFEP;          /* Compute Free energy contribution */
279     int        nkx, nky, nkz; /* Grid dimensions */
280     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
281     int        pme_order;
282     real       epsilon_r;
283
284     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
285     pmegrids_t pmegridB;
286     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
287     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
288     /* pmegrid_nz might be larger than strictly necessary to ensure
289      * memory alignment, pmegrid_nz_base gives the real base size.
290      */
291     int     pmegrid_nz_base;
292     /* The local PME grid starting indices */
293     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
294
295     /* Work data for spreading and gathering */
296     pme_spline_work_t    *spline_work;
297
298     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
299     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
300     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
301
302     t_complex            *cfftgridA;  /* Grids for complex FFT data */
303     t_complex            *cfftgridB;
304     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
305
306     gmx_parallel_3dfft_t  pfft_setupA;
307     gmx_parallel_3dfft_t  pfft_setupB;
308
309     int                  *nnx, *nny, *nnz;
310     real                 *fshx, *fshy, *fshz;
311
312     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
313     matrix                recipbox;
314     splinevec             bsp_mod;
315
316     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
317
318     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
319
320     rvec                 *bufv;       /* Communication buffer */
321     real                 *bufr;       /* Communication buffer */
322     int                   buf_nalloc; /* The communication buffer size */
323
324     /* thread local work data for solve_pme */
325     pme_work_t *work;
326
327     /* Work data for PME_redist */
328     gmx_bool redist_init;
329     int *    scounts;
330     int *    rcounts;
331     int *    sdispls;
332     int *    rdispls;
333     int *    sidx;
334     int *    idxa;
335     real *   redist_buf;
336     int      redist_buf_nalloc;
337
338     /* Work data for sum_qgrid */
339     real *   sum_qgrid_tmp;
340     real *   sum_qgrid_dd_tmp;
341 } t_gmx_pme;
342
343
344 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
345                                    int start, int end, int thread)
346 {
347     int             i;
348     int            *idxptr, tix, tiy, tiz;
349     real           *xptr, *fptr, tx, ty, tz;
350     real            rxx, ryx, ryy, rzx, rzy, rzz;
351     int             nx, ny, nz;
352     int             start_ix, start_iy, start_iz;
353     int            *g2tx, *g2ty, *g2tz;
354     gmx_bool        bThreads;
355     int            *thread_idx = NULL;
356     thread_plist_t *tpl        = NULL;
357     int            *tpl_n      = NULL;
358     int             thread_i;
359
360     nx  = pme->nkx;
361     ny  = pme->nky;
362     nz  = pme->nkz;
363
364     start_ix = pme->pmegrid_start_ix;
365     start_iy = pme->pmegrid_start_iy;
366     start_iz = pme->pmegrid_start_iz;
367
368     rxx = pme->recipbox[XX][XX];
369     ryx = pme->recipbox[YY][XX];
370     ryy = pme->recipbox[YY][YY];
371     rzx = pme->recipbox[ZZ][XX];
372     rzy = pme->recipbox[ZZ][YY];
373     rzz = pme->recipbox[ZZ][ZZ];
374
375     g2tx = pme->pmegridA.g2t[XX];
376     g2ty = pme->pmegridA.g2t[YY];
377     g2tz = pme->pmegridA.g2t[ZZ];
378
379     bThreads = (atc->nthread > 1);
380     if (bThreads)
381     {
382         thread_idx = atc->thread_idx;
383
384         tpl   = &atc->thread_plist[thread];
385         tpl_n = tpl->n;
386         for (i = 0; i < atc->nthread; i++)
387         {
388             tpl_n[i] = 0;
389         }
390     }
391
392     for (i = start; i < end; i++)
393     {
394         xptr   = atc->x[i];
395         idxptr = atc->idx[i];
396         fptr   = atc->fractx[i];
397
398         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
399         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
400         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
401         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
402
403         tix = (int)(tx);
404         tiy = (int)(ty);
405         tiz = (int)(tz);
406
407         /* Because decomposition only occurs in x and y,
408          * we never have a fraction correction in z.
409          */
410         fptr[XX] = tx - tix + pme->fshx[tix];
411         fptr[YY] = ty - tiy + pme->fshy[tiy];
412         fptr[ZZ] = tz - tiz;
413
414         idxptr[XX] = pme->nnx[tix];
415         idxptr[YY] = pme->nny[tiy];
416         idxptr[ZZ] = pme->nnz[tiz];
417
418 #ifdef DEBUG
419         range_check(idxptr[XX], 0, pme->pmegrid_nx);
420         range_check(idxptr[YY], 0, pme->pmegrid_ny);
421         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
422 #endif
423
424         if (bThreads)
425         {
426             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
427             thread_idx[i] = thread_i;
428             tpl_n[thread_i]++;
429         }
430     }
431
432     if (bThreads)
433     {
434         /* Make a list of particle indices sorted on thread */
435
436         /* Get the cumulative count */
437         for (i = 1; i < atc->nthread; i++)
438         {
439             tpl_n[i] += tpl_n[i-1];
440         }
441         /* The current implementation distributes particles equally
442          * over the threads, so we could actually allocate for that
443          * in pme_realloc_atomcomm_things.
444          */
445         if (tpl_n[atc->nthread-1] > tpl->nalloc)
446         {
447             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
448             srenew(tpl->i, tpl->nalloc);
449         }
450         /* Set tpl_n to the cumulative start */
451         for (i = atc->nthread-1; i >= 1; i--)
452         {
453             tpl_n[i] = tpl_n[i-1];
454         }
455         tpl_n[0] = 0;
456
457         /* Fill our thread local array with indices sorted on thread */
458         for (i = start; i < end; i++)
459         {
460             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
461         }
462         /* Now tpl_n contains the cumulative count again */
463     }
464 }
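/* A minimal standalone sketch of the index calculation above for a single
 * coordinate in a rectangular box (a hypothetical helper, not used by the
 * PME code; the real code uses the nnx/fshx lookup tables instead of an
 * explicit modulo, and fshx also carries the decomposition shift):
 */
static void example_grid_index_1d(real x, real box_len, int nk,
                                  int *grid_index, real *fraction)
{
    real t;
    int  ti;

    /* Fractional coordinate, shifted by +2.0 box lengths so the value stays
     * positive and the (int) truncation below behaves like floor() even for
     * coordinates slightly outside the box.
     */
    t  = nk*(x/box_len + 2.0);
    ti = (int)t;

    *fraction   = t - ti;  /* position within the grid cell, in [0,1) */
    *grid_index = ti % nk; /* wrap the shifted index back into [0,nk) */
}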
465
466 static void make_thread_local_ind(pme_atomcomm_t *atc,
467                                   int thread, splinedata_t *spline)
468 {
469     int             n, t, i, start, end;
470     thread_plist_t *tpl;
471
472     /* Combine the indices made by each thread into one index */
473
474     n     = 0;
475     start = 0;
476     for (t = 0; t < atc->nthread; t++)
477     {
478         tpl = &atc->thread_plist[t];
479         /* Copy our part (start to end) from the list of thread t */
480         if (thread > 0)
481         {
482             start = tpl->n[thread-1];
483         }
484         end = tpl->n[thread];
485         for (i = start; i < end; i++)
486         {
487             spline->ind[n++] = tpl->i[i];
488         }
489     }
490
491     spline->n = n;
492 }
493
494
495 static void pme_calc_pidx(int start, int end,
496                           matrix recipbox, rvec x[],
497                           pme_atomcomm_t *atc, int *count)
498 {
499     int   nslab, i;
500     int   si;
501     real *xptr, s;
502     real  rxx, ryx, rzx, ryy, rzy;
503     int  *pd;
504
505     /* Calculate the PME task index (pidx) for each particle.
506      * Here we always assign equally sized slabs to each node
507      * for load balancing reasons (the PME grid spacing is not used).
508      */
509
510     nslab = atc->nslab;
511     pd    = atc->pd;
512
513     /* Reset the count */
514     for (i = 0; i < nslab; i++)
515     {
516         count[i] = 0;
517     }
518
519     if (atc->dimind == 0)
520     {
521         rxx = recipbox[XX][XX];
522         ryx = recipbox[YY][XX];
523         rzx = recipbox[ZZ][XX];
524         /* Calculate the node index in x-dimension */
525         for (i = start; i < end; i++)
526         {
527             xptr   = x[i];
528             /* Fractional coordinates along box vectors */
529             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
530             si    = (int)(s + 2*nslab) % nslab;
531             pd[i] = si;
532             count[si]++;
533         }
534     }
535     else
536     {
537         ryy = recipbox[YY][YY];
538         rzy = recipbox[ZZ][YY];
539         /* Calculate the node index in y-dimension */
540         for (i = start; i < end; i++)
541         {
542             xptr   = x[i];
543             /* Fractional coordinates along box vectors */
544             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
545             si    = (int)(s + 2*nslab) % nslab;
546             pd[i] = si;
547             count[si]++;
548         }
549     }
550 }
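/* Worked example of the slab assignment above (illustrative numbers only):
 * with nslab = 4 and a fractional coordinate s_frac along x, the line
 *     si = (int)(nslab*s_frac + 2*nslab) % nslab
 * maps s_frac = 0.10 -> slab 0, 0.30 -> 1, 0.55 -> 2, 0.80 -> 3, and a
 * slightly negative -0.05 (an atom just outside the box) -> slab 3,
 * because the +2*nslab shift keeps the truncated value positive.
 */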
551
552 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
553                                   pme_atomcomm_t *atc)
554 {
555     int nthread, thread, slab;
556
557     nthread = atc->nthread;
558
559 #pragma omp parallel for num_threads(nthread) schedule(static)
560     for (thread = 0; thread < nthread; thread++)
561     {
562         pme_calc_pidx(natoms* thread   /nthread,
563                       natoms*(thread+1)/nthread,
564                       recipbox, x, atc, atc->count_thread[thread]);
565     }
566     /* Non-parallel reduction, since nslab is small */
567
568     for (thread = 1; thread < nthread; thread++)
569     {
570         for (slab = 0; slab < atc->nslab; slab++)
571         {
572             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
573         }
574     }
575 }
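/* The counting pattern above, reduced to its core: each thread fills its
 * own count array over a static range of particles, then the per-thread
 * arrays are summed serially into row 0. A self-contained sketch with
 * hypothetical names, assuming OpenMP is available as elsewhere in this
 * file:
 */
static void example_threaded_count(int n, const int bin_of[],
                                   int nbins, int nthread, int **count_thread)
{
    int thread, b;

#pragma omp parallel for num_threads(nthread) schedule(static)
    for (thread = 0; thread < nthread; thread++)
    {
        /* Block-local variables are private to each thread */
        int start = (n*thread)/nthread;
        int end   = (n*(thread + 1))/nthread;
        int i;

        for (i = 0; i < nbins; i++)
        {
            count_thread[thread][i] = 0;
        }
        for (i = start; i < end; i++)
        {
            count_thread[thread][bin_of[i]]++;
        }
    }

    /* Serial reduction into row 0; cheap because nbins (here: nslab) is small */
    for (thread = 1; thread < nthread; thread++)
    {
        for (b = 0; b < nbins; b++)
        {
            count_thread[0][b] += count_thread[thread][b];
        }
    }
}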
576
577 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
578 {
579     const int padding = 4;
580     int       i;
581
582     srenew(th[XX], nalloc);
583     srenew(th[YY], nalloc);
584     /* In z we add padding; this is only required for the aligned SSE code */
585     srenew(*ptr_z, nalloc+2*padding);
586     th[ZZ] = *ptr_z + padding;
587
588     for (i = 0; i < padding; i++)
589     {
590         (*ptr_z)[               i] = 0;
591         (*ptr_z)[padding+nalloc+i] = 0;
592     }
593 }
594
595 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
596 {
597     int i, d;
598
599     srenew(spline->ind, atc->nalloc);
600     /* Initialize the index to identity so it works without threads */
601     for (i = 0; i < atc->nalloc; i++)
602     {
603         spline->ind[i] = i;
604     }
605
606     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
607                       atc->pme_order*atc->nalloc);
608     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
609                       atc->pme_order*atc->nalloc);
610 }
611
612 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
613 {
614     int nalloc_old, i, j, nalloc_tpl;
615
616     /* We have to avoid a NULL pointer for atc->x to prevent
617      * possible fatal errors in MPI routines.
618      */
619     if (atc->n > atc->nalloc || atc->nalloc == 0)
620     {
621         nalloc_old  = atc->nalloc;
622         atc->nalloc = over_alloc_dd(max(atc->n, 1));
623
624         if (atc->nslab > 1)
625         {
626             srenew(atc->x, atc->nalloc);
627             srenew(atc->q, atc->nalloc);
628             srenew(atc->f, atc->nalloc);
629             for (i = nalloc_old; i < atc->nalloc; i++)
630             {
631                 clear_rvec(atc->f[i]);
632             }
633         }
634         if (atc->bSpread)
635         {
636             srenew(atc->fractx, atc->nalloc);
637             srenew(atc->idx, atc->nalloc);
638
639             if (atc->nthread > 1)
640             {
641                 srenew(atc->thread_idx, atc->nalloc);
642             }
643
644             for (i = 0; i < atc->nthread; i++)
645             {
646                 pme_realloc_splinedata(&atc->spline[i], atc);
647             }
648         }
649     }
650 }
651
652 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
653                          int n, gmx_bool bXF, rvec *x_f, real *charge,
654                          pme_atomcomm_t *atc)
655 /* Redistribute particle data for PME calculation */
656 /* domain decomposition by x coordinate           */
657 {
658     int *idxa;
659     int  i, ii;
660
661     if (FALSE == pme->redist_init)
662     {
663         snew(pme->scounts, atc->nslab);
664         snew(pme->rcounts, atc->nslab);
665         snew(pme->sdispls, atc->nslab);
666         snew(pme->rdispls, atc->nslab);
667         snew(pme->sidx, atc->nslab);
668         pme->redist_init = TRUE;
669     }
670     if (n > pme->redist_buf_nalloc)
671     {
672         pme->redist_buf_nalloc = over_alloc_dd(n);
673         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
674     }
675
676     pme->idxa = atc->pd;
677
678 #ifdef GMX_MPI
679     if (forw && bXF)
680     {
681         /* forward, redistribution from pp to pme */
682
683         /* Calculate send counts and exchange them with other nodes */
684         for (i = 0; (i < atc->nslab); i++)
685         {
686             pme->scounts[i] = 0;
687         }
688         for (i = 0; (i < n); i++)
689         {
690             pme->scounts[pme->idxa[i]]++;
691         }
692         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
693
694         /* Calculate send and receive displacements and index into send
695            buffer */
696         pme->sdispls[0] = 0;
697         pme->rdispls[0] = 0;
698         pme->sidx[0]    = 0;
699         for (i = 1; i < atc->nslab; i++)
700         {
701             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
702             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
703             pme->sidx[i]    = pme->sdispls[i];
704         }
705         /* Total # of particles to be received */
706         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
707
708         pme_realloc_atomcomm_things(atc);
709
710         /* Copy particle coordinates into send buffer and exchange */
711         for (i = 0; (i < n); i++)
712         {
713             ii = DIM*pme->sidx[pme->idxa[i]];
714             pme->sidx[pme->idxa[i]]++;
715             pme->redist_buf[ii+XX] = x_f[i][XX];
716             pme->redist_buf[ii+YY] = x_f[i][YY];
717             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
718         }
719         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
720                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
721                       pme->rvec_mpi, atc->mpi_comm);
722     }
723     if (forw)
724     {
725         /* Copy charge into send buffer and exchange */
726         for (i = 0; i < atc->nslab; i++)
727         {
728             pme->sidx[i] = pme->sdispls[i];
729         }
730         for (i = 0; (i < n); i++)
731         {
732             ii = pme->sidx[pme->idxa[i]];
733             pme->sidx[pme->idxa[i]]++;
734             pme->redist_buf[ii] = charge[i];
735         }
736         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
737                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
738                       atc->mpi_comm);
739     }
740     else   /* backward, redistribution from pme to pp */
741     {
742         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
743                       pme->redist_buf, pme->scounts, pme->sdispls,
744                       pme->rvec_mpi, atc->mpi_comm);
745
746         /* Copy data from receive buffer */
747         for (i = 0; i < atc->nslab; i++)
748         {
749             pme->sidx[i] = pme->sdispls[i];
750         }
751         for (i = 0; (i < n); i++)
752         {
753             ii          = DIM*pme->sidx[pme->idxa[i]];
754             x_f[i][XX] += pme->redist_buf[ii+XX];
755             x_f[i][YY] += pme->redist_buf[ii+YY];
756             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
757             pme->sidx[pme->idxa[i]]++;
758         }
759     }
760 #endif
761 }
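/* The MPI redistribution pattern used above, in isolation: MPI_Alltoall
 * exchanges the per-slab send counts, the exclusive prefix sums give the
 * displacements, and MPI_Alltoallv moves the payload. A minimal sketch
 * with a hypothetical helper name, not part of the PME code:
 */
#ifdef GMX_MPI
static void example_alltoallv_pattern(MPI_Comm comm, int nslab,
                                      int *scounts, int *rcounts,
                                      int *sdispls, int *rdispls,
                                      real *sendbuf, real *recvbuf)
{
    int i;

    /* Tell every rank how much it will receive from us */
    MPI_Alltoall(scounts, 1, MPI_INT, rcounts, 1, MPI_INT, comm);

    /* Displacements are the exclusive prefix sums of the counts */
    sdispls[0] = 0;
    rdispls[0] = 0;
    for (i = 1; i < nslab; i++)
    {
        sdispls[i] = sdispls[i-1] + scounts[i-1];
        rdispls[i] = rdispls[i-1] + rcounts[i-1];
    }

    /* Exchange the actual payload (here one real per particle) */
    MPI_Alltoallv(sendbuf, scounts, sdispls, mpi_type,
                  recvbuf, rcounts, rdispls, mpi_type, comm);
}
#endif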
762
763 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
764                             gmx_bool bBackward, int shift,
765                             void *buf_s, int nbyte_s,
766                             void *buf_r, int nbyte_r)
767 {
768 #ifdef GMX_MPI
769     int        dest, src;
770     MPI_Status stat;
771
772     if (bBackward == FALSE)
773     {
774         dest = atc->node_dest[shift];
775         src  = atc->node_src[shift];
776     }
777     else
778     {
779         dest = atc->node_src[shift];
780         src  = atc->node_dest[shift];
781     }
782
783     if (nbyte_s > 0 && nbyte_r > 0)
784     {
785         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
786                      dest, shift,
787                      buf_r, nbyte_r, MPI_BYTE,
788                      src, shift,
789                      atc->mpi_comm, &stat);
790     }
791     else if (nbyte_s > 0)
792     {
793         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
794                  dest, shift,
795                  atc->mpi_comm);
796     }
797     else if (nbyte_r > 0)
798     {
799         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
800                  src, shift,
801                  atc->mpi_comm, &stat);
802     }
803 #endif
804 }
805
806 static void dd_pmeredist_x_q(gmx_pme_t pme,
807                              int n, gmx_bool bX, rvec *x, real *charge,
808                              pme_atomcomm_t *atc)
809 {
810     int *commnode, *buf_index;
811     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
812
813     commnode  = atc->node_dest;
814     buf_index = atc->buf_index;
815
816     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
817
818     nsend = 0;
819     for (i = 0; i < nnodes_comm; i++)
820     {
821         buf_index[commnode[i]] = nsend;
822         nsend                 += atc->count[commnode[i]];
823     }
824     if (bX)
825     {
826         if (atc->count[atc->nodeid] + nsend != n)
827         {
828             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
829                       "This usually means that your system is not well equilibrated.",
830                       n - (atc->count[atc->nodeid] + nsend),
831                       pme->nodeid, 'x'+atc->dimind);
832         }
833
834         if (nsend > pme->buf_nalloc)
835         {
836             pme->buf_nalloc = over_alloc_dd(nsend);
837             srenew(pme->bufv, pme->buf_nalloc);
838             srenew(pme->bufr, pme->buf_nalloc);
839         }
840
841         atc->n = atc->count[atc->nodeid];
842         for (i = 0; i < nnodes_comm; i++)
843         {
844             scount = atc->count[commnode[i]];
845             /* Communicate the count */
846             if (debug)
847             {
848                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
849                         atc->dimind, atc->nodeid, commnode[i], scount);
850             }
851             pme_dd_sendrecv(atc, FALSE, i,
852                             &scount, sizeof(int),
853                             &atc->rcount[i], sizeof(int));
854             atc->n += atc->rcount[i];
855         }
856
857         pme_realloc_atomcomm_things(atc);
858     }
859
860     local_pos = 0;
861     for (i = 0; i < n; i++)
862     {
863         node = atc->pd[i];
864         if (node == atc->nodeid)
865         {
866             /* Copy directly to the receive buffer */
867             if (bX)
868             {
869                 copy_rvec(x[i], atc->x[local_pos]);
870             }
871             atc->q[local_pos] = charge[i];
872             local_pos++;
873         }
874         else
875         {
876             /* Copy to the send buffer */
877             if (bX)
878             {
879                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
880             }
881             pme->bufr[buf_index[node]] = charge[i];
882             buf_index[node]++;
883         }
884     }
885
886     buf_pos = 0;
887     for (i = 0; i < nnodes_comm; i++)
888     {
889         scount = atc->count[commnode[i]];
890         rcount = atc->rcount[i];
891         if (scount > 0 || rcount > 0)
892         {
893             if (bX)
894             {
895                 /* Communicate the coordinates */
896                 pme_dd_sendrecv(atc, FALSE, i,
897                                 pme->bufv[buf_pos], scount*sizeof(rvec),
898                                 atc->x[local_pos], rcount*sizeof(rvec));
899             }
900             /* Communicate the charges */
901             pme_dd_sendrecv(atc, FALSE, i,
902                             pme->bufr+buf_pos, scount*sizeof(real),
903                             atc->q+local_pos, rcount*sizeof(real));
904             buf_pos   += scount;
905             local_pos += atc->rcount[i];
906         }
907     }
908 }
909
910 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
911                            int n, rvec *f,
912                            gmx_bool bAddF)
913 {
914     int *commnode, *buf_index;
915     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
916
917     commnode  = atc->node_dest;
918     buf_index = atc->buf_index;
919
920     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
921
922     local_pos = atc->count[atc->nodeid];
923     buf_pos   = 0;
924     for (i = 0; i < nnodes_comm; i++)
925     {
926         scount = atc->rcount[i];
927         rcount = atc->count[commnode[i]];
928         if (scount > 0 || rcount > 0)
929         {
930             /* Communicate the forces */
931             pme_dd_sendrecv(atc, TRUE, i,
932                             atc->f[local_pos], scount*sizeof(rvec),
933                             pme->bufv[buf_pos], rcount*sizeof(rvec));
934             local_pos += scount;
935         }
936         buf_index[commnode[i]] = buf_pos;
937         buf_pos               += rcount;
938     }
939
940     local_pos = 0;
941     if (bAddF)
942     {
943         for (i = 0; i < n; i++)
944         {
945             node = atc->pd[i];
946             if (node == atc->nodeid)
947             {
948                 /* Add from the local force array */
949                 rvec_inc(f[i], atc->f[local_pos]);
950                 local_pos++;
951             }
952             else
953             {
954                 /* Add from the receive buffer */
955                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
956                 buf_index[node]++;
957             }
958         }
959     }
960     else
961     {
962         for (i = 0; i < n; i++)
963         {
964             node = atc->pd[i];
965             if (node == atc->nodeid)
966             {
967                 /* Copy from the local force array */
968                 copy_rvec(atc->f[local_pos], f[i]);
969                 local_pos++;
970             }
971             else
972             {
973                 /* Copy from the receive buffer */
974                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
975                 buf_index[node]++;
976             }
977         }
978     }
979 }
980
981 #ifdef GMX_MPI
982 static void
983 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
984 {
985     pme_overlap_t *overlap;
986     int            send_index0, send_nindex;
987     int            recv_index0, recv_nindex;
988     MPI_Status     stat;
989     int            i, j, k, ix, iy, iz, icnt;
990     int            ipulse, send_id, recv_id, datasize;
991     real          *p;
992     real          *sendptr, *recvptr;
993
994     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
995     overlap = &pme->overlap[1];
996
997     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
998     {
999         /* Since we have already (un)wrapped the overlap in the z-dimension,
1000          * we only have to communicate 0 to nkz (not pmegrid_nz).
1001          */
1002         if (direction == GMX_SUM_QGRID_FORWARD)
1003         {
1004             send_id       = overlap->send_id[ipulse];
1005             recv_id       = overlap->recv_id[ipulse];
1006             send_index0   = overlap->comm_data[ipulse].send_index0;
1007             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1008             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1009             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1010         }
1011         else
1012         {
1013             send_id       = overlap->recv_id[ipulse];
1014             recv_id       = overlap->send_id[ipulse];
1015             send_index0   = overlap->comm_data[ipulse].recv_index0;
1016             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1017             recv_index0   = overlap->comm_data[ipulse].send_index0;
1018             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1019         }
1020
1021         /* Copy data to contiguous send buffer */
1022         if (debug)
1023         {
1024             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1025                     pme->nodeid, overlap->nodeid, send_id,
1026                     pme->pmegrid_start_iy,
1027                     send_index0-pme->pmegrid_start_iy,
1028                     send_index0-pme->pmegrid_start_iy+send_nindex);
1029         }
1030         icnt = 0;
1031         for (i = 0; i < pme->pmegrid_nx; i++)
1032         {
1033             ix = i;
1034             for (j = 0; j < send_nindex; j++)
1035             {
1036                 iy = j + send_index0 - pme->pmegrid_start_iy;
1037                 for (k = 0; k < pme->nkz; k++)
1038                 {
1039                     iz = k;
1040                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1041                 }
1042             }
1043         }
1044
1045         datasize      = pme->pmegrid_nx * pme->nkz;
1046
1047         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1048                      send_id, ipulse,
1049                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1050                      recv_id, ipulse,
1051                      overlap->mpi_comm, &stat);
1052
1053         /* Get data from contiguous recv buffer */
1054         if (debug)
1055         {
1056             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1057                     pme->nodeid, overlap->nodeid, recv_id,
1058                     pme->pmegrid_start_iy,
1059                     recv_index0-pme->pmegrid_start_iy,
1060                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1061         }
1062         icnt = 0;
1063         for (i = 0; i < pme->pmegrid_nx; i++)
1064         {
1065             ix = i;
1066             for (j = 0; j < recv_nindex; j++)
1067             {
1068                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1069                 for (k = 0; k < pme->nkz; k++)
1070                 {
1071                     iz = k;
1072                     if (direction == GMX_SUM_QGRID_FORWARD)
1073                     {
1074                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1075                     }
1076                     else
1077                     {
1078                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1079                     }
1080                 }
1081             }
1082         }
1083     }
1084
1085     /* Major dimension is easier, no copying required,
1086      * but we might have to sum to a separate array.
1087      * Since we don't copy, we have to communicate up to pmegrid_nz,
1088      * not nkz as for the minor direction.
1089      */
1090     overlap = &pme->overlap[0];
1091
1092     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1093     {
1094         if (direction == GMX_SUM_QGRID_FORWARD)
1095         {
1096             send_id       = overlap->send_id[ipulse];
1097             recv_id       = overlap->recv_id[ipulse];
1098             send_index0   = overlap->comm_data[ipulse].send_index0;
1099             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1100             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1101             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1102             recvptr       = overlap->recvbuf;
1103         }
1104         else
1105         {
1106             send_id       = overlap->recv_id[ipulse];
1107             recv_id       = overlap->send_id[ipulse];
1108             send_index0   = overlap->comm_data[ipulse].recv_index0;
1109             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1110             recv_index0   = overlap->comm_data[ipulse].send_index0;
1111             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1112             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1113         }
1114
1115         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1116         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1117
1118         if (debug)
1119         {
1120             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1121                     pme->nodeid, overlap->nodeid, send_id,
1122                     pme->pmegrid_start_ix,
1123                     send_index0-pme->pmegrid_start_ix,
1124                     send_index0-pme->pmegrid_start_ix+send_nindex);
1125             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1126                     pme->nodeid, overlap->nodeid, recv_id,
1127                     pme->pmegrid_start_ix,
1128                     recv_index0-pme->pmegrid_start_ix,
1129                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1130         }
1131
1132         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1133                      send_id, ipulse,
1134                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1135                      recv_id, ipulse,
1136                      overlap->mpi_comm, &stat);
1137
1138         /* ADD data from contiguous recv buffer */
1139         if (direction == GMX_SUM_QGRID_FORWARD)
1140         {
1141             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1142             for (i = 0; i < recv_nindex*datasize; i++)
1143             {
1144                 p[i] += overlap->recvbuf[i];
1145             }
1146         }
1147     }
1148 }
1149 #endif
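/* The pack step above on its own: copying a contiguous range of y-indices
 * of a row-major (x,y,z) grid into a flat send buffer, mirroring the
 * grid[ix*(ny*nz)+iy*nz+iz] indexing. In the real code the inner loop runs
 * only over nkz while the grid stride is pmegrid_nz; here both are nz for
 * brevity. A hypothetical helper, shown only as an illustration:
 */
static void example_pack_y_slab(const real *grid, real *sendbuf,
                                int nx, int ny, int nz,
                                int iy0, int niy)
{
    int ix, iy, iz, icnt;

    icnt = 0;
    for (ix = 0; ix < nx; ix++)
    {
        for (iy = iy0; iy < iy0 + niy; iy++)
        {
            for (iz = 0; iz < nz; iz++)
            {
                sendbuf[icnt++] = grid[ix*(ny*nz) + iy*nz + iz];
            }
        }
    }
}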
1150
1151
1152 static int
1153 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1154 {
1155     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1156     ivec    local_pme_size;
1157     int     i, ix, iy, iz;
1158     int     pmeidx, fftidx;
1159
1160     /* Dimensions should be identical for A/B grid, so we just use A here */
1161     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1162                                    local_fft_ndata,
1163                                    local_fft_offset,
1164                                    local_fft_size);
1165
1166     local_pme_size[0] = pme->pmegrid_nx;
1167     local_pme_size[1] = pme->pmegrid_ny;
1168     local_pme_size[2] = pme->pmegrid_nz;
1169
1170     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1171        the offset is identical, and the PME grid always has more data (due to overlap)
1172      */
1173     {
1174 #ifdef DEBUG_PME
1175         FILE *fp, *fp2;
1176         char  fn[STRLEN], format[STRLEN];
1177         real  val;
1178         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1179         fp = ffopen(fn, "w");
1180         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1181         fp2 = ffopen(fn, "w");
1182         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1183 #endif
1184
1185         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1186         {
1187             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1188             {
1189                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1190                 {
1191                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1192                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1193                     fftgrid[fftidx] = pmegrid[pmeidx];
1194 #ifdef DEBUG_PME
1195                     val = 100*pmegrid[pmeidx];
1196                     if (pmegrid[pmeidx] != 0)
1197                     {
1198                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1199                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1200                     }
1201                     if (pmegrid[pmeidx] != 0)
1202                     {
1203                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1204                                 "qgrid",
1205                                 pme->pmegrid_start_ix + ix,
1206                                 pme->pmegrid_start_iy + iy,
1207                                 pme->pmegrid_start_iz + iz,
1208                                 pmegrid[pmeidx]);
1209                     }
1210 #endif
1211                 }
1212             }
1213         }
1214 #ifdef DEBUG_PME
1215         ffclose(fp);
1216         ffclose(fp2);
1217 #endif
1218     }
1219     return 0;
1220 }
1221
1222
1223 static gmx_cycles_t omp_cyc_start()
1224 {
1225     return gmx_cycles_read();
1226 }
1227
1228 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1229 {
1230     return gmx_cycles_read() - c;
1231 }
1232
1233
1234 static int
1235 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1236                         int nthread, int thread)
1237 {
1238     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1239     ivec          local_pme_size;
1240     int           ixy0, ixy1, ixy, ix, iy, iz;
1241     int           pmeidx, fftidx;
1242 #ifdef PME_TIME_THREADS
1243     gmx_cycles_t  c1;
1244     static double cs1 = 0;
1245     static int    cnt = 0;
1246 #endif
1247
1248 #ifdef PME_TIME_THREADS
1249     c1 = omp_cyc_start();
1250 #endif
1251     /* Dimensions should be identical for A/B grid, so we just use A here */
1252     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1253                                    local_fft_ndata,
1254                                    local_fft_offset,
1255                                    local_fft_size);
1256
1257     local_pme_size[0] = pme->pmegrid_nx;
1258     local_pme_size[1] = pme->pmegrid_ny;
1259     local_pme_size[2] = pme->pmegrid_nz;
1260
1261     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1262        the offset is identical, and the PME grid always has more data (due to overlap)
1263      */
1264     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1265     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1266
1267     for (ixy = ixy0; ixy < ixy1; ixy++)
1268     {
1269         ix = ixy/local_fft_ndata[YY];
1270         iy = ixy - ix*local_fft_ndata[YY];
1271
1272         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1273         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1274         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1275         {
1276             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1277         }
1278     }
1279
1280 #ifdef PME_TIME_THREADS
1281     c1   = omp_cyc_end(c1);
1282     cs1 += (double)c1;
1283     cnt++;
1284     if (cnt % 20 == 0)
1285     {
1286         printf("copy %.2f\n", cs1*1e-9);
1287     }
1288 #endif
1289
1290     return 0;
1291 }
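/* The static work split used above, shown on its own: thread t of nthread
 * gets the half-open index range [ (t*ntot)/nthread, ((t+1)*ntot)/nthread ),
 * which covers all ntot x/y columns exactly once. For example, ntot = 10
 * and nthread = 4 gives the ranges [0,2), [2,5), [5,7) and [7,10).
 */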
1292
1293
1294 static void
1295 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1296 {
1297     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1298
1299     nx = pme->nkx;
1300     ny = pme->nky;
1301     nz = pme->nkz;
1302
1303     pnx = pme->pmegrid_nx;
1304     pny = pme->pmegrid_ny;
1305     pnz = pme->pmegrid_nz;
1306
1307     overlap = pme->pme_order - 1;
1308
1309     /* Add periodic overlap in z */
1310     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1311     {
1312         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1313         {
1314             for (iz = 0; iz < overlap; iz++)
1315             {
1316                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1317                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1318             }
1319         }
1320     }
1321
1322     if (pme->nnodes_minor == 1)
1323     {
1324         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1325         {
1326             for (iy = 0; iy < overlap; iy++)
1327             {
1328                 for (iz = 0; iz < nz; iz++)
1329                 {
1330                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1331                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1332                 }
1333             }
1334         }
1335     }
1336
1337     if (pme->nnodes_major == 1)
1338     {
1339         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1340
1341         for (ix = 0; ix < overlap; ix++)
1342         {
1343             for (iy = 0; iy < ny_x; iy++)
1344             {
1345                 for (iz = 0; iz < nz; iz++)
1346                 {
1347                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1348                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1349                 }
1350             }
1351         }
1352     }
1353 }
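/* A one-dimensional picture of the wrapping above (illustration only): with
 * nz interior points, order = 4 and therefore overlap = 3, spreading writes
 * into indices 0 .. nz+2. Wrapping folds the tail back onto the start,
 *
 *     grid[iz] += grid[nz+iz]   for iz = 0, 1, 2
 *
 * so that afterwards the first nz entries hold the full periodic sum;
 * unwrap_periodic_pmegrid() below performs the inverse copy before the
 * forces are interpolated back from the grid.
 */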
1354
1355
1356 static void
1357 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1358 {
1359     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1360
1361     nx = pme->nkx;
1362     ny = pme->nky;
1363     nz = pme->nkz;
1364
1365     pnx = pme->pmegrid_nx;
1366     pny = pme->pmegrid_ny;
1367     pnz = pme->pmegrid_nz;
1368
1369     overlap = pme->pme_order - 1;
1370
1371     if (pme->nnodes_major == 1)
1372     {
1373         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1374
1375         for (ix = 0; ix < overlap; ix++)
1376         {
1377             int iy, iz;
1378
1379             for (iy = 0; iy < ny_x; iy++)
1380             {
1381                 for (iz = 0; iz < nz; iz++)
1382                 {
1383                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1384                         pmegrid[(ix*pny+iy)*pnz+iz];
1385                 }
1386             }
1387         }
1388     }
1389
1390     if (pme->nnodes_minor == 1)
1391     {
1392 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1393         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1394         {
1395             int iy, iz;
1396
1397             for (iy = 0; iy < overlap; iy++)
1398             {
1399                 for (iz = 0; iz < nz; iz++)
1400                 {
1401                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1402                         pmegrid[(ix*pny+iy)*pnz+iz];
1403                 }
1404             }
1405         }
1406     }
1407
1408     /* Copy periodic overlap in z */
1409 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1410     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1411     {
1412         int iy, iz;
1413
1414         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1415         {
1416             for (iz = 0; iz < overlap; iz++)
1417             {
1418                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1419                     pmegrid[(ix*pny+iy)*pnz+iz];
1420             }
1421         }
1422     }
1423 }
1424
1425 static void clear_grid(int nx, int ny, int nz, real *grid,
1426                        ivec fs, int *flag,
1427                        int fx, int fy, int fz,
1428                        int order)
1429 {
1430     int nc, ncz;
1431     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1432     int flind;
1433
1434     nc  = 2 + (order - 2)/FLBS;
1435     ncz = 2 + (order - 2)/FLBSZ;
1436
1437     for (fsx = fx; fsx < fx+nc; fsx++)
1438     {
1439         for (fsy = fy; fsy < fy+nc; fsy++)
1440         {
1441             for (fsz = fz; fsz < fz+ncz; fsz++)
1442             {
1443                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1444                 if (flag[flind] == 0)
1445                 {
1446                     gx  = fsx*FLBS;
1447                     gy  = fsy*FLBS;
1448                     gz  = fsz*FLBSZ;
1449                     g0x = (gx*ny + gy)*nz + gz;
1450                     for (x = 0; x < FLBS; x++)
1451                     {
1452                         g0y = g0x;
1453                         for (y = 0; y < FLBS; y++)
1454                         {
1455                             for (z = 0; z < FLBSZ; z++)
1456                             {
1457                                 grid[g0y+z] = 0;
1458                             }
1459                             g0y += nz;
1460                         }
1461                         g0x += ny*nz;
1462                     }
1463
1464                     flag[flind] = 1;
1465                 }
1466             }
1467         }
1468     }
1469 }
1470
1471 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1472 #define DO_BSPLINE(order)                            \
1473     for (ithx = 0; (ithx < order); ithx++)                    \
1474     {                                                    \
1475         index_x = (i0+ithx)*pny*pnz;                     \
1476         valx    = qn*thx[ithx];                          \
1477                                                      \
1478         for (ithy = 0; (ithy < order); ithy++)                \
1479         {                                                \
1480             valxy    = valx*thy[ithy];                   \
1481             index_xy = index_x+(j0+ithy)*pnz;            \
1482                                                      \
1483             for (ithz = 0; (ithz < order); ithz++)            \
1484             {                                            \
1485                 index_xyz        = index_xy+(k0+ithz);   \
1486                 grid[index_xyz] += valxy*thz[ithz];      \
1487             }                                            \
1488         }                                                \
1489     }
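/* The loop above distributes the charge qn over order^3 grid points with
 * weights thx[ithx]*thy[ithy]*thz[ithz]; because the B-spline weights along
 * each dimension sum to 1, the total charge deposited on the grid equals qn.
 * A standalone check for order 4, using the standard cubic B-spline formulas
 * written out by hand here (a sketch, independent of the spline routines
 * elsewhere in this file, and not tied to their exact index ordering):
 */
static void example_order4_weights(real dr, real w[4])
{
    /* Cardinal cubic B-spline evaluated at the four grid points around a
     * particle with in-cell fraction dr in [0,1); the weights sum to 1
     * for any dr, which is what makes the spreading charge-conserving.
     */
    w[0] = (1.0 - dr)*(1.0 - dr)*(1.0 - dr)/6.0;
    w[1] = (3.0*dr*dr*dr - 6.0*dr*dr + 4.0)/6.0;
    w[2] = (-3.0*dr*dr*dr + 3.0*dr*dr + 3.0*dr + 1.0)/6.0;
    w[3] = dr*dr*dr/6.0;
}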
1490
1491
1492 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1493                                      pme_atomcomm_t *atc, splinedata_t *spline,
1494                                      pme_spline_work_t *work)
1495 {
1496
1497     /* spread charges from home atoms to local grid */
1498     real          *grid;
1499     pme_overlap_t *ol;
1500     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1501     int       *    idxptr;
1502     int            order, norder, index_x, index_xy, index_xyz;
1503     real           valx, valxy, qn;
1504     real          *thx, *thy, *thz;
1505     int            localsize, bndsize;
1506     int            pnx, pny, pnz, ndatatot;
1507     int            offx, offy, offz;
1508
1509     pnx = pmegrid->s[XX];
1510     pny = pmegrid->s[YY];
1511     pnz = pmegrid->s[ZZ];
1512
1513     offx = pmegrid->offset[XX];
1514     offy = pmegrid->offset[YY];
1515     offz = pmegrid->offset[ZZ];
1516
1517     ndatatot = pnx*pny*pnz;
1518     grid     = pmegrid->grid;
1519     for (i = 0; i < ndatatot; i++)
1520     {
1521         grid[i] = 0;
1522     }
1523
1524     order = pmegrid->order;
1525
1526     for (nn = 0; nn < spline->n; nn++)
1527     {
1528         n  = spline->ind[nn];
1529         qn = atc->q[n];
1530
1531         if (qn != 0)
1532         {
1533             idxptr = atc->idx[n];
1534             norder = nn*order;
1535
1536             i0   = idxptr[XX] - offx;
1537             j0   = idxptr[YY] - offy;
1538             k0   = idxptr[ZZ] - offz;
1539
1540             thx = spline->theta[XX] + norder;
1541             thy = spline->theta[YY] + norder;
1542             thz = spline->theta[ZZ] + norder;
1543
1544             switch (order)
1545             {
1546                 case 4:
1547 #ifdef PME_SSE
1548 #ifdef PME_SSE_UNALIGNED
1549 #define PME_SPREAD_SSE_ORDER4
1550 #else
1551 #define PME_SPREAD_SSE_ALIGNED
1552 #define PME_ORDER 4
1553 #endif
1554 #include "pme_sse_single.h"
1555 #else
1556                     DO_BSPLINE(4);
1557 #endif
1558                     break;
1559                 case 5:
1560 #ifdef PME_SSE
1561 #define PME_SPREAD_SSE_ALIGNED
1562 #define PME_ORDER 5
1563 #include "pme_sse_single.h"
1564 #else
1565                     DO_BSPLINE(5);
1566 #endif
1567                     break;
1568                 default:
1569                     DO_BSPLINE(order);
1570                     break;
1571             }
1572         }
1573     }
1574 }
1575
1576 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1577 {
1578 #ifdef PME_SSE
1579     if (pme_order == 5
1580 #ifndef PME_SSE_UNALIGNED
1581         || pme_order == 4
1582 #endif
1583         )
1584     {
1585         /* Round nz up to a multiple of 4 to ensure alignment */
1586         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1587     }
1588 #endif
1589 }
1590
1591 static void set_gridsize_alignment(int *gridsize, int pme_order)
1592 {
1593 #ifdef PME_SSE
1594 #ifndef PME_SSE_UNALIGNED
1595     if (pme_order == 4)
1596     {
1597          * Add extra elements to ensure that aligned operations do not go
1598          * beyond the allocated grid size.
1599          * Note that for pme_order=5, the pme grid z-size alignment
1600          * ensures that we will not go beyond the grid size.
1601          */
1602         *gridsize += 4;
1603     }
1604 #endif
1605 #endif
1606 }
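/* The rounding used in set_grid_alignment(), on its own: for a non-negative
 * integer n, (n + 3) & ~3 is the smallest multiple of 4 that is >= n, e.g.
 * 13 -> 16, 16 -> 16, 17 -> 20. The extra elements added above for order 4
 * keep 4-wide aligned SSE loads/stores inside the allocation.
 */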
1607
1608 static void pmegrid_init(pmegrid_t *grid,
1609                          int cx, int cy, int cz,
1610                          int x0, int y0, int z0,
1611                          int x1, int y1, int z1,
1612                          gmx_bool set_alignment,
1613                          int pme_order,
1614                          real *ptr)
1615 {
1616     int nz, gridsize;
1617
1618     grid->ci[XX]     = cx;
1619     grid->ci[YY]     = cy;
1620     grid->ci[ZZ]     = cz;
1621     grid->offset[XX] = x0;
1622     grid->offset[YY] = y0;
1623     grid->offset[ZZ] = z0;
1624     grid->n[XX]      = x1 - x0 + pme_order - 1;
1625     grid->n[YY]      = y1 - y0 + pme_order - 1;
1626     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1627     copy_ivec(grid->n, grid->s);
1628
1629     nz = grid->s[ZZ];
1630     set_grid_alignment(&nz, pme_order);
1631     if (set_alignment)
1632     {
1633         grid->s[ZZ] = nz;
1634     }
1635     else if (nz != grid->s[ZZ])
1636     {
1637         gmx_incons("pmegrid_init call with an unaligned z size");
1638     }
1639
1640     grid->order = pme_order;
1641     if (ptr == NULL)
1642     {
1643         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1644         set_gridsize_alignment(&gridsize, pme_order);
1645         snew_aligned(grid->grid, gridsize, 16);
1646     }
1647     else
1648     {
1649         grid->grid = ptr;
1650     }
1651 }
1652
1653 static int div_round_up(int numerator, int denominator)
1654 {
1655     return (numerator + denominator - 1)/denominator;
1656 }
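
/* For example, div_round_up(10, 4) = (10 + 4 - 1)/4 = 3, i.e. integer
 * division of positive numbers rounded up instead of down.
 */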
1657
1658 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1659                                   ivec nsub)
1660 {
1661     int gsize_opt, gsize;
1662     int nsx, nsy, nsz;
1663     char *env;
1664
1665     gsize_opt = -1;
1666     for (nsx = 1; nsx <= nthread; nsx++)
1667     {
1668         if (nthread % nsx == 0)
1669         {
1670             for (nsy = 1; nsy <= nthread; nsy++)
1671             {
1672                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1673                 {
1674                     nsz = nthread/(nsx*nsy);
1675
1676                     /* Determine the number of grid points per thread */
1677                     gsize =
1678                         (div_round_up(n[XX], nsx) + ovl)*
1679                         (div_round_up(n[YY], nsy) + ovl)*
1680                         (div_round_up(n[ZZ], nsz) + ovl);
1681
1682                     /* Minimize the number of grid points per thread
1683                      * and, secondarily, the number of cuts in minor dimensions.
1684                      */
1685                     if (gsize_opt == -1 ||
1686                         gsize < gsize_opt ||
1687                         (gsize == gsize_opt &&
1688                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1689                     {
1690                         nsub[XX]  = nsx;
1691                         nsub[YY]  = nsy;
1692                         nsub[ZZ]  = nsz;
1693                         gsize_opt = gsize;
1694                     }
1695                 }
1696             }
1697         }
1698     }
1699
1700     env = getenv("GMX_PME_THREAD_DIVISION");
1701     if (env != NULL)
1702     {
1703         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1704     }
1705
1706     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1707     {
1708         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1709     }
1710 }
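
/* A worked example (the numbers here are chosen only for illustration):
 * for a 32x32x32 interior grid with ovl = 3 and nthread = 8, the search
 * above compares all factorizations nsx*nsy*nsz = 8; the per-thread
 * volume (32/nsx + 3)*(32/nsy + 3)*(32/nsz + 3) is smallest for the
 * 2 x 2 x 2 division (19*19*19 = 6859 grid points per thread), so nsub
 * becomes {2, 2, 2} unless GMX_PME_THREAD_DIVISION overrides it.
 */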
1711
1712 static void pmegrids_init(pmegrids_t *grids,
1713                           int nx, int ny, int nz, int nz_base,
1714                           int pme_order,
1715                           gmx_bool bUseThreads,
1716                           int nthread,
1717                           int overlap_x,
1718                           int overlap_y)
1719 {
1720     ivec n, n_base, g0, g1;
1721     int t, x, y, z, d, i, tfac;
1722     int max_comm_lines = -1;
1723
1724     n[XX] = nx - (pme_order - 1);
1725     n[YY] = ny - (pme_order - 1);
1726     n[ZZ] = nz - (pme_order - 1);
1727
1728     copy_ivec(n, n_base);
1729     n_base[ZZ] = nz_base;
1730
1731     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1732                  NULL);
1733
1734     grids->nthread = nthread;
1735
1736     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1737
1738     if (bUseThreads)
1739     {
1740         ivec nst;
1741         int gridsize;
1742
1743         for (d = 0; d < DIM; d++)
1744         {
1745             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1746         }
1747         set_grid_alignment(&nst[ZZ], pme_order);
1748
1749         if (debug)
1750         {
1751             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1752                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1753             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1754                     nx, ny, nz,
1755                     nst[XX], nst[YY], nst[ZZ]);
1756         }
1757
1758         snew(grids->grid_th, grids->nthread);
1759         t        = 0;
1760         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1761         set_gridsize_alignment(&gridsize, pme_order);
1762         snew_aligned(grids->grid_all,
1763                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1764                      16);
1765
1766         for (x = 0; x < grids->nc[XX]; x++)
1767         {
1768             for (y = 0; y < grids->nc[YY]; y++)
1769             {
1770                 for (z = 0; z < grids->nc[ZZ]; z++)
1771                 {
1772                     pmegrid_init(&grids->grid_th[t],
1773                                  x, y, z,
1774                                  (n[XX]*(x  ))/grids->nc[XX],
1775                                  (n[YY]*(y  ))/grids->nc[YY],
1776                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1777                                  (n[XX]*(x+1))/grids->nc[XX],
1778                                  (n[YY]*(y+1))/grids->nc[YY],
1779                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1780                                  TRUE,
1781                                  pme_order,
1782                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1783                     t++;
1784                 }
1785             }
1786         }
1787     }
1788     else
1789     {
1790         grids->grid_th = NULL;
1791     }
1792
1793     snew(grids->g2t, DIM);
1794     tfac = 1;
1795     for (d = DIM-1; d >= 0; d--)
1796     {
1797         snew(grids->g2t[d], n[d]);
1798         t = 0;
1799         for (i = 0; i < n[d]; i++)
1800         {
1801             /* The second check should match the parameters
1802              * of the pmegrid_init call above.
1803              */
1804             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1805             {
1806                 t++;
1807             }
1808             grids->g2t[d][i] = t*tfac;
1809         }
1810
1811         tfac *= grids->nc[d];
1812
1813         switch (d)
1814         {
1815             case XX: max_comm_lines = overlap_x;     break;
1816             case YY: max_comm_lines = overlap_y;     break;
1817             case ZZ: max_comm_lines = pme_order - 1; break;
1818         }
1819         grids->nthread_comm[d] = 0;
1820         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1821                grids->nthread_comm[d] < grids->nc[d])
1822         {
1823             grids->nthread_comm[d]++;
1824         }
1825         if (debug != NULL)
1826         {
1827             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1828                     'x'+d, grids->nthread_comm[d]);
1829         }
1830         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1831          * work, but this is not a problematic restriction.
1832          */
1833         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1834         {
1835             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1836         }
1837     }
1838 }
1839
1840
1841 static void pmegrids_destroy(pmegrids_t *grids)
1842 {
1843     int t;
1844
1845     if (grids->grid.grid != NULL)
1846     {
1847         sfree(grids->grid.grid);
1848
1849         if (grids->nthread > 0)
1850         {
1851             for (t = 0; t < grids->nthread; t++)
1852             {
1853                 sfree(grids->grid_th[t].grid);
1854             }
1855             sfree(grids->grid_th);
1856         }
1857     }
1858 }
1859
1860
1861 static void realloc_work(pme_work_t *work, int nkx)
1862 {
1863     if (nkx > work->nalloc)
1864     {
1865         work->nalloc = nkx;
1866         srenew(work->mhx, work->nalloc);
1867         srenew(work->mhy, work->nalloc);
1868         srenew(work->mhz, work->nalloc);
1869         srenew(work->m2, work->nalloc);
1870         /* Allocate an aligned pointer for SSE operations, including 3 extra
1871          * elements at the end since SSE operates on 4 elements at a time.
1872          */
1873         sfree_aligned(work->denom);
1874         sfree_aligned(work->tmp1);
1875         sfree_aligned(work->eterm);
1876         snew_aligned(work->denom, work->nalloc+3, 16);
1877         snew_aligned(work->tmp1, work->nalloc+3, 16);
1878         snew_aligned(work->eterm, work->nalloc+3, 16);
1879         srenew(work->m2inv, work->nalloc);
1880     }
1881 }
1882
1883
1884 static void free_work(pme_work_t *work)
1885 {
1886     sfree(work->mhx);
1887     sfree(work->mhy);
1888     sfree(work->mhz);
1889     sfree(work->m2);
1890     sfree_aligned(work->denom);
1891     sfree_aligned(work->tmp1);
1892     sfree_aligned(work->eterm);
1893     sfree(work->m2inv);
1894 }
1895
1896
1897 #ifdef PME_SSE
1898 /* Calculate exponentials using SSE in single precision */
1899 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1900 {
1901     {
1902         const __m128 two = _mm_set_ps(2.0f, 2.0f, 2.0f, 2.0f);
1903         __m128 f_sse;
1904         __m128 lu;
1905         __m128 tmp_d1, d_inv, tmp_r, tmp_e;
1906         int kx;
1907         f_sse = _mm_load1_ps(&f);
1908         for (kx = 0; kx < end; kx += 4)
1909         {
1910             tmp_d1   = _mm_load_ps(d_aligned+kx);
1911             lu       = _mm_rcp_ps(tmp_d1);
1912             d_inv    = _mm_mul_ps(lu, _mm_sub_ps(two, _mm_mul_ps(lu, tmp_d1)));
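            /* lu from _mm_rcp_ps is only an ~12-bit approximation of
             * 1/d; the line above refines it with one Newton-Raphson
             * step, x1 = x0*(2 - d*x0), roughly doubling the precision.
             */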
1913             tmp_r    = _mm_load_ps(r_aligned+kx);
1914             tmp_r    = gmx_mm_exp_ps(tmp_r);
1915             tmp_e    = _mm_mul_ps(f_sse, d_inv);
1916             tmp_e    = _mm_mul_ps(tmp_e, tmp_r);
1917             _mm_store_ps(e_aligned+kx, tmp_e);
1918         }
1919     }
1920 }
1921 #else
1922 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1923 {
1924     int kx;
1925     for (kx = start; kx < end; kx++)
1926     {
1927         d[kx] = 1.0/d[kx];
1928     }
1929     for (kx = start; kx < end; kx++)
1930     {
1931         r[kx] = exp(r[kx]);
1932     }
1933     for (kx = start; kx < end; kx++)
1934     {
1935         e[kx] = f*r[kx]*d[kx];
1936     }
1937 }
1938 #endif
1939
1940
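/* For reference, solve_pme_yzx below evaluates the smooth-PME
 * reciprocal-space sum of Essmann et al., J. Chem. Phys. 103, 8577 (1995):
 *
 *   E_rec = 1/(2 pi V) sum_{m != 0} exp(-pi^2 m^2 / beta^2) / m^2 * B(m) |S(m)|^2
 *
 * with beta the Ewald coefficient, V the box volume, S(m) the structure
 * factor taken from the FFT grid and B(m) the B-spline correction factor
 * built from the bsp_mod arrays. In the code, eterm collects everything
 * except |S(m)|^2, and the virial terms follow from differentiating E_rec
 * with respect to the box.
 */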
1941 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1942                          real ewaldcoeff, real vol,
1943                          gmx_bool bEnerVir,
1944                          int nthread, int thread)
1945 {
1946     /* do recip sum over local cells in grid */
1947     /* y major, z middle, x minor or continuous */
1948     t_complex *p0;
1949     int     kx, ky, kz, maxkx, maxky, maxkz;
1950     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1951     real    mx, my, mz;
1952     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1953     real    ets2, struct2, vfactor, ets2vf;
1954     real    d1, d2, energy = 0;
1955     real    by, bz;
1956     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1957     real    rxx, ryx, ryy, rzx, rzy, rzz;
1958     pme_work_t *work;
1959     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1960     real    mhxk, mhyk, mhzk, m2k;
1961     real    corner_fac;
1962     ivec    complex_order;
1963     ivec    local_ndata, local_offset, local_size;
1964     real    elfac;
1965
1966     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1967
1968     nx = pme->nkx;
1969     ny = pme->nky;
1970     nz = pme->nkz;
1971
1972     /* Dimensions should be identical for A/B grid, so we just use A here */
1973     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1974                                       complex_order,
1975                                       local_ndata,
1976                                       local_offset,
1977                                       local_size);
1978
1979     rxx = pme->recipbox[XX][XX];
1980     ryx = pme->recipbox[YY][XX];
1981     ryy = pme->recipbox[YY][YY];
1982     rzx = pme->recipbox[ZZ][XX];
1983     rzy = pme->recipbox[ZZ][YY];
1984     rzz = pme->recipbox[ZZ][ZZ];
1985
1986     maxkx = (nx+1)/2;
1987     maxky = (ny+1)/2;
1988     maxkz = nz/2+1;
1989
1990     work  = &pme->work[thread];
1991     mhx   = work->mhx;
1992     mhy   = work->mhy;
1993     mhz   = work->mhz;
1994     m2    = work->m2;
1995     denom = work->denom;
1996     tmp1  = work->tmp1;
1997     eterm = work->eterm;
1998     m2inv = work->m2inv;
1999
2000     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2001     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2002
2003     for (iyz = iyz0; iyz < iyz1; iyz++)
2004     {
2005         iy = iyz/local_ndata[ZZ];
2006         iz = iyz - iy*local_ndata[ZZ];
2007
2008         ky = iy + local_offset[YY];
2009
2010         if (ky < maxky)
2011         {
2012             my = ky;
2013         }
2014         else
2015         {
2016             my = (ky - ny);
2017         }
2018
2019         by = M_PI*vol*pme->bsp_mod[YY][ky];
2020
2021         kz = iz + local_offset[ZZ];
2022
2023         mz = kz;
2024
2025         bz = pme->bsp_mod[ZZ][kz];
2026
2027         /* 0.5 correction for corner points */
2028         corner_fac = 1;
2029         if (kz == 0 || kz == (nz+1)/2)
2030         {
2031             corner_fac = 0.5;
2032         }
2033
2034         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2035
2036         /* We should skip the k-space point (0,0,0) */
2037         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2038         {
2039             kxstart = local_offset[XX];
2040         }
2041         else
2042         {
2043             kxstart = local_offset[XX] + 1;
2044             p0++;
2045         }
2046         kxend = local_offset[XX] + local_ndata[XX];
2047
2048         if (bEnerVir)
2049         {
2050             /* More expensive inner loop, especially because of the storage
2051              * of the mh elements in arrays.
2052              * Because x is the minor grid index, all mh elements
2053              * depend on kx for triclinic unit cells.
2054              */
2055
2056             /* Two explicit loops to avoid a conditional inside the loop */
2057             for (kx = kxstart; kx < maxkx; kx++)
2058             {
2059                 mx = kx;
2060
2061                 mhxk      = mx * rxx;
2062                 mhyk      = mx * ryx + my * ryy;
2063                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2064                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2065                 mhx[kx]   = mhxk;
2066                 mhy[kx]   = mhyk;
2067                 mhz[kx]   = mhzk;
2068                 m2[kx]    = m2k;
2069                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2070                 tmp1[kx]  = -factor*m2k;
2071             }
2072
2073             for (kx = maxkx; kx < kxend; kx++)
2074             {
2075                 mx = (kx - nx);
2076
2077                 mhxk      = mx * rxx;
2078                 mhyk      = mx * ryx + my * ryy;
2079                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2080                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2081                 mhx[kx]   = mhxk;
2082                 mhy[kx]   = mhyk;
2083                 mhz[kx]   = mhzk;
2084                 m2[kx]    = m2k;
2085                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2086                 tmp1[kx]  = -factor*m2k;
2087             }
2088
2089             for (kx = kxstart; kx < kxend; kx++)
2090             {
2091                 m2inv[kx] = 1.0/m2[kx];
2092             }
2093
2094             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2095
2096             for (kx = kxstart; kx < kxend; kx++, p0++)
2097             {
2098                 d1      = p0->re;
2099                 d2      = p0->im;
2100
2101                 p0->re  = d1*eterm[kx];
2102                 p0->im  = d2*eterm[kx];
2103
2104                 struct2 = 2.0*(d1*d1+d2*d2);
2105
2106                 tmp1[kx] = eterm[kx]*struct2;
2107             }
2108
2109             for (kx = kxstart; kx < kxend; kx++)
2110             {
2111                 ets2     = corner_fac*tmp1[kx];
2112                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2113                 energy  += ets2;
2114
2115                 ets2vf   = ets2*vfactor;
2116                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2117                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2118                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2119                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2120                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2121                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2122             }
2123         }
2124         else
2125         {
2126             /* We don't need to calculate the energy and the virial.
2127              * In this case the triclinic overhead is small.
2128              */
2129
2130             /* Two explicit loops to avoid a conditional inside the loop */
2131
2132             for (kx = kxstart; kx < maxkx; kx++)
2133             {
2134                 mx = kx;
2135
2136                 mhxk      = mx * rxx;
2137                 mhyk      = mx * ryx + my * ryy;
2138                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2139                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2140                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2141                 tmp1[kx]  = -factor*m2k;
2142             }
2143
2144             for (kx = maxkx; kx < kxend; kx++)
2145             {
2146                 mx = (kx - nx);
2147
2148                 mhxk      = mx * rxx;
2149                 mhyk      = mx * ryx + my * ryy;
2150                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2151                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2152                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2153                 tmp1[kx]  = -factor*m2k;
2154             }
2155
2156             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2157
2158             for (kx = kxstart; kx < kxend; kx++, p0++)
2159             {
2160                 d1      = p0->re;
2161                 d2      = p0->im;
2162
2163                 p0->re  = d1*eterm[kx];
2164                 p0->im  = d2*eterm[kx];
2165             }
2166         }
2167     }
2168
2169     if (bEnerVir)
2170     {
2171         /* Update virial with local values.
2172          * The virial is symmetric by definition.
2173          * This virial seems OK for isotropic scaling, but I'm
2174          * experiencing problems on semi-isotropic membranes.
2175          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2176          */
2177         work->vir[XX][XX] = 0.25*virxx;
2178         work->vir[YY][YY] = 0.25*viryy;
2179         work->vir[ZZ][ZZ] = 0.25*virzz;
2180         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2181         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2182         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2183
2184         /* This energy should be corrected for a charged system */
2185         work->energy = 0.5*energy;
2186     }
2187
2188     /* Return the loop count */
2189     return local_ndata[YY]*local_ndata[XX];
2190 }
2191
2192 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2193                              real *mesh_energy, matrix vir)
2194 {
2195     /* This function sums output over threads
2196      * and should therefore only be called after thread synchronization.
2197      */
2198     int thread;
2199
2200     *mesh_energy = pme->work[0].energy;
2201     copy_mat(pme->work[0].vir, vir);
2202
2203     for (thread = 1; thread < nthread; thread++)
2204     {
2205         *mesh_energy += pme->work[thread].energy;
2206         m_add(vir, pme->work[thread].vir, vir);
2207     }
2208 }
2209
2210 #define DO_FSPLINE(order)                      \
2211     for (ithx = 0; (ithx < order); ithx++)              \
2212     {                                              \
2213         index_x = (i0+ithx)*pny*pnz;               \
2214         tx      = thx[ithx];                       \
2215         dx      = dthx[ithx];                      \
2216                                                \
2217         for (ithy = 0; (ithy < order); ithy++)          \
2218         {                                          \
2219             index_xy = index_x+(j0+ithy)*pnz;      \
2220             ty       = thy[ithy];                  \
2221             dy       = dthy[ithy];                 \
2222             fxy1     = fz1 = 0;                    \
2223                                                \
2224             for (ithz = 0; (ithz < order); ithz++)      \
2225             {                                      \
2226                 gval  = grid[index_xy+(k0+ithz)];  \
2227                 fxy1 += thz[ithz]*gval;            \
2228                 fz1  += dthz[ithz]*gval;           \
2229             }                                      \
2230             fx += dx*ty*fxy1;                      \
2231             fy += tx*dy*fxy1;                      \
2232             fz += tx*ty*fz1;                       \
2233         }                                          \
2234     }
2235
2236
2237 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2238                               gmx_bool bClearF, pme_atomcomm_t *atc,
2239                               splinedata_t *spline,
2240                               real scale)
2241 {
2242     /* sum forces for local particles */
2243     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2244     int     index_x, index_xy;
2245     int     nx, ny, nz, pnx, pny, pnz;
2246     int *   idxptr;
2247     real    tx, ty, dx, dy, qn;
2248     real    fx, fy, fz, gval;
2249     real    fxy1, fz1;
2250     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2251     int     norder;
2252     real    rxx, ryx, ryy, rzx, rzy, rzz;
2253     int     order;
2254
2255     pme_spline_work_t *work;
2256
2257     work = pme->spline_work;
2258
2259     order = pme->pme_order;
2260     thx   = spline->theta[XX];
2261     thy   = spline->theta[YY];
2262     thz   = spline->theta[ZZ];
2263     dthx  = spline->dtheta[XX];
2264     dthy  = spline->dtheta[YY];
2265     dthz  = spline->dtheta[ZZ];
2266     nx    = pme->nkx;
2267     ny    = pme->nky;
2268     nz    = pme->nkz;
2269     pnx   = pme->pmegrid_nx;
2270     pny   = pme->pmegrid_ny;
2271     pnz   = pme->pmegrid_nz;
2272
2273     rxx   = pme->recipbox[XX][XX];
2274     ryx   = pme->recipbox[YY][XX];
2275     ryy   = pme->recipbox[YY][YY];
2276     rzx   = pme->recipbox[ZZ][XX];
2277     rzy   = pme->recipbox[ZZ][YY];
2278     rzz   = pme->recipbox[ZZ][ZZ];
2279
2280     for (nn = 0; nn < spline->n; nn++)
2281     {
2282         n  = spline->ind[nn];
2283         qn = scale*atc->q[n];
2284
2285         if (bClearF)
2286         {
2287             atc->f[n][XX] = 0;
2288             atc->f[n][YY] = 0;
2289             atc->f[n][ZZ] = 0;
2290         }
2291         if (qn != 0)
2292         {
2293             fx     = 0;
2294             fy     = 0;
2295             fz     = 0;
2296             idxptr = atc->idx[n];
2297             norder = nn*order;
2298
2299             i0   = idxptr[XX];
2300             j0   = idxptr[YY];
2301             k0   = idxptr[ZZ];
2302
2303             /* Pointer arithmetic alert, next six statements */
2304             thx  = spline->theta[XX] + norder;
2305             thy  = spline->theta[YY] + norder;
2306             thz  = spline->theta[ZZ] + norder;
2307             dthx = spline->dtheta[XX] + norder;
2308             dthy = spline->dtheta[YY] + norder;
2309             dthz = spline->dtheta[ZZ] + norder;
2310
2311             switch (order)
2312             {
2313                 case 4:
2314 #ifdef PME_SSE
2315 #ifdef PME_SSE_UNALIGNED
2316 #define PME_GATHER_F_SSE_ORDER4
2317 #else
2318 #define PME_GATHER_F_SSE_ALIGNED
2319 #define PME_ORDER 4
2320 #endif
2321 #include "pme_sse_single.h"
2322 #else
2323                     DO_FSPLINE(4);
2324 #endif
2325                     break;
2326                 case 5:
2327 #ifdef PME_SSE
2328 #define PME_GATHER_F_SSE_ALIGNED
2329 #define PME_ORDER 5
2330 #include "pme_sse_single.h"
2331 #else
2332                     DO_FSPLINE(5);
2333 #endif
2334                     break;
2335                 default:
2336                     DO_FSPLINE(order);
2337                     break;
2338             }
2339
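            /* fx, fy and fz now hold the derivatives of the interpolated
             * potential with respect to the scaled fractional coordinates;
             * the lines below apply the chain rule through the reciprocal
             * box (hence the nx, ny, nz factors and the recipbox elements)
             * to obtain the Cartesian force components for charge qn.
             */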
2340             atc->f[n][XX] += -qn*( fx*nx*rxx );
2341             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2342             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2343         }
2344     }
2345     /* Since the energy and not the forces are interpolated,
2346      * the net force might not be exactly zero.
2347      * This can be solved by also interpolating F, but
2348      * that comes at a cost.
2349      * A better hack is to remove the net force every
2350      * step, but that must be done at a higher level
2351      * since this routine doesn't see all atoms when running
2352      * in parallel. It is unclear how important this is.  EL 990726
2353      */
2354 }
2355
2356
2357 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2358                                    pme_atomcomm_t *atc)
2359 {
2360     splinedata_t *spline;
2361     int     n, ithx, ithy, ithz, i0, j0, k0;
2362     int     index_x, index_xy;
2363     int *   idxptr;
2364     real    energy, pot, tx, ty, qn, gval;
2365     real    *thx, *thy, *thz;
2366     int     norder;
2367     int     order;
2368
2369     spline = &atc->spline[0];
2370
2371     order = pme->pme_order;
2372
2373     energy = 0;
2374     for (n = 0; (n < atc->n); n++)
2375     {
2376         qn      = atc->q[n];
2377
2378         if (qn != 0)
2379         {
2380             idxptr = atc->idx[n];
2381             norder = n*order;
2382
2383             i0   = idxptr[XX];
2384             j0   = idxptr[YY];
2385             k0   = idxptr[ZZ];
2386
2387             /* Pointer arithmetic alert, next three statements */
2388             thx  = spline->theta[XX] + norder;
2389             thy  = spline->theta[YY] + norder;
2390             thz  = spline->theta[ZZ] + norder;
2391
2392             pot = 0;
2393             for (ithx = 0; (ithx < order); ithx++)
2394             {
2395                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2396                 tx      = thx[ithx];
2397
2398                 for (ithy = 0; (ithy < order); ithy++)
2399                 {
2400                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2401                     ty       = thy[ithy];
2402
2403                     for (ithz = 0; (ithz < order); ithz++)
2404                     {
2405                         gval  = grid[index_xy+(k0+ithz)];
2406                         pot  += tx*ty*thz[ithz]*gval;
2407                     }
2408
2409                 }
2410             }
2411
2412             energy += pot*qn;
2413         }
2414     }
2415
2416     return energy;
2417 }
2418
2419 /* Macro to force loop unrolling by fixing order.
2420  * This gives a significant performance gain.
2421  */
2422 #define CALC_SPLINE(order)                     \
2423     {                                              \
2424         int j, k, l;                                 \
2425         real dr, div;                               \
2426         real data[PME_ORDER_MAX];                  \
2427         real ddata[PME_ORDER_MAX];                 \
2428                                                \
2429         for (j = 0; (j < DIM); j++)                     \
2430         {                                          \
2431             dr  = xptr[j];                         \
2432                                                \
2433             /* dr is relative offset from lower cell limit */ \
2434             data[order-1] = 0;                     \
2435             data[1]       = dr;                          \
2436             data[0]       = 1 - dr;                      \
2437                                                \
2438             for (k = 3; (k < order); k++)               \
2439             {                                      \
2440                 div       = 1.0/(k - 1.0);               \
2441                 data[k-1] = div*dr*data[k-2];      \
2442                 for (l = 1; (l < (k-1)); l++)           \
2443                 {                                  \
2444                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2445                                        data[k-l-1]);                \
2446                 }                                  \
2447                 data[0] = div*(1-dr)*data[0];      \
2448             }                                      \
2449             /* differentiate */                    \
2450             ddata[0] = -data[0];                   \
2451             for (k = 1; (k < order); k++)               \
2452             {                                      \
2453                 ddata[k] = data[k-1] - data[k];    \
2454             }                                      \
2455                                                \
2456             div           = 1.0/(order - 1);                 \
2457             data[order-1] = div*dr*data[order-2];  \
2458             for (l = 1; (l < (order-1)); l++)           \
2459             {                                      \
2460                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2461                                        (order-l-dr)*data[order-l-1]); \
2462             }                                      \
2463             data[0] = div*(1 - dr)*data[0];        \
2464                                                \
2465             for (k = 0; k < order; k++)                 \
2466             {                                      \
2467                 theta[j][i*order+k]  = data[k];    \
2468                 dtheta[j][i*order+k] = ddata[k];   \
2469             }                                      \
2470         }                                          \
2471     }
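
/* CALC_SPLINE evaluates the cardinal B-spline weights with the standard
 * recursion (see e.g. Essmann et al., J. Chem. Phys. 103, 8577 (1995)):
 *
 *   M_n(u)     = u/(n-1) * M_{n-1}(u) + (n-u)/(n-1) * M_{n-1}(u-1)
 *   dM_n(u)/du = M_{n-1}(u) - M_{n-1}(u-1)
 *
 * data[] holds the order weights for one atom and dimension and ddata[]
 * their derivatives; both are copied into theta and dtheta.
 */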
2472
2473 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2474                    rvec fractx[], int nr, int ind[], real charge[],
2475                    gmx_bool bFreeEnergy)
2476 {
2477     /* construct splines for local atoms */
2478     int  i, ii;
2479     real *xptr;
2480
2481     for (i = 0; i < nr; i++)
2482     {
2483         /* With free energy we do not use the charge check.
2484          * In most cases this will be more efficient than calling make_bsplines
2485          * twice, since usually more than half the particles have charges.
2486          */
2487         ii = ind[i];
2488         if (bFreeEnergy || charge[ii] != 0.0)
2489         {
2490             xptr = fractx[ii];
2491             switch (order)
2492             {
2493                 case 4:  CALC_SPLINE(4);     break;
2494                 case 5:  CALC_SPLINE(5);     break;
2495                 default: CALC_SPLINE(order); break;
2496             }
2497         }
2498     }
2499 }
2500
2501
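/* make_dft_mod computes |DFT(data)|^2 for each wave number; entries that
 * come out smaller than 1e-7 are replaced by the average of their
 * neighbours, since the reciprocal-space solver later divides by these
 * moduli.
 */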
2502 void make_dft_mod(real *mod, real *data, int ndata)
2503 {
2504     int i, j;
2505     real sc, ss, arg;
2506
2507     for (i = 0; i < ndata; i++)
2508     {
2509         sc = ss = 0;
2510         for (j = 0; j < ndata; j++)
2511         {
2512             arg = (2.0*M_PI*i*j)/ndata;
2513             sc += data[j]*cos(arg);
2514             ss += data[j]*sin(arg);
2515         }
2516         mod[i] = sc*sc+ss*ss;
2517     }
2518     for (i = 0; i < ndata; i++)
2519     {
2520         if (mod[i] < 1e-7)
2521         {
2522             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2523         }
2524     }
2525 }
2526
2527
2528 static void make_bspline_moduli(splinevec bsp_mod,
2529                                 int nx, int ny, int nz, int order)
2530 {
2531     int nmax = max(nx, max(ny, nz));
2532     real *data, *ddata, *bsp_data;
2533     int i, k, l;
2534     real div;
2535
2536     snew(data, order);
2537     snew(ddata, order);
2538     snew(bsp_data, nmax);
2539
2540     data[order-1] = 0;
2541     data[1]       = 0;
2542     data[0]       = 1;
2543
2544     for (k = 3; k < order; k++)
2545     {
2546         div       = 1.0/(k-1.0);
2547         data[k-1] = 0;
2548         for (l = 1; l < (k-1); l++)
2549         {
2550             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2551         }
2552         data[0] = div*data[0];
2553     }
2554     /* differentiate */
2555     ddata[0] = -data[0];
2556     for (k = 1; k < order; k++)
2557     {
2558         ddata[k] = data[k-1]-data[k];
2559     }
2560     div           = 1.0/(order-1);
2561     data[order-1] = 0;
2562     for (l = 1; l < (order-1); l++)
2563     {
2564         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2565     }
2566     data[0] = div*data[0];
2567
2568     for (i = 0; i < nmax; i++)
2569     {
2570         bsp_data[i] = 0;
2571     }
2572     for (i = 1; i <= order; i++)
2573     {
2574         bsp_data[i] = data[i-1];
2575     }
2576
2577     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2578     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2579     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2580
2581     sfree(data);
2582     sfree(ddata);
2583     sfree(bsp_data);
2584 }
2585
2586
2587 /* Return the P3M optimal influence function */
2588 static double do_p3m_influence(double z, int order)
2589 {
2590     double z2, z4;
2591
2592     z2 = z*z;
2593     z4 = z2*z2;
2594
2595     /* The formula and most constants can be found in:
2596      * Ballenegger et al., JCTC 8, 936 (2012)
2597      */
2598     switch (order)
2599     {
2600         case 2:
2601             return 1.0 - 2.0*z2/3.0;
2602             break;
2603         case 3:
2604             return 1.0 - z2 + 2.0*z4/15.0;
2605             break;
2606         case 4:
2607             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2608             break;
2609         case 5:
2610             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2611             break;
2612         case 6:
2613             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2614             break;
2615         case 7:
2616             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2617         case 8:
2618             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2619             break;
2620     }
2621
2622     return 0.0;
2623 }
2624
2625 /* Calculate the P3M B-spline moduli for one dimension */
2626 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2627 {
2628     double zarg, zai, sinzai, infl;
2629     int    maxk, i;
2630
2631     if (order > 8)
2632     {
2633         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2634     }
2635
2636     zarg = M_PI/n;
2637
2638     maxk = (n + 1)/2;
2639
2640     for (i = -maxk; i < 0; i++)
2641     {
2642         zai          = zarg*i;
2643         sinzai       = sin(zai);
2644         infl         = do_p3m_influence(sinzai, order);
2645         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2646     }
2647     bsp_mod[0] = 1.0;
2648     for (i = 1; i < maxk; i++)
2649     {
2650         zai        = zarg*i;
2651         sinzai     = sin(zai);
2652         infl       = do_p3m_influence(sinzai, order);
2653         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2654     }
2655 }
2656
2657 /* Calculate the P3M B-spline moduli */
2658 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2659                                     int nx, int ny, int nz, int order)
2660 {
2661     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2662     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2663     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2664 }
2665
2666
2667 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2668 {
2669     int nslab, n, i;
2670     int fw, bw;
2671
2672     nslab = atc->nslab;
2673
2674     n = 0;
2675     for (i = 1; i <= nslab/2; i++)
2676     {
2677         fw = (atc->nodeid + i) % nslab;
2678         bw = (atc->nodeid - i + nslab) % nslab;
2679         if (n < nslab - 1)
2680         {
2681             atc->node_dest[n] = fw;
2682             atc->node_src[n]  = bw;
2683             n++;
2684         }
2685         if (n < nslab - 1)
2686         {
2687             atc->node_dest[n] = bw;
2688             atc->node_src[n]  = fw;
2689             n++;
2690         }
2691     }
2692 }
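
/* For example (with illustrative numbers), nslab = 4 and nodeid = 0 give
 * node_dest = {1, 3, 2} and node_src = {3, 1, 2}, i.e. nslab-1
 * communication pulses that alternate between the forward and backward
 * neighbour.
 */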
2693
2694 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2695 {
2696     int thread;
2697
2698     if (NULL != log)
2699     {
2700         fprintf(log, "Destroying PME data structures.\n");
2701     }
2702
2703     sfree((*pmedata)->nnx);
2704     sfree((*pmedata)->nny);
2705     sfree((*pmedata)->nnz);
2706
2707     pmegrids_destroy(&(*pmedata)->pmegridA);
2708
2709     sfree((*pmedata)->fftgridA);
2710     sfree((*pmedata)->cfftgridA);
2711     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2712
2713     if ((*pmedata)->pmegridB.grid.grid != NULL)
2714     {
2715         pmegrids_destroy(&(*pmedata)->pmegridB);
2716         sfree((*pmedata)->fftgridB);
2717         sfree((*pmedata)->cfftgridB);
2718         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2719     }
2720     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2721     {
2722         free_work(&(*pmedata)->work[thread]);
2723     }
2724     sfree((*pmedata)->work);
2725
2726     sfree(*pmedata);
2727     *pmedata = NULL;
2728
2729     return 0;
2730 }
2731
2732 static int mult_up(int n, int f)
2733 {
2734     return ((n + f - 1)/f)*f;
2735 }
2736
2737
2738 static double pme_load_imbalance(gmx_pme_t pme)
2739 {
2740     int    nma, nmi;
2741     double n1, n2, n3;
2742
2743     nma = pme->nnodes_major;
2744     nmi = pme->nnodes_minor;
2745
2746     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2747     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2748     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2749
2750     /* pme_solve is roughly double the cost of an fft */
2751
2752     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2753 }
2754
2755 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2756                           int dimind, gmx_bool bSpread)
2757 {
2758     int nk, k, s, thread;
2759
2760     atc->dimind    = dimind;
2761     atc->nslab     = 1;
2762     atc->nodeid    = 0;
2763     atc->pd_nalloc = 0;
2764 #ifdef GMX_MPI
2765     if (pme->nnodes > 1)
2766     {
2767         atc->mpi_comm = pme->mpi_comm_d[dimind];
2768         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2769         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2770     }
2771     if (debug)
2772     {
2773         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2774     }
2775 #endif
2776
2777     atc->bSpread   = bSpread;
2778     atc->pme_order = pme->pme_order;
2779
2780     if (atc->nslab > 1)
2781     {
2782         /* These three allocations are not required for particle decomp. */
2783         snew(atc->node_dest, atc->nslab);
2784         snew(atc->node_src, atc->nslab);
2785         setup_coordinate_communication(atc);
2786
2787         snew(atc->count_thread, pme->nthread);
2788         for (thread = 0; thread < pme->nthread; thread++)
2789         {
2790             snew(atc->count_thread[thread], atc->nslab);
2791         }
2792         atc->count = atc->count_thread[0];
2793         snew(atc->rcount, atc->nslab);
2794         snew(atc->buf_index, atc->nslab);
2795     }
2796
2797     atc->nthread = pme->nthread;
2798     if (atc->nthread > 1)
2799     {
2800         snew(atc->thread_plist, atc->nthread);
2801     }
2802     snew(atc->spline, atc->nthread);
2803     for (thread = 0; thread < atc->nthread; thread++)
2804     {
2805         if (atc->nthread > 1)
2806         {
2807             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2808             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2809         }
2810         snew(atc->spline[thread].thread_one, pme->nthread);
2811         atc->spline[thread].thread_one[thread] = 1;
2812     }
2813 }
2814
2815 static void
2816 init_overlap_comm(pme_overlap_t *  ol,
2817                   int              norder,
2818 #ifdef GMX_MPI
2819                   MPI_Comm         comm,
2820 #endif
2821                   int              nnodes,
2822                   int              nodeid,
2823                   int              ndata,
2824                   int              commplainsize)
2825 {
2826     int lbnd, rbnd, maxlr, b, i;
2827     int exten;
2828     int nn, nk;
2829     pme_grid_comm_t *pgc;
2830     gmx_bool bCont;
2831     int fft_start, fft_end, send_index1, recv_index1;
2832 #ifdef GMX_MPI
2833     MPI_Status stat;
2834
2835     ol->mpi_comm = comm;
2836 #endif
2837
2838     ol->nnodes = nnodes;
2839     ol->nodeid = nodeid;
2840
2841     /* Linear translation of the PME grid won't affect reciprocal space
2842      * calculations, so to optimize we only interpolate "upwards",
2843      * which also means we only have to consider overlap in one direction.
2844      * I.e., particles on this node might also be spread to grid indices
2845      * that belong to higher nodes (modulo nnodes)
2846      */
2847
2848     snew(ol->s2g0, ol->nnodes+1);
2849     snew(ol->s2g1, ol->nnodes);
2850     if (debug)
2851     {
2852         fprintf(debug, "PME slab boundaries:");
2853     }
2854     for (i = 0; i < nnodes; i++)
2855     {
2856         /* s2g0 the local interpolation grid start.
2857          * s2g1 the local interpolation grid end.
2858          * Because the grid overlap communication only goes forward,
2859          * the grid slabs for the FFTs should be rounded down.
2860          */
2861         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2862         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2863
2864         if (debug)
2865         {
2866             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2867         }
2868     }
2869     ol->s2g0[nnodes] = ndata;
2870     if (debug)
2871     {
2872         fprintf(debug, "\n");
2873     }
2874
2875     /* Determine with how many nodes we need to communicate the grid overlap */
2876     b = 0;
2877     do
2878     {
2879         b++;
2880         bCont = FALSE;
2881         for (i = 0; i < nnodes; i++)
2882         {
2883             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2884                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2885             {
2886                 bCont = TRUE;
2887             }
2888         }
2889     }
2890     while (bCont && b < nnodes);
2891     ol->noverlap_nodes = b - 1;
2892
2893     snew(ol->send_id, ol->noverlap_nodes);
2894     snew(ol->recv_id, ol->noverlap_nodes);
2895     for (b = 0; b < ol->noverlap_nodes; b++)
2896     {
2897         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2898         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2899     }
2900     snew(ol->comm_data, ol->noverlap_nodes);
2901
2902     ol->send_size = 0;
2903     for (b = 0; b < ol->noverlap_nodes; b++)
2904     {
2905         pgc = &ol->comm_data[b];
2906         /* Send */
2907         fft_start        = ol->s2g0[ol->send_id[b]];
2908         fft_end          = ol->s2g0[ol->send_id[b]+1];
2909         if (ol->send_id[b] < nodeid)
2910         {
2911             fft_start += ndata;
2912             fft_end   += ndata;
2913         }
2914         send_index1       = ol->s2g1[nodeid];
2915         send_index1       = min(send_index1, fft_end);
2916         pgc->send_index0  = fft_start;
2917         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2918         ol->send_size    += pgc->send_nindex;
2919
2920         /* We always start receiving at the first index of our slab */
2921         fft_start        = ol->s2g0[ol->nodeid];
2922         fft_end          = ol->s2g0[ol->nodeid+1];
2923         recv_index1      = ol->s2g1[ol->recv_id[b]];
2924         if (ol->recv_id[b] > nodeid)
2925         {
2926             recv_index1 -= ndata;
2927         }
2928         recv_index1      = min(recv_index1, fft_end);
2929         pgc->recv_index0 = fft_start;
2930         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2931     }
2932
2933 #ifdef GMX_MPI
2934     /* Communicate the buffer sizes to receive */
2935     for (b = 0; b < ol->noverlap_nodes; b++)
2936     {
2937         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2938                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2939                      ol->mpi_comm, &stat);
2940     }
2941 #endif
2942
2943     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2944     snew(ol->sendbuf, norder*commplainsize);
2945     snew(ol->recvbuf, norder*commplainsize);
2946 }
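
/* A small example of the slab bookkeeping above (illustrative numbers
 * only): ndata = 15, nnodes = 4 and norder = 4 give
 *
 *   s2g0 = {0, 3, 7, 11, 15}   (FFT slab starts, rounded down)
 *   s2g1 = {7, 11, 15, 18}     (interpolation ends, including the
 *                               norder-1 = 3 points of forward overlap)
 *
 * so every rank overlaps only with its forward neighbour and
 * noverlap_nodes = 1.
 */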
2947
2948 static void
2949 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2950                               int **global_to_local,
2951                               real **fraction_shift)
2952 {
2953     int i;
2954     int * gtl;
2955     real * fsh;
2956
2957     snew(gtl, 5*n);
2958     snew(fsh, 5*n);
2959     for (i = 0; (i < 5*n); i++)
2960     {
2961         /* Determine the global to local grid index */
2962         gtl[i] = (i - local_start + n) % n;
2963         /* For coordinates that fall within the local grid the fraction
2964          * is correct, we don't need to shift it.
2965          */
2966         fsh[i] = 0;
2967         if (local_range < n)
2968         {
2969             /* Due to rounding issues i could be 1 beyond the lower or
2970              * upper boundary of the local grid. Correct the index for this.
2971              * If we shift the index, we need to shift the fraction by
2972              * the same amount in the other direction to not affect
2973              * the weights.
2974              * Note that due to this shifting the weights at the end of
2975              * the spline might change, but that will only involve values
2976              * between zero and values close to the precision of a real,
2977              * which is anyhow the accuracy of the whole mesh calculation.
2978              */
2979             /* With local_range=0 we should not change i=local_start */
2980             if (i % n != local_start)
2981             {
2982                 if (gtl[i] == n-1)
2983                 {
2984                     gtl[i] = 0;
2985                     fsh[i] = -1;
2986                 }
2987                 else if (gtl[i] == local_range)
2988                 {
2989                     gtl[i] = local_range - 1;
2990                     fsh[i] = 1;
2991                 }
2992             }
2993         }
2994     }
2995
2996     *global_to_local = gtl;
2997     *fraction_shift  = fsh;
2998 }
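
/* The gtl and fsh tables above are indexed with grid indices that can lie
 * well outside [0, n) (they span 5*n entries); gtl folds such an index
 * back to a local index in [0, n) relative to local_start, and fsh
 * supplies the matching fraction shift of -1, 0 or +1 that keeps the
 * spline weights consistent at the slab edges.
 */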
2999
3000 static pme_spline_work_t *make_pme_spline_work(int order)
3001 {
3002     pme_spline_work_t *work;
3003
3004 #ifdef PME_SSE
3005     float  tmp[8];
3006     __m128 zero_SSE;
3007     int    of, i;
3008
3009     snew_aligned(work, 1, 16);
3010
3011     zero_SSE = _mm_setzero_ps();
3012
3013     /* Generate bit masks to mask out the unused grid entries,
3014      * as we only operate on 'order' of the 8 grid entries that are
3015      * loaded into 2 SSE float registers.
3016      */
3017     for (of = 0; of < 8-(order-1); of++)
3018     {
3019         for (i = 0; i < 8; i++)
3020         {
3021             tmp[i] = (i >= of && i < of+order ? 1 : 0);
3022         }
3023         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
3024         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
3025         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
3026         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
3027     }
3028 #else
3029     work = NULL;
3030 #endif
3031
3032     return work;
3033 }
3034
3035 void gmx_pme_check_restrictions(int pme_order,
3036                                 int nkx, int nky, int nkz,
3037                                 int nnodes_major,
3038                                 int nnodes_minor,
3039                                 gmx_bool bUseThreads,
3040                                 gmx_bool bFatal,
3041                                 gmx_bool *bValidSettings)
3042 {
3043     if (pme_order > PME_ORDER_MAX)
3044     {
3045         if (!bFatal)
3046         {
3047             *bValidSettings = FALSE;
3048             return;
3049         }
3050         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3051                   pme_order, PME_ORDER_MAX);
3052     }
3053
3054     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3055         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3056         nkz <= pme_order)
3057     {
3058         if (!bFatal)
3059         {
3060             *bValidSettings = FALSE;
3061             return;
3062         }
3063         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and, for dimensions with domain decomposition, larger than 2*pme_order",
3064                   pme_order);
3065     }
3066
3067     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3068      * We only allow multiple communication pulses in dim 1, not in dim 0.
3069      */
3070     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3071                         nkx != nnodes_major*(pme_order - 1)))
3072     {
3073         if (!bFatal)
3074         {
3075             *bValidSettings = FALSE;
3076             return;
3077         }
3078         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3079                   nkx/(double)nnodes_major, pme_order);
3080     }
3081
3082     if (bValidSettings != NULL)
3083     {
3084         *bValidSettings = TRUE;
3085     }
3086
3087     return;
3088 }
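
/* A minimal usage sketch (the _try variable names are illustrative only)
 * of how a caller such as the PME load-balancing code can probe a trial
 * grid without aborting:
 *
 *   gmx_bool bValid;
 *
 *   gmx_pme_check_restrictions(pme_order, nkx_try, nky_try, nkz_try,
 *                              nnodes_major, nnodes_minor, bUseThreads,
 *                              FALSE, &bValid);
 *
 * With bFatal==FALSE the function only reports the verdict through
 * bValidSettings, so an unsuitable grid can simply be skipped instead of
 * triggering gmx_fatal.
 */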
3089
3090 int gmx_pme_init(gmx_pme_t *         pmedata,
3091                  t_commrec *         cr,
3092                  int                 nnodes_major,
3093                  int                 nnodes_minor,
3094                  t_inputrec *        ir,
3095                  int                 homenr,
3096                  gmx_bool            bFreeEnergy,
3097                  gmx_bool            bReproducible,
3098                  int                 nthread)
3099 {
3100     gmx_pme_t pme = NULL;
3101
3102     int  use_threads, sum_use_threads;
3103     ivec ndata;
3104
3105     if (debug)
3106     {
3107         fprintf(debug, "Creating PME data structures.\n");
3108     }
3109     snew(pme, 1);
3110
3111     pme->redist_init         = FALSE;
3112     pme->sum_qgrid_tmp       = NULL;
3113     pme->sum_qgrid_dd_tmp    = NULL;
3114     pme->buf_nalloc          = 0;
3115     pme->redist_buf_nalloc   = 0;
3116
3117     pme->nnodes              = 1;
3118     pme->bPPnode             = TRUE;
3119
3120     pme->nnodes_major        = nnodes_major;
3121     pme->nnodes_minor        = nnodes_minor;
3122
3123 #ifdef GMX_MPI
3124     if (nnodes_major*nnodes_minor > 1)
3125     {
3126         pme->mpi_comm = cr->mpi_comm_mygroup;
3127
3128         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3129         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3130         if (pme->nnodes != nnodes_major*nnodes_minor)
3131         {
3132             gmx_incons("PME node count mismatch");
3133         }
3134     }
3135     else
3136     {
3137         pme->mpi_comm = MPI_COMM_NULL;
3138     }
3139 #endif
3140
3141     if (pme->nnodes == 1)
3142     {
3143 #ifdef GMX_MPI
3144         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3145         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3146 #endif
3147         pme->ndecompdim   = 0;
3148         pme->nodeid_major = 0;
3149         pme->nodeid_minor = 0;
3150 #ifdef GMX_MPI
3151         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3152 #endif
3153     }
3154     else
3155     {
3156         if (nnodes_minor == 1)
3157         {
3158 #ifdef GMX_MPI
3159             pme->mpi_comm_d[0] = pme->mpi_comm;
3160             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3161 #endif
3162             pme->ndecompdim   = 1;
3163             pme->nodeid_major = pme->nodeid;
3164             pme->nodeid_minor = 0;
3165
3166         }
3167         else if (nnodes_major == 1)
3168         {
3169 #ifdef GMX_MPI
3170             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3171             pme->mpi_comm_d[1] = pme->mpi_comm;
3172 #endif
3173             pme->ndecompdim   = 1;
3174             pme->nodeid_major = 0;
3175             pme->nodeid_minor = pme->nodeid;
3176         }
3177         else
3178         {
3179             if (pme->nnodes % nnodes_major != 0)
3180             {
3181                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3182             }
3183             pme->ndecompdim = 2;
3184
3185 #ifdef GMX_MPI
3186             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3187                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3188             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3189                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3190
3191             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3192             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3193             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3194             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3195 #endif
3196         }
3197         pme->bPPnode = (cr->duty & DUTY_PP);
3198     }
3199
3200     pme->nthread = nthread;
3201
3202     /* Check if any of the PME MPI ranks uses threads */
3203     use_threads = (pme->nthread > 1 ? 1 : 0);
3204 #ifdef GMX_MPI
3205     if (pme->nnodes > 1)
3206     {
3207         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3208                       MPI_SUM, pme->mpi_comm);
3209     }
3210     else
3211 #endif
3212     {
3213         sum_use_threads = use_threads;
3214     }
3215     pme->bUseThreads = (sum_use_threads > 0);
3216
3217     if (ir->ePBC == epbcSCREW)
3218     {
3219         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3220     }
3221
3222     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3223     pme->nkx         = ir->nkx;
3224     pme->nky         = ir->nky;
3225     pme->nkz         = ir->nkz;
3226     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3227     pme->pme_order   = ir->pme_order;
3228     pme->epsilon_r   = ir->epsilon_r;
3229
3230     /* If we violate restrictions, generate a fatal error here */
3231     gmx_pme_check_restrictions(pme->pme_order,
3232                                pme->nkx, pme->nky, pme->nkz,
3233                                pme->nnodes_major,
3234                                pme->nnodes_minor,
3235                                pme->bUseThreads,
3236                                TRUE,
3237                                NULL);
3238
3239     if (pme->nnodes > 1)
3240     {
3241         double imbal;
3242
3243 #ifdef GMX_MPI
3244         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3245         MPI_Type_commit(&(pme->rvec_mpi));
3246 #endif
3247
3248         /* Note that the charge spreading and force gathering, which usually
3249          * take about the same amount of time as FFT+solve_pme,
3250          * are always fully load balanced
3251          * (unless the charge distribution is inhomogeneous).
3252          */
3253
3254         imbal = pme_load_imbalance(pme);
3255         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3256         {
3257             fprintf(stderr,
3258                     "\n"
3259                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3260                     "      For optimal PME load balancing\n"
3261                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3262                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3263                     "\n",
3264                     (int)((imbal-1)*100 + 0.5),
3265                     pme->nkx, pme->nky, pme->nnodes_major,
3266                     pme->nky, pme->nkz, pme->nnodes_minor);
3267         }
3268     }
3269
3270     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3271     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3272      * y is always copied through a buffer: we don't need padding in z,
3273      * but we do need the overlap in x because of the communication order.
3274      */
3275     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3276 #ifdef GMX_MPI
3277                       pme->mpi_comm_d[0],
3278 #endif
3279                       pme->nnodes_major, pme->nodeid_major,
3280                       pme->nkx,
3281                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3282
3283     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3284      * We do this with an offset buffer of equal size, so we need to allocate
3285      * extra for the offset. That's what the (+1)*pme->nkz is for.
3286      */
3287     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3288 #ifdef GMX_MPI
3289                       pme->mpi_comm_d[1],
3290 #endif
3291                       pme->nnodes_minor, pme->nodeid_minor,
3292                       pme->nky,
3293                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3294
3295     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3296      * Note that gmx_pme_check_restrictions checked for this already.
3297      */
3298     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3299     {
3300         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3301     }
3302
3303     snew(pme->bsp_mod[XX], pme->nkx);
3304     snew(pme->bsp_mod[YY], pme->nky);
3305     snew(pme->bsp_mod[ZZ], pme->nkz);
3306
3307     /* The required size of the interpolation grid, including overlap.
3308      * The allocated size (pmegrid_n?) might be slightly larger.
3309      */
3310     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3311         pme->overlap[0].s2g0[pme->nodeid_major];
3312     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3313         pme->overlap[1].s2g0[pme->nodeid_minor];
3314     pme->pmegrid_nz_base = pme->nkz;
3315     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3316     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
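    /* Illustration: with nkz = 64 and pme_order = 4 we get
     * pmegrid_nz_base = 64 and pmegrid_nz >= 67; set_grid_alignment
     * can round this up further, e.g. for aligned (SIMD) access.
     */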
3317
3318     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3319     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3320     pme->pmegrid_start_iz = 0;
3321
3322     make_gridindex5_to_localindex(pme->nkx,
3323                                   pme->pmegrid_start_ix,
3324                                   pme->pmegrid_nx - (pme->pme_order-1),
3325                                   &pme->nnx, &pme->fshx);
3326     make_gridindex5_to_localindex(pme->nky,
3327                                   pme->pmegrid_start_iy,
3328                                   pme->pmegrid_ny - (pme->pme_order-1),
3329                                   &pme->nny, &pme->fshy);
3330     make_gridindex5_to_localindex(pme->nkz,
3331                                   pme->pmegrid_start_iz,
3332                                   pme->pmegrid_nz_base,
3333                                   &pme->nnz, &pme->fshz);
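    /* These tables map (periodically wrapped) global grid indices to local
     * grid indices plus a fractional shift, so the spreading loops can use
     * table lookups instead of explicit modulo operations.
     */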
3334
3335     pmegrids_init(&pme->pmegridA,
3336                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3337                   pme->pmegrid_nz_base,
3338                   pme->pme_order,
3339                   pme->bUseThreads,
3340                   pme->nthread,
3341                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3342                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3343
3344     pme->spline_work = make_pme_spline_work(pme->pme_order);
3345
3346     ndata[0] = pme->nkx;
3347     ndata[1] = pme->nky;
3348     ndata[2] = pme->nkz;
3349
3350     /* This routine will allocate the grid data to fit the FFTs */
3351     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3352                             &pme->fftgridA, &pme->cfftgridA,
3353                             pme->mpi_comm_d,
3354                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3355                             bReproducible, pme->nthread);
3356
3357     if (bFreeEnergy)
3358     {
3359         pmegrids_init(&pme->pmegridB,
3360                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3361                       pme->pmegrid_nz_base,
3362                       pme->pme_order,
3363                       pme->bUseThreads,
3364                       pme->nthread,
3365                       pme->nkx % pme->nnodes_major != 0,
3366                       pme->nky % pme->nnodes_minor != 0);
3367
3368         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3369                                 &pme->fftgridB, &pme->cfftgridB,
3370                                 pme->mpi_comm_d,
3371                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3372                                 bReproducible, pme->nthread);
3373     }
3374     else
3375     {
3376         pme->pmegridB.grid.grid = NULL;
3377         pme->fftgridB           = NULL;
3378         pme->cfftgridB          = NULL;
3379     }
3380
3381     if (!pme->bP3M)
3382     {
3383         /* Use plain SPME B-spline interpolation */
3384         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3385     }
3386     else
3387     {
3388         /* Use the P3M grid-optimized influence function */
3389         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3390     }
3391
3392     /* Use atc[0] for spreading */
3393     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3394     if (pme->ndecompdim >= 2)
3395     {
3396         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3397     }
3398
3399     if (pme->nnodes == 1)
3400     {
3401         pme->atc[0].n = homenr;
3402         pme_realloc_atomcomm_things(&pme->atc[0]);
3403     }
3404
3405     {
3406         int thread;
3407
3408         /* Use fft5d; after the FFT the index order is y (major), z, x (minor) */
3409
3410         snew(pme->work, pme->nthread);
3411         for (thread = 0; thread < pme->nthread; thread++)
3412         {
3413             realloc_work(&pme->work[thread], pme->nkx);
3414         }
3415     }
3416
3417     *pmedata = pme;
3418
3419     return 0;
3420 }
3421
3422 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3423 {
3424     int d, t;
3425
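    /* Only reuse the old allocation when the new grid fits inside it
     * along every dimension; otherwise keep the freshly allocated grid.
     */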
3426     for (d = 0; d < DIM; d++)
3427     {
3428         if (new->grid.n[d] > old->grid.n[d])
3429         {
3430             return;
3431         }
3432     }
3433
3434     sfree_aligned(new->grid.grid);
3435     new->grid.grid = old->grid.grid;
3436
3437     if (new->grid_th != NULL && new->nthread == old->nthread)
3438     {
3439         sfree_aligned(new->grid_all);
3440         for (t = 0; t < new->nthread; t++)
3441         {
3442             new->grid_th[t].grid = old->grid_th[t].grid;
3443         }
3444     }
3445 }
3446
3447 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3448                    t_commrec *         cr,
3449                    gmx_pme_t           pme_src,
3450                    const t_inputrec *  ir,
3451                    ivec                grid_size)
3452 {
3453     t_inputrec irc;
3454     int homenr;
3455     int ret;
3456
3457     irc     = *ir;
3458     irc.nkx = grid_size[XX];
3459     irc.nky = grid_size[YY];
3460     irc.nkz = grid_size[ZZ];
3461
3462     if (pme_src->nnodes == 1)
3463     {
3464         homenr = pme_src->atc[0].n;
3465     }
3466     else
3467     {
3468         homenr = -1;
3469     }
3470
3471     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3472                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3473
3474     if (ret == 0)
3475     {
3476         /* We can easily reuse the allocated pme grids in pme_src */
3477         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3478         /* We would like to reuse the fft grids, but that's harder */
3479     }
3480
3481     return ret;
3482 }
3483
3484
3485 static void copy_local_grid(gmx_pme_t pme,
3486                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3487 {
3488     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3489     int  fft_my, fft_mz;
3490     int  nsx, nsy, nsz;
3491     ivec nf;
3492     int  offx, offy, offz, x, y, z, i0, i0t;
3493     int  d;
3494     pmegrid_t *pmegrid;
3495     real *grid_th;
3496
3497     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3498                                    local_fft_ndata,
3499                                    local_fft_offset,
3500                                    local_fft_size);
3501     fft_my = local_fft_size[YY];
3502     fft_mz = local_fft_size[ZZ];
3503
3504     pmegrid = &pmegrids->grid_th[thread];
3505
3506     nsx = pmegrid->s[XX];
3507     nsy = pmegrid->s[YY];
3508     nsz = pmegrid->s[ZZ];
3509
3510     for (d = 0; d < DIM; d++)
3511     {
3512         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3513                     local_fft_ndata[d] - pmegrid->offset[d]);
3514     }
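    /* nf[] is now the interior part of this thread's grid: its extent minus
     * the pme_order-1 spreading overlap, clipped to the local FFT grid.
     */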
3515
3516     offx = pmegrid->offset[XX];
3517     offy = pmegrid->offset[YY];
3518     offz = pmegrid->offset[ZZ];
3519
3520     /* Directly copy the non-overlapping parts of the local grids.
3521      * This also initializes the full grid.
3522      */
3523     grid_th = pmegrid->grid;
3524     for (x = 0; x < nf[XX]; x++)
3525     {
3526         for (y = 0; y < nf[YY]; y++)
3527         {
3528             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3529             i0t = (x*nsy + y)*nsz;
3530             for (z = 0; z < nf[ZZ]; z++)
3531             {
3532                 fftgrid[i0+z] = grid_th[i0t+z];
3533             }
3534         }
3535     }
3536 }
3537
3538 static void
3539 reduce_threadgrid_overlap(gmx_pme_t pme,
3540                           const pmegrids_t *pmegrids, int thread,
3541                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3542 {
3543     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3544     int  fft_nx, fft_ny, fft_nz;
3545     int  fft_my, fft_mz;
3546     int  buf_my = -1;
3547     int  nsx, nsy, nsz;
3548     ivec ne;
3549     int  offx, offy, offz, x, y, z, i0, i0t;
3550     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3551     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3552     gmx_bool bCommX, bCommY;
3553     int  d;
3554     int  thread_f;
3555     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3556     const real *grid_th;
3557     real *commbuf = NULL;
3558
3559     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3560                                    local_fft_ndata,
3561                                    local_fft_offset,
3562                                    local_fft_size);
3563     fft_nx = local_fft_ndata[XX];
3564     fft_ny = local_fft_ndata[YY];
3565     fft_nz = local_fft_ndata[ZZ];
3566
3567     fft_my = local_fft_size[YY];
3568     fft_mz = local_fft_size[ZZ];
3569
3570     /* This routine is called when all threads have finished spreading.
3571      * Here each thread sums grid contributions calculated by other threads
3572      * to the thread local grid volume.
3573      * To minimize the number of grid copying operations,
3574      * this routine sums immediately from the pmegrid to the fftgrid.
3575      */
3576
3577     /* Determine which part of the full node grid we should operate on:
3578      * this is our thread's local part of the full grid.
3579      */
3580     pmegrid = &pmegrids->grid_th[thread];
3581
3582     for (d = 0; d < DIM; d++)
3583     {
3584         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3585                     local_fft_ndata[d]);
3586     }
3587
3588     offx = pmegrid->offset[XX];
3589     offy = pmegrid->offset[YY];
3590     offz = pmegrid->offset[ZZ];
3591
3592
3593     bClearBufX  = TRUE;
3594     bClearBufY  = TRUE;
3595     bClearBufXY = TRUE;
3596
3597     /* Now loop over all the thread data blocks that contribute
3598      * to the grid region we (our thread) are operating on.
3599      */
3600     /* Note that fft_nx/y is equal to the number of grid points
3601      * between the first point of our node grid and the one of the next node.
3602      */
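    /* The loops below run over sx, sy, sz <= 0: spreading extends
     * pme_order-1 points upward from each thread block, so only blocks at
     * equal or lower indices can overlap our volume. Blocks that wrap
     * around the node boundary are accumulated into the communication
     * buffers instead (bCommX/bCommY).
     */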
3603     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3604     {
3605         fx     = pmegrid->ci[XX] + sx;
3606         ox     = 0;
3607         bCommX = FALSE;
3608         if (fx < 0)
3609         {
3610             fx    += pmegrids->nc[XX];
3611             ox    -= fft_nx;
3612             bCommX = (pme->nnodes_major > 1);
3613         }
3614         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3615         ox       += pmegrid_g->offset[XX];
3616         if (!bCommX)
3617         {
3618             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3619         }
3620         else
3621         {
3622             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3623         }
3624
3625         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3626         {
3627             fy     = pmegrid->ci[YY] + sy;
3628             oy     = 0;
3629             bCommY = FALSE;
3630             if (fy < 0)
3631             {
3632                 fy    += pmegrids->nc[YY];
3633                 oy    -= fft_ny;
3634                 bCommY = (pme->nnodes_minor > 1);
3635             }
3636             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3637             oy       += pmegrid_g->offset[YY];
3638             if (!bCommY)
3639             {
3640                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3641             }
3642             else
3643             {
3644                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3645             }
3646
3647             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3648             {
3649                 fz = pmegrid->ci[ZZ] + sz;
3650                 oz = 0;
3651                 if (fz < 0)
3652                 {
3653                     fz += pmegrids->nc[ZZ];
3654                     oz -= fft_nz;
3655                 }
3656                 pmegrid_g = &pmegrids->grid_th[fz];
3657                 oz       += pmegrid_g->offset[ZZ];
3658                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3659
3660                 if (sx == 0 && sy == 0 && sz == 0)
3661                 {
3662                     /* We have already added our local contribution
3663                      * before calling this routine, so skip it here.
3664                      */
3665                     continue;
3666                 }
3667
3668                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3669
3670                 pmegrid_f = &pmegrids->grid_th[thread_f];
3671
3672                 grid_th = pmegrid_f->grid;
3673
3674                 nsx = pmegrid_f->s[XX];
3675                 nsy = pmegrid_f->s[YY];
3676                 nsz = pmegrid_f->s[ZZ];
3677
3678 #ifdef DEBUG_PME_REDUCE
3679                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3680                        pme->nodeid, thread, thread_f,
3681                        pme->pmegrid_start_ix,
3682                        pme->pmegrid_start_iy,
3683                        pme->pmegrid_start_iz,
3684                        sx, sy, sz,
3685                        offx-ox, tx1-ox, offx, tx1,
3686                        offy-oy, ty1-oy, offy, ty1,
3687                        offz-oz, tz1-oz, offz, tz1);
3688 #endif
3689
3690                 if (!(bCommX || bCommY))
3691                 {
3692                     /* Copy from the thread local grid to the node grid */
3693                     for (x = offx; x < tx1; x++)
3694                     {
3695                         for (y = offy; y < ty1; y++)
3696                         {
3697                             i0  = (x*fft_my + y)*fft_mz;
3698                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3699                             for (z = offz; z < tz1; z++)
3700                             {
3701                                 fftgrid[i0+z] += grid_th[i0t+z];
3702                             }
3703                         }
3704                     }
3705                 }
3706                 else
3707                 {
3708                     /* The order of this conditional decides
3709                      * where the corner volume gets stored with x+y decomp.
3710                      */
3711                     if (bCommY)
3712                     {
3713                         commbuf = commbuf_y;
3714                         buf_my  = ty1 - offy;
3715                         if (bCommX)
3716                         {
3717                             /* We index commbuf modulo the local grid size */
3718                             commbuf += buf_my*fft_nx*fft_nz;
3719
3720                             bClearBuf   = bClearBufXY;
3721                             bClearBufXY = FALSE;
3722                         }
3723                         else
3724                         {
3725                             bClearBuf  = bClearBufY;
3726                             bClearBufY = FALSE;
3727                         }
3728                     }
3729                     else
3730                     {
3731                         commbuf    = commbuf_x;
3732                         buf_my     = fft_ny;
3733                         bClearBuf  = bClearBufX;
3734                         bClearBufX = FALSE;
3735                     }
3736
3737                     /* Copy to the communication buffer */
3738                     for (x = offx; x < tx1; x++)
3739                     {
3740                         for (y = offy; y < ty1; y++)
3741                         {
3742                             i0  = (x*buf_my + y)*fft_nz;
3743                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3744
3745                             if (bClearBuf)
3746                             {
3747                                 /* First access of commbuf, initialize it */
3748                                 for (z = offz; z < tz1; z++)
3749                                 {
3750                                     commbuf[i0+z]  = grid_th[i0t+z];
3751                                 }
3752                             }
3753                             else
3754                             {
3755                                 for (z = offz; z < tz1; z++)
3756                                 {
3757                                     commbuf[i0+z] += grid_th[i0t+z];
3758                                 }
3759                             }
3760                         }
3761                     }
3762                 }
3763             }
3764         }
3765     }
3766 }
3767
3768
3769 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3770 {
3771     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3772     pme_overlap_t *overlap;
3773     int  send_index0, send_nindex;
3774     int  recv_nindex;
3775 #ifdef GMX_MPI
3776     MPI_Status stat;
3777 #endif
3778     int  send_size_y, recv_size_y;
3779     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3780     real *sendptr, *recvptr;
3781     int  x, y, z, indg, indb;
3782
3783     /* Note that this routine is only used for forward communication.
3784      * Since the force gathering, unlike the charge spreading,
3785      * can be trivially parallelized over the particles,
3786      * the backwards process is much simpler and can use the "old"
3787      * communication setup.
3788      */
3789
3790     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3791                                    local_fft_ndata,
3792                                    local_fft_offset,
3793                                    local_fft_size);
3794
3795     if (pme->nnodes_minor > 1)
3796     {
3797         /* Minor dimension */
3798         overlap = &pme->overlap[1];
3799
3800         if (pme->nnodes_major > 1)
3801         {
3802             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3803         }
3804         else
3805         {
3806             size_yx = 0;
3807         }
3808         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
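        /* With an additional decomposition along x, the slab received
         * along y also carries a size_yx region that still has to be
         * forwarded along x below, hence the extended datasize.
         */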
3809
3810         send_size_y = overlap->send_size;
3811
3812         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3813         {
3814             send_id       = overlap->send_id[ipulse];
3815             recv_id       = overlap->recv_id[ipulse];
3816             send_index0   =
3817                 overlap->comm_data[ipulse].send_index0 -
3818                 overlap->comm_data[0].send_index0;
3819             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3820             /* We don't use recv_index0, as we always receive starting at 0 */
3821             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3822             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3823
3824             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3825             recvptr = overlap->recvbuf;
3826
3827 #ifdef GMX_MPI
3828             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3829                          send_id, ipulse,
3830                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3831                          recv_id, ipulse,
3832                          overlap->mpi_comm, &stat);
3833 #endif
3834
3835             for (x = 0; x < local_fft_ndata[XX]; x++)
3836             {
3837                 for (y = 0; y < recv_nindex; y++)
3838                 {
3839                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3840                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3841                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3842                     {
3843                         fftgrid[indg+z] += recvptr[indb+z];
3844                     }
3845                 }
3846             }
3847
3848             if (pme->nnodes_major > 1)
3849             {
3850                 /* Copy from the received buffer to the send buffer for dim 0 */
3851                 sendptr = pme->overlap[0].sendbuf;
3852                 for (x = 0; x < size_yx; x++)
3853                 {
3854                     for (y = 0; y < recv_nindex; y++)
3855                     {
3856                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3857                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3858                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3859                         {
3860                             sendptr[indg+z] += recvptr[indb+z];
3861                         }
3862                     }
3863                 }
3864             }
3865         }
3866     }
3867
3868     /* We only support a single pulse here.
3869      * This is not a severe limitation, as this code is only used
3870      * with OpenMP and with OpenMP the (PME) domains can be larger.
3871      */
3872     if (pme->nnodes_major > 1)
3873     {
3874         /* Major dimension */
3875         overlap = &pme->overlap[0];
3876
3877         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3878         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3879
3880         ipulse = 0;
3881
3882         send_id       = overlap->send_id[ipulse];
3883         recv_id       = overlap->recv_id[ipulse];
3884         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3885         /* We don't use recv_index0, as we always receive starting at 0 */
3886         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3887
3888         sendptr = overlap->sendbuf;
3889         recvptr = overlap->recvbuf;
3890
3891         if (debug != NULL)
3892         {
3893             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3894                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3895         }
3896
3897 #ifdef GMX_MPI
3898         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3899                      send_id, ipulse,
3900                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3901                      recv_id, ipulse,
3902                      overlap->mpi_comm, &stat);
3903 #endif
3904
3905         for (x = 0; x < recv_nindex; x++)
3906         {
3907             for (y = 0; y < local_fft_ndata[YY]; y++)
3908             {
3909                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3910                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3911                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3912                 {
3913                     fftgrid[indg+z] += recvptr[indb+z];
3914                 }
3915             }
3916         }
3917     }
3918 }
3919
3920
3921 static void spread_on_grid(gmx_pme_t pme,
3922                            pme_atomcomm_t *atc, pmegrids_t *grids,
3923                            gmx_bool bCalcSplines, gmx_bool bSpread,
3924                            real *fftgrid)
3925 {
3926     int nthread, thread;
3927 #ifdef PME_TIME_THREADS
3928     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3929     static double cs1     = 0, cs2 = 0, cs3 = 0;
3930     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3931     static int cnt        = 0;
3932 #endif
3933
3934     nthread = pme->nthread;
3935     assert(nthread > 0);
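    /* Spreading proceeds in up to three stages:
     * 1) compute the interpolation grid indices for all atoms,
     * 2) compute the B-splines and spread each thread's atoms onto its
     *    (thread-local) grid,
     * 3) with thread decomposition, reduce the overlapping thread grids
     *    onto the FFT grid and sum the inter-node overlap regions.
     */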
3936
3937 #ifdef PME_TIME_THREADS
3938     c1 = omp_cyc_start();
3939 #endif
3940     if (bCalcSplines)
3941     {
3942 #pragma omp parallel for num_threads(nthread) schedule(static)
3943         for (thread = 0; thread < nthread; thread++)
3944         {
3945             int start, end;
3946
3947             start = atc->n* thread   /nthread;
3948             end   = atc->n*(thread+1)/nthread;
3949
3950             /* Compute fftgrid index for all atoms,
3951              * with help of some extra variables.
3952              */
3953             calc_interpolation_idx(pme, atc, start, end, thread);
3954         }
3955     }
3956 #ifdef PME_TIME_THREADS
3957     c1   = omp_cyc_end(c1);
3958     cs1 += (double)c1;
3959 #endif
3960
3961 #ifdef PME_TIME_THREADS
3962     c2 = omp_cyc_start();
3963 #endif
3964 #pragma omp parallel for num_threads(nthread) schedule(static)
3965     for (thread = 0; thread < nthread; thread++)
3966     {
3967         splinedata_t *spline;
3968         pmegrid_t *grid = NULL;
3969
3970         /* make local bsplines  */
3971         if (grids == NULL || !pme->bUseThreads)
3972         {
3973             spline = &atc->spline[0];
3974
3975             spline->n = atc->n;
3976
3977             if (bSpread)
3978             {
3979                 grid = &grids->grid;
3980             }
3981         }
3982         else
3983         {
3984             spline = &atc->spline[thread];
3985
3986             if (grids->nthread == 1)
3987             {
3988                 /* One thread, we operate on all charges */
3989                 spline->n = atc->n;
3990             }
3991             else
3992             {
3993                 /* Get the indices our thread should operate on */
3994                 make_thread_local_ind(atc, thread, spline);
3995             }
3996
3997             grid = &grids->grid_th[thread];
3998         }
3999
4000         if (bCalcSplines)
4001         {
4002             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
4003                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
4004         }
4005
4006         if (bSpread)
4007         {
4008             /* put local atoms on grid. */
4009 #ifdef PME_TIME_SPREAD
4010             ct1a = omp_cyc_start();
4011 #endif
4012             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4013
4014             if (pme->bUseThreads)
4015             {
4016                 copy_local_grid(pme, grids, thread, fftgrid);
4017             }
4018 #ifdef PME_TIME_SPREAD
4019             ct1a          = omp_cyc_end(ct1a);
4020             cs1a[thread] += (double)ct1a;
4021 #endif
4022         }
4023     }
4024 #ifdef PME_TIME_THREADS
4025     c2   = omp_cyc_end(c2);
4026     cs2 += (double)c2;
4027 #endif
4028
4029     if (bSpread && pme->bUseThreads)
4030     {
4031 #ifdef PME_TIME_THREADS
4032         c3 = omp_cyc_start();
4033 #endif
4034 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4035         for (thread = 0; thread < grids->nthread; thread++)
4036         {
4037             reduce_threadgrid_overlap(pme, grids, thread,
4038                                       fftgrid,
4039                                       pme->overlap[0].sendbuf,
4040                                       pme->overlap[1].sendbuf);
4041         }
4042 #ifdef PME_TIME_THREADS
4043         c3   = omp_cyc_end(c3);
4044         cs3 += (double)c3;
4045 #endif
4046
4047         if (pme->nnodes > 1)
4048         {
4049             /* Communicate the overlapping part of the fftgrid.
4050              * For this communication call we need to check pme->bUseThreads
4051              * to have all ranks communicate here, regardless of pme->nthread.
4052              */
4053             sum_fftgrid_dd(pme, fftgrid);
4054         }
4055     }
4056
4057 #ifdef PME_TIME_THREADS
4058     cnt++;
4059     if (cnt % 20 == 0)
4060     {
4061         printf("idx %.2f spread %.2f red %.2f",
4062                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4063 #ifdef PME_TIME_SPREAD
4064         for (thread = 0; thread < nthread; thread++)
4065         {
4066             printf(" %.2f", cs1a[thread]*1e-9);
4067         }
4068 #endif
4069         printf("\n");
4070     }
4071 #endif
4072 }
4073
4074
4075 static void dump_grid(FILE *fp,
4076                       int sx, int sy, int sz, int nx, int ny, int nz,
4077                       int my, int mz, const real *g)
4078 {
4079     int x, y, z;
4080
4081     for (x = 0; x < nx; x++)
4082     {
4083         for (y = 0; y < ny; y++)
4084         {
4085             for (z = 0; z < nz; z++)
4086             {
4087                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4088                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4089             }
4090         }
4091     }
4092 }
4093
4094 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4095 {
4096     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4097
4098     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4099                                    local_fft_ndata,
4100                                    local_fft_offset,
4101                                    local_fft_size);
4102
4103     dump_grid(stderr,
4104               pme->pmegrid_start_ix,
4105               pme->pmegrid_start_iy,
4106               pme->pmegrid_start_iz,
4107               pme->pmegrid_nx-pme->pme_order+1,
4108               pme->pmegrid_ny-pme->pme_order+1,
4109               pme->pmegrid_nz-pme->pme_order+1,
4110               local_fft_size[YY],
4111               local_fft_size[ZZ],
4112               fftgrid);
4113 }
4114
4115
4116 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4117 {
4118     pme_atomcomm_t *atc;
4119     pmegrids_t *grid;
4120
4121     if (pme->nnodes > 1)
4122     {
4123         gmx_incons("gmx_pme_calc_energy called in parallel");
4124     }
4125     if (pme->bFEP)
4126     {
4127         gmx_incons("gmx_pme_calc_energy with free energy");
4128     }
4129
4130     atc            = &pme->atc_energy;
4131     atc->nthread   = 1;
4132     if (atc->spline == NULL)
4133     {
4134         snew(atc->spline, atc->nthread);
4135     }
4136     atc->nslab     = 1;
4137     atc->bSpread   = TRUE;
4138     atc->pme_order = pme->pme_order;
4139     atc->n         = n;
4140     pme_realloc_atomcomm_things(atc);
4141     atc->x         = x;
4142     atc->q         = q;
4143
4144     /* We only use the A-charges grid */
4145     grid = &pme->pmegridA;
4146
4147     /* Only calculate the spline coefficients, don't actually spread */
4148     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4149
4150     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4151 }
4152
4153
4154 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4155                                    t_nrnb *nrnb, t_inputrec *ir,
4156                                    gmx_large_int_t step)
4157 {
4158     /* Reset all the counters related to performance over the run */
4159     wallcycle_stop(wcycle, ewcRUN);
4160     wallcycle_reset_all(wcycle);
4161     init_nrnb(nrnb);
4162     if (ir->nsteps >= 0)
4163     {
4164         /* ir->nsteps is not used here, but we update it for consistency */
4165         ir->nsteps -= step - ir->init_step;
4166     }
4167     ir->init_step = step;
4168     wallcycle_start(wcycle, ewcRUN);
4169 }
4170
4171
4172 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4173                                ivec grid_size,
4174                                t_commrec *cr, t_inputrec *ir,
4175                                gmx_pme_t *pme_ret)
4176 {
4177     int ind;
4178     gmx_pme_t pme = NULL;
4179
4180     ind = 0;
4181     while (ind < *npmedata)
4182     {
4183         pme = (*pmedata)[ind];
4184         if (pme->nkx == grid_size[XX] &&
4185             pme->nky == grid_size[YY] &&
4186             pme->nkz == grid_size[ZZ])
4187         {
4188             *pme_ret = pme;
4189
4190             return;
4191         }
4192
4193         ind++;
4194     }
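    /* No matching PME struct was found: ind now equals the old *npmedata,
     * so after growing the array the new struct goes into slot ind.
     */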
4195
4196     (*npmedata)++;
4197     srenew(*pmedata, *npmedata);
4198
4199     /* Generate a new PME data structure, copying part of the old pointers */
4200     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4201
4202     *pme_ret = (*pmedata)[ind];
4203 }
4204
4205
4206 int gmx_pmeonly(gmx_pme_t pme,
4207                 t_commrec *cr,    t_nrnb *nrnb,
4208                 gmx_wallcycle_t wcycle,
4209                 real ewaldcoeff,  gmx_bool bGatherOnly,
4210                 t_inputrec *ir)
4211 {
4212     int npmedata;
4213     gmx_pme_t *pmedata;
4214     gmx_pme_pp_t pme_pp;
4215     int  ret;
4216     int  natoms;
4217     matrix box;
4218     rvec *x_pp      = NULL, *f_pp = NULL;
4219     real *chargeA   = NULL, *chargeB = NULL;
4220     real lambda     = 0;
4221     int  maxshift_x = 0, maxshift_y = 0;
4222     real energy, dvdlambda;
4223     matrix vir;
4224     float cycles;
4225     int  count;
4226     gmx_bool bEnerVir;
4227     gmx_large_int_t step, step_rel;
4228     ivec grid_switch;
4229
4230     /* These data are only used with PME tuning, i.e. switching PME grids */
4231     npmedata = 1;
4232     snew(pmedata, npmedata);
4233     pmedata[0] = pme;
4234
4235     pme_pp = gmx_pme_pp_init(cr);
4236
4237     init_nrnb(nrnb);
4238
4239     count = 0;
4240     do /****** this is a quasi-loop over time steps! */
4241     {
4242         /* The reason for having a loop here is PME grid tuning/switching */
4243         do
4244         {
4245             /* Domain decomposition */
4246             ret = gmx_pme_recv_q_x(pme_pp,
4247                                    &natoms,
4248                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4249                                    &maxshift_x, &maxshift_y,
4250                                    &pme->bFEP, &lambda,
4251                                    &bEnerVir,
4252                                    &step,
4253                                    grid_switch, &ewaldcoeff);
4254
4255             if (ret == pmerecvqxSWITCHGRID)
4256             {
4257                 /* Switch the PME grid to grid_switch */
4258                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4259             }
4260
4261             if (ret == pmerecvqxRESETCOUNTERS)
4262             {
4263                 /* Reset the cycle and flop counters */
4264                 reset_pmeonly_counters(cr, wcycle, nrnb, ir, step);
4265             }
4266         }
4267         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4268
4269         if (ret == pmerecvqxFINISH)
4270         {
4271             /* We should stop: break out of the loop */
4272             break;
4273         }
4274
4275         step_rel = step - ir->init_step;
4276
4277         if (count == 0)
4278         {
4279             wallcycle_start(wcycle, ewcRUN);
4280         }
4281
4282         wallcycle_start(wcycle, ewcPMEMESH);
4283
4284         dvdlambda = 0;
4285         clear_mat(vir);
4286         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4287                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4288                    &energy, lambda, &dvdlambda,
4289                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4290
4291         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4292
4293         gmx_pme_send_force_vir_ener(pme_pp,
4294                                     f_pp, vir, energy, dvdlambda,
4295                                     cycles);
4296
4297         count++;
4298     } /***** end of quasi-loop, we stop with the break above */
4299     while (TRUE);
4300
4301     return 0;
4302 }
4303
4304 int gmx_pme_do(gmx_pme_t pme,
4305                int start,       int homenr,
4306                rvec x[],        rvec f[],
4307                real *chargeA,   real *chargeB,
4308                matrix box, t_commrec *cr,
4309                int  maxshift_x, int maxshift_y,
4310                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4311                matrix vir,      real ewaldcoeff,
4312                real *energy,    real lambda,
4313                real *dvdlambda, int flags)
4314 {
4315     int     q, d, i, j, ntot, npme;
4316     int     nx, ny, nz;
4317     int     n_d, local_ny;
4318     pme_atomcomm_t *atc = NULL;
4319     pmegrids_t *pmegrid = NULL;
4320     real    *grid       = NULL;
4321     real    *ptr;
4322     rvec    *x_d, *f_d;
4323     real    *charge = NULL, *q_d;
4324     real    energy_AB[2];
4325     matrix  vir_AB[2];
4326     gmx_bool bClearF;
4327     gmx_parallel_3dfft_t pfft_setup;
4328     real *  fftgrid;
4329     t_complex * cfftgrid;
4330     int     thread;
4331     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4332     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4333
4334     assert(pme->nnodes > 0);
4335     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4336
4337     if (pme->nnodes > 1)
4338     {
4339         atc      = &pme->atc[0];
4340         atc->npd = homenr;
4341         if (atc->npd > atc->pd_nalloc)
4342         {
4343             atc->pd_nalloc = over_alloc_dd(atc->npd);
4344             srenew(atc->pd, atc->pd_nalloc);
4345         }
4346         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4347     }
4348     else
4349     {
4350         /* This could be necessary for TPI */
4351         pme->atc[0].n = homenr;
4352     }
4353
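    /* With free energy we loop twice: q=0 spreads the A-state charges,
     * q=1 the B-state charges; the resulting energies and virials are
     * combined with lambda below.
     */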
4354     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4355     {
4356         if (q == 0)
4357         {
4358             pmegrid    = &pme->pmegridA;
4359             fftgrid    = pme->fftgridA;
4360             cfftgrid   = pme->cfftgridA;
4361             pfft_setup = pme->pfft_setupA;
4362             charge     = chargeA+start;
4363         }
4364         else
4365         {
4366             pmegrid    = &pme->pmegridB;
4367             fftgrid    = pme->fftgridB;
4368             cfftgrid   = pme->cfftgridB;
4369             pfft_setup = pme->pfft_setupB;
4370             charge     = chargeB+start;
4371         }
4372         grid = pmegrid->grid.grid;
4373         /* Unpack structure */
4374         if (debug)
4375         {
4376             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4377                     cr->nnodes, cr->nodeid);
4378             fprintf(debug, "Grid = %p\n", (void*)grid);
4379             if (grid == NULL)
4380             {
4381                 gmx_fatal(FARGS, "No grid!");
4382             }
4383         }
4384         where();
4385
4386         m_inv_ur0(box, pme->recipbox);
4387
4388         if (pme->nnodes == 1)
4389         {
4390             atc = &pme->atc[0];
4391             if (DOMAINDECOMP(cr))
4392             {
4393                 atc->n = homenr;
4394                 pme_realloc_atomcomm_things(atc);
4395             }
4396             atc->x = x;
4397             atc->q = charge;
4398             atc->f = f;
4399         }
4400         else
4401         {
4402             wallcycle_start(wcycle, ewcPME_REDISTXF);
4403             for (d = pme->ndecompdim-1; d >= 0; d--)
4404             {
4405                 if (d == pme->ndecompdim-1)
4406                 {
4407                     n_d = homenr;
4408                     x_d = x + start;
4409                     q_d = charge;
4410                 }
4411                 else
4412                 {
4413                     n_d = pme->atc[d+1].n;
4414                     x_d = atc->x;
4415                     q_d = atc->q;
4416                 }
4417                 atc      = &pme->atc[d];
4418                 atc->npd = n_d;
4419                 if (atc->npd > atc->pd_nalloc)
4420                 {
4421                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4422                     srenew(atc->pd, atc->pd_nalloc);
4423                 }
4424                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4425                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4426                 where();
4427
4428                 GMX_BARRIER(cr->mpi_comm_mygroup);
4429                 /* Redistribute x (only once) and qA or qB */
4430                 if (DOMAINDECOMP(cr))
4431                 {
4432                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4433                 }
4434                 else
4435                 {
4436                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4437                 }
4438             }
4439             where();
4440
4441             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4442         }
4443
4444         if (debug)
4445         {
4446             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4447                     cr->nodeid, atc->n);
4448         }
4449
4450         if (flags & GMX_PME_SPREAD_Q)
4451         {
4452             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4453
4454             /* Spread the charges on a grid */
4455             GMX_MPE_LOG(ev_spread_on_grid_start);
4456
4458             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4459             GMX_MPE_LOG(ev_spread_on_grid_finish);
4460
4461             if (q == 0)
4462             {
4463                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4464             }
4465             inc_nrnb(nrnb, eNR_SPREADQBSP,
4466                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4467
4468             if (!pme->bUseThreads)
4469             {
4470                 wrap_periodic_pmegrid(pme, grid);
4471
4472                 /* sum contributions to local grid from other nodes */
4473 #ifdef GMX_MPI
4474                 if (pme->nnodes > 1)
4475                 {
4476                     GMX_BARRIER(cr->mpi_comm_mygroup);
4477                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4478                     where();
4479                 }
4480 #endif
4481
4482                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4483             }
4484
4485             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4486
4487             /*
4488                dump_local_fftgrid(pme,fftgrid);
4489                exit(0);
4490              */
4491         }
4492
4493         /* Here we start a large thread parallel region */
4494 #pragma omp parallel num_threads(pme->nthread) private(thread)
4495         {
4496             thread = gmx_omp_get_thread_num();
4497             if (flags & GMX_PME_SOLVE)
4498             {
4499                 int loop_count;
4500
4501                 /* do 3d-fft */
4502                 if (thread == 0)
4503                 {
4504                     GMX_BARRIER(cr->mpi_comm_mygroup);
4505                     GMX_MPE_LOG(ev_gmxfft3d_start);
4506                     wallcycle_start(wcycle, ewcPME_FFT);
4507                 }
4508                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4509                                            fftgrid, cfftgrid, thread, wcycle);
4510                 if (thread == 0)
4511                 {
4512                     wallcycle_stop(wcycle, ewcPME_FFT);
4513                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4514                 }
4515                 where();
4516
4517                 /* solve in k-space for our local cells */
4518                 if (thread == 0)
4519                 {
4520                     GMX_BARRIER(cr->mpi_comm_mygroup);
4521                     GMX_MPE_LOG(ev_solve_pme_start);
4522                     wallcycle_start(wcycle, ewcPME_SOLVE);
4523                 }
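                /* The box matrix is triangular, so the product of its
                 * diagonal elements passed below is the box volume.
                 */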
4524                 loop_count =
4525                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4526                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4527                                   bCalcEnerVir,
4528                                   pme->nthread, thread);
4529                 if (thread == 0)
4530                 {
4531                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4532                     where();
4533                     GMX_MPE_LOG(ev_solve_pme_finish);
4534                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4535                 }
4536             }
4537
4538             if (bCalcF)
4539             {
4540                 /* do 3d-invfft */
4541                 if (thread == 0)
4542                 {
4543                     GMX_BARRIER(cr->mpi_comm_mygroup);
4544                     GMX_MPE_LOG(ev_gmxfft3d_start);
4545                     where();
4546                     wallcycle_start(wcycle, ewcPME_FFT);
4547                 }
4548                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4549                                            cfftgrid, fftgrid, thread, wcycle);
4550                 if (thread == 0)
4551                 {
4552                     wallcycle_stop(wcycle, ewcPME_FFT);
4553
4554                     where();
4555                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4556
4557                     if (pme->nodeid == 0)
4558                     {
4559                         ntot  = pme->nkx*pme->nky*pme->nkz;
4560                         npme  = ntot*log((real)ntot)/log(2.0);
4561                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4562                     }
4563
4564                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4565                 }
4566
4567                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4568             }
4569         }
4570         /* End of thread parallel section.
4571          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4572          */
4573
4574         if (bCalcF)
4575         {
4576             /* distribute local grid to all nodes */
4577 #ifdef GMX_MPI
4578             if (pme->nnodes > 1)
4579             {
4580                 GMX_BARRIER(cr->mpi_comm_mygroup);
4581                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4582             }
4583 #endif
4584             where();
4585
4586             unwrap_periodic_pmegrid(pme, grid);
4587
4588             /* interpolate forces for our local atoms */
4589             GMX_BARRIER(cr->mpi_comm_mygroup);
4590             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4591
4592             where();
4593
4594             /* If we are running without parallelization,
4595              * atc->f is the actual force array, not a buffer,
4596              * therefore we should not clear it.
4597              */
4598             bClearF = (q == 0 && PAR(cr));
4599 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4600             for (thread = 0; thread < pme->nthread; thread++)
4601             {
4602                 gather_f_bsplines(pme, grid, bClearF, atc,
4603                                   &atc->spline[thread],
4604                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4605             }
4606
4607             where();
4608
4609             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4610
4611             inc_nrnb(nrnb, eNR_GATHERFBSP,
4612                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4613             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4614         }
4615
4616         if (bCalcEnerVir)
4617         {
4618             /* This should only be called on the master thread
4619              * and after the threads have synchronized.
4620              */
4621             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4622         }
4623     } /* of q-loop */
4624
4625     if (bCalcF && pme->nnodes > 1)
4626     {
4627         wallcycle_start(wcycle, ewcPME_REDISTXF);
4628         for (d = 0; d < pme->ndecompdim; d++)
4629         {
4630             atc = &pme->atc[d];
4631             if (d == pme->ndecompdim - 1)
4632             {
4633                 n_d = homenr;
4634                 f_d = f + start;
4635             }
4636             else
4637             {
4638                 n_d = pme->atc[d+1].n;
4639                 f_d = pme->atc[d+1].f;
4640             }
4641             GMX_BARRIER(cr->mpi_comm_mygroup);
4642             if (DOMAINDECOMP(cr))
4643             {
4644                 dd_pmeredist_f(pme, atc, n_d, f_d,
4645                                d == pme->ndecompdim-1 && pme->bPPnode);
4646             }
4647             else
4648             {
4649                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4650             }
4651         }
4652
4653         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4654     }
4655     where();
4656
4657     if (bCalcEnerVir)
4658     {
4659         if (!pme->bFEP)
4660         {
4661             *energy = energy_AB[0];
4662             m_add(vir, vir_AB[0], vir);
4663         }
4664         else
4665         {
4666             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4667             *dvdlambda += energy_AB[1] - energy_AB[0];
4668             for (i = 0; i < DIM; i++)
4669             {
4670                 for (j = 0; j < DIM; j++)
4671                 {
4672                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4673                         lambda*vir_AB[1][i][j];
4674                 }
4675             }
4676         }
4677     }
4678     else
4679     {
4680         *energy = 0;
4681     }
4682
4683     if (debug)
4684     {
4685         fprintf(debug, "PME mesh energy: %g\n", *energy);
4686     }
4687
4688     return 0;
4689 }