[alexxy/gromacs.git] / src / mdlib / pme.c
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team,
6  * check out http://www.gromacs.org for more information.
7  * Copyright (c) 2012,2013, by the GROMACS development team, led by
8  * David van der Spoel, Berk Hess, Erik Lindahl, and including many
9  * others, as listed in the AUTHORS file in the top-level source
10  * directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* IMPORTANT FOR DEVELOPERS:
39  *
40  * Triclinic PME is not entirely trivial, and we've experienced
41  * some bugs during development (many of them due to me). To avoid
42  * this in the future, please check the following things if you make
43  * changes in this file:
44  *
45  * 1. You should obtain identical (at least to the PME precision)
46  *    energies, forces, and virial for
47  *    a rectangular box and a triclinic one where the z (or y) axis is
48  *    tilted by a whole box side. For instance, you could use these boxes:
49  *
50  *    rectangular       triclinic
51  *     2  0  0           2  0  0
52  *     0  2  0           0  2  0
53  *     0  0  6           2  2  6
54  *
55  * 2. You should check the energy conservation in a triclinic box.
56  *
57  * It might seem like overkill, but better safe than sorry.
58  * /Erik 001109
59  */
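/* Why the two example boxes above must give matching results: the triclinic
 * c vector (2,2,6) differs from the rectangular one (0,0,6) by exactly
 * a + b, so both matrices generate the same Bravais lattice and describe
 * the same periodic system.  A standalone check of that property could look
 * like the sketch below (illustrative only, plain C, not used by the PME
 * code; it assumes both boxes are lower triangular and share their a and b
 * rows, as in the example above):
 *
 *   static int boxes_share_lattice(const double rect[3][3],
 *                                  const double tric[3][3])
 *   {
 *       // decompose c_tric - c_rect in the shared (a, b) basis
 *       double dx = tric[2][0] - rect[2][0];
 *       double dy = tric[2][1] - rect[2][1];
 *       double dz = tric[2][2] - rect[2][2];
 *       double cb = dy/rect[1][1];
 *       double ca = (dx - cb*rect[1][0])/rect[0][0];
 *
 *       // same lattice iff dz == 0 and both coefficients are integers
 *       return (dz == 0 &&
 *               fabs(ca - floor(ca + 0.5)) < 1e-12 &&
 *               fabs(cb - floor(cb + 0.5)) < 1e-12);
 *   }
 *
 * For the boxes above, ca = cb = 1, so the check passes.
 */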
60
61 #ifdef HAVE_CONFIG_H
62 #include <config.h>
63 #endif
64
65 #ifdef GMX_LIB_MPI
66 #include <mpi.h>
67 #endif
68 #ifdef GMX_THREAD_MPI
69 #include "tmpi.h"
70 #endif
71
72 #include <stdio.h>
73 #include <string.h>
74 #include <math.h>
75 #include <assert.h>
76 #include "typedefs.h"
77 #include "txtdump.h"
78 #include "vec.h"
79 #include "gmxcomplex.h"
80 #include "smalloc.h"
81 #include "futil.h"
82 #include "coulomb.h"
83 #include "gmx_fatal.h"
84 #include "pme.h"
85 #include "network.h"
86 #include "physics.h"
87 #include "nrnb.h"
88 #include "copyrite.h"
89 #include "gmx_wallcycle.h"
90 #include "gmx_parallel_3dfft.h"
91 #include "pdbio.h"
92 #include "gmx_cyclecounter.h"
93 #include "gmx_omp.h"
94
95 /* Single precision, with SSE2 or higher available */
96 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
97
98 #include "gmx_x86_simd_single.h"
99
100 #define PME_SSE
101 /* Some old AMD processors could have problems with unaligned loads+stores */
102 #ifndef GMX_FAHCORE
103 #define PME_SSE_UNALIGNED
104 #endif
105 #endif
106
107 #include "mpelogging.h"
108
109 #define DFT_TOL 1e-7
110 /* #define PRT_FORCE */
111 /* conditions for on-the-fly time measurement */
112 /* #define TAKETIME (step > 1 && timesteps < 10) */
113 #define TAKETIME FALSE
114
115 /* #define PME_TIME_THREADS */
116
117 #ifdef GMX_DOUBLE
118 #define mpi_type MPI_DOUBLE
119 #else
120 #define mpi_type MPI_FLOAT
121 #endif
122
123 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
124 #define GMX_CACHE_SEP 64
125
126 /* We only define a maximum to be able to use local arrays without allocation.
127  * An order larger than 12 should never be needed, even for test cases.
128  * If needed it can be changed here.
129  */
130 #define PME_ORDER_MAX 12
131
132 /* Internal datastructures */
133 typedef struct {
134     int send_index0;
135     int send_nindex;
136     int recv_index0;
137     int recv_nindex;
138     int recv_size;   /* Receive buffer width, used with OpenMP */
139 } pme_grid_comm_t;
140
141 typedef struct {
142 #ifdef GMX_MPI
143     MPI_Comm         mpi_comm;
144 #endif
145     int              nnodes, nodeid;
146     int             *s2g0;
147     int             *s2g1;
148     int              noverlap_nodes;
149     int             *send_id, *recv_id;
150     int              send_size; /* Send buffer width, used with OpenMP */
151     pme_grid_comm_t *comm_data;
152     real            *sendbuf;
153     real            *recvbuf;
154 } pme_overlap_t;
155
156 typedef struct {
157     int *n;      /* Cumulative counts of the number of particles per thread */
158     int  nalloc; /* Allocation size of i */
159     int *i;      /* Particle indices ordered on thread index (n) */
160 } thread_plist_t;
161
162 typedef struct {
163     int      *thread_one;
164     int       n;
165     int      *ind;
166     splinevec theta;
167     real     *ptr_theta_z;
168     splinevec dtheta;
169     real     *ptr_dtheta_z;
170 } splinedata_t;
171
172 typedef struct {
173     int      dimind;        /* The index of the dimension, 0=x, 1=y */
174     int      nslab;
175     int      nodeid;
176 #ifdef GMX_MPI
177     MPI_Comm mpi_comm;
178 #endif
179
180     int     *node_dest;     /* The nodes to send x and q to with DD */
181     int     *node_src;      /* The nodes to receive x and q from with DD */
182     int     *buf_index;     /* Index for commnode into the buffers */
183
184     int      maxshift;
185
186     int      npd;
187     int      pd_nalloc;
188     int     *pd;
189     int     *count;         /* The number of atoms to send to each node */
190     int    **count_thread;
191     int     *rcount;        /* The number of atoms to receive */
192
193     int      n;
194     int      nalloc;
195     rvec    *x;
196     real    *q;
197     rvec    *f;
198     gmx_bool bSpread;       /* These coordinates are used for spreading */
199     int      pme_order;
200     ivec    *idx;
201     rvec    *fractx;            /* Fractional coordinate relative to the
202                                  * lower cell boundary
203                                  */
204     int             nthread;
205     int            *thread_idx; /* Which thread should spread which charge */
206     thread_plist_t *thread_plist;
207     splinedata_t   *spline;
208 } pme_atomcomm_t;
209
210 #define FLBS  3
211 #define FLBSZ 4
212
213 typedef struct {
214     ivec  ci;     /* The spatial location of this grid         */
215     ivec  n;      /* The used size of *grid, including order-1 */
216     ivec  offset; /* The grid offset from the full node grid   */
217     int   order;  /* PME spreading order                       */
218     ivec  s;      /* The allocated size of *grid, s >= n       */
219     real *grid;   /* The grid of the local thread, size n      */
220 } pmegrid_t;
221
222 typedef struct {
223     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
224     int        nthread;      /* The number of threads operating on this grid     */
225     ivec       nc;           /* The local spatial decomposition over the threads */
226     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
227     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
228     int      **g2t;          /* The grid to thread index                         */
229     ivec       nthread_comm; /* The number of threads to communicate with        */
230 } pmegrids_t;
231
232
233 typedef struct {
234 #ifdef PME_SSE
235     /* Masks for SSE aligned spreading and gathering */
236     __m128 mask_SSE0[6], mask_SSE1[6];
237 #else
238     int    dummy; /* C89 requires that a struct has at least one member */
239 #endif
240 } pme_spline_work_t;
241
242 typedef struct {
243     /* work data for solve_pme */
244     int      nalloc;
245     real *   mhx;
246     real *   mhy;
247     real *   mhz;
248     real *   m2;
249     real *   denom;
250     real *   tmp1_alloc;
251     real *   tmp1;
252     real *   eterm;
253     real *   m2inv;
254
255     real     energy;
256     matrix   vir;
257 } pme_work_t;
258
259 typedef struct gmx_pme {
260     int           ndecompdim; /* The number of decomposition dimensions */
261     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
262     int           nodeid_major;
263     int           nodeid_minor;
264     int           nnodes;    /* The number of nodes doing PME */
265     int           nnodes_major;
266     int           nnodes_minor;
267
268     MPI_Comm      mpi_comm;
269     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
270 #ifdef GMX_MPI
271     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
272 #endif
273
274     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
275     int        nthread;       /* The number of threads doing PME on our rank */
276
277     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
278     gmx_bool   bFEP;          /* Compute Free energy contribution */
279     int        nkx, nky, nkz; /* Grid dimensions */
280     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
281     int        pme_order;
282     real       epsilon_r;
283
284     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
285     pmegrids_t pmegridB;
286     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
287     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
288     /* pmegrid_nz might be larger than strictly necessary to ensure
289      * memory alignment, pmegrid_nz_base gives the real base size.
290      */
291     int     pmegrid_nz_base;
292     /* The local PME grid starting indices */
293     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
294
295     /* Work data for spreading and gathering */
296     pme_spline_work_t    *spline_work;
297
298     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
299     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
300     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
301
302     t_complex            *cfftgridA;  /* Grids for complex FFT data */
303     t_complex            *cfftgridB;
304     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
305
306     gmx_parallel_3dfft_t  pfft_setupA;
307     gmx_parallel_3dfft_t  pfft_setupB;
308
309     int                  *nnx, *nny, *nnz;
310     real                 *fshx, *fshy, *fshz;
311
312     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
313     matrix                recipbox;
314     splinevec             bsp_mod;
315
316     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
317
318     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
319
320     rvec                 *bufv;       /* Communication buffer */
321     real                 *bufr;       /* Communication buffer */
322     int                   buf_nalloc; /* The communication buffer size */
323
324     /* thread local work data for solve_pme */
325     pme_work_t *work;
326
327     /* Work data for PME_redist */
328     gmx_bool redist_init;
329     int *    scounts;
330     int *    rcounts;
331     int *    sdispls;
332     int *    rdispls;
333     int *    sidx;
334     int *    idxa;
335     real *   redist_buf;
336     int      redist_buf_nalloc;
337
338     /* Work data for sum_qgrid */
339     real *   sum_qgrid_tmp;
340     real *   sum_qgrid_dd_tmp;
341 } t_gmx_pme;
342
343
344 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
345                                    int start, int end, int thread)
346 {
347     int             i;
348     int            *idxptr, tix, tiy, tiz;
349     real           *xptr, *fptr, tx, ty, tz;
350     real            rxx, ryx, ryy, rzx, rzy, rzz;
351     int             nx, ny, nz;
352     int             start_ix, start_iy, start_iz;
353     int            *g2tx, *g2ty, *g2tz;
354     gmx_bool        bThreads;
355     int            *thread_idx = NULL;
356     thread_plist_t *tpl        = NULL;
357     int            *tpl_n      = NULL;
358     int             thread_i;
359
360     nx  = pme->nkx;
361     ny  = pme->nky;
362     nz  = pme->nkz;
363
364     start_ix = pme->pmegrid_start_ix;
365     start_iy = pme->pmegrid_start_iy;
366     start_iz = pme->pmegrid_start_iz;
367
368     rxx = pme->recipbox[XX][XX];
369     ryx = pme->recipbox[YY][XX];
370     ryy = pme->recipbox[YY][YY];
371     rzx = pme->recipbox[ZZ][XX];
372     rzy = pme->recipbox[ZZ][YY];
373     rzz = pme->recipbox[ZZ][ZZ];
374
375     g2tx = pme->pmegridA.g2t[XX];
376     g2ty = pme->pmegridA.g2t[YY];
377     g2tz = pme->pmegridA.g2t[ZZ];
378
379     bThreads = (atc->nthread > 1);
380     if (bThreads)
381     {
382         thread_idx = atc->thread_idx;
383
384         tpl   = &atc->thread_plist[thread];
385         tpl_n = tpl->n;
386         for (i = 0; i < atc->nthread; i++)
387         {
388             tpl_n[i] = 0;
389         }
390     }
391
392     for (i = start; i < end; i++)
393     {
394         xptr   = atc->x[i];
395         idxptr = atc->idx[i];
396         fptr   = atc->fractx[i];
397
398         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
399         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
400         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
401         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
402
403         tix = (int)(tx);
404         tiy = (int)(ty);
405         tiz = (int)(tz);
406
407         /* Because decomposition only occurs in x and y,
408          * we never have a fraction correction in z.
409          */
410         fptr[XX] = tx - tix + pme->fshx[tix];
411         fptr[YY] = ty - tiy + pme->fshy[tiy];
412         fptr[ZZ] = tz - tiz;
413
414         idxptr[XX] = pme->nnx[tix];
415         idxptr[YY] = pme->nny[tiy];
416         idxptr[ZZ] = pme->nnz[tiz];
417
418 #ifdef DEBUG
419         range_check(idxptr[XX], 0, pme->pmegrid_nx);
420         range_check(idxptr[YY], 0, pme->pmegrid_ny);
421         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
422 #endif
423
424         if (bThreads)
425         {
426             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
427             thread_idx[i] = thread_i;
428             tpl_n[thread_i]++;
429         }
430     }
431
432     if (bThreads)
433     {
434         /* Make a list of particle indices sorted on thread */
435
436         /* Get the cumulative count */
437         for (i = 1; i < atc->nthread; i++)
438         {
439             tpl_n[i] += tpl_n[i-1];
440         }
441         /* The current implementation distributes particles equally
442          * over the threads, so we could actually allocate for that
443          * in pme_realloc_atomcomm_things.
444          */
445         if (tpl_n[atc->nthread-1] > tpl->nalloc)
446         {
447             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
448             srenew(tpl->i, tpl->nalloc);
449         }
450         /* Set tpl_n to the cumulative start */
451         for (i = atc->nthread-1; i >= 1; i--)
452         {
453             tpl_n[i] = tpl_n[i-1];
454         }
455         tpl_n[0] = 0;
456
457         /* Fill our thread local array with indices sorted on thread */
458         for (i = start; i < end; i++)
459         {
460             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
461         }
462         /* Now tpl_n contains the cumulative count again */
463     }
464 }
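/* A condensed illustration of the index/fraction split done above (a
 * standalone sketch with a hypothetical name, not called by the PME code;
 * it assumes a single PME node, so the fshx/fshy shift correction can be
 * dropped and the nnx[] wrap table reduces to a plain modulo).
 */
static void calc_idx_1d_sketch(int n, real u, int *idx, real *frac)
{
    /* u is the fractional coordinate along one box vector, computed from
     * recipbox as in the loop above.  For triclinic boxes it can be
     * slightly negative, so two box lengths are added before truncating,
     * which makes the truncation behave like floor.
     */
    real t  = n*(u + 2.0);
    int  ti = (int)t;

    *frac = t - ti;   /* B-spline fraction in [0,1) */
    *idx  = ti % n;   /* grid index in [0,n), what the nnx[] table provides */
}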
465
466 static void make_thread_local_ind(pme_atomcomm_t *atc,
467                                   int thread, splinedata_t *spline)
468 {
469     int             n, t, i, start, end;
470     thread_plist_t *tpl;
471
472     /* Combine the indices made by each thread into one index */
473
474     n     = 0;
475     start = 0;
476     for (t = 0; t < atc->nthread; t++)
477     {
478         tpl = &atc->thread_plist[t];
479         /* Copy our part (start to end) from the list of thread t */
480         if (thread > 0)
481         {
482             start = tpl->n[thread-1];
483         }
484         end = tpl->n[thread];
485         for (i = start; i < end; i++)
486         {
487             spline->ind[n++] = tpl->i[i];
488         }
489     }
490
491     spline->n = n;
492 }
493
494
495 static void pme_calc_pidx(int start, int end,
496                           matrix recipbox, rvec x[],
497                           pme_atomcomm_t *atc, int *count)
498 {
499     int   nslab, i;
500     int   si;
501     real *xptr, s;
502     real  rxx, ryx, rzx, ryy, rzy;
503     int  *pd;
504
505     /* Calculate the PME task index (pidx) for each particle.
506      * Here we always assign equally sized slabs to each node
507      * for load balancing reasons (the PME grid spacing is not used).
508      */
509
510     nslab = atc->nslab;
511     pd    = atc->pd;
512
513     /* Reset the count */
514     for (i = 0; i < nslab; i++)
515     {
516         count[i] = 0;
517     }
518
519     if (atc->dimind == 0)
520     {
521         rxx = recipbox[XX][XX];
522         ryx = recipbox[YY][XX];
523         rzx = recipbox[ZZ][XX];
524         /* Calculate the node index in x-dimension */
525         for (i = start; i < end; i++)
526         {
527             xptr   = x[i];
528             /* Fractional coordinates along box vectors */
529             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
530             si    = (int)(s + 2*nslab) % nslab;
531             pd[i] = si;
532             count[si]++;
533         }
534     }
535     else
536     {
537         ryy = recipbox[YY][YY];
538         rzy = recipbox[ZZ][YY];
539         /* Calculate the node index in y-dimension */
540         for (i = start; i < end; i++)
541         {
542             xptr   = x[i];
543             /* Fractional coordinates along box vectors */
544             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
545             si    = (int)(s + 2*nslab) % nslab;
546             pd[i] = si;
547             count[si]++;
548         }
549     }
550 }
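/* The slab assignment above deliberately ignores the PME grid spacing and
 * cuts the unit cell into nslab equal slices, which balances the particle
 * load over the PME ranks.  A standalone sketch of the mapping (plain C,
 * hypothetical name, not called by the PME code):
 */
static int slab_index_sketch(int nslab, real frac)
{
    /* frac is the fractional coordinate along the decomposed dimension.
     * Adding 2*nslab keeps the argument positive before truncation and the
     * modulo wraps coordinates outside the unit cell back into [0,nslab);
     * for example, nslab = 4 and frac = 0.55 gives slab 2.
     */
    return (int)(nslab*frac + 2*nslab) % nslab;
}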
551
552 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
553                                   pme_atomcomm_t *atc)
554 {
555     int nthread, thread, slab;
556
557     nthread = atc->nthread;
558
559 #pragma omp parallel for num_threads(nthread) schedule(static)
560     for (thread = 0; thread < nthread; thread++)
561     {
562         pme_calc_pidx(natoms* thread   /nthread,
563                       natoms*(thread+1)/nthread,
564                       recipbox, x, atc, atc->count_thread[thread]);
565     }
566     /* Non-parallel reduction, since nslab is small */
567
568     for (thread = 1; thread < nthread; thread++)
569     {
570         for (slab = 0; slab < atc->nslab; slab++)
571         {
572             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
573         }
574     }
575 }
576
577 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
578 {
579     const int padding = 4;
580     int       i;
581
582     srenew(th[XX], nalloc);
583     srenew(th[YY], nalloc);
584     /* In z we add padding; this is only required for the aligned SSE code */
585     srenew(*ptr_z, nalloc+2*padding);
586     th[ZZ] = *ptr_z + padding;
587
588     for (i = 0; i < padding; i++)
589     {
590         (*ptr_z)[               i] = 0;
591         (*ptr_z)[padding+nalloc+i] = 0;
592     }
593 }
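/* The layout produced above for the z component (illustrative, with
 * padding = 4 as defined in realloc_splinevec):
 *
 *   ptr_z:  [ 0 0 0 0 | theta_z[0] ... theta_z[nalloc-1] | 0 0 0 0 ]
 *                       ^-- th[ZZ] points here
 *
 * Both borders are zeroed, so an aligned 4-wide SSE load that runs past
 * either end of the data stays inside the allocation and picks up only
 * zeros.
 */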
594
595 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
596 {
597     int i, d;
598
599     srenew(spline->ind, atc->nalloc);
600     /* Initialize the index to identity so it works without threads */
601     for (i = 0; i < atc->nalloc; i++)
602     {
603         spline->ind[i] = i;
604     }
605
606     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
607                       atc->pme_order*atc->nalloc);
608     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
609                       atc->pme_order*atc->nalloc);
610 }
611
612 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
613 {
614     int nalloc_old, i, j, nalloc_tpl;
615
616     /* We have to keep atc->x non-NULL to avoid
617      * possible fatal errors in MPI routines.
618      */
619     if (atc->n > atc->nalloc || atc->nalloc == 0)
620     {
621         nalloc_old  = atc->nalloc;
622         atc->nalloc = over_alloc_dd(max(atc->n, 1));
623
624         if (atc->nslab > 1)
625         {
626             srenew(atc->x, atc->nalloc);
627             srenew(atc->q, atc->nalloc);
628             srenew(atc->f, atc->nalloc);
629             for (i = nalloc_old; i < atc->nalloc; i++)
630             {
631                 clear_rvec(atc->f[i]);
632             }
633         }
634         if (atc->bSpread)
635         {
636             srenew(atc->fractx, atc->nalloc);
637             srenew(atc->idx, atc->nalloc);
638
639             if (atc->nthread > 1)
640             {
641                 srenew(atc->thread_idx, atc->nalloc);
642             }
643
644             for (i = 0; i < atc->nthread; i++)
645             {
646                 pme_realloc_splinedata(&atc->spline[i], atc);
647             }
648         }
649     }
650 }
651
652 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
653                          int n, gmx_bool bXF, rvec *x_f, real *charge,
654                          pme_atomcomm_t *atc)
655 /* Redistribute particle data for PME calculation */
656 /* domain decomposition by x coordinate           */
657 {
658     int *idxa;
659     int  i, ii;
660
661     if (FALSE == pme->redist_init)
662     {
663         snew(pme->scounts, atc->nslab);
664         snew(pme->rcounts, atc->nslab);
665         snew(pme->sdispls, atc->nslab);
666         snew(pme->rdispls, atc->nslab);
667         snew(pme->sidx, atc->nslab);
668         pme->redist_init = TRUE;
669     }
670     if (n > pme->redist_buf_nalloc)
671     {
672         pme->redist_buf_nalloc = over_alloc_dd(n);
673         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
674     }
675
676     pme->idxa = atc->pd;
677
678 #ifdef GMX_MPI
679     if (forw && bXF)
680     {
681         /* forward, redistribution from pp to pme */
682
683         /* Calculate send counts and exchange them with other nodes */
684         for (i = 0; (i < atc->nslab); i++)
685         {
686             pme->scounts[i] = 0;
687         }
688         for (i = 0; (i < n); i++)
689         {
690             pme->scounts[pme->idxa[i]]++;
691         }
692         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
693
694         /* Calculate send and receive displacements and index into send
695            buffer */
696         pme->sdispls[0] = 0;
697         pme->rdispls[0] = 0;
698         pme->sidx[0]    = 0;
699         for (i = 1; i < atc->nslab; i++)
700         {
701             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
702             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
703             pme->sidx[i]    = pme->sdispls[i];
704         }
705         /* Total # of particles to be received */
706         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
707
708         pme_realloc_atomcomm_things(atc);
709
710         /* Copy particle coordinates into send buffer and exchange */
711         for (i = 0; (i < n); i++)
712         {
713             ii = DIM*pme->sidx[pme->idxa[i]];
714             pme->sidx[pme->idxa[i]]++;
715             pme->redist_buf[ii+XX] = x_f[i][XX];
716             pme->redist_buf[ii+YY] = x_f[i][YY];
717             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
718         }
719         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
720                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
721                       pme->rvec_mpi, atc->mpi_comm);
722     }
723     if (forw)
724     {
725         /* Copy charge into send buffer and exchange */
726         for (i = 0; i < atc->nslab; i++)
727         {
728             pme->sidx[i] = pme->sdispls[i];
729         }
730         for (i = 0; (i < n); i++)
731         {
732             ii = pme->sidx[pme->idxa[i]];
733             pme->sidx[pme->idxa[i]]++;
734             pme->redist_buf[ii] = charge[i];
735         }
736         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
737                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
738                       atc->mpi_comm);
739     }
740     else   /* backward, redistribution from pme to pp */
741     {
742         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
743                       pme->redist_buf, pme->scounts, pme->sdispls,
744                       pme->rvec_mpi, atc->mpi_comm);
745
746         /* Copy data from receive buffer */
747         for (i = 0; i < atc->nslab; i++)
748         {
749             pme->sidx[i] = pme->sdispls[i];
750         }
751         for (i = 0; (i < n); i++)
752         {
753             ii          = DIM*pme->sidx[pme->idxa[i]];
754             x_f[i][XX] += pme->redist_buf[ii+XX];
755             x_f[i][YY] += pme->redist_buf[ii+YY];
756             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
757             pme->sidx[pme->idxa[i]]++;
758         }
759     }
760 #endif
761 }
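#ifdef GMX_MPI
/* Condensed standalone sketch of the forward charge redistribution done in
 * pmeredist_pd above (illustrative only: hypothetical names, plain doubles
 * instead of real/rvec, and the send buffer is assumed to be ordered by
 * destination slab already, which the sidx[] packing above takes care of).
 */
static void redist_charges_sketch(MPI_Comm comm, int nslab, int n,
                                  int *slab_of, double *send, double *recv,
                                  int *scounts, int *rcounts,
                                  int *sdispls, int *rdispls)
{
    int i;

    /* Count how many of our particles go to each slab */
    for (i = 0; i < nslab; i++)
    {
        scounts[i] = 0;
    }
    for (i = 0; i < n; i++)
    {
        scounts[slab_of[i]]++;
    }
    /* Every rank learns how much it will receive from every other rank */
    MPI_Alltoall(scounts, 1, MPI_INT, rcounts, 1, MPI_INT, comm);

    /* The displacements are prefix sums of the counts */
    sdispls[0] = 0;
    rdispls[0] = 0;
    for (i = 1; i < nslab; i++)
    {
        sdispls[i] = sdispls[i-1] + scounts[i-1];
        rdispls[i] = rdispls[i-1] + rcounts[i-1];
    }
    /* One collective moves everything; the backward (force) direction in
     * pmeredist_pd simply swaps the send and receive arguments.
     */
    MPI_Alltoallv(send, scounts, sdispls, MPI_DOUBLE,
                  recv, rcounts, rdispls, MPI_DOUBLE, comm);
}
#endif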
762
763 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
764                             gmx_bool bBackward, int shift,
765                             void *buf_s, int nbyte_s,
766                             void *buf_r, int nbyte_r)
767 {
768 #ifdef GMX_MPI
769     int        dest, src;
770     MPI_Status stat;
771
772     if (bBackward == FALSE)
773     {
774         dest = atc->node_dest[shift];
775         src  = atc->node_src[shift];
776     }
777     else
778     {
779         dest = atc->node_src[shift];
780         src  = atc->node_dest[shift];
781     }
782
783     if (nbyte_s > 0 && nbyte_r > 0)
784     {
785         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
786                      dest, shift,
787                      buf_r, nbyte_r, MPI_BYTE,
788                      src, shift,
789                      atc->mpi_comm, &stat);
790     }
791     else if (nbyte_s > 0)
792     {
793         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
794                  dest, shift,
795                  atc->mpi_comm);
796     }
797     else if (nbyte_r > 0)
798     {
799         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
800                  src, shift,
801                  atc->mpi_comm, &stat);
802     }
803 #endif
804 }
805
806 static void dd_pmeredist_x_q(gmx_pme_t pme,
807                              int n, gmx_bool bX, rvec *x, real *charge,
808                              pme_atomcomm_t *atc)
809 {
810     int *commnode, *buf_index;
811     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
812
813     commnode  = atc->node_dest;
814     buf_index = atc->buf_index;
815
816     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
817
818     nsend = 0;
819     for (i = 0; i < nnodes_comm; i++)
820     {
821         buf_index[commnode[i]] = nsend;
822         nsend                 += atc->count[commnode[i]];
823     }
824     if (bX)
825     {
826         if (atc->count[atc->nodeid] + nsend != n)
827         {
828             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
829                       "This usually means that your system is not well equilibrated.",
830                       n - (atc->count[atc->nodeid] + nsend),
831                       pme->nodeid, 'x'+atc->dimind);
832         }
833
834         if (nsend > pme->buf_nalloc)
835         {
836             pme->buf_nalloc = over_alloc_dd(nsend);
837             srenew(pme->bufv, pme->buf_nalloc);
838             srenew(pme->bufr, pme->buf_nalloc);
839         }
840
841         atc->n = atc->count[atc->nodeid];
842         for (i = 0; i < nnodes_comm; i++)
843         {
844             scount = atc->count[commnode[i]];
845             /* Communicate the count */
846             if (debug)
847             {
848                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
849                         atc->dimind, atc->nodeid, commnode[i], scount);
850             }
851             pme_dd_sendrecv(atc, FALSE, i,
852                             &scount, sizeof(int),
853                             &atc->rcount[i], sizeof(int));
854             atc->n += atc->rcount[i];
855         }
856
857         pme_realloc_atomcomm_things(atc);
858     }
859
860     local_pos = 0;
861     for (i = 0; i < n; i++)
862     {
863         node = atc->pd[i];
864         if (node == atc->nodeid)
865         {
866             /* Copy directly to the receive buffer */
867             if (bX)
868             {
869                 copy_rvec(x[i], atc->x[local_pos]);
870             }
871             atc->q[local_pos] = charge[i];
872             local_pos++;
873         }
874         else
875         {
876             /* Copy to the send buffer */
877             if (bX)
878             {
879                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
880             }
881             pme->bufr[buf_index[node]] = charge[i];
882             buf_index[node]++;
883         }
884     }
885
886     buf_pos = 0;
887     for (i = 0; i < nnodes_comm; i++)
888     {
889         scount = atc->count[commnode[i]];
890         rcount = atc->rcount[i];
891         if (scount > 0 || rcount > 0)
892         {
893             if (bX)
894             {
895                 /* Communicate the coordinates */
896                 pme_dd_sendrecv(atc, FALSE, i,
897                                 pme->bufv[buf_pos], scount*sizeof(rvec),
898                                 atc->x[local_pos], rcount*sizeof(rvec));
899             }
900             /* Communicate the charges */
901             pme_dd_sendrecv(atc, FALSE, i,
902                             pme->bufr+buf_pos, scount*sizeof(real),
903                             atc->q+local_pos, rcount*sizeof(real));
904             buf_pos   += scount;
905             local_pos += atc->rcount[i];
906         }
907     }
908 }
909
910 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
911                            int n, rvec *f,
912                            gmx_bool bAddF)
913 {
914     int *commnode, *buf_index;
915     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
916
917     commnode  = atc->node_dest;
918     buf_index = atc->buf_index;
919
920     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
921
922     local_pos = atc->count[atc->nodeid];
923     buf_pos   = 0;
924     for (i = 0; i < nnodes_comm; i++)
925     {
926         scount = atc->rcount[i];
927         rcount = atc->count[commnode[i]];
928         if (scount > 0 || rcount > 0)
929         {
930             /* Communicate the forces */
931             pme_dd_sendrecv(atc, TRUE, i,
932                             atc->f[local_pos], scount*sizeof(rvec),
933                             pme->bufv[buf_pos], rcount*sizeof(rvec));
934             local_pos += scount;
935         }
936         buf_index[commnode[i]] = buf_pos;
937         buf_pos               += rcount;
938     }
939
940     local_pos = 0;
941     if (bAddF)
942     {
943         for (i = 0; i < n; i++)
944         {
945             node = atc->pd[i];
946             if (node == atc->nodeid)
947             {
948                 /* Add from the local force array */
949                 rvec_inc(f[i], atc->f[local_pos]);
950                 local_pos++;
951             }
952             else
953             {
954                 /* Add from the receive buffer */
955                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
956                 buf_index[node]++;
957             }
958         }
959     }
960     else
961     {
962         for (i = 0; i < n; i++)
963         {
964             node = atc->pd[i];
965             if (node == atc->nodeid)
966             {
967                 /* Copy from the local force array */
968                 copy_rvec(atc->f[local_pos], f[i]);
969                 local_pos++;
970             }
971             else
972             {
973                 /* Copy from the receive buffer */
974                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
975                 buf_index[node]++;
976             }
977         }
978     }
979 }
980
981 #ifdef GMX_MPI
982 static void
983 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
984 {
985     pme_overlap_t *overlap;
986     int            send_index0, send_nindex;
987     int            recv_index0, recv_nindex;
988     MPI_Status     stat;
989     int            i, j, k, ix, iy, iz, icnt;
990     int            ipulse, send_id, recv_id, datasize;
991     real          *p;
992     real          *sendptr, *recvptr;
993
994     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
995     overlap = &pme->overlap[1];
996
997     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
998     {
999         /* Since we have already (un)wrapped the overlap in the z-dimension,
1000          * we only have to communicate 0 to nkz (not pmegrid_nz).
1001          */
1002         if (direction == GMX_SUM_QGRID_FORWARD)
1003         {
1004             send_id       = overlap->send_id[ipulse];
1005             recv_id       = overlap->recv_id[ipulse];
1006             send_index0   = overlap->comm_data[ipulse].send_index0;
1007             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1008             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1009             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1010         }
1011         else
1012         {
1013             send_id       = overlap->recv_id[ipulse];
1014             recv_id       = overlap->send_id[ipulse];
1015             send_index0   = overlap->comm_data[ipulse].recv_index0;
1016             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1017             recv_index0   = overlap->comm_data[ipulse].send_index0;
1018             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1019         }
1020
1021         /* Copy data to contiguous send buffer */
1022         if (debug)
1023         {
1024             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1025                     pme->nodeid, overlap->nodeid, send_id,
1026                     pme->pmegrid_start_iy,
1027                     send_index0-pme->pmegrid_start_iy,
1028                     send_index0-pme->pmegrid_start_iy+send_nindex);
1029         }
1030         icnt = 0;
1031         for (i = 0; i < pme->pmegrid_nx; i++)
1032         {
1033             ix = i;
1034             for (j = 0; j < send_nindex; j++)
1035             {
1036                 iy = j + send_index0 - pme->pmegrid_start_iy;
1037                 for (k = 0; k < pme->nkz; k++)
1038                 {
1039                     iz = k;
1040                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1041                 }
1042             }
1043         }
1044
1045         datasize      = pme->pmegrid_nx * pme->nkz;
1046
1047         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1048                      send_id, ipulse,
1049                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1050                      recv_id, ipulse,
1051                      overlap->mpi_comm, &stat);
1052
1053         /* Get data from contiguous recv buffer */
1054         if (debug)
1055         {
1056             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1057                     pme->nodeid, overlap->nodeid, recv_id,
1058                     pme->pmegrid_start_iy,
1059                     recv_index0-pme->pmegrid_start_iy,
1060                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1061         }
1062         icnt = 0;
1063         for (i = 0; i < pme->pmegrid_nx; i++)
1064         {
1065             ix = i;
1066             for (j = 0; j < recv_nindex; j++)
1067             {
1068                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1069                 for (k = 0; k < pme->nkz; k++)
1070                 {
1071                     iz = k;
1072                     if (direction == GMX_SUM_QGRID_FORWARD)
1073                     {
1074                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1075                     }
1076                     else
1077                     {
1078                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1079                     }
1080                 }
1081             }
1082         }
1083     }
1084
1085     /* Major dimension is easier, no copying required,
1086      * but we might have to sum to a separate array.
1087      * Since we don't copy, we have to communicate up to pmegrid_nz,
1088      * not nkz as for the minor direction.
1089      */
1090     overlap = &pme->overlap[0];
1091
1092     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1093     {
1094         if (direction == GMX_SUM_QGRID_FORWARD)
1095         {
1096             send_id       = overlap->send_id[ipulse];
1097             recv_id       = overlap->recv_id[ipulse];
1098             send_index0   = overlap->comm_data[ipulse].send_index0;
1099             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1100             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1101             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1102             recvptr       = overlap->recvbuf;
1103         }
1104         else
1105         {
1106             send_id       = overlap->recv_id[ipulse];
1107             recv_id       = overlap->send_id[ipulse];
1108             send_index0   = overlap->comm_data[ipulse].recv_index0;
1109             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1110             recv_index0   = overlap->comm_data[ipulse].send_index0;
1111             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1112             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1113         }
1114
1115         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1116         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1117
1118         if (debug)
1119         {
1120             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1121                     pme->nodeid, overlap->nodeid, send_id,
1122                     pme->pmegrid_start_ix,
1123                     send_index0-pme->pmegrid_start_ix,
1124                     send_index0-pme->pmegrid_start_ix+send_nindex);
1125             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1126                     pme->nodeid, overlap->nodeid, recv_id,
1127                     pme->pmegrid_start_ix,
1128                     recv_index0-pme->pmegrid_start_ix,
1129                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1130         }
1131
1132         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1133                      send_id, ipulse,
1134                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1135                      recv_id, ipulse,
1136                      overlap->mpi_comm, &stat);
1137
1138         /* ADD data from contiguous recv buffer */
1139         if (direction == GMX_SUM_QGRID_FORWARD)
1140         {
1141             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1142             for (i = 0; i < recv_nindex*datasize; i++)
1143             {
1144                 p[i] += overlap->recvbuf[i];
1145             }
1146         }
1147     }
1148 }
1149 #endif
1150
1151
1152 static int
1153 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1154 {
1155     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1156     ivec    local_pme_size;
1157     int     i, ix, iy, iz;
1158     int     pmeidx, fftidx;
1159
1160     /* Dimensions should be identical for A/B grid, so we just use A here */
1161     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1162                                    local_fft_ndata,
1163                                    local_fft_offset,
1164                                    local_fft_size);
1165
1166     local_pme_size[0] = pme->pmegrid_nx;
1167     local_pme_size[1] = pme->pmegrid_ny;
1168     local_pme_size[2] = pme->pmegrid_nz;
1169
1170     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1171        the offset is identical, and the PME grid always has more data (due to overlap)
1172      */
1173     {
1174 #ifdef DEBUG_PME
1175         FILE *fp, *fp2;
1176         char  fn[STRLEN], format[STRLEN];
1177         real  val;
1178         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1179         fp = ffopen(fn, "w");
1180         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1181         fp2 = ffopen(fn, "w");
1182         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1183 #endif
1184
1185         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1186         {
1187             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1188             {
1189                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1190                 {
1191                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1192                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1193                     fftgrid[fftidx] = pmegrid[pmeidx];
1194 #ifdef DEBUG_PME
1195                     val = 100*pmegrid[pmeidx];
1196                     if (pmegrid[pmeidx] != 0)
1197                     {
1198                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1199                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1200                     }
1201                     if (pmegrid[pmeidx] != 0)
1202                     {
1203                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1204                                 "qgrid",
1205                                 pme->pmegrid_start_ix + ix,
1206                                 pme->pmegrid_start_iy + iy,
1207                                 pme->pmegrid_start_iz + iz,
1208                                 pmegrid[pmeidx]);
1209                     }
1210 #endif
1211                 }
1212             }
1213         }
1214 #ifdef DEBUG_PME
1215         ffclose(fp);
1216         ffclose(fp2);
1217 #endif
1218     }
1219     return 0;
1220 }
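/* The copy above is a strided copy of an nx*ny*nz sub-block between two
 * row-major 3D arrays that share the same origin but have different
 * allocated sizes (local_pme_size versus local_fft_size).  The helper
 * below (illustrative only, hypothetical name, not called by the PME code)
 * shows just that index arithmetic:
 */
static void copy_subblock_sketch(const real *src, int src_ny, int src_nz,
                                 real *dst, int dst_ny, int dst_nz,
                                 int nx, int ny, int nz)
{
    int ix, iy, iz;

    for (ix = 0; ix < nx; ix++)
    {
        for (iy = 0; iy < ny; iy++)
        {
            /* Row (ix,iy) starts at (ix*ny_alloc + iy)*nz_alloc in each array */
            const real *s = src + (ix*src_ny + iy)*src_nz;
            real       *d = dst + (ix*dst_ny + iy)*dst_nz;

            for (iz = 0; iz < nz; iz++)
            {
                d[iz] = s[iz];
            }
        }
    }
}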
1221
1222
1223 static gmx_cycles_t omp_cyc_start()
1224 {
1225     return gmx_cycles_read();
1226 }
1227
1228 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1229 {
1230     return gmx_cycles_read() - c;
1231 }
1232
1233
1234 static int
1235 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1236                         int nthread, int thread)
1237 {
1238     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1239     ivec          local_pme_size;
1240     int           ixy0, ixy1, ixy, ix, iy, iz;
1241     int           pmeidx, fftidx;
1242 #ifdef PME_TIME_THREADS
1243     gmx_cycles_t  c1;
1244     static double cs1 = 0;
1245     static int    cnt = 0;
1246 #endif
1247
1248 #ifdef PME_TIME_THREADS
1249     c1 = omp_cyc_start();
1250 #endif
1251     /* Dimensions should be identical for A/B grid, so we just use A here */
1252     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1253                                    local_fft_ndata,
1254                                    local_fft_offset,
1255                                    local_fft_size);
1256
1257     local_pme_size[0] = pme->pmegrid_nx;
1258     local_pme_size[1] = pme->pmegrid_ny;
1259     local_pme_size[2] = pme->pmegrid_nz;
1260
1261     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1262        the offset is identical, and the PME grid always has more data (due to overlap)
1263      */
1264     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1265     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1266
1267     for (ixy = ixy0; ixy < ixy1; ixy++)
1268     {
1269         ix = ixy/local_fft_ndata[YY];
1270         iy = ixy - ix*local_fft_ndata[YY];
1271
1272         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1273         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1274         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1275         {
1276             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1277         }
1278     }
1279
1280 #ifdef PME_TIME_THREADS
1281     c1   = omp_cyc_end(c1);
1282     cs1 += (double)c1;
1283     cnt++;
1284     if (cnt % 20 == 0)
1285     {
1286         printf("copy %.2f\n", cs1*1e-9);
1287     }
1288 #endif
1289
1290     return 0;
1291 }
1292
1293
1294 static void
1295 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1296 {
1297     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1298
1299     nx = pme->nkx;
1300     ny = pme->nky;
1301     nz = pme->nkz;
1302
1303     pnx = pme->pmegrid_nx;
1304     pny = pme->pmegrid_ny;
1305     pnz = pme->pmegrid_nz;
1306
1307     overlap = pme->pme_order - 1;
1308
1309     /* Add periodic overlap in z */
1310     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1311     {
1312         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1313         {
1314             for (iz = 0; iz < overlap; iz++)
1315             {
1316                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1317                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1318             }
1319         }
1320     }
1321
1322     if (pme->nnodes_minor == 1)
1323     {
1324         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1325         {
1326             for (iy = 0; iy < overlap; iy++)
1327             {
1328                 for (iz = 0; iz < nz; iz++)
1329                 {
1330                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1331                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1332                 }
1333             }
1334         }
1335     }
1336
1337     if (pme->nnodes_major == 1)
1338     {
1339         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1340
1341         for (ix = 0; ix < overlap; ix++)
1342         {
1343             for (iy = 0; iy < ny_x; iy++)
1344             {
1345                 for (iz = 0; iz < nz; iz++)
1346                 {
1347                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1348                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1349                 }
1350             }
1351         }
1352     }
1353 }
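/* A 1D picture of the wrapping above (illustrative sketch, hypothetical
 * name, not called by the PME code): spreading writes into order-1 cells
 * beyond the n real cells of a grid line, and before the FFT those cells
 * are folded back onto the start of the line.  unwrap_periodic_pmegrid
 * below performs the inverse copy before the forces are gathered.
 */
static void wrap_line_1d_sketch(real *line, int n, int order, gmx_bool bUnwrap)
{
    int i;

    for (i = 0; i < order-1; i++)
    {
        if (bUnwrap)
        {
            line[n+i]  = line[i];   /* copy the start back out to the overlap */
        }
        else
        {
            line[i]   += line[n+i]; /* fold the overlap into the real cells */
        }
    }
}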
1354
1355
1356 static void
1357 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1358 {
1359     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1360
1361     nx = pme->nkx;
1362     ny = pme->nky;
1363     nz = pme->nkz;
1364
1365     pnx = pme->pmegrid_nx;
1366     pny = pme->pmegrid_ny;
1367     pnz = pme->pmegrid_nz;
1368
1369     overlap = pme->pme_order - 1;
1370
1371     if (pme->nnodes_major == 1)
1372     {
1373         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1374
1375         for (ix = 0; ix < overlap; ix++)
1376         {
1377             int iy, iz;
1378
1379             for (iy = 0; iy < ny_x; iy++)
1380             {
1381                 for (iz = 0; iz < nz; iz++)
1382                 {
1383                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1384                         pmegrid[(ix*pny+iy)*pnz+iz];
1385                 }
1386             }
1387         }
1388     }
1389
1390     if (pme->nnodes_minor == 1)
1391     {
1392 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1393         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1394         {
1395             int iy, iz;
1396
1397             for (iy = 0; iy < overlap; iy++)
1398             {
1399                 for (iz = 0; iz < nz; iz++)
1400                 {
1401                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1402                         pmegrid[(ix*pny+iy)*pnz+iz];
1403                 }
1404             }
1405         }
1406     }
1407
1408     /* Copy periodic overlap in z */
1409 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1410     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1411     {
1412         int iy, iz;
1413
1414         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1415         {
1416             for (iz = 0; iz < overlap; iz++)
1417             {
1418                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1419                     pmegrid[(ix*pny+iy)*pnz+iz];
1420             }
1421         }
1422     }
1423 }
1424
1425 static void clear_grid(int nx, int ny, int nz, real *grid,
1426                        ivec fs, int *flag,
1427                        int fx, int fy, int fz,
1428                        int order)
1429 {
1430     int nc, ncz;
1431     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1432     int flind;
1433
1434     nc  = 2 + (order - 2)/FLBS;
1435     ncz = 2 + (order - 2)/FLBSZ;
1436
1437     for (fsx = fx; fsx < fx+nc; fsx++)
1438     {
1439         for (fsy = fy; fsy < fy+nc; fsy++)
1440         {
1441             for (fsz = fz; fsz < fz+ncz; fsz++)
1442             {
1443                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1444                 if (flag[flind] == 0)
1445                 {
1446                     gx  = fsx*FLBS;
1447                     gy  = fsy*FLBS;
1448                     gz  = fsz*FLBSZ;
1449                     g0x = (gx*ny + gy)*nz + gz;
1450                     for (x = 0; x < FLBS; x++)
1451                     {
1452                         g0y = g0x;
1453                         for (y = 0; y < FLBS; y++)
1454                         {
1455                             for (z = 0; z < FLBSZ; z++)
1456                             {
1457                                 grid[g0y+z] = 0;
1458                             }
1459                             g0y += nz;
1460                         }
1461                         g0x += ny*nz;
1462                     }
1463
1464                     flag[flind] = 1;
1465                 }
1466             }
1467         }
1468     }
1469 }
1470
1471 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1472 #define DO_BSPLINE(order)                            \
1473     for (ithx = 0; (ithx < order); ithx++)                    \
1474     {                                                    \
1475         index_x = (i0+ithx)*pny*pnz;                     \
1476         valx    = qn*thx[ithx];                          \
1477                                                      \
1478         for (ithy = 0; (ithy < order); ithy++)                \
1479         {                                                \
1480             valxy    = valx*thy[ithy];                   \
1481             index_xy = index_x+(j0+ithy)*pnz;            \
1482                                                      \
1483             for (ithz = 0; (ithz < order); ithz++)            \
1484             {                                            \
1485                 index_xyz        = index_xy+(k0+ithz);   \
1486                 grid[index_xyz] += valxy*thz[ithz];      \
1487             }                                            \
1488         }                                                \
1489     }
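/* What DO_BSPLINE computes, written out as a plain function (illustrative
 * only, hypothetical name, not called by the PME code): one charge qn is
 * spread onto an order^3 block of the grid as a tensor product of the
 * per-dimension B-spline weights thx, thy and thz.  The macro form above
 * is kept so that compilers can fully unroll the fixed-order cases.
 */
static void spread_one_charge_sketch(real *grid, int pny, int pnz,
                                     int i0, int j0, int k0, int order,
                                     real qn, const real *thx,
                                     const real *thy, const real *thz)
{
    int ithx, ithy, ithz;

    for (ithx = 0; ithx < order; ithx++)
    {
        for (ithy = 0; ithy < order; ithy++)
        {
            /* Start of the z-line at grid point (i0+ithx, j0+ithy, k0) */
            real *line  = grid + ((i0+ithx)*pny + (j0+ithy))*pnz + k0;
            real  valxy = qn*thx[ithx]*thy[ithy];

            for (ithz = 0; ithz < order; ithz++)
            {
                line[ithz] += valxy*thz[ithz];
            }
        }
    }
}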
1490
1491
1492 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1493                                      pme_atomcomm_t *atc, splinedata_t *spline,
1494                                      pme_spline_work_t *work)
1495 {
1496
1497     /* spread charges from home atoms to local grid */
1498     real          *grid;
1499     pme_overlap_t *ol;
1500     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1501     int       *    idxptr;
1502     int            order, norder, index_x, index_xy, index_xyz;
1503     real           valx, valxy, qn;
1504     real          *thx, *thy, *thz;
1505     int            localsize, bndsize;
1506     int            pnx, pny, pnz, ndatatot;
1507     int            offx, offy, offz;
1508
1509     pnx = pmegrid->s[XX];
1510     pny = pmegrid->s[YY];
1511     pnz = pmegrid->s[ZZ];
1512
1513     offx = pmegrid->offset[XX];
1514     offy = pmegrid->offset[YY];
1515     offz = pmegrid->offset[ZZ];
1516
1517     ndatatot = pnx*pny*pnz;
1518     grid     = pmegrid->grid;
1519     for (i = 0; i < ndatatot; i++)
1520     {
1521         grid[i] = 0;
1522     }
1523
1524     order = pmegrid->order;
1525
1526     for (nn = 0; nn < spline->n; nn++)
1527     {
1528         n  = spline->ind[nn];
1529         qn = atc->q[n];
1530
1531         if (qn != 0)
1532         {
1533             idxptr = atc->idx[n];
1534             norder = nn*order;
1535
1536             i0   = idxptr[XX] - offx;
1537             j0   = idxptr[YY] - offy;
1538             k0   = idxptr[ZZ] - offz;
1539
1540             thx = spline->theta[XX] + norder;
1541             thy = spline->theta[YY] + norder;
1542             thz = spline->theta[ZZ] + norder;
1543
1544             switch (order)
1545             {
1546                 case 4:
1547 #ifdef PME_SSE
1548 #ifdef PME_SSE_UNALIGNED
1549 #define PME_SPREAD_SSE_ORDER4
1550 #else
1551 #define PME_SPREAD_SSE_ALIGNED
1552 #define PME_ORDER 4
1553 #endif
1554 #include "pme_sse_single.h"
1555 #else
1556                     DO_BSPLINE(4);
1557 #endif
1558                     break;
1559                 case 5:
1560 #ifdef PME_SSE
1561 #define PME_SPREAD_SSE_ALIGNED
1562 #define PME_ORDER 5
1563 #include "pme_sse_single.h"
1564 #else
1565                     DO_BSPLINE(5);
1566 #endif
1567                     break;
1568                 default:
1569                     DO_BSPLINE(order);
1570                     break;
1571             }
1572         }
1573     }
1574 }
1575
1576 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1577 {
1578 #ifdef PME_SSE
1579     if (pme_order == 5
1580 #ifndef PME_SSE_UNALIGNED
1581         || pme_order == 4
1582 #endif
1583         )
1584     {
1585         /* Round nz up to a multiple of 4 to ensure alignment */
1586         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1587     }
1588 #endif
1589 }
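/* The rounding used above, spelled out (illustrative, hypothetical name,
 * not called by the PME code): adding 3 crosses the next multiple of 4
 * unless n is already aligned, and masking off the two lowest bits
 * truncates back down, e.g. 13 -> 16 and 16 -> 16.
 */
static int round_up_to_multiple_of_4_sketch(int n)
{
    return (n + 3) & ~3;
}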
1590
1591 static void set_gridsize_alignment(int *gridsize, int pme_order)
1592 {
1593 #ifdef PME_SSE
1594 #ifndef PME_SSE_UNALIGNED
1595     if (pme_order == 4)
1596     {
1597         /* Add extra elements to ensure that aligned operations do not go
1598          * beyond the allocated grid size.
1599          * Note that for pme_order=5, the pme grid z-size alignment
1600          * ensures that we will not go beyond the grid size.
1601          */
1602         *gridsize += 4;
1603     }
1604 #endif
1605 #endif
1606 }
1607
1608 static void pmegrid_init(pmegrid_t *grid,
1609                          int cx, int cy, int cz,
1610                          int x0, int y0, int z0,
1611                          int x1, int y1, int z1,
1612                          gmx_bool set_alignment,
1613                          int pme_order,
1614                          real *ptr)
1615 {
1616     int nz, gridsize;
1617
1618     grid->ci[XX]     = cx;
1619     grid->ci[YY]     = cy;
1620     grid->ci[ZZ]     = cz;
1621     grid->offset[XX] = x0;
1622     grid->offset[YY] = y0;
1623     grid->offset[ZZ] = z0;
1624     grid->n[XX]      = x1 - x0 + pme_order - 1;
1625     grid->n[YY]      = y1 - y0 + pme_order - 1;
1626     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1627     copy_ivec(grid->n, grid->s);
1628
1629     nz = grid->s[ZZ];
1630     set_grid_alignment(&nz, pme_order);
1631     if (set_alignment)
1632     {
1633         grid->s[ZZ] = nz;
1634     }
1635     else if (nz != grid->s[ZZ])
1636     {
1637         gmx_incons("pmegrid_init call with an unaligned z size");
1638     }
1639
1640     grid->order = pme_order;
1641     if (ptr == NULL)
1642     {
1643         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1644         set_gridsize_alignment(&gridsize, pme_order);
1645         snew_aligned(grid->grid, gridsize, 16);
1646     }
1647     else
1648     {
1649         grid->grid = ptr;
1650     }
1651 }
1652
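/* Integer division rounded upwards, e.g. div_round_up(10, 4) == 3. */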
1653 static int div_round_up(int numerator, int denominator)
1654 {
1655     return (numerator + denominator - 1)/denominator;
1656 }
1657
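/* Choose a thread subgrid division with nsub[XX]*nsub[YY]*nsub[ZZ] == nthread.
 * All factorizations of nthread are tried and the one minimizing the number
 * of grid points per thread, including the overlap ovl added in each
 * dimension, is selected. The environment variable GMX_PME_THREAD_DIVISION
 * overrides the automatic choice, but must still match the thread count.
 */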
1658 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1659                                   ivec nsub)
1660 {
1661     int gsize_opt, gsize;
1662     int nsx, nsy, nsz;
1663     char *env;
1664
1665     gsize_opt = -1;
1666     for (nsx = 1; nsx <= nthread; nsx++)
1667     {
1668         if (nthread % nsx == 0)
1669         {
1670             for (nsy = 1; nsy <= nthread; nsy++)
1671             {
1672                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1673                 {
1674                     nsz = nthread/(nsx*nsy);
1675
1676                     /* Determine the number of grid points per thread */
1677                     gsize =
1678                         (div_round_up(n[XX], nsx) + ovl)*
1679                         (div_round_up(n[YY], nsy) + ovl)*
1680                         (div_round_up(n[ZZ], nsz) + ovl);
1681
1682                     /* Minimize the number of grids points per thread
1683                      * and, secondarily, the number of cuts in minor dimensions.
1684                      */
1685                     if (gsize_opt == -1 ||
1686                         gsize < gsize_opt ||
1687                         (gsize == gsize_opt &&
1688                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1689                     {
1690                         nsub[XX]  = nsx;
1691                         nsub[YY]  = nsy;
1692                         nsub[ZZ]  = nsz;
1693                         gsize_opt = gsize;
1694                     }
1695                 }
1696             }
1697         }
1698     }
1699
1700     env = getenv("GMX_PME_THREAD_DIVISION");
1701     if (env != NULL)
1702     {
1703         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1704     }
1705
1706     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1707     {
1708         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1709     }
1710 }
1711
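/* Set up the PME grids: one full-size grid plus, with bUseThreads, one
 * subgrid per thread (including overlap) packed into grid_all with
 * cache-line separation. The g2t tables map grid indices along each
 * dimension to thread cells (scaled such that the three contributions sum
 * to the flat thread index), and nthread_comm[d] counts how many
 * neighboring thread cells the overlap communication in d has to reach.
 */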
1712 static void pmegrids_init(pmegrids_t *grids,
1713                           int nx, int ny, int nz, int nz_base,
1714                           int pme_order,
1715                           gmx_bool bUseThreads,
1716                           int nthread,
1717                           int overlap_x,
1718                           int overlap_y)
1719 {
1720     ivec n, n_base, g0, g1;
1721     int t, x, y, z, d, i, tfac;
1722     int max_comm_lines = -1;
1723
1724     n[XX] = nx - (pme_order - 1);
1725     n[YY] = ny - (pme_order - 1);
1726     n[ZZ] = nz - (pme_order - 1);
1727
1728     copy_ivec(n, n_base);
1729     n_base[ZZ] = nz_base;
1730
1731     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1732                  NULL);
1733
1734     grids->nthread = nthread;
1735
1736     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1737
1738     if (bUseThreads)
1739     {
1740         ivec nst;
1741         int gridsize;
1742
1743         for (d = 0; d < DIM; d++)
1744         {
1745             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1746         }
1747         set_grid_alignment(&nst[ZZ], pme_order);
1748
1749         if (debug)
1750         {
1751             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1752                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1753             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1754                     nx, ny, nz,
1755                     nst[XX], nst[YY], nst[ZZ]);
1756         }
1757
1758         snew(grids->grid_th, grids->nthread);
1759         t        = 0;
1760         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1761         set_gridsize_alignment(&gridsize, pme_order);
1762         snew_aligned(grids->grid_all,
1763                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1764                      16);
1765
1766         for (x = 0; x < grids->nc[XX]; x++)
1767         {
1768             for (y = 0; y < grids->nc[YY]; y++)
1769             {
1770                 for (z = 0; z < grids->nc[ZZ]; z++)
1771                 {
1772                     pmegrid_init(&grids->grid_th[t],
1773                                  x, y, z,
1774                                  (n[XX]*(x  ))/grids->nc[XX],
1775                                  (n[YY]*(y  ))/grids->nc[YY],
1776                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1777                                  (n[XX]*(x+1))/grids->nc[XX],
1778                                  (n[YY]*(y+1))/grids->nc[YY],
1779                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1780                                  TRUE,
1781                                  pme_order,
1782                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1783                     t++;
1784                 }
1785             }
1786         }
1787     }
1788     else
1789     {
1790         grids->grid_th = NULL;
1791     }
1792
1793     snew(grids->g2t, DIM);
1794     tfac = 1;
1795     for (d = DIM-1; d >= 0; d--)
1796     {
1797         snew(grids->g2t[d], n[d]);
1798         t = 0;
1799         for (i = 0; i < n[d]; i++)
1800         {
1801             /* The second check should match the parameters
1802              * of the pmegrid_init call above.
1803              */
1804             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1805             {
1806                 t++;
1807             }
1808             grids->g2t[d][i] = t*tfac;
1809         }
1810
1811         tfac *= grids->nc[d];
1812
1813         switch (d)
1814         {
1815             case XX: max_comm_lines = overlap_x;     break;
1816             case YY: max_comm_lines = overlap_y;     break;
1817             case ZZ: max_comm_lines = pme_order - 1; break;
1818         }
1819         grids->nthread_comm[d] = 0;
1820         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1821                grids->nthread_comm[d] < grids->nc[d])
1822         {
1823             grids->nthread_comm[d]++;
1824         }
1825         if (debug != NULL)
1826         {
1827             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1828                     'x'+d, grids->nthread_comm[d]);
1829         }
1830         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1831          * work, but this is not a problematic restriction.
1832          */
1833         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1834         {
1835             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1836         }
1837     }
1838 }
1839
1840
1841 static void pmegrids_destroy(pmegrids_t *grids)
1842 {
1843     int t;
1844
1845     if (grids->grid.grid != NULL)
1846     {
1847         sfree(grids->grid.grid);
1848
1849         if (grids->nthread > 0)
1850         {
1851             for (t = 0; t < grids->nthread; t++)
1852             {
1853                 sfree(grids->grid_th[t].grid);
1854             }
1855             sfree(grids->grid_th);
1856         }
1857     }
1858 }
1859
1860
1861 static void realloc_work(pme_work_t *work, int nkx)
1862 {
1863     if (nkx > work->nalloc)
1864     {
1865         work->nalloc = nkx;
1866         srenew(work->mhx, work->nalloc);
1867         srenew(work->mhy, work->nalloc);
1868         srenew(work->mhz, work->nalloc);
1869         srenew(work->m2, work->nalloc);
1870         /* Allocate an aligned pointer for SSE operations, including 3 extra
1871          * elements at the end since SSE operates on 4 elements at a time.
1872          */
1873         sfree_aligned(work->denom);
1874         sfree_aligned(work->tmp1);
1875         sfree_aligned(work->eterm);
1876         snew_aligned(work->denom, work->nalloc+3, 16);
1877         snew_aligned(work->tmp1, work->nalloc+3, 16);
1878         snew_aligned(work->eterm, work->nalloc+3, 16);
1879         srenew(work->m2inv, work->nalloc);
1880     }
1881 }
1882
1883
1884 static void free_work(pme_work_t *work)
1885 {
1886     sfree(work->mhx);
1887     sfree(work->mhy);
1888     sfree(work->mhz);
1889     sfree(work->m2);
1890     sfree_aligned(work->denom);
1891     sfree_aligned(work->tmp1);
1892     sfree_aligned(work->eterm);
1893     sfree(work->m2inv);
1894 }
1895
1896
1897 #ifdef PME_SSE
1898 /* Calculate exponentials through SSE in float precision */
1899 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1900 {
1901     {
1902         const __m128 two = _mm_set_ps(2.0f, 2.0f, 2.0f, 2.0f);
1903         __m128 f_sse;
1904         __m128 lu;
1905         __m128 tmp_d1, d_inv, tmp_r, tmp_e;
1906         int kx;
1907         f_sse = _mm_load1_ps(&f);
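        /* Process 4 elements per iteration: e[kx] = f*exp(r[kx])/d[kx].
         * _mm_rcp_ps is only a ~12-bit approximation, so one Newton-Raphson
         * step, x <- x*(2 - d*x), refines the reciprocal of d.
         */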
1908         for (kx = 0; kx < end; kx += 4)
1909         {
1910             tmp_d1   = _mm_load_ps(d_aligned+kx);
1911             lu       = _mm_rcp_ps(tmp_d1);
1912             d_inv    = _mm_mul_ps(lu, _mm_sub_ps(two, _mm_mul_ps(lu, tmp_d1)));
1913             tmp_r    = _mm_load_ps(r_aligned+kx);
1914             tmp_r    = gmx_mm_exp_ps(tmp_r);
1915             tmp_e    = _mm_mul_ps(f_sse, d_inv);
1916             tmp_e    = _mm_mul_ps(tmp_e, tmp_r);
1917             _mm_store_ps(e_aligned+kx, tmp_e);
1918         }
1919     }
1920 }
1921 #else
1922 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1923 {
1924     int kx;
1925     for (kx = start; kx < end; kx++)
1926     {
1927         d[kx] = 1.0/d[kx];
1928     }
1929     for (kx = start; kx < end; kx++)
1930     {
1931         r[kx] = exp(r[kx]);
1932     }
1933     for (kx = start; kx < end; kx++)
1934     {
1935         e[kx] = f*r[kx]*d[kx];
1936     }
1937 }
1938 #endif
1939
1940
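/* Reciprocal-space part of PME: scale each Fourier coefficient of the
 * charge grid by the influence function, essentially
 *   eterm(m) ~ exp(-pi^2 m^2/ewaldcoeff^2) / (m^2 * B(mx)*B(my)*B(mz)),
 * with B the B-spline moduli, and with bEnerVir also accumulate the mesh
 * energy and virial from |S(m)|^2. The local y*z slab is divided over
 * nthread threads; per-thread results are summed in get_pme_ener_vir.
 */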
1941 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1942                          real ewaldcoeff, real vol,
1943                          gmx_bool bEnerVir,
1944                          int nthread, int thread)
1945 {
1946     /* do recip sum over local cells in grid */
1947     /* y major, z middle, x minor or contiguous */
1948     t_complex *p0;
1949     int     kx, ky, kz, maxkx, maxky, maxkz;
1950     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1951     real    mx, my, mz;
1952     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1953     real    ets2, struct2, vfactor, ets2vf;
1954     real    d1, d2, energy = 0;
1955     real    by, bz;
1956     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1957     real    rxx, ryx, ryy, rzx, rzy, rzz;
1958     pme_work_t *work;
1959     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1960     real    mhxk, mhyk, mhzk, m2k;
1961     real    corner_fac;
1962     ivec    complex_order;
1963     ivec    local_ndata, local_offset, local_size;
1964     real    elfac;
1965
1966     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1967
1968     nx = pme->nkx;
1969     ny = pme->nky;
1970     nz = pme->nkz;
1971
1972     /* Dimensions should be identical for A/B grid, so we just use A here */
1973     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1974                                       complex_order,
1975                                       local_ndata,
1976                                       local_offset,
1977                                       local_size);
1978
1979     rxx = pme->recipbox[XX][XX];
1980     ryx = pme->recipbox[YY][XX];
1981     ryy = pme->recipbox[YY][YY];
1982     rzx = pme->recipbox[ZZ][XX];
1983     rzy = pme->recipbox[ZZ][YY];
1984     rzz = pme->recipbox[ZZ][ZZ];
1985
1986     maxkx = (nx+1)/2;
1987     maxky = (ny+1)/2;
1988     maxkz = nz/2+1;
1989
1990     work  = &pme->work[thread];
1991     mhx   = work->mhx;
1992     mhy   = work->mhy;
1993     mhz   = work->mhz;
1994     m2    = work->m2;
1995     denom = work->denom;
1996     tmp1  = work->tmp1;
1997     eterm = work->eterm;
1998     m2inv = work->m2inv;
1999
2000     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
2001     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
2002
2003     for (iyz = iyz0; iyz < iyz1; iyz++)
2004     {
2005         iy = iyz/local_ndata[ZZ];
2006         iz = iyz - iy*local_ndata[ZZ];
2007
2008         ky = iy + local_offset[YY];
2009
2010         if (ky < maxky)
2011         {
2012             my = ky;
2013         }
2014         else
2015         {
2016             my = (ky - ny);
2017         }
2018
2019         by = M_PI*vol*pme->bsp_mod[YY][ky];
2020
2021         kz = iz + local_offset[ZZ];
2022
2023         mz = kz;
2024
2025         bz = pme->bsp_mod[ZZ][kz];
2026
2027         /* 0.5 correction for corner points */
2028         corner_fac = 1;
2029         if (kz == 0 || kz == (nz+1)/2)
2030         {
2031             corner_fac = 0.5;
2032         }
2033
2034         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
2035
2036         /* We should skip the k-space point (0,0,0) */
2037         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
2038         {
2039             kxstart = local_offset[XX];
2040         }
2041         else
2042         {
2043             kxstart = local_offset[XX] + 1;
2044             p0++;
2045         }
2046         kxend = local_offset[XX] + local_ndata[XX];
2047
2048         if (bEnerVir)
2049         {
2050             /* More expensive inner loop, especially because of the storage
2051              * of the mh elements in arrays.
2052              * Because x is the minor grid index, all mh elements
2053              * depend on kx for triclinic unit cells.
2054              */
2055
2056             /* Two explicit loops to avoid a conditional inside the loop */
2057             for (kx = kxstart; kx < maxkx; kx++)
2058             {
2059                 mx = kx;
2060
2061                 mhxk      = mx * rxx;
2062                 mhyk      = mx * ryx + my * ryy;
2063                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2064                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2065                 mhx[kx]   = mhxk;
2066                 mhy[kx]   = mhyk;
2067                 mhz[kx]   = mhzk;
2068                 m2[kx]    = m2k;
2069                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2070                 tmp1[kx]  = -factor*m2k;
2071             }
2072
2073             for (kx = maxkx; kx < kxend; kx++)
2074             {
2075                 mx = (kx - nx);
2076
2077                 mhxk      = mx * rxx;
2078                 mhyk      = mx * ryx + my * ryy;
2079                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2080                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2081                 mhx[kx]   = mhxk;
2082                 mhy[kx]   = mhyk;
2083                 mhz[kx]   = mhzk;
2084                 m2[kx]    = m2k;
2085                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2086                 tmp1[kx]  = -factor*m2k;
2087             }
2088
2089             for (kx = kxstart; kx < kxend; kx++)
2090             {
2091                 m2inv[kx] = 1.0/m2[kx];
2092             }
2093
2094             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2095
2096             for (kx = kxstart; kx < kxend; kx++, p0++)
2097             {
2098                 d1      = p0->re;
2099                 d2      = p0->im;
2100
2101                 p0->re  = d1*eterm[kx];
2102                 p0->im  = d2*eterm[kx];
2103
2104                 struct2 = 2.0*(d1*d1+d2*d2);
2105
2106                 tmp1[kx] = eterm[kx]*struct2;
2107             }
2108
2109             for (kx = kxstart; kx < kxend; kx++)
2110             {
2111                 ets2     = corner_fac*tmp1[kx];
2112                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2113                 energy  += ets2;
2114
2115                 ets2vf   = ets2*vfactor;
2116                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2117                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2118                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2119                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2120                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2121                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2122             }
2123         }
2124         else
2125         {
2126             /* We don't need to calculate the energy and the virial.
2127              * In this case the triclinic overhead is small.
2128              */
2129
2130             /* Two explicit loops to avoid a conditional inside the loop */
2131
2132             for (kx = kxstart; kx < maxkx; kx++)
2133             {
2134                 mx = kx;
2135
2136                 mhxk      = mx * rxx;
2137                 mhyk      = mx * ryx + my * ryy;
2138                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2139                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2140                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2141                 tmp1[kx]  = -factor*m2k;
2142             }
2143
2144             for (kx = maxkx; kx < kxend; kx++)
2145             {
2146                 mx = (kx - nx);
2147
2148                 mhxk      = mx * rxx;
2149                 mhyk      = mx * ryx + my * ryy;
2150                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2151                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2152                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2153                 tmp1[kx]  = -factor*m2k;
2154             }
2155
2156             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2157
2158             for (kx = kxstart; kx < kxend; kx++, p0++)
2159             {
2160                 d1      = p0->re;
2161                 d2      = p0->im;
2162
2163                 p0->re  = d1*eterm[kx];
2164                 p0->im  = d2*eterm[kx];
2165             }
2166         }
2167     }
2168
2169     if (bEnerVir)
2170     {
2171         /* Update virial with local values.
2172          * The virial is symmetric by definition.
2173          * This virial seems OK for isotropic scaling, but I'm
2174          * experiencing problems with semi-isotropic membranes.
2175          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2176          */
2177         work->vir[XX][XX] = 0.25*virxx;
2178         work->vir[YY][YY] = 0.25*viryy;
2179         work->vir[ZZ][ZZ] = 0.25*virzz;
2180         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2181         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2182         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2183
2184         /* This energy should be corrected for a charged system */
2185         work->energy = 0.5*energy;
2186     }
2187
2188     /* Return the loop count */
2189     return local_ndata[YY]*local_ndata[XX];
2190 }
2191
2192 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2193                              real *mesh_energy, matrix vir)
2194 {
2195     /* This function sums output over threads
2196      * and should therefore only be called after thread synchronization.
2197      */
2198     int thread;
2199
2200     *mesh_energy = pme->work[0].energy;
2201     copy_mat(pme->work[0].vir, vir);
2202
2203     for (thread = 1; thread < nthread; thread++)
2204     {
2205         *mesh_energy += pme->work[thread].energy;
2206         m_add(vir, pme->work[thread].vir, vir);
2207     }
2208 }
2209
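/* Force-gathering inner loops for one atom: accumulate the gradient of the
 * interpolated potential over an order^3 stencil,
 *   fx += dthx*thy*thz*grid, fy += thx*dthy*thz*grid, fz += thx*thy*dthz*grid,
 * where th* are the B-spline values and dth* their derivatives.
 */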
2210 #define DO_FSPLINE(order)                      \
2211     for (ithx = 0; (ithx < order); ithx++)              \
2212     {                                              \
2213         index_x = (i0+ithx)*pny*pnz;               \
2214         tx      = thx[ithx];                       \
2215         dx      = dthx[ithx];                      \
2216                                                \
2217         for (ithy = 0; (ithy < order); ithy++)          \
2218         {                                          \
2219             index_xy = index_x+(j0+ithy)*pnz;      \
2220             ty       = thy[ithy];                  \
2221             dy       = dthy[ithy];                 \
2222             fxy1     = fz1 = 0;                    \
2223                                                \
2224             for (ithz = 0; (ithz < order); ithz++)      \
2225             {                                      \
2226                 gval  = grid[index_xy+(k0+ithz)];  \
2227                 fxy1 += thz[ithz]*gval;            \
2228                 fz1  += dthz[ithz]*gval;           \
2229             }                                      \
2230             fx += dx*ty*fxy1;                      \
2231             fy += tx*dy*fxy1;                      \
2232             fz += tx*ty*fz1;                       \
2233         }                                          \
2234     }
2235
2236
2237 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2238                               gmx_bool bClearF, pme_atomcomm_t *atc,
2239                               splinedata_t *spline,
2240                               real scale)
2241 {
2242     /* sum forces for local particles */
2243     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2244     int     index_x, index_xy;
2245     int     nx, ny, nz, pnx, pny, pnz;
2246     int *   idxptr;
2247     real    tx, ty, dx, dy, qn;
2248     real    fx, fy, fz, gval;
2249     real    fxy1, fz1;
2250     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2251     int     norder;
2252     real    rxx, ryx, ryy, rzx, rzy, rzz;
2253     int     order;
2254
2255     pme_spline_work_t *work;
2256
2257     work = pme->spline_work;
2258
2259     order = pme->pme_order;
2260     thx   = spline->theta[XX];
2261     thy   = spline->theta[YY];
2262     thz   = spline->theta[ZZ];
2263     dthx  = spline->dtheta[XX];
2264     dthy  = spline->dtheta[YY];
2265     dthz  = spline->dtheta[ZZ];
2266     nx    = pme->nkx;
2267     ny    = pme->nky;
2268     nz    = pme->nkz;
2269     pnx   = pme->pmegrid_nx;
2270     pny   = pme->pmegrid_ny;
2271     pnz   = pme->pmegrid_nz;
2272
2273     rxx   = pme->recipbox[XX][XX];
2274     ryx   = pme->recipbox[YY][XX];
2275     ryy   = pme->recipbox[YY][YY];
2276     rzx   = pme->recipbox[ZZ][XX];
2277     rzy   = pme->recipbox[ZZ][YY];
2278     rzz   = pme->recipbox[ZZ][ZZ];
2279
2280     for (nn = 0; nn < spline->n; nn++)
2281     {
2282         n  = spline->ind[nn];
2283         qn = scale*atc->q[n];
2284
2285         if (bClearF)
2286         {
2287             atc->f[n][XX] = 0;
2288             atc->f[n][YY] = 0;
2289             atc->f[n][ZZ] = 0;
2290         }
2291         if (qn != 0)
2292         {
2293             fx     = 0;
2294             fy     = 0;
2295             fz     = 0;
2296             idxptr = atc->idx[n];
2297             norder = nn*order;
2298
2299             i0   = idxptr[XX];
2300             j0   = idxptr[YY];
2301             k0   = idxptr[ZZ];
2302
2303             /* Pointer arithmetic alert, next six statements */
2304             thx  = spline->theta[XX] + norder;
2305             thy  = spline->theta[YY] + norder;
2306             thz  = spline->theta[ZZ] + norder;
2307             dthx = spline->dtheta[XX] + norder;
2308             dthy = spline->dtheta[YY] + norder;
2309             dthz = spline->dtheta[ZZ] + norder;
2310
2311             switch (order)
2312             {
2313                 case 4:
2314 #ifdef PME_SSE
2315 #ifdef PME_SSE_UNALIGNED
2316 #define PME_GATHER_F_SSE_ORDER4
2317 #else
2318 #define PME_GATHER_F_SSE_ALIGNED
2319 #define PME_ORDER 4
2320 #endif
2321 #include "pme_sse_single.h"
2322 #else
2323                     DO_FSPLINE(4);
2324 #endif
2325                     break;
2326                 case 5:
2327 #ifdef PME_SSE
2328 #define PME_GATHER_F_SSE_ALIGNED
2329 #define PME_ORDER 5
2330 #include "pme_sse_single.h"
2331 #else
2332                     DO_FSPLINE(5);
2333 #endif
2334                     break;
2335                 default:
2336                     DO_FSPLINE(order);
2337                     break;
2338             }
2339
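            /* Convert the gradients with respect to fractional grid
             * coordinates to Cartesian forces using the reciprocal box;
             * the charge and minus sign give F = -q dV/dr.
             */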
2340             atc->f[n][XX] += -qn*( fx*nx*rxx );
2341             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2342             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2343         }
2344     }
2345     /* Since the energy, and not the forces, is interpolated,
2346      * the net force might not be exactly zero.
2347      * This can be solved by also interpolating F, but
2348      * that comes at a cost.
2349      * A better hack is to remove the net force every
2350      * step, but that must be done at a higher level
2351      * since this routine doesn't see all atoms if running
2352      * in parallel. It is unclear how important this is. EL 990726
2353      */
2354 }
2355
2356
2357 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2358                                    pme_atomcomm_t *atc)
2359 {
2360     splinedata_t *spline;
2361     int     n, ithx, ithy, ithz, i0, j0, k0;
2362     int     index_x, index_xy;
2363     int *   idxptr;
2364     real    energy, pot, tx, ty, qn, gval;
2365     real    *thx, *thy, *thz;
2366     int     norder;
2367     int     order;
2368
2369     spline = &atc->spline[0];
2370
2371     order = pme->pme_order;
2372
2373     energy = 0;
2374     for (n = 0; (n < atc->n); n++)
2375     {
2376         qn      = atc->q[n];
2377
2378         if (qn != 0)
2379         {
2380             idxptr = atc->idx[n];
2381             norder = n*order;
2382
2383             i0   = idxptr[XX];
2384             j0   = idxptr[YY];
2385             k0   = idxptr[ZZ];
2386
2387             /* Pointer arithmetic alert, next three statements */
2388             thx  = spline->theta[XX] + norder;
2389             thy  = spline->theta[YY] + norder;
2390             thz  = spline->theta[ZZ] + norder;
2391
2392             pot = 0;
2393             for (ithx = 0; (ithx < order); ithx++)
2394             {
2395                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2396                 tx      = thx[ithx];
2397
2398                 for (ithy = 0; (ithy < order); ithy++)
2399                 {
2400                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2401                     ty       = thy[ithy];
2402
2403                     for (ithz = 0; (ithz < order); ithz++)
2404                     {
2405                         gval  = grid[index_xy+(k0+ithz)];
2406                         pot  += tx*ty*thz[ithz]*gval;
2407                     }
2408
2409                 }
2410             }
2411
2412             energy += pot*qn;
2413         }
2414     }
2415
2416     return energy;
2417 }
2418
2419 /* Macro to force loop unrolling by fixing order.
2420  * This gives a significant performance gain.
2421  */
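/* CALC_SPLINE evaluates, for one atom and each dimension, the order
 * cardinal B-spline weights at fractional offset dr with the standard
 * recursion (cf. Essmann et al., J. Chem. Phys. 103, 8577 (1995)):
 *   M_k(x) = (x*M_{k-1}(x) + (k - x)*M_{k-1}(x-1))/(k - 1),
 * and the derivatives from M_k'(x) = M_{k-1}(x) - M_{k-1}(x-1),
 * which is essentially what the data[]/ddata[] updates below implement.
 */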
2422 #define CALC_SPLINE(order)                     \
2423     {                                              \
2424         int j, k, l;                                 \
2425         real dr, div;                               \
2426         real data[PME_ORDER_MAX];                  \
2427         real ddata[PME_ORDER_MAX];                 \
2428                                                \
2429         for (j = 0; (j < DIM); j++)                     \
2430         {                                          \
2431             dr  = xptr[j];                         \
2432                                                \
2433             /* dr is relative offset from lower cell limit */ \
2434             data[order-1] = 0;                     \
2435             data[1]       = dr;                          \
2436             data[0]       = 1 - dr;                      \
2437                                                \
2438             for (k = 3; (k < order); k++)               \
2439             {                                      \
2440                 div       = 1.0/(k - 1.0);               \
2441                 data[k-1] = div*dr*data[k-2];      \
2442                 for (l = 1; (l < (k-1)); l++)           \
2443                 {                                  \
2444                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2445                                        data[k-l-1]);                \
2446                 }                                  \
2447                 data[0] = div*(1-dr)*data[0];      \
2448             }                                      \
2449             /* differentiate */                    \
2450             ddata[0] = -data[0];                   \
2451             for (k = 1; (k < order); k++)               \
2452             {                                      \
2453                 ddata[k] = data[k-1] - data[k];    \
2454             }                                      \
2455                                                \
2456             div           = 1.0/(order - 1);                 \
2457             data[order-1] = div*dr*data[order-2];  \
2458             for (l = 1; (l < (order-1)); l++)           \
2459             {                                      \
2460                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2461                                        (order-l-dr)*data[order-l-1]); \
2462             }                                      \
2463             data[0] = div*(1 - dr)*data[0];        \
2464                                                \
2465             for (k = 0; k < order; k++)                 \
2466             {                                      \
2467                 theta[j][i*order+k]  = data[k];    \
2468                 dtheta[j][i*order+k] = ddata[k];   \
2469             }                                      \
2470         }                                          \
2471     }
2472
2473 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2474                    rvec fractx[], int nr, int ind[], real charge[],
2475                    gmx_bool bFreeEnergy)
2476 {
2477     /* construct splines for local atoms */
2478     int  i, ii;
2479     real *xptr;
2480
2481     for (i = 0; i < nr; i++)
2482     {
2483         /* With free energy we do not use the charge check.
2484          * In most cases this will be more efficient than calling make_bsplines
2485          * twice, since usually more than half the particles have charges.
2486          */
2487         ii = ind[i];
2488         if (bFreeEnergy || charge[ii] != 0.0)
2489         {
2490             xptr = fractx[ii];
2491             switch (order)
2492             {
2493                 case 4:  CALC_SPLINE(4);     break;
2494                 case 5:  CALC_SPLINE(5);     break;
2495                 default: CALC_SPLINE(order); break;
2496             }
2497         }
2498     }
2499 }
2500
2501
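/* Compute the squared modulus of the discrete Fourier transform of data:
 *   mod[i] = (sum_j data[j]*cos(2*pi*i*j/ndata))^2
 *          + (sum_j data[j]*sin(2*pi*i*j/ndata))^2.
 * Near-zero entries are replaced by the average of their neighbors, which
 * avoids divisions by (near) zero when the moduli are used in the solver.
 */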
2502 void make_dft_mod(real *mod, real *data, int ndata)
2503 {
2504     int i, j;
2505     real sc, ss, arg;
2506
2507     for (i = 0; i < ndata; i++)
2508     {
2509         sc = ss = 0;
2510         for (j = 0; j < ndata; j++)
2511         {
2512             arg = (2.0*M_PI*i*j)/ndata;
2513             sc += data[j]*cos(arg);
2514             ss += data[j]*sin(arg);
2515         }
2516         mod[i] = sc*sc+ss*ss;
2517     }
2518     for (i = 0; i < ndata; i++)
2519     {
2520         if (mod[i] < 1e-7)
2521         {
2522             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2523         }
2524     }
2525 }
2526
2527
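/* Tabulate the cardinal B-spline of the given order at the integer grid
 * points and compute its squared DFT moduli along x, y and z; these enter
 * the denominator of the influence function in solve_pme_yzx.
 */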
2528 static void make_bspline_moduli(splinevec bsp_mod,
2529                                 int nx, int ny, int nz, int order)
2530 {
2531     int nmax = max(nx, max(ny, nz));
2532     real *data, *ddata, *bsp_data;
2533     int i, k, l;
2534     real div;
2535
2536     snew(data, order);
2537     snew(ddata, order);
2538     snew(bsp_data, nmax);
2539
2540     data[order-1] = 0;
2541     data[1]       = 0;
2542     data[0]       = 1;
2543
2544     for (k = 3; k < order; k++)
2545     {
2546         div       = 1.0/(k-1.0);
2547         data[k-1] = 0;
2548         for (l = 1; l < (k-1); l++)
2549         {
2550             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2551         }
2552         data[0] = div*data[0];
2553     }
2554     /* differentiate */
2555     ddata[0] = -data[0];
2556     for (k = 1; k < order; k++)
2557     {
2558         ddata[k] = data[k-1]-data[k];
2559     }
2560     div           = 1.0/(order-1);
2561     data[order-1] = 0;
2562     for (l = 1; l < (order-1); l++)
2563     {
2564         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2565     }
2566     data[0] = div*data[0];
2567
2568     for (i = 0; i < nmax; i++)
2569     {
2570         bsp_data[i] = 0;
2571     }
2572     for (i = 1; i <= order; i++)
2573     {
2574         bsp_data[i] = data[i-1];
2575     }
2576
2577     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2578     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2579     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2580
2581     sfree(data);
2582     sfree(ddata);
2583     sfree(bsp_data);
2584 }
2585
2586
2587 /* Return the P3M optimal influence function */
2588 static double do_p3m_influence(double z, int order)
2589 {
2590     double z2, z4;
2591
2592     z2 = z*z;
2593     z4 = z2*z2;
2594
2595     /* The formula and most constants can be found in:
2596      * Ballenegger et al., JCTC 8, 936 (2012)
2597      */
2598     switch (order)
2599     {
2600         case 2:
2601             return 1.0 - 2.0*z2/3.0;
2602             break;
2603         case 3:
2604             return 1.0 - z2 + 2.0*z4/15.0;
2605             break;
2606         case 4:
2607             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2608             break;
2609         case 5:
2610             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2611             break;
2612         case 6:
2613             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2614             break;
2615         case 7:
2616             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2617         case 8:
2618             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2619             break;
2620     }
2621
2622     return 0.0;
2623 }
2624
2625 /* Calculate the P3M B-spline moduli for one dimension */
2626 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2627 {
2628     double zarg, zai, sinzai, infl;
2629     int    maxk, i;
2630
2631     if (order > 8)
2632     {
2633         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2634     }
2635
2636     zarg = M_PI/n;
2637
2638     maxk = (n + 1)/2;
2639
2640     for (i = -maxk; i < 0; i++)
2641     {
2642         zai          = zarg*i;
2643         sinzai       = sin(zai);
2644         infl         = do_p3m_influence(sinzai, order);
2645         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2646     }
2647     bsp_mod[0] = 1.0;
2648     for (i = 1; i < maxk; i++)
2649     {
2650         zai        = zarg*i;
2651         sinzai     = sin(zai);
2652         infl       = do_p3m_influence(sinzai, order);
2653         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2654     }
2655 }
2656
2657 /* Calculate the P3M B-spline moduli */
2658 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2659                                     int nx, int ny, int nz, int order)
2660 {
2661     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2662     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2663     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2664 }
2665
2666
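/* Fill node_dest/node_src with the slabs to communicate with, ordered by
 * increasing distance from this node and alternating between the forward
 * and backward direction around the slab ring.
 */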
2667 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2668 {
2669     int nslab, n, i;
2670     int fw, bw;
2671
2672     nslab = atc->nslab;
2673
2674     n = 0;
2675     for (i = 1; i <= nslab/2; i++)
2676     {
2677         fw = (atc->nodeid + i) % nslab;
2678         bw = (atc->nodeid - i + nslab) % nslab;
2679         if (n < nslab - 1)
2680         {
2681             atc->node_dest[n] = fw;
2682             atc->node_src[n]  = bw;
2683             n++;
2684         }
2685         if (n < nslab - 1)
2686         {
2687             atc->node_dest[n] = bw;
2688             atc->node_src[n]  = fw;
2689             n++;
2690         }
2691     }
2692 }
2693
2694 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2695 {
2696     int thread;
2697
2698     if (NULL != log)
2699     {
2700         fprintf(log, "Destroying PME data structures.\n");
2701     }
2702
2703     sfree((*pmedata)->nnx);
2704     sfree((*pmedata)->nny);
2705     sfree((*pmedata)->nnz);
2706
2707     pmegrids_destroy(&(*pmedata)->pmegridA);
2708
2709     sfree((*pmedata)->fftgridA);
2710     sfree((*pmedata)->cfftgridA);
2711     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2712
2713     if ((*pmedata)->pmegridB.grid.grid != NULL)
2714     {
2715         pmegrids_destroy(&(*pmedata)->pmegridB);
2716         sfree((*pmedata)->fftgridB);
2717         sfree((*pmedata)->cfftgridB);
2718         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2719     }
2720     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2721     {
2722         free_work(&(*pmedata)->work[thread]);
2723     }
2724     sfree((*pmedata)->work);
2725
2726     sfree(*pmedata);
2727     *pmedata = NULL;
2728
2729     return 0;
2730 }
2731
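/* Round n up to the nearest multiple of f, e.g. mult_up(25, 4) == 28. */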
2732 static int mult_up(int n, int f)
2733 {
2734     return ((n + f - 1)/f)*f;
2735 }
2736
2737
2738 static double pme_load_imbalance(gmx_pme_t pme)
2739 {
2740     int    nma, nmi;
2741     double n1, n2, n3;
2742
2743     nma = pme->nnodes_major;
2744     nmi = pme->nnodes_minor;
2745
2746     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2747     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2748     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2749
2750     /* pme_solve is roughly double the cost of an fft */
2751
2752     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2753 }
2754
2755 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2756                           int dimind, gmx_bool bSpread)
2757 {
2758     int nk, k, s, thread;
2759
2760     atc->dimind    = dimind;
2761     atc->nslab     = 1;
2762     atc->nodeid    = 0;
2763     atc->pd_nalloc = 0;
2764 #ifdef GMX_MPI
2765     if (pme->nnodes > 1)
2766     {
2767         atc->mpi_comm = pme->mpi_comm_d[dimind];
2768         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2769         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2770     }
2771     if (debug)
2772     {
2773         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2774     }
2775 #endif
2776
2777     atc->bSpread   = bSpread;
2778     atc->pme_order = pme->pme_order;
2779
2780     if (atc->nslab > 1)
2781     {
2782         /* These three allocations are not required for particle decomp. */
2783         snew(atc->node_dest, atc->nslab);
2784         snew(atc->node_src, atc->nslab);
2785         setup_coordinate_communication(atc);
2786
2787         snew(atc->count_thread, pme->nthread);
2788         for (thread = 0; thread < pme->nthread; thread++)
2789         {
2790             snew(atc->count_thread[thread], atc->nslab);
2791         }
2792         atc->count = atc->count_thread[0];
2793         snew(atc->rcount, atc->nslab);
2794         snew(atc->buf_index, atc->nslab);
2795     }
2796
2797     atc->nthread = pme->nthread;
2798     if (atc->nthread > 1)
2799     {
2800         snew(atc->thread_plist, atc->nthread);
2801     }
2802     snew(atc->spline, atc->nthread);
2803     for (thread = 0; thread < atc->nthread; thread++)
2804     {
2805         if (atc->nthread > 1)
2806         {
2807             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2808             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2809         }
2810         snew(atc->spline[thread].thread_one, pme->nthread);
2811         atc->spline[thread].thread_one[thread] = 1;
2812     }
2813 }
2814
2815 static void
2816 init_overlap_comm(pme_overlap_t *  ol,
2817                   int              norder,
2818 #ifdef GMX_MPI
2819                   MPI_Comm         comm,
2820 #endif
2821                   int              nnodes,
2822                   int              nodeid,
2823                   int              ndata,
2824                   int              commplainsize)
2825 {
2826     int lbnd, rbnd, maxlr, b, i;
2827     int exten;
2828     int nn, nk;
2829     pme_grid_comm_t *pgc;
2830     gmx_bool bCont;
2831     int fft_start, fft_end, send_index1, recv_index1;
2832 #ifdef GMX_MPI
2833     MPI_Status stat;
2834
2835     ol->mpi_comm = comm;
2836 #endif
2837
2838     ol->nnodes = nnodes;
2839     ol->nodeid = nodeid;
2840
2841     /* Linear translation of the PME grid won't affect reciprocal space
2842      * calculations, so to optimize we only interpolate "upwards",
2843      * which also means we only have to consider overlap in one direction.
2844      * I.e., particles on this node might also be spread to grid indices
2845      * that belong to higher nodes (modulo nnodes)
2846      */
2847
2848     snew(ol->s2g0, ol->nnodes+1);
2849     snew(ol->s2g1, ol->nnodes);
2850     if (debug)
2851     {
2852         fprintf(debug, "PME slab boundaries:");
2853     }
2854     for (i = 0; i < nnodes; i++)
2855     {
2856         /* s2g0 is the local interpolation grid start.
2857          * s2g1 is the local interpolation grid end.
2858          * Because grid overlap communication only goes forward,
2859          * the slabs for the FFTs should be rounded down.
2860          */
2861         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2862         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2863
2864         if (debug)
2865         {
2866             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2867         }
2868     }
2869     ol->s2g0[nnodes] = ndata;
2870     if (debug)
2871     {
2872         fprintf(debug, "\n");
2873     }
2874
2875     /* Determine how many nodes we need to communicate the grid overlap with */
2876     b = 0;
2877     do
2878     {
2879         b++;
2880         bCont = FALSE;
2881         for (i = 0; i < nnodes; i++)
2882         {
2883             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2884                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2885             {
2886                 bCont = TRUE;
2887             }
2888         }
2889     }
2890     while (bCont && b < nnodes);
2891     ol->noverlap_nodes = b - 1;
2892
2893     snew(ol->send_id, ol->noverlap_nodes);
2894     snew(ol->recv_id, ol->noverlap_nodes);
2895     for (b = 0; b < ol->noverlap_nodes; b++)
2896     {
2897         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2898         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2899     }
2900     snew(ol->comm_data, ol->noverlap_nodes);
2901
2902     ol->send_size = 0;
2903     for (b = 0; b < ol->noverlap_nodes; b++)
2904     {
2905         pgc = &ol->comm_data[b];
2906         /* Send */
2907         fft_start        = ol->s2g0[ol->send_id[b]];
2908         fft_end          = ol->s2g0[ol->send_id[b]+1];
2909         if (ol->send_id[b] < nodeid)
2910         {
2911             fft_start += ndata;
2912             fft_end   += ndata;
2913         }
2914         send_index1       = ol->s2g1[nodeid];
2915         send_index1       = min(send_index1, fft_end);
2916         pgc->send_index0  = fft_start;
2917         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2918         ol->send_size    += pgc->send_nindex;
2919
2920         /* We always start receiving at the first index of our slab */
2921         fft_start        = ol->s2g0[ol->nodeid];
2922         fft_end          = ol->s2g0[ol->nodeid+1];
2923         recv_index1      = ol->s2g1[ol->recv_id[b]];
2924         if (ol->recv_id[b] > nodeid)
2925         {
2926             recv_index1 -= ndata;
2927         }
2928         recv_index1      = min(recv_index1, fft_end);
2929         pgc->recv_index0 = fft_start;
2930         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2931     }
2932
2933 #ifdef GMX_MPI
2934     /* Communicate the buffer sizes to receive */
2935     for (b = 0; b < ol->noverlap_nodes; b++)
2936     {
2937         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2938                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2939                      ol->mpi_comm, &stat);
2940     }
2941 #endif
2942
2943     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2944     snew(ol->sendbuf, norder*commplainsize);
2945     snew(ol->recvbuf, norder*commplainsize);
2946 }
2947
2948 static void
2949 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2950                               int **global_to_local,
2951                               real **fraction_shift)
2952 {
2953     int i;
2954     int * gtl;
2955     real * fsh;
2956
2957     snew(gtl, 5*n);
2958     snew(fsh, 5*n);
2959     for (i = 0; (i < 5*n); i++)
2960     {
2961         /* Determine the global to local grid index */
2962         gtl[i] = (i - local_start + n) % n;
2963         /* For coordinates that fall within the local grid the fraction
2964          * is correct, we don't need to shift it.
2965          */
2966         fsh[i] = 0;
2967         if (local_range < n)
2968         {
2969             /* Due to rounding issues i could be 1 beyond the lower or
2970              * upper boundary of the local grid. Correct the index for this.
2971              * If we shift the index, we need to shift the fraction by
2972              * the same amount in the other direction to not affect
2973              * the weights.
2974              * Note that due to this shifting the weights at the end of
2975              * the spline might change, but that will only involve values
2976              * between zero and values close to the precision of a real,
2977              * which is anyhow the accuracy of the whole mesh calculation.
2978              */
2979             /* With local_range=0 we should not change i=local_start */
2980             if (i % n != local_start)
2981             {
2982                 if (gtl[i] == n-1)
2983                 {
2984                     gtl[i] = 0;
2985                     fsh[i] = -1;
2986                 }
2987                 else if (gtl[i] == local_range)
2988                 {
2989                     gtl[i] = local_range - 1;
2990                     fsh[i] = 1;
2991                 }
2992             }
2993         }
2994     }
2995
2996     *global_to_local = gtl;
2997     *fraction_shift  = fsh;
2998 }
2999
3000 static pme_spline_work_t *make_pme_spline_work(int order)
3001 {
3002     pme_spline_work_t *work;
3003
3004 #ifdef PME_SSE
3005     float  tmp[8];
3006     __m128 zero_SSE;
3007     int    of, i;
3008
3009     snew_aligned(work, 1, 16);
3010
3011     zero_SSE = _mm_setzero_ps();
3012
3013     /* Generate bit masks to mask out the unused grid entries,
3014      * as we only operate on 'order' of the 8 grid entries that are
3015      * loaded into 2 SSE float registers.
3016      */
3017     for (of = 0; of < 8-(order-1); of++)
3018     {
3019         for (i = 0; i < 8; i++)
3020         {
3021             tmp[i] = (i >= of && i < of+order ? 1 : 0);
3022         }
3023         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
3024         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
3025         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
3026         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
3027     }
3028 #else
3029     work = NULL;
3030 #endif
3031
3032     return work;
3033 }
3034
3035 static void
3036 gmx_pme_check_grid_restrictions(FILE *fplog, char dim, int nnodes, int *nk)
3037 {
3038     int nk_new;
3039
3040     if (*nk % nnodes != 0)
3041     {
3042         nk_new = nnodes*(*nk/nnodes + 1);
3043
3044         if (2*nk_new >= 3*(*nk))
3045         {
3046             gmx_fatal(FARGS, "The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
3047                       dim, *nk, dim, nnodes, dim);
3048         }
3049
3050         if (fplog != NULL)
3051         {
3052             fprintf(fplog, "\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
3053                     dim, *nk, dim, nnodes, dim, nk_new, dim);
3054         }
3055
3056         *nk = nk_new;
3057     }
3058 }
3059
3060 int gmx_pme_init(gmx_pme_t *         pmedata,
3061                  t_commrec *         cr,
3062                  int                 nnodes_major,
3063                  int                 nnodes_minor,
3064                  t_inputrec *        ir,
3065                  int                 homenr,
3066                  gmx_bool            bFreeEnergy,
3067                  gmx_bool            bReproducible,
3068                  int                 nthread)
3069 {
3070     gmx_pme_t pme = NULL;
3071
3072     int  use_threads, sum_use_threads;
3073     ivec ndata;
3074
3075     if (debug)
3076     {
3077         fprintf(debug, "Creating PME data structures.\n");
3078     }
3079     snew(pme, 1);
3080
3081     pme->redist_init         = FALSE;
3082     pme->sum_qgrid_tmp       = NULL;
3083     pme->sum_qgrid_dd_tmp    = NULL;
3084     pme->buf_nalloc          = 0;
3085     pme->redist_buf_nalloc   = 0;
3086
3087     pme->nnodes              = 1;
3088     pme->bPPnode             = TRUE;
3089
3090     pme->nnodes_major        = nnodes_major;
3091     pme->nnodes_minor        = nnodes_minor;
3092
3093 #ifdef GMX_MPI
3094     if (nnodes_major*nnodes_minor > 1)
3095     {
3096         pme->mpi_comm = cr->mpi_comm_mygroup;
3097
3098         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3099         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3100         if (pme->nnodes != nnodes_major*nnodes_minor)
3101         {
3102             gmx_incons("PME node count mismatch");
3103         }
3104     }
3105     else
3106     {
3107         pme->mpi_comm = MPI_COMM_NULL;
3108     }
3109 #endif
3110
3111     if (pme->nnodes == 1)
3112     {
3113 #ifdef GMX_MPI
3114         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3115         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3116 #endif
3117         pme->ndecompdim   = 0;
3118         pme->nodeid_major = 0;
3119         pme->nodeid_minor = 0;
3123     }
3124     else
3125     {
3126         if (nnodes_minor == 1)
3127         {
3128 #ifdef GMX_MPI
3129             pme->mpi_comm_d[0] = pme->mpi_comm;
3130             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3131 #endif
3132             pme->ndecompdim   = 1;
3133             pme->nodeid_major = pme->nodeid;
3134             pme->nodeid_minor = 0;
3135
3136         }
3137         else if (nnodes_major == 1)
3138         {
3139 #ifdef GMX_MPI
3140             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3141             pme->mpi_comm_d[1] = pme->mpi_comm;
3142 #endif
3143             pme->ndecompdim   = 1;
3144             pme->nodeid_major = 0;
3145             pme->nodeid_minor = pme->nodeid;
3146         }
3147         else
3148         {
3149             if (pme->nnodes % nnodes_major != 0)
3150             {
3151                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3152             }
3153             pme->ndecompdim = 2;
3154
3155 #ifdef GMX_MPI
3156             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3157                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3158             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3159                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3160
3161             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3162             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3163             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3164             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3165 #endif
3166         }
3167         pme->bPPnode = (cr->duty & DUTY_PP);
3168     }
3169
3170     pme->nthread = nthread;
3171
3172     /* Check if any of the PME MPI ranks uses threads */
3173     use_threads = (pme->nthread > 1 ? 1 : 0);
3174 #ifdef GMX_MPI
3175     if (pme->nnodes > 1)
3176     {
3177         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3178                       MPI_SUM, pme->mpi_comm);
3179     }
3180     else
3181 #endif
3182     {
3183         sum_use_threads = use_threads;
3184     }
3185     pme->bUseThreads = (sum_use_threads > 0);
3186
3187     if (ir->ePBC == epbcSCREW)
3188     {
3189         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3190     }
3191
3192     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3193     pme->nkx         = ir->nkx;
3194     pme->nky         = ir->nky;
3195     pme->nkz         = ir->nkz;
3196     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3197     pme->pme_order   = ir->pme_order;
3198     pme->epsilon_r   = ir->epsilon_r;
3199
3200     if (pme->pme_order > PME_ORDER_MAX)
3201     {
3202         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3203                   pme->pme_order, PME_ORDER_MAX);
3204     }
3205
3206     /* Currently pme.c supports only the fft5d FFT code.
3207      * Therefore the grid always needs to be divisible by nnodes.
3208      * When the old 1D code is also supported again, change this check.
3209      *
3210      * This check should be done before calling gmx_pme_init
3211      * and fplog should be passed instead of stderr.
3212      *
3213        if (pme->ndecompdim >= 2)
3214      */
3215     if (pme->ndecompdim >= 1)
3216     {
3217         /*
3218            gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3219                                         'x',nnodes_major,&pme->nkx);
3220            gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3221                                         'y',nnodes_minor,&pme->nky);
3222          */
3223     }
3224
3225     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3226         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3227         pme->nkz <= pme->pme_order)
3228     {
3229         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order", pme->pme_order);
3230     }
3231
3232     if (pme->nnodes > 1)
3233     {
3234         double imbal;
3235
3236 #ifdef GMX_MPI
3237         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3238         MPI_Type_commit(&(pme->rvec_mpi));
3239 #endif
3240
3241         /* Note that the charge spreading and force gathering, which usually
3242          * take about the same amount of time as FFT+solve_pme,
3243          * are always fully load balanced
3244          * (unless the charge distribution is inhomogeneous).
3245          */
3246
3247         imbal = pme_load_imbalance(pme);
3248         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3249         {
3250             fprintf(stderr,
3251                     "\n"
3252                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3253                     "      For optimal PME load balancing\n"
3254                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3255                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3256                     "\n",
3257                     (int)((imbal-1)*100 + 0.5),
3258                     pme->nkx, pme->nky, pme->nnodes_major,
3259                     pme->nky, pme->nkz, pme->nnodes_minor);
3260         }
3261     }
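    /* Editor's note: a worked example of the divisibility advice above,
     * with illustrative numbers.  With nkx = 100 and nnodes_major = 7 the
     * x slabs cannot all be equal: some ranks get 15 grid lines and others
     * 14, so the busiest rank does about 15/(100/7) = 1.05 times the
     * average FFT+solve work along that dimension.  Choosing nkx divisible
     * by 7 (e.g. 98 or 105) makes all slabs equal.  The NOTE above is only
     * printed when the combined estimate from pme_load_imbalance() is at
     * least 20% (imbal >= 1.2).
     */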
3262
3263     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3264     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3265      * y is always copied through a buffer: we don't need padding in z,
3266      * but we do need the overlap in x because of the communication order.
3267      */
3268     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3269 #ifdef GMX_MPI
3270                       pme->mpi_comm_d[0],
3271 #endif
3272                       pme->nnodes_major, pme->nodeid_major,
3273                       pme->nkx,
3274                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3275
3276     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3277      * We do this with an offset buffer of equal size, so we need to allocate
3278      * extra for the offset. That's what the (+1)*pme->nkz is for.
3279      */
3280     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3281 #ifdef GMX_MPI
3282                       pme->mpi_comm_d[1],
3283 #endif
3284                       pme->nnodes_minor, pme->nodeid_minor,
3285                       pme->nky,
3286                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
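    /* Editor's note: an illustrative sizing of the two communication
     * buffers requested above, assuming div_round_up(n, d) is ceiling
     * division.  For nkx = nky = nkz = 96, nnodes_major = 4,
     * nnodes_minor = 3 and pme_order = 4:
     *   dim 0: (div_round_up(96, 3) + 4)*(96 + 4 - 1) = 36*99 = 3564 reals
     *   dim 1: (div_round_up(96, 4) + 4 + 1)*96       = 29*96 = 2784 reals
     * The extra pme_order rows and pme_order-1 planes in dim 0 cover the
     * overlap/padding described above, and the extra +1 row in dim 1 is
     * the offset buffer mentioned in the comment above.
     */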
3287
3288     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3289      * We only allow multiple communication pulses in dim 1, not in dim 0.
3290      */
3291     if (pme->bUseThreads && (pme->overlap[0].noverlap_nodes > 1 ||
3292                              pme->nkx < pme->nnodes_major*pme->pme_order))
3293     {
3294                 gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d). To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3295                   pme->nkx/(double)pme->nnodes_major, pme->pme_order);
3296     }
3297
3298     snew(pme->bsp_mod[XX], pme->nkx);
3299     snew(pme->bsp_mod[YY], pme->nky);
3300     snew(pme->bsp_mod[ZZ], pme->nkz);
3301
3302     /* The required size of the interpolation grid, including overlap.
3303      * The allocated size (pmegrid_n?) might be slightly larger.
3304      */
3305     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3306         pme->overlap[0].s2g0[pme->nodeid_major];
3307     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3308         pme->overlap[1].s2g0[pme->nodeid_minor];
3309     pme->pmegrid_nz_base = pme->nkz;
3310     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3311     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
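    /* Editor's note: a small worked example of these sizes, with
     * illustrative numbers.  For nkx = 72, nnodes_major = 3 and
     * pme_order = 4, each rank owns 72/3 = 24 grid lines along x and the
     * local interpolation grid adds pme_order-1 = 3 overlap lines, so
     * pmegrid_nx is typically 27.  Similarly pmegrid_nz becomes
     * nkz + pme_order - 1 (e.g. 96 + 3 = 99) before set_grid_alignment()
     * possibly rounds it up further for aligned access.
     */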
3312
3313     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3314     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3315     pme->pmegrid_start_iz = 0;
3316
3317     make_gridindex5_to_localindex(pme->nkx,
3318                                   pme->pmegrid_start_ix,
3319                                   pme->pmegrid_nx - (pme->pme_order-1),
3320                                   &pme->nnx, &pme->fshx);
3321     make_gridindex5_to_localindex(pme->nky,
3322                                   pme->pmegrid_start_iy,
3323                                   pme->pmegrid_ny - (pme->pme_order-1),
3324                                   &pme->nny, &pme->fshy);
3325     make_gridindex5_to_localindex(pme->nkz,
3326                                   pme->pmegrid_start_iz,
3327                                   pme->pmegrid_nz_base,
3328                                   &pme->nnz, &pme->fshz);
3329
3330     pmegrids_init(&pme->pmegridA,
3331                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3332                   pme->pmegrid_nz_base,
3333                   pme->pme_order,
3334                   pme->bUseThreads,
3335                   pme->nthread,
3336                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3337                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3338
3339     pme->spline_work = make_pme_spline_work(pme->pme_order);
3340
3341     ndata[0] = pme->nkx;
3342     ndata[1] = pme->nky;
3343     ndata[2] = pme->nkz;
3344
3345     /* This routine will allocate the grid data to fit the FFTs */
3346     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3347                             &pme->fftgridA, &pme->cfftgridA,
3348                             pme->mpi_comm_d,
3349                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3350                             bReproducible, pme->nthread);
3351
3352     if (bFreeEnergy)
3353     {
3354         pmegrids_init(&pme->pmegridB,
3355                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3356                       pme->pmegrid_nz_base,
3357                       pme->pme_order,
3358                       pme->bUseThreads,
3359                       pme->nthread,
3360                       pme->nkx % pme->nnodes_major != 0,
3361                       pme->nky % pme->nnodes_minor != 0);
3362
3363         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3364                                 &pme->fftgridB, &pme->cfftgridB,
3365                                 pme->mpi_comm_d,
3366                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3367                                 bReproducible, pme->nthread);
3368     }
3369     else
3370     {
3371         pme->pmegridB.grid.grid = NULL;
3372         pme->fftgridB           = NULL;
3373         pme->cfftgridB          = NULL;
3374     }
3375
3376     if (!pme->bP3M)
3377     {
3378         /* Use plain SPME B-spline interpolation */
3379         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3380     }
3381     else
3382     {
3383         /* Use the P3M grid-optimized influence function */
3384         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3385     }
3386
3387     /* Use atc[0] for spreading */
3388     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3389     if (pme->ndecompdim >= 2)
3390     {
3391         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3392     }
3393
3394     if (pme->nnodes == 1)
3395     {
3396         pme->atc[0].n = homenr;
3397         pme_realloc_atomcomm_things(&pme->atc[0]);
3398     }
3399
3400     {
3401         int thread;
3402
3403         /* Use fft5d, order after FFT is y major, z, x minor */
3404
3405         snew(pme->work, pme->nthread);
3406         for (thread = 0; thread < pme->nthread; thread++)
3407         {
3408             realloc_work(&pme->work[thread], pme->nkx);
3409         }
3410     }
3411
3412     *pmedata = pme;
3413
3414     return 0;
3415 }
3416
3417 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3418 {
3419     int d, t;
3420
3421     for (d = 0; d < DIM; d++)
3422     {
3423         if (new->grid.n[d] > old->grid.n[d])
3424         {
3425             return;
3426         }
3427     }
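    /* Editor's note: reaching this point means the new grid fits inside
     * the old allocation in every dimension; e.g. switching from a
     * 64x64x64 to a 60x60x60 grid (illustrative numbers, as when PME
     * tuning moves to a coarser grid) can reuse the old storage, while
     * any dimension that grows returns early above and keeps the freshly
     * allocated grids.
     */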
3428
3429     sfree_aligned(new->grid.grid);
3430     new->grid.grid = old->grid.grid;
3431
3432     if (new->grid_th != NULL && new->nthread == old->nthread)
3433     {
3434         sfree_aligned(new->grid_all);
3435         for (t = 0; t < new->nthread; t++)
3436         {
3437             new->grid_th[t].grid = old->grid_th[t].grid;
3438         }
3439     }
3440 }
3441
3442 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3443                    t_commrec *         cr,
3444                    gmx_pme_t           pme_src,
3445                    const t_inputrec *  ir,
3446                    ivec                grid_size)
3447 {
3448     t_inputrec irc;
3449     int homenr;
3450     int ret;
3451
3452     irc     = *ir;
3453     irc.nkx = grid_size[XX];
3454     irc.nky = grid_size[YY];
3455     irc.nkz = grid_size[ZZ];
3456
3457     if (pme_src->nnodes == 1)
3458     {
3459         homenr = pme_src->atc[0].n;
3460     }
3461     else
3462     {
3463         homenr = -1;
3464     }
3465
3466     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3467                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3468
3469     if (ret == 0)
3470     {
3471         /* We can easily reuse the allocated pme grids in pme_src */
3472         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3473         /* We would like to reuse the fft grids, but that's harder */
3474     }
3475
3476     return ret;
3477 }
3478
3479
3480 static void copy_local_grid(gmx_pme_t pme,
3481                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3482 {
3483     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3484     int  fft_my, fft_mz;
3485     int  nsx, nsy, nsz;
3486     ivec nf;
3487     int  offx, offy, offz, x, y, z, i0, i0t;
3488     int  d;
3489     pmegrid_t *pmegrid;
3490     real *grid_th;
3491
3492     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3493                                    local_fft_ndata,
3494                                    local_fft_offset,
3495                                    local_fft_size);
3496     fft_my = local_fft_size[YY];
3497     fft_mz = local_fft_size[ZZ];
3498
3499     pmegrid = &pmegrids->grid_th[thread];
3500
3501     nsx = pmegrid->s[XX];
3502     nsy = pmegrid->s[YY];
3503     nsz = pmegrid->s[ZZ];
3504
3505     for (d = 0; d < DIM; d++)
3506     {
3507         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3508                     local_fft_ndata[d] - pmegrid->offset[d]);
3509     }
3510
3511     offx = pmegrid->offset[XX];
3512     offy = pmegrid->offset[YY];
3513     offz = pmegrid->offset[ZZ];
3514
3515     /* Directly copy the non-overlapping parts of the local grids.
3516      * This also initializes the full grid.
3517      */
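    /* Editor's note: a worked example of the flat indexing below, with
     * illustrative sizes.  For fft_my = 28, fft_mz = 32 and a thread
     * block at offset (offx, offy, offz) = (8, 0, 0), the point
     * (x, y, z) = (1, 2, 3) of this thread lands at index
     * ((8 + 1)*28 + (0 + 2))*32 + 0 + 3 = 8131 in the node fftgrid,
     * and at (1*12 + 2)*16 + 3 = 227 in the thread-local grid
     * (pitches nsy = 12, nsz = 16).
     */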
3518     grid_th = pmegrid->grid;
3519     for (x = 0; x < nf[XX]; x++)
3520     {
3521         for (y = 0; y < nf[YY]; y++)
3522         {
3523             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3524             i0t = (x*nsy + y)*nsz;
3525             for (z = 0; z < nf[ZZ]; z++)
3526             {
3527                 fftgrid[i0+z] = grid_th[i0t+z];
3528             }
3529         }
3530     }
3531 }
3532
3533 static void
3534 reduce_threadgrid_overlap(gmx_pme_t pme,
3535                           const pmegrids_t *pmegrids, int thread,
3536                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3537 {
3538     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3539     int  fft_nx, fft_ny, fft_nz;
3540     int  fft_my, fft_mz;
3541     int  buf_my = -1;
3542     int  nsx, nsy, nsz;
3543     ivec ne;
3544     int  offx, offy, offz, x, y, z, i0, i0t;
3545     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3546     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3547     gmx_bool bCommX, bCommY;
3548     int  d;
3549     int  thread_f;
3550     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3551     const real *grid_th;
3552     real *commbuf = NULL;
3553
3554     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3555                                    local_fft_ndata,
3556                                    local_fft_offset,
3557                                    local_fft_size);
3558     fft_nx = local_fft_ndata[XX];
3559     fft_ny = local_fft_ndata[YY];
3560     fft_nz = local_fft_ndata[ZZ];
3561
3562     fft_my = local_fft_size[YY];
3563     fft_mz = local_fft_size[ZZ];
3564
3565     /* This routine is called when all threads have finished spreading.
3566      * Here each thread sums grid contributions calculated by other threads
3567      * to the thread local grid volume.
3568      * To minimize the number of grid copying operations,
3569      * this routine sums immediately from the pmegrid to the fftgrid.
3570      */
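    /* Editor's note: a schematic of the reduction performed below
     * (simplified sketch of the control flow, not literal code):
     *
     *   for each neighbouring thread block (sx, sy, sz) that overlaps us:
     *       if the overlapping region stays on this rank:
     *           fftgrid[region] += neighbour pmegrid[region];
     *       else:
     *           commbuf_x or commbuf_y [region] += neighbour pmegrid[region];
     *
     * where the communication buffers are overwritten on their first use,
     * accumulated into afterwards, and later exchanged between ranks by
     * sum_fftgrid_dd().
     */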
3571
3572     /* Determine which part of the full node grid we should operate on;
3573      * this is our thread local part of the full grid.
3574      */
3575     pmegrid = &pmegrids->grid_th[thread];
3576
3577     for (d = 0; d < DIM; d++)
3578     {
3579         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3580                     local_fft_ndata[d]);
3581     }
3582
3583     offx = pmegrid->offset[XX];
3584     offy = pmegrid->offset[YY];
3585     offz = pmegrid->offset[ZZ];
3586
3587
3588     bClearBufX  = TRUE;
3589     bClearBufY  = TRUE;
3590     bClearBufXY = TRUE;
3591
3592     /* Now loop over all the thread data blocks that contribute
3593      * to the grid region we (our thread) are operating on.
3594      */
3595     /* Note that fft_nx/y is equal to the number of grid points
3596      * between the first point of our node grid and the one of the next node.
3597      */
3598     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3599     {
3600         fx     = pmegrid->ci[XX] + sx;
3601         ox     = 0;
3602         bCommX = FALSE;
3603         if (fx < 0)
3604         {
3605             fx    += pmegrids->nc[XX];
3606             ox    -= fft_nx;
3607             bCommX = (pme->nnodes_major > 1);
3608         }
3609         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3610         ox       += pmegrid_g->offset[XX];
3611         if (!bCommX)
3612         {
3613             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3614         }
3615         else
3616         {
3617             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3618         }
3619
3620         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3621         {
3622             fy     = pmegrid->ci[YY] + sy;
3623             oy     = 0;
3624             bCommY = FALSE;
3625             if (fy < 0)
3626             {
3627                 fy    += pmegrids->nc[YY];
3628                 oy    -= fft_ny;
3629                 bCommY = (pme->nnodes_minor > 1);
3630             }
3631             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3632             oy       += pmegrid_g->offset[YY];
3633             if (!bCommY)
3634             {
3635                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3636             }
3637             else
3638             {
3639                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3640             }
3641
3642             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3643             {
3644                 fz = pmegrid->ci[ZZ] + sz;
3645                 oz = 0;
3646                 if (fz < 0)
3647                 {
3648                     fz += pmegrids->nc[ZZ];
3649                     oz -= fft_nz;
3650                 }
3651                 pmegrid_g = &pmegrids->grid_th[fz];
3652                 oz       += pmegrid_g->offset[ZZ];
3653                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3654
3655                 if (sx == 0 && sy == 0 && sz == 0)
3656                 {
3657                     /* We have already added our local contribution
3658                      * before calling this routine, so skip it here.
3659                      */
3660                     continue;
3661                 }
3662
3663                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3664
3665                 pmegrid_f = &pmegrids->grid_th[thread_f];
3666
3667                 grid_th = pmegrid_f->grid;
3668
3669                 nsx = pmegrid_f->s[XX];
3670                 nsy = pmegrid_f->s[YY];
3671                 nsz = pmegrid_f->s[ZZ];
3672
3673 #ifdef DEBUG_PME_REDUCE
3674                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3675                        pme->nodeid, thread, thread_f,
3676                        pme->pmegrid_start_ix,
3677                        pme->pmegrid_start_iy,
3678                        pme->pmegrid_start_iz,
3679                        sx, sy, sz,
3680                        offx-ox, tx1-ox, offx, tx1,
3681                        offy-oy, ty1-oy, offy, ty1,
3682                        offz-oz, tz1-oz, offz, tz1);
3683 #endif
3684
3685                 if (!(bCommX || bCommY))
3686                 {
3687                     /* Copy from the thread local grid to the node grid */
3688                     for (x = offx; x < tx1; x++)
3689                     {
3690                         for (y = offy; y < ty1; y++)
3691                         {
3692                             i0  = (x*fft_my + y)*fft_mz;
3693                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3694                             for (z = offz; z < tz1; z++)
3695                             {
3696                                 fftgrid[i0+z] += grid_th[i0t+z];
3697                             }
3698                         }
3699                     }
3700                 }
3701                 else
3702                 {
3703                     /* The order of this conditional decides
3704                      * where the corner volume gets stored with x+y decomp.
3705                      */
3706                     if (bCommY)
3707                     {
3708                         commbuf = commbuf_y;
3709                         buf_my  = ty1 - offy;
3710                         if (bCommX)
3711                         {
3712                             /* We index commbuf modulo the local grid size */
3713                             commbuf += buf_my*fft_nx*fft_nz;
3714
3715                             bClearBuf   = bClearBufXY;
3716                             bClearBufXY = FALSE;
3717                         }
3718                         else
3719                         {
3720                             bClearBuf  = bClearBufY;
3721                             bClearBufY = FALSE;
3722                         }
3723                     }
3724                     else
3725                     {
3726                         commbuf    = commbuf_x;
3727                         buf_my     = fft_ny;
3728                         bClearBuf  = bClearBufX;
3729                         bClearBufX = FALSE;
3730                     }
3731
3732                     /* Copy to the communication buffer */
3733                     for (x = offx; x < tx1; x++)
3734                     {
3735                         for (y = offy; y < ty1; y++)
3736                         {
3737                             i0  = (x*buf_my + y)*fft_nz;
3738                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3739
3740                             if (bClearBuf)
3741                             {
3742                                 /* First access of commbuf, initialize it */
3743                                 for (z = offz; z < tz1; z++)
3744                                 {
3745                                     commbuf[i0+z]  = grid_th[i0t+z];
3746                                 }
3747                             }
3748                             else
3749                             {
3750                                 for (z = offz; z < tz1; z++)
3751                                 {
3752                                     commbuf[i0+z] += grid_th[i0t+z];
3753                                 }
3754                             }
3755                         }
3756                     }
3757                 }
3758             }
3759         }
3760     }
3761 }
3762
3763
3764 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3765 {
3766     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3767     pme_overlap_t *overlap;
3768     int  send_index0, send_nindex;
3769     int  recv_nindex;
3770 #ifdef GMX_MPI
3771     MPI_Status stat;
3772 #endif
3773     int  send_size_y, recv_size_y;
3774     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3775     real *sendptr, *recvptr;
3776     int  x, y, z, indg, indb;
3777
3778     /* Note that this routine is only used for forward communication.
3779      * Since the force gathering, unlike the charge spreading,
3780      * can be trivially parallelized over the particles,
3781      * the backwards process is much simpler and can use the "old"
3782      * communication setup.
3783      */
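    /* Editor's note: a short summary of the communication below
     * (descriptive only).  First, if the minor (y) dimension is
     * decomposed, the y overlap is exchanged with MPI_Sendrecv and summed
     * into fftgrid, and the part that also overlaps in x is accumulated
     * into pme->overlap[0].sendbuf.  Then, if the major (x) dimension is
     * decomposed, a single communication pulse exchanges that x overlap
     * and sums it into fftgrid as well.
     */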
3784
3785     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3786                                    local_fft_ndata,
3787                                    local_fft_offset,
3788                                    local_fft_size);
3789
3790     if (pme->nnodes_minor > 1)
3791     {
3792         /* Minor dimension */
3793         overlap = &pme->overlap[1];
3794
3795         if (pme->nnodes_major > 1)
3796         {
3797             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3798         }
3799         else
3800         {
3801             size_yx = 0;
3802         }
3803         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3804
3805         send_size_y = overlap->send_size;
3806
3807         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3808         {
3809             send_id       = overlap->send_id[ipulse];
3810             recv_id       = overlap->recv_id[ipulse];
3811             send_index0   =
3812                 overlap->comm_data[ipulse].send_index0 -
3813                 overlap->comm_data[0].send_index0;
3814             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3815             /* We don't use recv_index0, as we always receive starting at 0 */
3816             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3817             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3818
3819             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3820             recvptr = overlap->recvbuf;
3821
3822 #ifdef GMX_MPI
3823             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3824                          send_id, ipulse,
3825                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3826                          recv_id, ipulse,
3827                          overlap->mpi_comm, &stat);
3828 #endif
3829
3830             for (x = 0; x < local_fft_ndata[XX]; x++)
3831             {
3832                 for (y = 0; y < recv_nindex; y++)
3833                 {
3834                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3835                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3836                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3837                     {
3838                         fftgrid[indg+z] += recvptr[indb+z];
3839                     }
3840                 }
3841             }
3842
3843             if (pme->nnodes_major > 1)
3844             {
3845                 /* Copy from the received buffer to the send buffer for dim 0 */
3846                 sendptr = pme->overlap[0].sendbuf;
3847                 for (x = 0; x < size_yx; x++)
3848                 {
3849                     for (y = 0; y < recv_nindex; y++)
3850                     {
3851                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3852                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3853                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3854                         {
3855                             sendptr[indg+z] += recvptr[indb+z];
3856                         }
3857                     }
3858                 }
3859             }
3860         }
3861     }
3862
3863     /* We only support a single pulse here.
3864      * This is not a severe limitation, as this code is only used
3865      * with OpenMP and with OpenMP the (PME) domains can be larger.
3866      */
3867     if (pme->nnodes_major > 1)
3868     {
3869         /* Major dimension */
3870         overlap = &pme->overlap[0];
3871
3872         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3873         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3874
3875         ipulse = 0;
3876
3877         send_id       = overlap->send_id[ipulse];
3878         recv_id       = overlap->recv_id[ipulse];
3879         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3880         /* We don't use recv_index0, as we always receive starting at 0 */
3881         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3882
3883         sendptr = overlap->sendbuf;
3884         recvptr = overlap->recvbuf;
3885
3886         if (debug != NULL)
3887         {
3888             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3889                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3890         }
3891
3892 #ifdef GMX_MPI
3893         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3894                      send_id, ipulse,
3895                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3896                      recv_id, ipulse,
3897                      overlap->mpi_comm, &stat);
3898 #endif
3899
3900         for (x = 0; x < recv_nindex; x++)
3901         {
3902             for (y = 0; y < local_fft_ndata[YY]; y++)
3903             {
3904                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3905                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3906                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3907                 {
3908                     fftgrid[indg+z] += recvptr[indb+z];
3909                 }
3910             }
3911         }
3912     }
3913 }
3914
3915
3916 static void spread_on_grid(gmx_pme_t pme,
3917                            pme_atomcomm_t *atc, pmegrids_t *grids,
3918                            gmx_bool bCalcSplines, gmx_bool bSpread,
3919                            real *fftgrid)
3920 {
3921     int nthread, thread;
3922 #ifdef PME_TIME_THREADS
3923     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3924     static double cs1     = 0, cs2 = 0, cs3 = 0;
3925     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3926     static int cnt        = 0;
3927 #endif
3928
3929     nthread = pme->nthread;
3930     assert(nthread > 0);
3931
3932 #ifdef PME_TIME_THREADS
3933     c1 = omp_cyc_start();
3934 #endif
3935     if (bCalcSplines)
3936     {
3937 #pragma omp parallel for num_threads(nthread) schedule(static)
3938         for (thread = 0; thread < nthread; thread++)
3939         {
3940             int start, end;
3941
3942             start = atc->n* thread   /nthread;
3943             end   = atc->n*(thread+1)/nthread;
3944
3945             /* Compute fftgrid index for all atoms,
3946              * with help of some extra variables.
3947              */
3948             calc_interpolation_idx(pme, atc, start, end, thread);
3949         }
3950     }
3951 #ifdef PME_TIME_THREADS
3952     c1   = omp_cyc_end(c1);
3953     cs1 += (double)c1;
3954 #endif
3955
3956 #ifdef PME_TIME_THREADS
3957     c2 = omp_cyc_start();
3958 #endif
3959 #pragma omp parallel for num_threads(nthread) schedule(static)
3960     for (thread = 0; thread < nthread; thread++)
3961     {
3962         splinedata_t *spline;
3963         pmegrid_t *grid = NULL;
3964
3965         /* make local bsplines  */
3966         if (grids == NULL || !pme->bUseThreads)
3967         {
3968             spline = &atc->spline[0];
3969
3970             spline->n = atc->n;
3971
3972             if (bSpread)
3973             {
3974                 grid = &grids->grid;
3975             }
3976         }
3977         else
3978         {
3979             spline = &atc->spline[thread];
3980
3981             if (grids->nthread == 1)
3982             {
3983                 /* One thread, we operate on all charges */
3984                 spline->n = atc->n;
3985             }
3986             else
3987             {
3988                 /* Get the indices our thread should operate on */
3989                 make_thread_local_ind(atc, thread, spline);
3990             }
3991
3992             grid = &grids->grid_th[thread];
3993         }
3994
3995         if (bCalcSplines)
3996         {
3997             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
3998                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
3999         }
4000
4001         if (bSpread)
4002         {
4003             /* put local atoms on grid. */
4004 #ifdef PME_TIME_SPREAD
4005             ct1a = omp_cyc_start();
4006 #endif
4007             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4008
4009             if (pme->bUseThreads)
4010             {
4011                 copy_local_grid(pme, grids, thread, fftgrid);
4012             }
4013 #ifdef PME_TIME_SPREAD
4014             ct1a          = omp_cyc_end(ct1a);
4015             cs1a[thread] += (double)ct1a;
4016 #endif
4017         }
4018     }
4019 #ifdef PME_TIME_THREADS
4020     c2   = omp_cyc_end(c2);
4021     cs2 += (double)c2;
4022 #endif
4023
4024     if (bSpread && pme->bUseThreads)
4025     {
4026 #ifdef PME_TIME_THREADS
4027         c3 = omp_cyc_start();
4028 #endif
4029 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4030         for (thread = 0; thread < grids->nthread; thread++)
4031         {
4032             reduce_threadgrid_overlap(pme, grids, thread,
4033                                       fftgrid,
4034                                       pme->overlap[0].sendbuf,
4035                                       pme->overlap[1].sendbuf);
4036         }
4037 #ifdef PME_TIME_THREADS
4038         c3   = omp_cyc_end(c3);
4039         cs3 += (double)c3;
4040 #endif
4041
4042         if (pme->nnodes > 1)
4043         {
4044             /* Communicate the overlapping part of the fftgrid.
4045              * For this communication call we need to check pme->bUseThreads
4046              * to have all ranks communicate here, regardless of pme->nthread.
4047              */
4048             sum_fftgrid_dd(pme, fftgrid);
4049         }
4050     }
4051
4052 #ifdef PME_TIME_THREADS
4053     cnt++;
4054     if (cnt % 20 == 0)
4055     {
4056         printf("idx %.2f spread %.2f red %.2f",
4057                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4058 #ifdef PME_TIME_SPREAD
4059         for (thread = 0; thread < nthread; thread++)
4060         {
4061             printf(" %.2f", cs1a[thread]*1e-9);
4062         }
4063 #endif
4064         printf("\n");
4065     }
4066 #endif
4067 }
4068
4069
4070 static void dump_grid(FILE *fp,
4071                       int sx, int sy, int sz, int nx, int ny, int nz,
4072                       int my, int mz, const real *g)
4073 {
4074     int x, y, z;
4075
4076     for (x = 0; x < nx; x++)
4077     {
4078         for (y = 0; y < ny; y++)
4079         {
4080             for (z = 0; z < nz; z++)
4081             {
4082                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4083                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4084             }
4085         }
4086     }
4087 }
4088
4089 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4090 {
4091     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4092
4093     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4094                                    local_fft_ndata,
4095                                    local_fft_offset,
4096                                    local_fft_size);
4097
4098     dump_grid(stderr,
4099               pme->pmegrid_start_ix,
4100               pme->pmegrid_start_iy,
4101               pme->pmegrid_start_iz,
4102               pme->pmegrid_nx-pme->pme_order+1,
4103               pme->pmegrid_ny-pme->pme_order+1,
4104               pme->pmegrid_nz-pme->pme_order+1,
4105               local_fft_size[YY],
4106               local_fft_size[ZZ],
4107               fftgrid);
4108 }
4109
4110
4111 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4112 {
4113     pme_atomcomm_t *atc;
4114     pmegrids_t *grid;
4115
4116     if (pme->nnodes > 1)
4117     {
4118         gmx_incons("gmx_pme_calc_energy called in parallel");
4119     }
4120     if (pme->bFEP)
4121     {
4122         gmx_incons("gmx_pme_calc_energy with free energy");
4123     }
4124
4125     atc            = &pme->atc_energy;
4126     atc->nthread   = 1;
4127     if (atc->spline == NULL)
4128     {
4129         snew(atc->spline, atc->nthread);
4130     }
4131     atc->nslab     = 1;
4132     atc->bSpread   = TRUE;
4133     atc->pme_order = pme->pme_order;
4134     atc->n         = n;
4135     pme_realloc_atomcomm_things(atc);
4136     atc->x         = x;
4137     atc->q         = q;
4138
4139     /* We only use the A-charges grid */
4140     grid = &pme->pmegridA;
4141
4142     /* Only calculate the spline coefficients, don't actually spread */
4143     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4144
4145     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4146 }
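/* Editor's note: a minimal usage sketch of gmx_pme_calc_energy()
 * (illustration only; the variable names are hypothetical and the pme
 * struct is assumed to come from a single-rank gmx_pme_init()):
 *
 *   real V_test;
 *   gmx_pme_calc_energy(pme, n_test, x_test, q_test, &V_test);
 *   // V_test now holds the reciprocal-space energy of the n_test test
 *   // charges at positions x_test, evaluated on the A-charges grid.
 */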
4147
4148
4149 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4150                                    t_nrnb *nrnb, t_inputrec *ir,
4151                                    gmx_large_int_t step)
4152 {
4153     /* Reset all the counters related to performance over the run */
4154     wallcycle_stop(wcycle, ewcRUN);
4155     wallcycle_reset_all(wcycle);
4156     init_nrnb(nrnb);
4157     if (ir->nsteps >= 0)
4158     {
4159         /* ir->nsteps is not used here, but we update it for consistency */
4160         ir->nsteps -= step - ir->init_step;
4161     }
4162     ir->init_step = step;
4163     wallcycle_start(wcycle, ewcRUN);
4164 }
4165
4166
4167 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4168                                ivec grid_size,
4169                                t_commrec *cr, t_inputrec *ir,
4170                                gmx_pme_t *pme_ret)
4171 {
4172     int ind;
4173     gmx_pme_t pme = NULL;
4174
4175     ind = 0;
4176     while (ind < *npmedata)
4177     {
4178         pme = (*pmedata)[ind];
4179         if (pme->nkx == grid_size[XX] &&
4180             pme->nky == grid_size[YY] &&
4181             pme->nkz == grid_size[ZZ])
4182         {
4183             *pme_ret = pme;
4184
4185             return;
4186         }
4187
4188         ind++;
4189     }
4190
4191     (*npmedata)++;
4192     srenew(*pmedata, *npmedata);
4193
4194     /* Generate a new PME data structure, copying part of the old pointers */
4195     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4196
4197     *pme_ret = (*pmedata)[ind];
4198 }
4199
4200
4201 int gmx_pmeonly(gmx_pme_t pme,
4202                 t_commrec *cr,    t_nrnb *nrnb,
4203                 gmx_wallcycle_t wcycle,
4204                 real ewaldcoeff,  gmx_bool bGatherOnly,
4205                 t_inputrec *ir)
4206 {
4207     int npmedata;
4208     gmx_pme_t *pmedata;
4209     gmx_pme_pp_t pme_pp;
4210     int  ret;
4211     int  natoms;
4212     matrix box;
4213     rvec *x_pp      = NULL, *f_pp = NULL;
4214     real *chargeA   = NULL, *chargeB = NULL;
4215     real lambda     = 0;
4216     int  maxshift_x = 0, maxshift_y = 0;
4217     real energy, dvdlambda;
4218     matrix vir;
4219     float cycles;
4220     int  count;
4221     gmx_bool bEnerVir;
4222     gmx_large_int_t step, step_rel;
4223     ivec grid_switch;
4224
4225     /* This data will only be used with PME tuning, i.e. switching PME grids */
4226     npmedata = 1;
4227     snew(pmedata, npmedata);
4228     pmedata[0] = pme;
4229
4230     pme_pp = gmx_pme_pp_init(cr);
4231
4232     init_nrnb(nrnb);
4233
4234     count = 0;
4235     do /****** this is a quasi-loop over time steps! */
4236     {
4237         /* The reason for having a loop here is PME grid tuning/switching */
4238         do
4239         {
4240             /* Domain decomposition */
4241             ret = gmx_pme_recv_q_x(pme_pp,
4242                                    &natoms,
4243                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4244                                    &maxshift_x, &maxshift_y,
4245                                    &pme->bFEP, &lambda,
4246                                    &bEnerVir,
4247                                    &step,
4248                                    grid_switch, &ewaldcoeff);
4249
4250             if (ret == pmerecvqxSWITCHGRID)
4251             {
4252                 /* Switch the PME grid to grid_switch */
4253                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4254             }
4255
4256             if (ret == pmerecvqxRESETCOUNTERS)
4257             {
4258                 /* Reset the cycle and flop counters */
4259                 reset_pmeonly_counters(cr, wcycle, nrnb, ir, step);
4260             }
4261         }
4262         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4263
4264         if (ret == pmerecvqxFINISH)
4265         {
4266             /* We should stop: break out of the loop */
4267             break;
4268         }
4269
4270         step_rel = step - ir->init_step;
4271
4272         if (count == 0)
4273         {
4274             wallcycle_start(wcycle, ewcRUN);
4275         }
4276
4277         wallcycle_start(wcycle, ewcPMEMESH);
4278
4279         dvdlambda = 0;
4280         clear_mat(vir);
4281         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4282                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4283                    &energy, lambda, &dvdlambda,
4284                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4285
4286         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4287
4288         gmx_pme_send_force_vir_ener(pme_pp,
4289                                     f_pp, vir, energy, dvdlambda,
4290                                     cycles);
4291
4292         count++;
4293     } /***** end of quasi-loop, we stop with the break above */
4294     while (TRUE);
4295
4296     return 0;
4297 }
4298
4299 int gmx_pme_do(gmx_pme_t pme,
4300                int start,       int homenr,
4301                rvec x[],        rvec f[],
4302                real *chargeA,   real *chargeB,
4303                matrix box, t_commrec *cr,
4304                int  maxshift_x, int maxshift_y,
4305                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4306                matrix vir,      real ewaldcoeff,
4307                real *energy,    real lambda,
4308                real *dvdlambda, int flags)
4309 {
4310     int     q, d, i, j, ntot, npme;
4311     int     nx, ny, nz;
4312     int     n_d, local_ny;
4313     pme_atomcomm_t *atc = NULL;
4314     pmegrids_t *pmegrid = NULL;
4315     real    *grid       = NULL;
4316     real    *ptr;
4317     rvec    *x_d, *f_d;
4318     real    *charge = NULL, *q_d;
4319     real    energy_AB[2];
4320     matrix  vir_AB[2];
4321     gmx_bool bClearF;
4322     gmx_parallel_3dfft_t pfft_setup;
4323     real *  fftgrid;
4324     t_complex * cfftgrid;
4325     int     thread;
4326     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4327     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4328
4329     assert(pme->nnodes > 0);
4330     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4331
4332     if (pme->nnodes > 1)
4333     {
4334         atc      = &pme->atc[0];
4335         atc->npd = homenr;
4336         if (atc->npd > atc->pd_nalloc)
4337         {
4338             atc->pd_nalloc = over_alloc_dd(atc->npd);
4339             srenew(atc->pd, atc->pd_nalloc);
4340         }
4341         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4342     }
4343     else
4344     {
4345         /* This could be necessary for TPI */
4346         pme->atc[0].n = homenr;
4347     }
4348
4349     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4350     {
4351         if (q == 0)
4352         {
4353             pmegrid    = &pme->pmegridA;
4354             fftgrid    = pme->fftgridA;
4355             cfftgrid   = pme->cfftgridA;
4356             pfft_setup = pme->pfft_setupA;
4357             charge     = chargeA+start;
4358         }
4359         else
4360         {
4361             pmegrid    = &pme->pmegridB;
4362             fftgrid    = pme->fftgridB;
4363             cfftgrid   = pme->cfftgridB;
4364             pfft_setup = pme->pfft_setupB;
4365             charge     = chargeB+start;
4366         }
4367         grid = pmegrid->grid.grid;
4368         /* Unpack structure */
4369         if (debug)
4370         {
4371             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4372                     cr->nnodes, cr->nodeid);
4373             fprintf(debug, "Grid = %p\n", (void*)grid);
4374             if (grid == NULL)
4375             {
4376                 gmx_fatal(FARGS, "No grid!");
4377             }
4378         }
4379         where();
4380
4381         m_inv_ur0(box, pme->recipbox);
4382
4383         if (pme->nnodes == 1)
4384         {
4385             atc = &pme->atc[0];
4386             if (DOMAINDECOMP(cr))
4387             {
4388                 atc->n = homenr;
4389                 pme_realloc_atomcomm_things(atc);
4390             }
4391             atc->x = x;
4392             atc->q = charge;
4393             atc->f = f;
4394         }
4395         else
4396         {
4397             wallcycle_start(wcycle, ewcPME_REDISTXF);
4398             for (d = pme->ndecompdim-1; d >= 0; d--)
4399             {
4400                 if (d == pme->ndecompdim-1)
4401                 {
4402                     n_d = homenr;
4403                     x_d = x + start;
4404                     q_d = charge;
4405                 }
4406                 else
4407                 {
4408                     n_d = pme->atc[d+1].n;
4409                     x_d = atc->x;
4410                     q_d = atc->q;
4411                 }
4412                 atc      = &pme->atc[d];
4413                 atc->npd = n_d;
4414                 if (atc->npd > atc->pd_nalloc)
4415                 {
4416                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4417                     srenew(atc->pd, atc->pd_nalloc);
4418                 }
4419                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4420                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4421                 where();
4422
4423                 GMX_BARRIER(cr->mpi_comm_mygroup);
4424                 /* Redistribute x (only once) and qA or qB */
4425                 if (DOMAINDECOMP(cr))
4426                 {
4427                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4428                 }
4429                 else
4430                 {
4431                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4432                 }
4433             }
4434             where();
4435
4436             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4437         }
4438
4439         if (debug)
4440         {
4441             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4442                     cr->nodeid, atc->n);
4443         }
4444
4445         if (flags & GMX_PME_SPREAD_Q)
4446         {
4447             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4448
4449             /* Spread the charges on a grid */
4450             GMX_MPE_LOG(ev_spread_on_grid_start);
4451
4453             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4454             GMX_MPE_LOG(ev_spread_on_grid_finish);
4455
4456             if (q == 0)
4457             {
4458                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4459             }
4460             inc_nrnb(nrnb, eNR_SPREADQBSP,
4461                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4462
4463             if (!pme->bUseThreads)
4464             {
4465                 wrap_periodic_pmegrid(pme, grid);
4466
4467                 /* sum contributions to local grid from other nodes */
4468 #ifdef GMX_MPI
4469                 if (pme->nnodes > 1)
4470                 {
4471                     GMX_BARRIER(cr->mpi_comm_mygroup);
4472                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4473                     where();
4474                 }
4475 #endif
4476
4477                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4478             }
4479
4480             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4481
4482             /*
4483                dump_local_fftgrid(pme,fftgrid);
4484                exit(0);
4485              */
4486         }
4487
4488         /* Here we start a large thread parallel region */
4489 #pragma omp parallel num_threads(pme->nthread) private(thread)
4490         {
4491             thread = gmx_omp_get_thread_num();
4492             if (flags & GMX_PME_SOLVE)
4493             {
4494                 int loop_count;
4495
4496                 /* do 3d-fft */
4497                 if (thread == 0)
4498                 {
4499                     GMX_BARRIER(cr->mpi_comm_mygroup);
4500                     GMX_MPE_LOG(ev_gmxfft3d_start);
4501                     wallcycle_start(wcycle, ewcPME_FFT);
4502                 }
4503                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4504                                            fftgrid, cfftgrid, thread, wcycle);
4505                 if (thread == 0)
4506                 {
4507                     wallcycle_stop(wcycle, ewcPME_FFT);
4508                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4509                 }
4510                 where();
4511
4512                 /* solve in k-space for our local cells */
4513                 if (thread == 0)
4514                 {
4515                     GMX_BARRIER(cr->mpi_comm_mygroup);
4516                     GMX_MPE_LOG(ev_solve_pme_start);
4517                     wallcycle_start(wcycle, ewcPME_SOLVE);
4518                 }
4519                 loop_count =
4520                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4521                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4522                                   bCalcEnerVir,
4523                                   pme->nthread, thread);
4524                 if (thread == 0)
4525                 {
4526                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4527                     where();
4528                     GMX_MPE_LOG(ev_solve_pme_finish);
4529                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4530                 }
4531             }
4532
4533             if (bCalcF)
4534             {
4535                 /* do 3d-invfft */
4536                 if (thread == 0)
4537                 {
4538                     GMX_BARRIER(cr->mpi_comm_mygroup);
4539                     GMX_MPE_LOG(ev_gmxfft3d_start);
4540                     where();
4541                     wallcycle_start(wcycle, ewcPME_FFT);
4542                 }
4543                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4544                                            cfftgrid, fftgrid, thread, wcycle);
4545                 if (thread == 0)
4546                 {
4547                     wallcycle_stop(wcycle, ewcPME_FFT);
4548
4549                     where();
4550                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4551
4552                     if (pme->nodeid == 0)
4553                     {
4554                         ntot  = pme->nkx*pme->nky*pme->nkz;
4555                         npme  = ntot*log((real)ntot)/log(2.0);
4556                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4557                     }
4558
4559                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4560                 }
4561
4562                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4563             }
4564         }
4565         /* End of thread parallel section.
4566          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4567          */
4568
4569         if (bCalcF)
4570         {
4571             /* distribute local grid to all nodes */
4572 #ifdef GMX_MPI
4573             if (pme->nnodes > 1)
4574             {
4575                 GMX_BARRIER(cr->mpi_comm_mygroup);
4576                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4577             }
4578 #endif
4579             where();
4580
4581             unwrap_periodic_pmegrid(pme, grid);
4582
4583             /* interpolate forces for our local atoms */
4584             GMX_BARRIER(cr->mpi_comm_mygroup);
4585             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4586
4587             where();
4588
4589             /* If we are running without parallelization,
4590              * atc->f is the actual force array, not a buffer,
4591              * atc->f is the actual force array, not a buffer;
4592              */
4593             bClearF = (q == 0 && PAR(cr));
4594 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4595             for (thread = 0; thread < pme->nthread; thread++)
4596             {
4597                 gather_f_bsplines(pme, grid, bClearF, atc,
4598                                   &atc->spline[thread],
4599                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4600             }
4601
4602             where();
4603
4604             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4605
4606             inc_nrnb(nrnb, eNR_GATHERFBSP,
4607                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4608             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4609         }
4610
4611         if (bCalcEnerVir)
4612         {
4613             /* This should only be called on the master thread
4614              * and after the threads have synchronized.
4615              */
4616             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4617         }
4618     } /* of q-loop */
4619
4620     if (bCalcF && pme->nnodes > 1)
4621     {
4622         wallcycle_start(wcycle, ewcPME_REDISTXF);
4623         for (d = 0; d < pme->ndecompdim; d++)
4624         {
4625             atc = &pme->atc[d];
4626             if (d == pme->ndecompdim - 1)
4627             {
4628                 n_d = homenr;
4629                 f_d = f + start;
4630             }
4631             else
4632             {
4633                 n_d = pme->atc[d+1].n;
4634                 f_d = pme->atc[d+1].f;
4635             }
4636             GMX_BARRIER(cr->mpi_comm_mygroup);
4637             if (DOMAINDECOMP(cr))
4638             {
4639                 dd_pmeredist_f(pme, atc, n_d, f_d,
4640                                d == pme->ndecompdim-1 && pme->bPPnode);
4641             }
4642             else
4643             {
4644                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4645             }
4646         }
4647
4648         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4649     }
4650     where();
4651
4652     if (bCalcEnerVir)
4653     {
4654         if (!pme->bFEP)
4655         {
4656             *energy = energy_AB[0];
4657             m_add(vir, vir_AB[0], vir);
4658         }
4659         else
4660         {
4661             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4662             *dvdlambda += energy_AB[1] - energy_AB[0];
4663             for (i = 0; i < DIM; i++)
4664             {
4665                 for (j = 0; j < DIM; j++)
4666                 {
4667                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4668                         lambda*vir_AB[1][i][j];
4669                 }
4670             }
4671         }
4672     }
4673     else
4674     {
4675         *energy = 0;
4676     }
4677
4678     if (debug)
4679     {
4680         fprintf(debug, "PME mesh energy: %g\n", *energy);
4681     }
4682
4683     return 0;
4684 }