mdlib: clean up -Wunused-parameter warnings
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
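/* For illustration (assuming the standard GROMACS convention that
 * recipbox is the inverse of the lower-triangular box matrix): for the
 * triclinic test box above,
 *     2 0 0
 *     0 2 0
 *     2 2 6
 * the reciprocal box elements used in calc_interpolation_idx come out as
 *     rxx = 1/2,  ryx = 0,    ryy = 1/2,
 *     rzx = -1/6, rzy = -1/6, rzz = 1/6,
 * so the fractional coordinates are
 *     sx = x/2 - z/6,  sy = y/2 - z/6,  sz = z/6.
 */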
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #include "gromacs/fft/parallel_3dfft.h"
64 #include "gromacs/utility/gmxmpi.h"
65
66 #include <stdio.h>
67 #include <string.h>
68 #include <math.h>
69 #include <assert.h>
70 #include "typedefs.h"
71 #include "txtdump.h"
72 #include "vec.h"
73 #include "gmxcomplex.h"
74 #include "smalloc.h"
75 #include "futil.h"
76 #include "coulomb.h"
77 #include "gmx_fatal.h"
78 #include "pme.h"
79 #include "network.h"
80 #include "physics.h"
81 #include "nrnb.h"
82 #include "gmx_wallcycle.h"
83 #include "pdbio.h"
84 #include "gmx_cyclecounter.h"
85 #include "gmx_omp.h"
86 #include "macros.h"
87
88 /* Single precision, with SSE2 or higher available */
89 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
90
91 #include "gmx_x86_simd_single.h"
92
93 #define PME_SSE
94 /* Some old AMD processors could have problems with unaligned loads+stores */
95 #ifndef GMX_FAHCORE
96 #define PME_SSE_UNALIGNED
97 #endif
98 #endif
99
100 #define DFT_TOL 1e-7
101 /* #define PRT_FORCE */
102 /* Conditions for on-the-fly time measurement */
103 /* #define TAKETIME (step > 1 && timesteps < 10) */
104 #define TAKETIME FALSE
105
106 /* #define PME_TIME_THREADS */
107
108 #ifdef GMX_DOUBLE
109 #define mpi_type MPI_DOUBLE
110 #else
111 #define mpi_type MPI_FLOAT
112 #endif
113
114 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
115 #define GMX_CACHE_SEP 64
116
117 /* We only define a maximum to be able to use local arrays without allocation.
118  * An order larger than 12 should never be needed, even for test cases.
119  * If needed it can be changed here.
120  */
121 #define PME_ORDER_MAX 12
122
123 /* Internal datastructures */
124 typedef struct {
125     int send_index0;
126     int send_nindex;
127     int recv_index0;
128     int recv_nindex;
129     int recv_size;   /* Receive buffer width, used with OpenMP */
130 } pme_grid_comm_t;
131
132 typedef struct {
133 #ifdef GMX_MPI
134     MPI_Comm         mpi_comm;
135 #endif
136     int              nnodes, nodeid;
137     int             *s2g0;
138     int             *s2g1;
139     int              noverlap_nodes;
140     int             *send_id, *recv_id;
141     int              send_size; /* Send buffer width, used with OpenMP */
142     pme_grid_comm_t *comm_data;
143     real            *sendbuf;
144     real            *recvbuf;
145 } pme_overlap_t;
146
147 typedef struct {
148     int *n;      /* Cumulative counts of the number of particles per thread */
149     int  nalloc; /* Allocation size of i */
150     int *i;      /* Particle indices ordered on thread index (n) */
151 } thread_plist_t;
152
153 typedef struct {
154     int      *thread_one;
155     int       n;
156     int      *ind;
157     splinevec theta;
158     real     *ptr_theta_z;
159     splinevec dtheta;
160     real     *ptr_dtheta_z;
161 } splinedata_t;
162
163 typedef struct {
164     int      dimind;        /* The index of the dimension, 0=x, 1=y */
165     int      nslab;
166     int      nodeid;
167 #ifdef GMX_MPI
168     MPI_Comm mpi_comm;
169 #endif
170
171     int     *node_dest;     /* The nodes to send x and q to with DD */
172     int     *node_src;      /* The nodes to receive x and q from with DD */
173     int     *buf_index;     /* Index into the buffers for each commnode */
174
175     int      maxshift;
176
177     int      npd;
178     int      pd_nalloc;
179     int     *pd;
180     int     *count;         /* The number of atoms to send to each node */
181     int    **count_thread;
182     int     *rcount;        /* The number of atoms to receive */
183
184     int      n;
185     int      nalloc;
186     rvec    *x;
187     real    *q;
188     rvec    *f;
189     gmx_bool bSpread;       /* These coordinates are used for spreading */
190     int      pme_order;
191     ivec    *idx;
192     rvec    *fractx;            /* Fractional coordinate relative to the
193                                  * lower cell boundary
194                                  */
195     int             nthread;
196     int            *thread_idx; /* Which thread should spread which charge */
197     thread_plist_t *thread_plist;
198     splinedata_t   *spline;
199 } pme_atomcomm_t;
200
201 #define FLBS  3
202 #define FLBSZ 4
203
204 typedef struct {
205     ivec  ci;     /* The spatial location of this grid         */
206     ivec  n;      /* The used size of *grid, including order-1 */
207     ivec  offset; /* The grid offset from the full node grid   */
208     int   order;  /* PME spreading order                       */
209     ivec  s;      /* The allocated size of *grid, s >= n       */
210     real *grid;   /* The grid of this local thread, size n     */
211 } pmegrid_t;
212
213 typedef struct {
214     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
215     int        nthread;      /* The number of threads operating on this grid     */
216     ivec       nc;           /* The local spatial decomposition over the threads */
217     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
218     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
219     int      **g2t;          /* The grid to thread index                         */
220     ivec       nthread_comm; /* The number of threads to communicate with        */
221 } pmegrids_t;
222
223
224 typedef struct {
225 #ifdef PME_SSE
226     /* Masks for SSE aligned spreading and gathering */
227     __m128 mask_SSE0[6], mask_SSE1[6];
228 #else
229     int    dummy; /* C89 requires that a struct has at least one member */
230 #endif
231 } pme_spline_work_t;
232
233 typedef struct {
234     /* work data for solve_pme */
235     int      nalloc;
236     real *   mhx;
237     real *   mhy;
238     real *   mhz;
239     real *   m2;
240     real *   denom;
241     real *   tmp1_alloc;
242     real *   tmp1;
243     real *   eterm;
244     real *   m2inv;
245
246     real     energy;
247     matrix   vir;
248 } pme_work_t;
249
250 typedef struct gmx_pme {
251     int           ndecompdim; /* The number of decomposition dimensions */
252     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
253     int           nodeid_major;
254     int           nodeid_minor;
255     int           nnodes;    /* The number of nodes doing PME */
256     int           nnodes_major;
257     int           nnodes_minor;
258
259     MPI_Comm      mpi_comm;
260     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
261 #ifdef GMX_MPI
262     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
263 #endif
264
265     gmx_bool   bUseThreads;   /* Does any PME rank have nthread>1 ?          */
266     int        nthread;       /* The number of threads doing PME on our rank */
267
268     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
269     gmx_bool   bFEP;          /* Compute Free energy contribution */
270     int        nkx, nky, nkz; /* Grid dimensions */
271     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
272     int        pme_order;
273     real       epsilon_r;
274
275     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
276     pmegrids_t pmegridB;
277     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
278     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
279     /* pmegrid_nz might be larger than strictly necessary to ensure
280      * memory alignment, pmegrid_nz_base gives the real base size.
281      */
282     int     pmegrid_nz_base;
283     /* The local PME grid starting indices */
284     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
285
286     /* Work data for spreading and gathering */
287     pme_spline_work_t    *spline_work;
288
289     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
290     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
291     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
292
293     t_complex            *cfftgridA;  /* Grids for complex FFT data */
294     t_complex            *cfftgridB;
295     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
296
297     gmx_parallel_3dfft_t  pfft_setupA;
298     gmx_parallel_3dfft_t  pfft_setupB;
299
300     int                  *nnx, *nny, *nnz;
301     real                 *fshx, *fshy, *fshz;
302
303     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
304     matrix                recipbox;
305     splinevec             bsp_mod;
306
307     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
308
309     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
310
311     rvec                 *bufv;       /* Communication buffer */
312     real                 *bufr;       /* Communication buffer */
313     int                   buf_nalloc; /* The communication buffer size */
314
315     /* thread local work data for solve_pme */
316     pme_work_t *work;
317
318     /* Work data for PME_redist */
319     gmx_bool redist_init;
320     int *    scounts;
321     int *    rcounts;
322     int *    sdispls;
323     int *    rdispls;
324     int *    sidx;
325     int *    idxa;
326     real *   redist_buf;
327     int      redist_buf_nalloc;
328
329     /* Work data for sum_qgrid */
330     real *   sum_qgrid_tmp;
331     real *   sum_qgrid_dd_tmp;
332 } t_gmx_pme;
333
334
335 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
336                                    int start, int end, int thread)
337 {
338     int             i;
339     int            *idxptr, tix, tiy, tiz;
340     real           *xptr, *fptr, tx, ty, tz;
341     real            rxx, ryx, ryy, rzx, rzy, rzz;
342     int             nx, ny, nz;
343     int             start_ix, start_iy, start_iz;
344     int            *g2tx, *g2ty, *g2tz;
345     gmx_bool        bThreads;
346     int            *thread_idx = NULL;
347     thread_plist_t *tpl        = NULL;
348     int            *tpl_n      = NULL;
349     int             thread_i;
350
351     nx  = pme->nkx;
352     ny  = pme->nky;
353     nz  = pme->nkz;
354
355     start_ix = pme->pmegrid_start_ix;
356     start_iy = pme->pmegrid_start_iy;
357     start_iz = pme->pmegrid_start_iz;
358
359     rxx = pme->recipbox[XX][XX];
360     ryx = pme->recipbox[YY][XX];
361     ryy = pme->recipbox[YY][YY];
362     rzx = pme->recipbox[ZZ][XX];
363     rzy = pme->recipbox[ZZ][YY];
364     rzz = pme->recipbox[ZZ][ZZ];
365
366     g2tx = pme->pmegridA.g2t[XX];
367     g2ty = pme->pmegridA.g2t[YY];
368     g2tz = pme->pmegridA.g2t[ZZ];
369
370     bThreads = (atc->nthread > 1);
371     if (bThreads)
372     {
373         thread_idx = atc->thread_idx;
374
375         tpl   = &atc->thread_plist[thread];
376         tpl_n = tpl->n;
377         for (i = 0; i < atc->nthread; i++)
378         {
379             tpl_n[i] = 0;
380         }
381     }
382
383     for (i = start; i < end; i++)
384     {
385         xptr   = atc->x[i];
386         idxptr = atc->idx[i];
387         fptr   = atc->fractx[i];
388
389         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
390         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
391         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
392         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
393
394         tix = (int)(tx);
395         tiy = (int)(ty);
396         tiz = (int)(tz);
397
398         /* Because decomposition only occurs in x and y,
399          * we never have a fraction correction in z.
400          */
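        /* The nnx/nny/nnz tables below map the shifted indices (which still
         * include the offsets added above) onto the local grid, and
         * fshx/fshy provide the matching fractional shift corrections.
         */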
401         fptr[XX] = tx - tix + pme->fshx[tix];
402         fptr[YY] = ty - tiy + pme->fshy[tiy];
403         fptr[ZZ] = tz - tiz;
404
405         idxptr[XX] = pme->nnx[tix];
406         idxptr[YY] = pme->nny[tiy];
407         idxptr[ZZ] = pme->nnz[tiz];
408
409 #ifdef DEBUG
410         range_check(idxptr[XX], 0, pme->pmegrid_nx);
411         range_check(idxptr[YY], 0, pme->pmegrid_ny);
412         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
413 #endif
414
415         if (bThreads)
416         {
417             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
418             thread_idx[i] = thread_i;
419             tpl_n[thread_i]++;
420         }
421     }
422
423     if (bThreads)
424     {
425         /* Make a list of particle indices sorted on thread */
426
427         /* Get the cumulative count */
428         for (i = 1; i < atc->nthread; i++)
429         {
430             tpl_n[i] += tpl_n[i-1];
431         }
432         /* The current implementation distributes particles equally
433          * over the threads, so we could actually allocate for that
434          * in pme_realloc_atomcomm_things.
435          */
436         if (tpl_n[atc->nthread-1] > tpl->nalloc)
437         {
438             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
439             srenew(tpl->i, tpl->nalloc);
440         }
441         /* Set tpl_n to the cumulative start */
442         for (i = atc->nthread-1; i >= 1; i--)
443         {
444             tpl_n[i] = tpl_n[i-1];
445         }
446         tpl_n[0] = 0;
447
448         /* Fill our thread local array with indices sorted on thread */
449         for (i = start; i < end; i++)
450         {
451             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
452         }
453         /* Now tpl_n contains the cumulative count again */
454     }
455 }
456
457 static void make_thread_local_ind(pme_atomcomm_t *atc,
458                                   int thread, splinedata_t *spline)
459 {
460     int             n, t, i, start, end;
461     thread_plist_t *tpl;
462
463     /* Combine the indices made by each thread into one index */
464
465     n     = 0;
466     start = 0;
467     for (t = 0; t < atc->nthread; t++)
468     {
469         tpl = &atc->thread_plist[t];
470         /* Copy our part (start - end) from the list of thread t */
471         if (thread > 0)
472         {
473             start = tpl->n[thread-1];
474         }
475         end = tpl->n[thread];
476         for (i = start; i < end; i++)
477         {
478             spline->ind[n++] = tpl->i[i];
479         }
480     }
481
482     spline->n = n;
483 }
484
485
486 static void pme_calc_pidx(int start, int end,
487                           matrix recipbox, rvec x[],
488                           pme_atomcomm_t *atc, int *count)
489 {
490     int   nslab, i;
491     int   si;
492     real *xptr, s;
493     real  rxx, ryx, rzx, ryy, rzy;
494     int  *pd;
495
496     /* Calculate the PME task index (pidx) for each particle.
497      * Here we always assign equally sized slabs to each node
498      * for load balancing reasons (the PME grid spacing is not used).
499      */
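    /* Note that below, (int)(s + 2*nslab) % nslab maps s robustly onto
     * [0, nslab), also when s is slightly negative.
     */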
500
501     nslab = atc->nslab;
502     pd    = atc->pd;
503
504     /* Reset the count */
505     for (i = 0; i < nslab; i++)
506     {
507         count[i] = 0;
508     }
509
510     if (atc->dimind == 0)
511     {
512         rxx = recipbox[XX][XX];
513         ryx = recipbox[YY][XX];
514         rzx = recipbox[ZZ][XX];
515         /* Calculate the node index in x-dimension */
516         for (i = start; i < end; i++)
517         {
518             xptr   = x[i];
519             /* Fractional coordinates along box vectors */
520             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
521             si    = (int)(s + 2*nslab) % nslab;
522             pd[i] = si;
523             count[si]++;
524         }
525     }
526     else
527     {
528         ryy = recipbox[YY][YY];
529         rzy = recipbox[ZZ][YY];
530         /* Calculate the node index in y-dimension */
531         for (i = start; i < end; i++)
532         {
533             xptr   = x[i];
534             /* Fractional coordinates along box vectors */
535             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
536             si    = (int)(s + 2*nslab) % nslab;
537             pd[i] = si;
538             count[si]++;
539         }
540     }
541 }
542
543 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
544                                   pme_atomcomm_t *atc)
545 {
546     int nthread, thread, slab;
547
548     nthread = atc->nthread;
549
550 #pragma omp parallel for num_threads(nthread) schedule(static)
551     for (thread = 0; thread < nthread; thread++)
552     {
553         pme_calc_pidx(natoms* thread   /nthread,
554                       natoms*(thread+1)/nthread,
555                       recipbox, x, atc, atc->count_thread[thread]);
556     }
557     /* Non-parallel reduction, since nslab is small */
558
559     for (thread = 1; thread < nthread; thread++)
560     {
561         for (slab = 0; slab < atc->nslab; slab++)
562         {
563             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
564         }
565     }
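    /* atc->count_thread[0] now holds the summed atom count per slab */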
566 }
567
568 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
569 {
570     const int padding = 4;
571     int       i;
572
573     srenew(th[XX], nalloc);
574     srenew(th[YY], nalloc);
575     /* In z we add padding; this is only required for the aligned SSE code */
576     srenew(*ptr_z, nalloc+2*padding);
577     th[ZZ] = *ptr_z + padding;
578
579     for (i = 0; i < padding; i++)
580     {
581         (*ptr_z)[               i] = 0;
582         (*ptr_z)[padding+nalloc+i] = 0;
583     }
584 }
585
586 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
587 {
588     int i, d;
589
590     srenew(spline->ind, atc->nalloc);
591     /* Initialize the index to identity so it works without threads */
592     for (i = 0; i < atc->nalloc; i++)
593     {
594         spline->ind[i] = i;
595     }
596
597     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
598                       atc->pme_order*atc->nalloc);
599     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
600                       atc->pme_order*atc->nalloc);
601 }
602
603 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
604 {
605     int nalloc_old, i, j, nalloc_tpl;
606
607     /* We must avoid a NULL pointer for atc->x, since it can cause
608      * fatal errors in MPI routines.
609      */
610     if (atc->n > atc->nalloc || atc->nalloc == 0)
611     {
612         nalloc_old  = atc->nalloc;
613         atc->nalloc = over_alloc_dd(max(atc->n, 1));
614
615         if (atc->nslab > 1)
616         {
617             srenew(atc->x, atc->nalloc);
618             srenew(atc->q, atc->nalloc);
619             srenew(atc->f, atc->nalloc);
620             for (i = nalloc_old; i < atc->nalloc; i++)
621             {
622                 clear_rvec(atc->f[i]);
623             }
624         }
625         if (atc->bSpread)
626         {
627             srenew(atc->fractx, atc->nalloc);
628             srenew(atc->idx, atc->nalloc);
629
630             if (atc->nthread > 1)
631             {
632                 srenew(atc->thread_idx, atc->nalloc);
633             }
634
635             for (i = 0; i < atc->nthread; i++)
636             {
637                 pme_realloc_splinedata(&atc->spline[i], atc);
638             }
639         }
640     }
641 }
642
643 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
644                          int n, gmx_bool bXF, rvec *x_f, real *charge,
645                          pme_atomcomm_t *atc)
646 /* Redistribute particle data for the PME calculation. */
647 /* The domain decomposition is along the x coordinate.  */
648 {
649     int *idxa;
650     int  i, ii;
651
652     if (FALSE == pme->redist_init)
653     {
654         snew(pme->scounts, atc->nslab);
655         snew(pme->rcounts, atc->nslab);
656         snew(pme->sdispls, atc->nslab);
657         snew(pme->rdispls, atc->nslab);
658         snew(pme->sidx, atc->nslab);
659         pme->redist_init = TRUE;
660     }
661     if (n > pme->redist_buf_nalloc)
662     {
663         pme->redist_buf_nalloc = over_alloc_dd(n);
664         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
665     }
666
667     pme->idxa = atc->pd;
668
669 #ifdef GMX_MPI
670     if (forw && bXF)
671     {
672         /* forward, redistribution from pp to pme */
673
674         /* Calculate send counts and exchange them with other nodes */
675         for (i = 0; (i < atc->nslab); i++)
676         {
677             pme->scounts[i] = 0;
678         }
679         for (i = 0; (i < n); i++)
680         {
681             pme->scounts[pme->idxa[i]]++;
682         }
683         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
684
685         /* Calculate send and receive displacements and index into send
686            buffer */
687         pme->sdispls[0] = 0;
688         pme->rdispls[0] = 0;
689         pme->sidx[0]    = 0;
690         for (i = 1; i < atc->nslab; i++)
691         {
692             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
693             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
694             pme->sidx[i]    = pme->sdispls[i];
695         }
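        /* pme->sidx[] now mirrors sdispls[] and is used below as a running
         * write cursor into the send buffer for each destination slab.
         */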
696         /* Total # of particles to be received */
697         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
698
699         pme_realloc_atomcomm_things(atc);
700
701         /* Copy particle coordinates into send buffer and exchange */
702         for (i = 0; (i < n); i++)
703         {
704             ii = DIM*pme->sidx[pme->idxa[i]];
705             pme->sidx[pme->idxa[i]]++;
706             pme->redist_buf[ii+XX] = x_f[i][XX];
707             pme->redist_buf[ii+YY] = x_f[i][YY];
708             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
709         }
710         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
711                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
712                       pme->rvec_mpi, atc->mpi_comm);
713     }
714     if (forw)
715     {
716         /* Copy charge into send buffer and exchange */
717         for (i = 0; i < atc->nslab; i++)
718         {
719             pme->sidx[i] = pme->sdispls[i];
720         }
721         for (i = 0; (i < n); i++)
722         {
723             ii = pme->sidx[pme->idxa[i]];
724             pme->sidx[pme->idxa[i]]++;
725             pme->redist_buf[ii] = charge[i];
726         }
727         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
728                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
729                       atc->mpi_comm);
730     }
731     else   /* backward, redistribution from pme to pp */
732     {
733         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
734                       pme->redist_buf, pme->scounts, pme->sdispls,
735                       pme->rvec_mpi, atc->mpi_comm);
736
737         /* Copy data from receive buffer */
738         for (i = 0; i < atc->nslab; i++)
739         {
740             pme->sidx[i] = pme->sdispls[i];
741         }
742         for (i = 0; (i < n); i++)
743         {
744             ii          = DIM*pme->sidx[pme->idxa[i]];
745             x_f[i][XX] += pme->redist_buf[ii+XX];
746             x_f[i][YY] += pme->redist_buf[ii+YY];
747             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
748             pme->sidx[pme->idxa[i]]++;
749         }
750     }
751 #endif
752 }
753
754 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
755                             gmx_bool bBackward, int shift,
756                             void *buf_s, int nbyte_s,
757                             void *buf_r, int nbyte_r)
758 {
759 #ifdef GMX_MPI
760     int        dest, src;
761     MPI_Status stat;
762
763     if (bBackward == FALSE)
764     {
765         dest = atc->node_dest[shift];
766         src  = atc->node_src[shift];
767     }
768     else
769     {
770         dest = atc->node_src[shift];
771         src  = atc->node_dest[shift];
772     }
773
774     if (nbyte_s > 0 && nbyte_r > 0)
775     {
776         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
777                      dest, shift,
778                      buf_r, nbyte_r, MPI_BYTE,
779                      src, shift,
780                      atc->mpi_comm, &stat);
781     }
782     else if (nbyte_s > 0)
783     {
784         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
785                  dest, shift,
786                  atc->mpi_comm);
787     }
788     else if (nbyte_r > 0)
789     {
790         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
791                  src, shift,
792                  atc->mpi_comm, &stat);
793     }
794 #endif
795 }
796
797 static void dd_pmeredist_x_q(gmx_pme_t pme,
798                              int n, gmx_bool bX, rvec *x, real *charge,
799                              pme_atomcomm_t *atc)
800 {
801     int *commnode, *buf_index;
802     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
803
804     commnode  = atc->node_dest;
805     buf_index = atc->buf_index;
806
807     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
808
809     nsend = 0;
810     for (i = 0; i < nnodes_comm; i++)
811     {
812         buf_index[commnode[i]] = nsend;
813         nsend                 += atc->count[commnode[i]];
814     }
815     if (bX)
816     {
817         if (atc->count[atc->nodeid] + nsend != n)
818         {
819             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
820                       "This usually means that your system is not well equilibrated.",
821                       n - (atc->count[atc->nodeid] + nsend),
822                       pme->nodeid, 'x'+atc->dimind);
823         }
824
825         if (nsend > pme->buf_nalloc)
826         {
827             pme->buf_nalloc = over_alloc_dd(nsend);
828             srenew(pme->bufv, pme->buf_nalloc);
829             srenew(pme->bufr, pme->buf_nalloc);
830         }
831
832         atc->n = atc->count[atc->nodeid];
833         for (i = 0; i < nnodes_comm; i++)
834         {
835             scount = atc->count[commnode[i]];
836             /* Communicate the count */
837             if (debug)
838             {
839                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
840                         atc->dimind, atc->nodeid, commnode[i], scount);
841             }
842             pme_dd_sendrecv(atc, FALSE, i,
843                             &scount, sizeof(int),
844                             &atc->rcount[i], sizeof(int));
845             atc->n += atc->rcount[i];
846         }
847
848         pme_realloc_atomcomm_things(atc);
849     }
850
851     local_pos = 0;
852     for (i = 0; i < n; i++)
853     {
854         node = atc->pd[i];
855         if (node == atc->nodeid)
856         {
857             /* Copy direct to the receive buffer */
858             if (bX)
859             {
860                 copy_rvec(x[i], atc->x[local_pos]);
861             }
862             atc->q[local_pos] = charge[i];
863             local_pos++;
864         }
865         else
866         {
867             /* Copy to the send buffer */
868             if (bX)
869             {
870                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
871             }
872             pme->bufr[buf_index[node]] = charge[i];
873             buf_index[node]++;
874         }
875     }
876
877     buf_pos = 0;
878     for (i = 0; i < nnodes_comm; i++)
879     {
880         scount = atc->count[commnode[i]];
881         rcount = atc->rcount[i];
882         if (scount > 0 || rcount > 0)
883         {
884             if (bX)
885             {
886                 /* Communicate the coordinates */
887                 pme_dd_sendrecv(atc, FALSE, i,
888                                 pme->bufv[buf_pos], scount*sizeof(rvec),
889                                 atc->x[local_pos], rcount*sizeof(rvec));
890             }
891             /* Communicate the charges */
892             pme_dd_sendrecv(atc, FALSE, i,
893                             pme->bufr+buf_pos, scount*sizeof(real),
894                             atc->q+local_pos, rcount*sizeof(real));
895             buf_pos   += scount;
896             local_pos += atc->rcount[i];
897         }
898     }
899 }
900
901 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
902                            int n, rvec *f,
903                            gmx_bool bAddF)
904 {
905     int *commnode, *buf_index;
906     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
907
908     commnode  = atc->node_dest;
909     buf_index = atc->buf_index;
910
911     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
912
913     local_pos = atc->count[atc->nodeid];
914     buf_pos   = 0;
915     for (i = 0; i < nnodes_comm; i++)
916     {
917         scount = atc->rcount[i];
918         rcount = atc->count[commnode[i]];
919         if (scount > 0 || rcount > 0)
920         {
921             /* Communicate the forces */
922             pme_dd_sendrecv(atc, TRUE, i,
923                             atc->f[local_pos], scount*sizeof(rvec),
924                             pme->bufv[buf_pos], rcount*sizeof(rvec));
925             local_pos += scount;
926         }
927         buf_index[commnode[i]] = buf_pos;
928         buf_pos               += rcount;
929     }
930
931     local_pos = 0;
932     if (bAddF)
933     {
934         for (i = 0; i < n; i++)
935         {
936             node = atc->pd[i];
937             if (node == atc->nodeid)
938             {
939                 /* Add from the local force array */
940                 rvec_inc(f[i], atc->f[local_pos]);
941                 local_pos++;
942             }
943             else
944             {
945                 /* Add from the receive buffer */
946                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
947                 buf_index[node]++;
948             }
949         }
950     }
951     else
952     {
953         for (i = 0; i < n; i++)
954         {
955             node = atc->pd[i];
956             if (node == atc->nodeid)
957             {
958                 /* Copy from the local force array */
959                 copy_rvec(atc->f[local_pos], f[i]);
960                 local_pos++;
961             }
962             else
963             {
964                 /* Copy from the receive buffer */
965                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
966                 buf_index[node]++;
967             }
968         }
969     }
970 }
971
972 #ifdef GMX_MPI
973 static void
974 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
975 {
976     pme_overlap_t *overlap;
977     int            send_index0, send_nindex;
978     int            recv_index0, recv_nindex;
979     MPI_Status     stat;
980     int            i, j, k, ix, iy, iz, icnt;
981     int            ipulse, send_id, recv_id, datasize;
982     real          *p;
983     real          *sendptr, *recvptr;
984
985     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
986     overlap = &pme->overlap[1];
987
988     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
989     {
990         /* Since we have already (un)wrapped the overlap in the z-dimension,
991          * we only have to communicate 0 to nkz (not pmegrid_nz).
992          */
993         if (direction == GMX_SUM_QGRID_FORWARD)
994         {
995             send_id       = overlap->send_id[ipulse];
996             recv_id       = overlap->recv_id[ipulse];
997             send_index0   = overlap->comm_data[ipulse].send_index0;
998             send_nindex   = overlap->comm_data[ipulse].send_nindex;
999             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1000             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1001         }
1002         else
1003         {
1004             send_id       = overlap->recv_id[ipulse];
1005             recv_id       = overlap->send_id[ipulse];
1006             send_index0   = overlap->comm_data[ipulse].recv_index0;
1007             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1008             recv_index0   = overlap->comm_data[ipulse].send_index0;
1009             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1010         }
1011
1012         /* Copy data to contiguous send buffer */
1013         if (debug)
1014         {
1015             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1016                     pme->nodeid, overlap->nodeid, send_id,
1017                     pme->pmegrid_start_iy,
1018                     send_index0-pme->pmegrid_start_iy,
1019                     send_index0-pme->pmegrid_start_iy+send_nindex);
1020         }
1021         icnt = 0;
1022         for (i = 0; i < pme->pmegrid_nx; i++)
1023         {
1024             ix = i;
1025             for (j = 0; j < send_nindex; j++)
1026             {
1027                 iy = j + send_index0 - pme->pmegrid_start_iy;
1028                 for (k = 0; k < pme->nkz; k++)
1029                 {
1030                     iz = k;
1031                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
1032                 }
1033             }
1034         }
1035
1036         datasize      = pme->pmegrid_nx * pme->nkz;
1037
1038         MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
1039                      send_id, ipulse,
1040                      overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
1041                      recv_id, ipulse,
1042                      overlap->mpi_comm, &stat);
1043
1044         /* Get data from contiguous recv buffer */
1045         if (debug)
1046         {
1047             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1048                     pme->nodeid, overlap->nodeid, recv_id,
1049                     pme->pmegrid_start_iy,
1050                     recv_index0-pme->pmegrid_start_iy,
1051                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1052         }
1053         icnt = 0;
1054         for (i = 0; i < pme->pmegrid_nx; i++)
1055         {
1056             ix = i;
1057             for (j = 0; j < recv_nindex; j++)
1058             {
1059                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1060                 for (k = 0; k < pme->nkz; k++)
1061                 {
1062                     iz = k;
1063                     if (direction == GMX_SUM_QGRID_FORWARD)
1064                     {
1065                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1066                     }
1067                     else
1068                     {
1069                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1070                     }
1071                 }
1072             }
1073         }
1074     }
1075
1076     /* Major dimension is easier, no copying required,
1077      * but we might have to sum to a separate array.
1078      * Since we don't copy, we have to communicate up to pmegrid_nz,
1079      * not nkz as for the minor direction.
1080      */
1081     overlap = &pme->overlap[0];
1082
1083     for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
1084     {
1085         if (direction == GMX_SUM_QGRID_FORWARD)
1086         {
1087             send_id       = overlap->send_id[ipulse];
1088             recv_id       = overlap->recv_id[ipulse];
1089             send_index0   = overlap->comm_data[ipulse].send_index0;
1090             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1091             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1092             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1093             recvptr       = overlap->recvbuf;
1094         }
1095         else
1096         {
1097             send_id       = overlap->recv_id[ipulse];
1098             recv_id       = overlap->send_id[ipulse];
1099             send_index0   = overlap->comm_data[ipulse].recv_index0;
1100             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1101             recv_index0   = overlap->comm_data[ipulse].send_index0;
1102             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1103             recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1104         }
1105
1106         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1107         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1108
1109         if (debug)
1110         {
1111             fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1112                     pme->nodeid, overlap->nodeid, send_id,
1113                     pme->pmegrid_start_ix,
1114                     send_index0-pme->pmegrid_start_ix,
1115                     send_index0-pme->pmegrid_start_ix+send_nindex);
1116             fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1117                     pme->nodeid, overlap->nodeid, recv_id,
1118                     pme->pmegrid_start_ix,
1119                     recv_index0-pme->pmegrid_start_ix,
1120                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1121         }
1122
1123         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
1124                      send_id, ipulse,
1125                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
1126                      recv_id, ipulse,
1127                      overlap->mpi_comm, &stat);
1128
1129         /* ADD data from contiguous recv buffer */
1130         if (direction == GMX_SUM_QGRID_FORWARD)
1131         {
1132             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1133             for (i = 0; i < recv_nindex*datasize; i++)
1134             {
1135                 p[i] += overlap->recvbuf[i];
1136             }
1137         }
1138     }
1139 }
1140 #endif
1141
1142
1143 static int
1144 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1145 {
1146     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1147     ivec    local_pme_size;
1148     int     i, ix, iy, iz;
1149     int     pmeidx, fftidx;
1150
1151     /* Dimensions should be identical for A/B grid, so we just use A here */
1152     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1153                                    local_fft_ndata,
1154                                    local_fft_offset,
1155                                    local_fft_size);
1156
1157     local_pme_size[0] = pme->pmegrid_nx;
1158     local_pme_size[1] = pme->pmegrid_ny;
1159     local_pme_size[2] = pme->pmegrid_nz;
1160
1161     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1162        the offset is identical, and the PME grid always has more data (due to overlap)
1163      */
1164     {
1165 #ifdef DEBUG_PME
1166         FILE *fp, *fp2;
1167         char  fn[STRLEN], format[STRLEN];
1168         real  val;
1169         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1170         fp = ffopen(fn, "w");
1171         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1172         fp2 = ffopen(fn, "w");
1173         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1174 #endif
1175
1176         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1177         {
1178             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1179             {
1180                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1181                 {
1182                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1183                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1184                     fftgrid[fftidx] = pmegrid[pmeidx];
1185 #ifdef DEBUG_PME
1186                     val = 100*pmegrid[pmeidx];
1187                     if (pmegrid[pmeidx] != 0)
1188                     {
1189                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1190                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1191                     }
1192                     if (pmegrid[pmeidx] != 0)
1193                     {
1194                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1195                                 "qgrid",
1196                                 pme->pmegrid_start_ix + ix,
1197                                 pme->pmegrid_start_iy + iy,
1198                                 pme->pmegrid_start_iz + iz,
1199                                 pmegrid[pmeidx]);
1200                     }
1201 #endif
1202                 }
1203             }
1204         }
1205 #ifdef DEBUG_PME
1206         ffclose(fp);
1207         ffclose(fp2);
1208 #endif
1209     }
1210     return 0;
1211 }
1212
1213
1214 static gmx_cycles_t omp_cyc_start()
1215 {
1216     return gmx_cycles_read();
1217 }
1218
1219 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1220 {
1221     return gmx_cycles_read() - c;
1222 }
1223
1224
1225 static int
1226 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1227                         int nthread, int thread)
1228 {
1229     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1230     ivec          local_pme_size;
1231     int           ixy0, ixy1, ixy, ix, iy, iz;
1232     int           pmeidx, fftidx;
1233 #ifdef PME_TIME_THREADS
1234     gmx_cycles_t  c1;
1235     static double cs1 = 0;
1236     static int    cnt = 0;
1237 #endif
1238
1239 #ifdef PME_TIME_THREADS
1240     c1 = omp_cyc_start();
1241 #endif
1242     /* Dimensions should be identical for A/B grid, so we just use A here */
1243     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1244                                    local_fft_ndata,
1245                                    local_fft_offset,
1246                                    local_fft_size);
1247
1248     local_pme_size[0] = pme->pmegrid_nx;
1249     local_pme_size[1] = pme->pmegrid_ny;
1250     local_pme_size[2] = pme->pmegrid_nz;
1251
1252     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1253        the offset is identical, and the PME grid always has more data (due to overlap)
1254      */
1255     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1256     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1257
1258     for (ixy = ixy0; ixy < ixy1; ixy++)
1259     {
1260         ix = ixy/local_fft_ndata[YY];
1261         iy = ixy - ix*local_fft_ndata[YY];
1262
1263         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1264         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1265         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1266         {
1267             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1268         }
1269     }
1270
1271 #ifdef PME_TIME_THREADS
1272     c1   = omp_cyc_end(c1);
1273     cs1 += (double)c1;
1274     cnt++;
1275     if (cnt % 20 == 0)
1276     {
1277         printf("copy %.2f\n", cs1*1e-9);
1278     }
1279 #endif
1280
1281     return 0;
1282 }
1283
1284
1285 static void
1286 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1287 {
1288     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1289
1290     nx = pme->nkx;
1291     ny = pme->nky;
1292     nz = pme->nkz;
1293
1294     pnx = pme->pmegrid_nx;
1295     pny = pme->pmegrid_ny;
1296     pnz = pme->pmegrid_nz;
1297
1298     overlap = pme->pme_order - 1;
1299
1300     /* Add periodic overlap in z */
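    /* For example, with pme_order=4 (overlap=3) and nz=32, charge spread
     * onto planes iz = 32..34 is folded back onto planes iz = 0..2.
     */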
1301     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1302     {
1303         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1304         {
1305             for (iz = 0; iz < overlap; iz++)
1306             {
1307                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1308                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1309             }
1310         }
1311     }
1312
1313     if (pme->nnodes_minor == 1)
1314     {
1315         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1316         {
1317             for (iy = 0; iy < overlap; iy++)
1318             {
1319                 for (iz = 0; iz < nz; iz++)
1320                 {
1321                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1322                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1323                 }
1324             }
1325         }
1326     }
1327
1328     if (pme->nnodes_major == 1)
1329     {
1330         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1331
1332         for (ix = 0; ix < overlap; ix++)
1333         {
1334             for (iy = 0; iy < ny_x; iy++)
1335             {
1336                 for (iz = 0; iz < nz; iz++)
1337                 {
1338                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1339                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1340                 }
1341             }
1342         }
1343     }
1344 }
1345
1346
1347 static void
1348 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1349 {
1350     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1351
1352     nx = pme->nkx;
1353     ny = pme->nky;
1354     nz = pme->nkz;
1355
1356     pnx = pme->pmegrid_nx;
1357     pny = pme->pmegrid_ny;
1358     pnz = pme->pmegrid_nz;
1359
1360     overlap = pme->pme_order - 1;
1361
1362     if (pme->nnodes_major == 1)
1363     {
1364         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1365
1366         for (ix = 0; ix < overlap; ix++)
1367         {
1368             int iy, iz;
1369
1370             for (iy = 0; iy < ny_x; iy++)
1371             {
1372                 for (iz = 0; iz < nz; iz++)
1373                 {
1374                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1375                         pmegrid[(ix*pny+iy)*pnz+iz];
1376                 }
1377             }
1378         }
1379     }
1380
1381     if (pme->nnodes_minor == 1)
1382     {
1383 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1384         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1385         {
1386             int iy, iz;
1387
1388             for (iy = 0; iy < overlap; iy++)
1389             {
1390                 for (iz = 0; iz < nz; iz++)
1391                 {
1392                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1393                         pmegrid[(ix*pny+iy)*pnz+iz];
1394                 }
1395             }
1396         }
1397     }
1398
1399     /* Copy periodic overlap in z */
1400 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1401     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1402     {
1403         int iy, iz;
1404
1405         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1406         {
1407             for (iz = 0; iz < overlap; iz++)
1408             {
1409                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1410                     pmegrid[(ix*pny+iy)*pnz+iz];
1411             }
1412         }
1413     }
1414 }
1415
1416
1417 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1418 #define DO_BSPLINE(order)                            \
1419     for (ithx = 0; (ithx < order); ithx++)                    \
1420     {                                                    \
1421         index_x = (i0+ithx)*pny*pnz;                     \
1422         valx    = qn*thx[ithx];                          \
1423                                                      \
1424         for (ithy = 0; (ithy < order); ithy++)                \
1425         {                                                \
1426             valxy    = valx*thy[ithy];                   \
1427             index_xy = index_x+(j0+ithy)*pnz;            \
1428                                                      \
1429             for (ithz = 0; (ithz < order); ithz++)            \
1430             {                                            \
1431                 index_xyz        = index_xy+(k0+ithz);   \
1432                 grid[index_xyz] += valxy*thz[ithz];      \
1433             }                                            \
1434         }                                                \
1435     }
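/* Conceptually, DO_BSPLINE(order) performs, for a single charge qn,
 *   grid[i0+ithx, j0+ithy, k0+ithz] += qn*thx[ithx]*thy[ithy]*thz[ithz]
 * over the order^3 neighborhood, i.e. a tensor product of the 1D B-spline
 * weights in x, y and z (with the 3D index flattened as in the code above).
 */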
1436
1437
1438 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1439                                      pme_atomcomm_t *atc, splinedata_t *spline,
1440                                      pme_spline_work_t *work)
1441 {
1442
1443     /* spread charges from home atoms to local grid */
1444     real          *grid;
1445     pme_overlap_t *ol;
1446     int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
1447     int       *    idxptr;
1448     int            order, norder, index_x, index_xy, index_xyz;
1449     real           valx, valxy, qn;
1450     real          *thx, *thy, *thz;
1451     int            localsize, bndsize;
1452     int            pnx, pny, pnz, ndatatot;
1453     int            offx, offy, offz;
1454
1455     pnx = pmegrid->s[XX];
1456     pny = pmegrid->s[YY];
1457     pnz = pmegrid->s[ZZ];
1458
1459     offx = pmegrid->offset[XX];
1460     offy = pmegrid->offset[YY];
1461     offz = pmegrid->offset[ZZ];
1462
1463     ndatatot = pnx*pny*pnz;
1464     grid     = pmegrid->grid;
1465     for (i = 0; i < ndatatot; i++)
1466     {
1467         grid[i] = 0;
1468     }
1469
1470     order = pmegrid->order;
1471
1472     for (nn = 0; nn < spline->n; nn++)
1473     {
1474         n  = spline->ind[nn];
1475         qn = atc->q[n];
1476
1477         if (qn != 0)
1478         {
1479             idxptr = atc->idx[n];
1480             norder = nn*order;
1481
1482             i0   = idxptr[XX] - offx;
1483             j0   = idxptr[YY] - offy;
1484             k0   = idxptr[ZZ] - offz;
1485
1486             thx = spline->theta[XX] + norder;
1487             thy = spline->theta[YY] + norder;
1488             thz = spline->theta[ZZ] + norder;
1489
1490             switch (order)
1491             {
1492                 case 4:
1493 #ifdef PME_SSE
1494 #ifdef PME_SSE_UNALIGNED
1495 #define PME_SPREAD_SSE_ORDER4
1496 #else
1497 #define PME_SPREAD_SSE_ALIGNED
1498 #define PME_ORDER 4
1499 #endif
1500 #include "pme_sse_single.h"
1501 #else
1502                     DO_BSPLINE(4);
1503 #endif
1504                     break;
1505                 case 5:
1506 #ifdef PME_SSE
1507 #define PME_SPREAD_SSE_ALIGNED
1508 #define PME_ORDER 5
1509 #include "pme_sse_single.h"
1510 #else
1511                     DO_BSPLINE(5);
1512 #endif
1513                     break;
1514                 default:
1515                     DO_BSPLINE(order);
1516                     break;
1517             }
1518         }
1519     }
1520 }
1521
1522 static void set_grid_alignment(int *pmegrid_nz, int pme_order)
1523 {
1524 #ifdef PME_SSE
1525     if (pme_order == 5
1526 #ifndef PME_SSE_UNALIGNED
1527         || pme_order == 4
1528 #endif
1529         )
1530     {
1531         /* Round nz up to a multiple of 4 to ensure alignment */
1532         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
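        /* e.g. 30 -> 32, while 32 stays 32 */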
1533     }
1534 #endif
1535 }
1536
1537 static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
1538 {
1539 #ifdef PME_SSE
1540 #ifndef PME_SSE_UNALIGNED
1541     if (pme_order == 4)
1542     {
1543         /* Add extra elements to ensure that aligned operations do not go
1544          * beyond the allocated grid size.
1545          * Note that for pme_order=5, the pme grid z-size alignment
1546          * ensures that we will not go beyond the grid size.
1547          */
1548         *gridsize += 4;
1549     }
1550 #endif
1551 #endif
1552 }
1553
1554 static void pmegrid_init(pmegrid_t *grid,
1555                          int cx, int cy, int cz,
1556                          int x0, int y0, int z0,
1557                          int x1, int y1, int z1,
1558                          gmx_bool set_alignment,
1559                          int pme_order,
1560                          real *ptr)
1561 {
1562     int nz, gridsize;
1563
1564     grid->ci[XX]     = cx;
1565     grid->ci[YY]     = cy;
1566     grid->ci[ZZ]     = cz;
1567     grid->offset[XX] = x0;
1568     grid->offset[YY] = y0;
1569     grid->offset[ZZ] = z0;
1570     grid->n[XX]      = x1 - x0 + pme_order - 1;
1571     grid->n[YY]      = y1 - y0 + pme_order - 1;
1572     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1573     copy_ivec(grid->n, grid->s);
1574
1575     nz = grid->s[ZZ];
1576     set_grid_alignment(&nz, pme_order);
1577     if (set_alignment)
1578     {
1579         grid->s[ZZ] = nz;
1580     }
1581     else if (nz != grid->s[ZZ])
1582     {
1583         gmx_incons("pmegrid_init call with an unaligned z size");
1584     }
1585
1586     grid->order = pme_order;
1587     if (ptr == NULL)
1588     {
1589         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1590         set_gridsize_alignment(&gridsize, pme_order);
1591         snew_aligned(grid->grid, gridsize, 16);
1592     }
1593     else
1594     {
1595         grid->grid = ptr;
1596     }
1597 }
1598
1599 static int div_round_up(int numerator, int denominator)
1600 {
1601     return (numerator + denominator - 1)/denominator;
1602 }
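/* e.g. div_round_up(10, 4) == 3 */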
1603
1604 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1605                                   ivec nsub)
1606 {
1607     int gsize_opt, gsize;
1608     int nsx, nsy, nsz;
1609     char *env;
1610
1611     gsize_opt = -1;
1612     for (nsx = 1; nsx <= nthread; nsx++)
1613     {
1614         if (nthread % nsx == 0)
1615         {
1616             for (nsy = 1; nsy <= nthread; nsy++)
1617             {
1618                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1619                 {
1620                     nsz = nthread/(nsx*nsy);
1621
1622                     /* Determine the number of grid points per thread */
1623                     gsize =
1624                         (div_round_up(n[XX], nsx) + ovl)*
1625                         (div_round_up(n[YY], nsy) + ovl)*
1626                         (div_round_up(n[ZZ], nsz) + ovl);
1627
1628                     /* Minimize the number of grid points per thread
1629                      * and, secondarily, the number of cuts in minor dimensions.
1630                      */
1631                     if (gsize_opt == -1 ||
1632                         gsize < gsize_opt ||
1633                         (gsize == gsize_opt &&
1634                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1635                     {
1636                         nsub[XX]  = nsx;
1637                         nsub[YY]  = nsy;
1638                         nsub[ZZ]  = nsz;
1639                         gsize_opt = gsize;
1640                     }
1641                 }
1642             }
1643         }
1644     }
1645
1646     env = getenv("GMX_PME_THREAD_DIVISION");
1647     if (env != NULL)
1648     {
1649         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1650     }
1651
1652     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1653     {
1654         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1655     }
1656 }
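/* Usage note (example values, not a recommendation): with 4 PME threads,
 * setting the environment variable GMX_PME_THREAD_DIVISION to "2 2 1"
 * overrides the automatic search above and forces a 2 x 2 x 1 subgrid
 * division; the product of the three numbers must equal the PME thread
 * count, otherwise the gmx_fatal check above is triggered.
 */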
1657
1658 static void pmegrids_init(pmegrids_t *grids,
1659                           int nx, int ny, int nz, int nz_base,
1660                           int pme_order,
1661                           gmx_bool bUseThreads,
1662                           int nthread,
1663                           int overlap_x,
1664                           int overlap_y)
1665 {
1666     ivec n, n_base, g0, g1;
1667     int t, x, y, z, d, i, tfac;
1668     int max_comm_lines = -1;
1669
1670     n[XX] = nx - (pme_order - 1);
1671     n[YY] = ny - (pme_order - 1);
1672     n[ZZ] = nz - (pme_order - 1);
1673
1674     copy_ivec(n, n_base);
1675     n_base[ZZ] = nz_base;
1676
1677     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
1678                  NULL);
1679
1680     grids->nthread = nthread;
1681
1682     make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);
1683
1684     if (bUseThreads)
1685     {
1686         ivec nst;
1687         int gridsize;
1688
1689         for (d = 0; d < DIM; d++)
1690         {
1691             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
1692         }
1693         set_grid_alignment(&nst[ZZ], pme_order);
1694
1695         if (debug)
1696         {
1697             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
1698                     grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
1699             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1700                     nx, ny, nz,
1701                     nst[XX], nst[YY], nst[ZZ]);
1702         }
1703
1704         snew(grids->grid_th, grids->nthread);
1705         t        = 0;
1706         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1707         set_gridsize_alignment(&gridsize, pme_order);
1708         snew_aligned(grids->grid_all,
1709                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1710                      16);
1711
1712         for (x = 0; x < grids->nc[XX]; x++)
1713         {
1714             for (y = 0; y < grids->nc[YY]; y++)
1715             {
1716                 for (z = 0; z < grids->nc[ZZ]; z++)
1717                 {
1718                     pmegrid_init(&grids->grid_th[t],
1719                                  x, y, z,
1720                                  (n[XX]*(x  ))/grids->nc[XX],
1721                                  (n[YY]*(y  ))/grids->nc[YY],
1722                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1723                                  (n[XX]*(x+1))/grids->nc[XX],
1724                                  (n[YY]*(y+1))/grids->nc[YY],
1725                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1726                                  TRUE,
1727                                  pme_order,
1728                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1729                     t++;
1730                 }
1731             }
1732         }
1733     }
1734     else
1735     {
1736         grids->grid_th = NULL;
1737     }
1738
1739     snew(grids->g2t, DIM);
1740     tfac = 1;
1741     for (d = DIM-1; d >= 0; d--)
1742     {
1743         snew(grids->g2t[d], n[d]);
1744         t = 0;
1745         for (i = 0; i < n[d]; i++)
1746         {
1747             /* The second check should match the parameters
1748              * of the pmegrid_init call above.
1749              */
1750             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1751             {
1752                 t++;
1753             }
1754             grids->g2t[d][i] = t*tfac;
1755         }
1756
1757         tfac *= grids->nc[d];
1758
1759         switch (d)
1760         {
1761             case XX: max_comm_lines = overlap_x;     break;
1762             case YY: max_comm_lines = overlap_y;     break;
1763             case ZZ: max_comm_lines = pme_order - 1; break;
1764         }
1765         grids->nthread_comm[d] = 0;
1766         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1767                grids->nthread_comm[d] < grids->nc[d])
1768         {
1769             grids->nthread_comm[d]++;
1770         }
1771         if (debug != NULL)
1772         {
1773             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
1774                     'x'+d, grids->nthread_comm[d]);
1775         }
1776         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1777          * work, but this is not a problematic restriction.
1778          */
1779         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1780         {
1781             gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
1782         }
1783     }
1784 }
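/* Sketch of the g2t mapping built above (hypothetical sizes): with
 * n[d] = 10 interior lines and nc[d] = 4 thread cuts, thread t owns lines
 * [10*t/4, 10*(t+1)/4), i.e. g2t[d] = {0,0,1,1,1,2,2,3,3,3} before scaling
 * by tfac. Summing g2t[XX][ix] + g2t[YY][iy] + g2t[ZZ][iz] then gives the
 * owning thread index in the same x-major, z-minor order used when
 * filling grid_th above.
 */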
1785
1786
1787 static void pmegrids_destroy(pmegrids_t *grids)
1788 {
1789     int t;
1790
1791     if (grids->grid.grid != NULL)
1792     {
1793         sfree(grids->grid.grid);
1794
1795         if (grids->nthread > 0)
1796         {
1797             for (t = 0; t < grids->nthread; t++)
1798             {
1799                 sfree(grids->grid_th[t].grid);
1800             }
1801             sfree(grids->grid_th);
1802         }
1803     }
1804 }
1805
1806
1807 static void realloc_work(pme_work_t *work, int nkx)
1808 {
1809     if (nkx > work->nalloc)
1810     {
1811         work->nalloc = nkx;
1812         srenew(work->mhx, work->nalloc);
1813         srenew(work->mhy, work->nalloc);
1814         srenew(work->mhz, work->nalloc);
1815         srenew(work->m2, work->nalloc);
1816         /* Allocate an aligned pointer for SSE operations, including 3 extra
1817          * elements at the end since SSE operates on 4 elements at a time.
1818          */
1819         sfree_aligned(work->denom);
1820         sfree_aligned(work->tmp1);
1821         sfree_aligned(work->eterm);
1822         snew_aligned(work->denom, work->nalloc+3, 16);
1823         snew_aligned(work->tmp1, work->nalloc+3, 16);
1824         snew_aligned(work->eterm, work->nalloc+3, 16);
1825         srenew(work->m2inv, work->nalloc);
1826     }
1827 }
1828
1829
1830 static void free_work(pme_work_t *work)
1831 {
1832     sfree(work->mhx);
1833     sfree(work->mhy);
1834     sfree(work->mhz);
1835     sfree(work->m2);
1836     sfree_aligned(work->denom);
1837     sfree_aligned(work->tmp1);
1838     sfree_aligned(work->eterm);
1839     sfree(work->m2inv);
1840 }
1841
1842
1843 #ifdef PME_SSE
1844 /* Calculate exponentials through SSE in float precision */
1845 inline static void calc_exponentials(int gmx_unused start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1846 {
1847     {
1848         const __m128 two = _mm_set_ps(2.0f, 2.0f, 2.0f, 2.0f);
1849         __m128 f_sse;
1850         __m128 lu;
1851         __m128 tmp_d1, d_inv, tmp_r, tmp_e;
1852         int kx;
1853         f_sse = _mm_load1_ps(&f);
1854         for (kx = 0; kx < end; kx += 4)
1855         {
1856             tmp_d1   = _mm_load_ps(d_aligned+kx);
1857             lu       = _mm_rcp_ps(tmp_d1);
1858             d_inv    = _mm_mul_ps(lu, _mm_sub_ps(two, _mm_mul_ps(lu, tmp_d1)));
1859             tmp_r    = _mm_load_ps(r_aligned+kx);
1860             tmp_r    = gmx_mm_exp_ps(tmp_r);
1861             tmp_e    = _mm_mul_ps(f_sse, d_inv);
1862             tmp_e    = _mm_mul_ps(tmp_e, tmp_r);
1863             _mm_store_ps(e_aligned+kx, tmp_e);
1864         }
1865     }
1866 }
1867 #else
1868 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1869 {
1870     int kx;
1871     for (kx = start; kx < end; kx++)
1872     {
1873         d[kx] = 1.0/d[kx];
1874     }
1875     for (kx = start; kx < end; kx++)
1876     {
1877         r[kx] = exp(r[kx]);
1878     }
1879     for (kx = start; kx < end; kx++)
1880     {
1881         e[kx] = f*r[kx]*d[kx];
1882     }
1883 }
1884 #endif
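/* Note on the two variants above: the SSE path ignores 'start' and always
 * begins at kx = 0 in steps of 4, computing 1/denom from a reciprocal
 * estimate refined by one Newton-Raphson step (d_inv = lu*(2 - lu*d)).
 * This relies on denom/tmp1/eterm being allocated with 3 extra elements
 * and 16-byte alignment in realloc_work; elements below kxstart are
 * written but appear never to be read by solve_pme_yzx.
 */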
1885
1886
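/* Rough summary of the per-k factor computed below (standard smooth-PME
 * form; see the loops for the exact expression): with beta = ewaldcoeff
 * and Bx, By, Bz the B-spline moduli,
 *   eterm(m) ~ (1/(4*pi*eps0*eps_r)) * exp(-pi^2 m^2 / beta^2)
 *              / (pi * V * Bx*By*Bz * m^2).
 * The grid is multiplied by eterm, and when bEnerVir is set the energy
 * and virial are accumulated from eterm(m)*|S(m)|^2.
 */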
1887 static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
1888                          real ewaldcoeff, real vol,
1889                          gmx_bool bEnerVir,
1890                          int nthread, int thread)
1891 {
1892     /* do recip sum over local cells in grid */
1893     /* y major, z middle, x minor or continuous */
1894     t_complex *p0;
1895     int     kx, ky, kz, maxkx, maxky, maxkz;
1896     int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
1897     real    mx, my, mz;
1898     real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1899     real    ets2, struct2, vfactor, ets2vf;
1900     real    d1, d2, energy = 0;
1901     real    by, bz;
1902     real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
1903     real    rxx, ryx, ryy, rzx, rzy, rzz;
1904     pme_work_t *work;
1905     real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
1906     real    mhxk, mhyk, mhzk, m2k;
1907     real    corner_fac;
1908     ivec    complex_order;
1909     ivec    local_ndata, local_offset, local_size;
1910     real    elfac;
1911
1912     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1913
1914     nx = pme->nkx;
1915     ny = pme->nky;
1916     nz = pme->nkz;
1917
1918     /* Dimensions should be identical for A/B grid, so we just use A here */
1919     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1920                                       complex_order,
1921                                       local_ndata,
1922                                       local_offset,
1923                                       local_size);
1924
1925     rxx = pme->recipbox[XX][XX];
1926     ryx = pme->recipbox[YY][XX];
1927     ryy = pme->recipbox[YY][YY];
1928     rzx = pme->recipbox[ZZ][XX];
1929     rzy = pme->recipbox[ZZ][YY];
1930     rzz = pme->recipbox[ZZ][ZZ];
1931
1932     maxkx = (nx+1)/2;
1933     maxky = (ny+1)/2;
1934     maxkz = nz/2+1;
1935
1936     work  = &pme->work[thread];
1937     mhx   = work->mhx;
1938     mhy   = work->mhy;
1939     mhz   = work->mhz;
1940     m2    = work->m2;
1941     denom = work->denom;
1942     tmp1  = work->tmp1;
1943     eterm = work->eterm;
1944     m2inv = work->m2inv;
1945
1946     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1947     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1948
1949     for (iyz = iyz0; iyz < iyz1; iyz++)
1950     {
1951         iy = iyz/local_ndata[ZZ];
1952         iz = iyz - iy*local_ndata[ZZ];
1953
1954         ky = iy + local_offset[YY];
1955
1956         if (ky < maxky)
1957         {
1958             my = ky;
1959         }
1960         else
1961         {
1962             my = (ky - ny);
1963         }
1964
1965         by = M_PI*vol*pme->bsp_mod[YY][ky];
1966
1967         kz = iz + local_offset[ZZ];
1968
1969         mz = kz;
1970
1971         bz = pme->bsp_mod[ZZ][kz];
1972
1973         /* 0.5 correction for corner points */
1974         corner_fac = 1;
1975         if (kz == 0 || kz == (nz+1)/2)
1976         {
1977             corner_fac = 0.5;
1978         }
1979
1980         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1981
1982         /* We should skip the k-space point (0,0,0) */
1983         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1984         {
1985             kxstart = local_offset[XX];
1986         }
1987         else
1988         {
1989             kxstart = local_offset[XX] + 1;
1990             p0++;
1991         }
1992         kxend = local_offset[XX] + local_ndata[XX];
1993
1994         if (bEnerVir)
1995         {
1996             /* More expensive inner loop, especially because of the storage
1997              * of the mh elements in arrays.
1998              * Because x is the minor grid index, all mh elements
1999              * depend on kx for triclinic unit cells.
2000              */
2001
2002             /* Two explicit loops to avoid a conditional inside the loop */
2003             for (kx = kxstart; kx < maxkx; kx++)
2004             {
2005                 mx = kx;
2006
2007                 mhxk      = mx * rxx;
2008                 mhyk      = mx * ryx + my * ryy;
2009                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2010                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2011                 mhx[kx]   = mhxk;
2012                 mhy[kx]   = mhyk;
2013                 mhz[kx]   = mhzk;
2014                 m2[kx]    = m2k;
2015                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2016                 tmp1[kx]  = -factor*m2k;
2017             }
2018
2019             for (kx = maxkx; kx < kxend; kx++)
2020             {
2021                 mx = (kx - nx);
2022
2023                 mhxk      = mx * rxx;
2024                 mhyk      = mx * ryx + my * ryy;
2025                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2026                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2027                 mhx[kx]   = mhxk;
2028                 mhy[kx]   = mhyk;
2029                 mhz[kx]   = mhzk;
2030                 m2[kx]    = m2k;
2031                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2032                 tmp1[kx]  = -factor*m2k;
2033             }
2034
2035             for (kx = kxstart; kx < kxend; kx++)
2036             {
2037                 m2inv[kx] = 1.0/m2[kx];
2038             }
2039
2040             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2041
2042             for (kx = kxstart; kx < kxend; kx++, p0++)
2043             {
2044                 d1      = p0->re;
2045                 d2      = p0->im;
2046
2047                 p0->re  = d1*eterm[kx];
2048                 p0->im  = d2*eterm[kx];
2049
2050                 struct2 = 2.0*(d1*d1+d2*d2);
2051
2052                 tmp1[kx] = eterm[kx]*struct2;
2053             }
2054
2055             for (kx = kxstart; kx < kxend; kx++)
2056             {
2057                 ets2     = corner_fac*tmp1[kx];
2058                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2059                 energy  += ets2;
2060
2061                 ets2vf   = ets2*vfactor;
2062                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2063                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2064                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2065                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2066                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2067                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2068             }
2069         }
2070         else
2071         {
2072             /* We don't need to calculate the energy and the virial.
2073              * In this case the triclinic overhead is small.
2074              */
2075
2076             /* Two explicit loops to avoid a conditional inside the loop */
2077
2078             for (kx = kxstart; kx < maxkx; kx++)
2079             {
2080                 mx = kx;
2081
2082                 mhxk      = mx * rxx;
2083                 mhyk      = mx * ryx + my * ryy;
2084                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2085                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2086                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2087                 tmp1[kx]  = -factor*m2k;
2088             }
2089
2090             for (kx = maxkx; kx < kxend; kx++)
2091             {
2092                 mx = (kx - nx);
2093
2094                 mhxk      = mx * rxx;
2095                 mhyk      = mx * ryx + my * ryy;
2096                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2097                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2098                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2099                 tmp1[kx]  = -factor*m2k;
2100             }
2101
2102             calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);
2103
2104             for (kx = kxstart; kx < kxend; kx++, p0++)
2105             {
2106                 d1      = p0->re;
2107                 d2      = p0->im;
2108
2109                 p0->re  = d1*eterm[kx];
2110                 p0->im  = d2*eterm[kx];
2111             }
2112         }
2113     }
2114
2115     if (bEnerVir)
2116     {
2117         /* Update virial with local values.
2118          * The virial is symmetric by definition.
2119          * this virial seems ok for isotropic scaling, but I'm
2120          * experiencing problems on semiisotropic membranes.
2121          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2122          */
2123         work->vir[XX][XX] = 0.25*virxx;
2124         work->vir[YY][YY] = 0.25*viryy;
2125         work->vir[ZZ][ZZ] = 0.25*virzz;
2126         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2127         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2128         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2129
2130         /* This energy should be corrected for a charged system */
2131         work->energy = 0.5*energy;
2132     }
2133
2134     /* Return the loop count */
2135     return local_ndata[YY]*local_ndata[XX];
2136 }
2137
2138 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2139                              real *mesh_energy, matrix vir)
2140 {
2141     /* This function sums output over threads
2142      * and should therefore only be called after thread synchronization.
2143      */
2144     int thread;
2145
2146     *mesh_energy = pme->work[0].energy;
2147     copy_mat(pme->work[0].vir, vir);
2148
2149     for (thread = 1; thread < nthread; thread++)
2150     {
2151         *mesh_energy += pme->work[thread].energy;
2152         m_add(vir, pme->work[thread].vir, vir);
2153     }
2154 }
2155
2156 #define DO_FSPLINE(order)                      \
2157     for (ithx = 0; (ithx < order); ithx++)              \
2158     {                                              \
2159         index_x = (i0+ithx)*pny*pnz;               \
2160         tx      = thx[ithx];                       \
2161         dx      = dthx[ithx];                      \
2162                                                \
2163         for (ithy = 0; (ithy < order); ithy++)          \
2164         {                                          \
2165             index_xy = index_x+(j0+ithy)*pnz;      \
2166             ty       = thy[ithy];                  \
2167             dy       = dthy[ithy];                 \
2168             fxy1     = fz1 = 0;                    \
2169                                                \
2170             for (ithz = 0; (ithz < order); ithz++)      \
2171             {                                      \
2172                 gval  = grid[index_xy+(k0+ithz)];  \
2173                 fxy1 += thz[ithz]*gval;            \
2174                 fz1  += dthz[ithz]*gval;           \
2175             }                                      \
2176             fx += dx*ty*fxy1;                      \
2177             fy += tx*dy*fxy1;                      \
2178             fz += tx*ty*fz1;                       \
2179         }                                          \
2180     }
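/* The macro above accumulates the gradient of the interpolated potential
 * with respect to the fractional grid coordinates:
 *   fx ~ sum dthx*thy*thz*Q, fy ~ sum thx*dthy*thz*Q, fz ~ sum thx*thy*dthz*Q.
 * gather_f_bsplines below converts these to Cartesian forces using the
 * (triangular) reciprocal box at the end of its particle loop.
 */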
2181
2182
2183 static void gather_f_bsplines(gmx_pme_t pme, real *grid,
2184                               gmx_bool bClearF, pme_atomcomm_t *atc,
2185                               splinedata_t *spline,
2186                               real scale)
2187 {
2188     /* sum forces for local particles */
2189     int     nn, n, ithx, ithy, ithz, i0, j0, k0;
2190     int     index_x, index_xy;
2191     int     nx, ny, nz, pnx, pny, pnz;
2192     int *   idxptr;
2193     real    tx, ty, dx, dy, qn;
2194     real    fx, fy, fz, gval;
2195     real    fxy1, fz1;
2196     real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
2197     int     norder;
2198     real    rxx, ryx, ryy, rzx, rzy, rzz;
2199     int     order;
2200
2201     pme_spline_work_t *work;
2202
2203     work = pme->spline_work;
2204
2205     order = pme->pme_order;
2206     thx   = spline->theta[XX];
2207     thy   = spline->theta[YY];
2208     thz   = spline->theta[ZZ];
2209     dthx  = spline->dtheta[XX];
2210     dthy  = spline->dtheta[YY];
2211     dthz  = spline->dtheta[ZZ];
2212     nx    = pme->nkx;
2213     ny    = pme->nky;
2214     nz    = pme->nkz;
2215     pnx   = pme->pmegrid_nx;
2216     pny   = pme->pmegrid_ny;
2217     pnz   = pme->pmegrid_nz;
2218
2219     rxx   = pme->recipbox[XX][XX];
2220     ryx   = pme->recipbox[YY][XX];
2221     ryy   = pme->recipbox[YY][YY];
2222     rzx   = pme->recipbox[ZZ][XX];
2223     rzy   = pme->recipbox[ZZ][YY];
2224     rzz   = pme->recipbox[ZZ][ZZ];
2225
2226     for (nn = 0; nn < spline->n; nn++)
2227     {
2228         n  = spline->ind[nn];
2229         qn = scale*atc->q[n];
2230
2231         if (bClearF)
2232         {
2233             atc->f[n][XX] = 0;
2234             atc->f[n][YY] = 0;
2235             atc->f[n][ZZ] = 0;
2236         }
2237         if (qn != 0)
2238         {
2239             fx     = 0;
2240             fy     = 0;
2241             fz     = 0;
2242             idxptr = atc->idx[n];
2243             norder = nn*order;
2244
2245             i0   = idxptr[XX];
2246             j0   = idxptr[YY];
2247             k0   = idxptr[ZZ];
2248
2249             /* Pointer arithmetic alert, next six statements */
2250             thx  = spline->theta[XX] + norder;
2251             thy  = spline->theta[YY] + norder;
2252             thz  = spline->theta[ZZ] + norder;
2253             dthx = spline->dtheta[XX] + norder;
2254             dthy = spline->dtheta[YY] + norder;
2255             dthz = spline->dtheta[ZZ] + norder;
2256
2257             switch (order)
2258             {
2259                 case 4:
2260 #ifdef PME_SSE
2261 #ifdef PME_SSE_UNALIGNED
2262 #define PME_GATHER_F_SSE_ORDER4
2263 #else
2264 #define PME_GATHER_F_SSE_ALIGNED
2265 #define PME_ORDER 4
2266 #endif
2267 #include "pme_sse_single.h"
2268 #else
2269                     DO_FSPLINE(4);
2270 #endif
2271                     break;
2272                 case 5:
2273 #ifdef PME_SSE
2274 #define PME_GATHER_F_SSE_ALIGNED
2275 #define PME_ORDER 5
2276 #include "pme_sse_single.h"
2277 #else
2278                     DO_FSPLINE(5);
2279 #endif
2280                     break;
2281                 default:
2282                     DO_FSPLINE(order);
2283                     break;
2284             }
2285
2286             atc->f[n][XX] += -qn*( fx*nx*rxx );
2287             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2288             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2289         }
2290     }
2291     /* Since the energy and not the forces are interpolated
2292      * the net force might not be exactly zero.
2293      * This can be solved by also interpolating F, but
2294      * that comes at a cost.
2295      * A better hack is to remove the net force every
2296      * step, but that must be done at a higher level
2297      * since this routine doesn't see all atoms if running
2298      * in parallel. Don't know how important it is?  EL 990726
2299      */
2300 }
2301
2302
2303 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2304                                    pme_atomcomm_t *atc)
2305 {
2306     splinedata_t *spline;
2307     int     n, ithx, ithy, ithz, i0, j0, k0;
2308     int     index_x, index_xy;
2309     int *   idxptr;
2310     real    energy, pot, tx, ty, qn, gval;
2311     real    *thx, *thy, *thz;
2312     int     norder;
2313     int     order;
2314
2315     spline = &atc->spline[0];
2316
2317     order = pme->pme_order;
2318
2319     energy = 0;
2320     for (n = 0; (n < atc->n); n++)
2321     {
2322         qn      = atc->q[n];
2323
2324         if (qn != 0)
2325         {
2326             idxptr = atc->idx[n];
2327             norder = n*order;
2328
2329             i0   = idxptr[XX];
2330             j0   = idxptr[YY];
2331             k0   = idxptr[ZZ];
2332
2333             /* Pointer arithmetic alert, next three statements */
2334             thx  = spline->theta[XX] + norder;
2335             thy  = spline->theta[YY] + norder;
2336             thz  = spline->theta[ZZ] + norder;
2337
2338             pot = 0;
2339             for (ithx = 0; (ithx < order); ithx++)
2340             {
2341                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2342                 tx      = thx[ithx];
2343
2344                 for (ithy = 0; (ithy < order); ithy++)
2345                 {
2346                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2347                     ty       = thy[ithy];
2348
2349                     for (ithz = 0; (ithz < order); ithz++)
2350                     {
2351                         gval  = grid[index_xy+(k0+ithz)];
2352                         pot  += tx*ty*thz[ithz]*gval;
2353                     }
2354
2355                 }
2356             }
2357
2358             energy += pot*qn;
2359         }
2360     }
2361
2362     return energy;
2363 }
2364
2365 /* Macro to force loop unrolling by fixing order.
2366  * This gives a significant performance gain.
2367  */
2368 #define CALC_SPLINE(order)                     \
2369     {                                              \
2370         int j, k, l;                                 \
2371         real dr, div;                               \
2372         real data[PME_ORDER_MAX];                  \
2373         real ddata[PME_ORDER_MAX];                 \
2374                                                \
2375         for (j = 0; (j < DIM); j++)                     \
2376         {                                          \
2377             dr  = xptr[j];                         \
2378                                                \
2379             /* dr is relative offset from lower cell limit */ \
2380             data[order-1] = 0;                     \
2381             data[1]       = dr;                          \
2382             data[0]       = 1 - dr;                      \
2383                                                \
2384             for (k = 3; (k < order); k++)               \
2385             {                                      \
2386                 div       = 1.0/(k - 1.0);               \
2387                 data[k-1] = div*dr*data[k-2];      \
2388                 for (l = 1; (l < (k-1)); l++)           \
2389                 {                                  \
2390                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2391                                        data[k-l-1]);                \
2392                 }                                  \
2393                 data[0] = div*(1-dr)*data[0];      \
2394             }                                      \
2395             /* differentiate */                    \
2396             ddata[0] = -data[0];                   \
2397             for (k = 1; (k < order); k++)               \
2398             {                                      \
2399                 ddata[k] = data[k-1] - data[k];    \
2400             }                                      \
2401                                                \
2402             div           = 1.0/(order - 1);                 \
2403             data[order-1] = div*dr*data[order-2];  \
2404             for (l = 1; (l < (order-1)); l++)           \
2405             {                                      \
2406                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2407                                        (order-l-dr)*data[order-l-1]); \
2408             }                                      \
2409             data[0] = div*(1 - dr)*data[0];        \
2410                                                \
2411             for (k = 0; k < order; k++)                 \
2412             {                                      \
2413                 theta[j][i*order+k]  = data[k];    \
2414                 dtheta[j][i*order+k] = ddata[k];   \
2415             }                                      \
2416         }                                          \
2417     }
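/* The recursion inside CALC_SPLINE is the standard cardinal B-spline
 * construction used by smooth PME:
 *   M_k(u) = ( u*M_{k-1}(u) + (k-u)*M_{k-1}(u-1) ) / (k-1)
 * with the derivative obtained from
 *   dM_k(u)/du = M_{k-1}(u) - M_{k-1}(u-1),
 * evaluated at the 'order' grid points surrounding the particle (up to
 * the indexing convention of data[]/ddata[] used here).
 */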
2418
2419 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2420                    rvec fractx[], int nr, int ind[], real charge[],
2421                    gmx_bool bFreeEnergy)
2422 {
2423     /* construct splines for local atoms */
2424     int  i, ii;
2425     real *xptr;
2426
2427     for (i = 0; i < nr; i++)
2428     {
2429         /* With free energy we do not use the charge check.
2430          * In most cases this will be more efficient than calling make_bsplines
2431          * twice, since usually more than half the particles have charges.
2432          */
2433         ii = ind[i];
2434         if (bFreeEnergy || charge[ii] != 0.0)
2435         {
2436             xptr = fractx[ii];
2437             switch (order)
2438             {
2439                 case 4:  CALC_SPLINE(4);     break;
2440                 case 5:  CALC_SPLINE(5);     break;
2441                 default: CALC_SPLINE(order); break;
2442             }
2443         }
2444     }
2445 }
2446
2447
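/* make_dft_mod computes, for each frequency i, the squared modulus of the
 * discrete Fourier transform of the tabulated B-spline values:
 *   mod[i] = |sum_j data[j] * exp(2*pi*I*i*j/ndata)|^2.
 * Entries that come out numerically close to zero are patched by averaging
 * their neighbours, since the moduli later end up in denominators.
 */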
2448 void make_dft_mod(real *mod, real *data, int ndata)
2449 {
2450     int i, j;
2451     real sc, ss, arg;
2452
2453     for (i = 0; i < ndata; i++)
2454     {
2455         sc = ss = 0;
2456         for (j = 0; j < ndata; j++)
2457         {
2458             arg = (2.0*M_PI*i*j)/ndata;
2459             sc += data[j]*cos(arg);
2460             ss += data[j]*sin(arg);
2461         }
2462         mod[i] = sc*sc+ss*ss;
2463     }
2464     for (i = 0; i < ndata; i++)
2465     {
2466         if (mod[i] < 1e-7)
2467         {
2468             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2469         }
2470     }
2471 }
2472
2473
2474 static void make_bspline_moduli(splinevec bsp_mod,
2475                                 int nx, int ny, int nz, int order)
2476 {
2477     int nmax = max(nx, max(ny, nz));
2478     real *data, *ddata, *bsp_data;
2479     int i, k, l;
2480     real div;
2481
2482     snew(data, order);
2483     snew(ddata, order);
2484     snew(bsp_data, nmax);
2485
2486     data[order-1] = 0;
2487     data[1]       = 0;
2488     data[0]       = 1;
2489
2490     for (k = 3; k < order; k++)
2491     {
2492         div       = 1.0/(k-1.0);
2493         data[k-1] = 0;
2494         for (l = 1; l < (k-1); l++)
2495         {
2496             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2497         }
2498         data[0] = div*data[0];
2499     }
2500     /* differentiate */
2501     ddata[0] = -data[0];
2502     for (k = 1; k < order; k++)
2503     {
2504         ddata[k] = data[k-1]-data[k];
2505     }
2506     div           = 1.0/(order-1);
2507     data[order-1] = 0;
2508     for (l = 1; l < (order-1); l++)
2509     {
2510         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2511     }
2512     data[0] = div*data[0];
2513
2514     for (i = 0; i < nmax; i++)
2515     {
2516         bsp_data[i] = 0;
2517     }
2518     for (i = 1; i <= order; i++)
2519     {
2520         bsp_data[i] = data[i-1];
2521     }
2522
2523     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2524     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2525     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2526
2527     sfree(data);
2528     sfree(ddata);
2529     sfree(bsp_data);
2530 }
2531
2532
2533 /* Return the P3M optimal influence function */
2534 static double do_p3m_influence(double z, int order)
2535 {
2536     double z2, z4;
2537
2538     z2 = z*z;
2539     z4 = z2*z2;
2540
2541     /* The formula and most constants can be found in:
2542      * Ballenegger et al., JCTC 8, 936 (2012)
2543      */
2544     switch (order)
2545     {
2546         case 2:
2547             return 1.0 - 2.0*z2/3.0;
2548             break;
2549         case 3:
2550             return 1.0 - z2 + 2.0*z4/15.0;
2551             break;
2552         case 4:
2553             return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2554             break;
2555         case 5:
2556             return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2557             break;
2558         case 6:
2559             return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2560             break;
2561         case 7:
2562             return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2563         case 8:
2564             return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2565             break;
2566     }
2567
2568     return 0.0;
2569 }
2570
2571 /* Calculate the P3M B-spline moduli for one dimension */
2572 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2573 {
2574     double zarg, zai, sinzai, infl;
2575     int    maxk, i;
2576
2577     if (order > 8)
2578     {
2579         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2580     }
2581
2582     zarg = M_PI/n;
2583
2584     maxk = (n + 1)/2;
2585
2586     for (i = -maxk; i < 0; i++)
2587     {
2588         zai          = zarg*i;
2589         sinzai       = sin(zai);
2590         infl         = do_p3m_influence(sinzai, order);
2591         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2592     }
2593     bsp_mod[0] = 1.0;
2594     for (i = 1; i < maxk; i++)
2595     {
2596         zai        = zarg*i;
2597         sinzai     = sin(zai);
2598         infl       = do_p3m_influence(sinzai, order);
2599         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2600     }
2601 }
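/* The loops above evaluate, for each non-zero frequency k,
 *   bsp_mod[k] = f(sin(z))^2 * (sin(z)/z)^(-2*order),   z = pi*k/n,
 * where f is the influence polynomial returned by do_p3m_influence, so
 * that solve_pme_yzx can use the P3M-optimal influence function in place
 * of the plain B-spline moduli.
 */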
2602
2603 /* Calculate the P3M B-spline moduli */
2604 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2605                                     int nx, int ny, int nz, int order)
2606 {
2607     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2608     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2609     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2610 }
2611
2612
2613 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2614 {
2615     int nslab, n, i;
2616     int fw, bw;
2617
2618     nslab = atc->nslab;
2619
2620     n = 0;
2621     for (i = 1; i <= nslab/2; i++)
2622     {
2623         fw = (atc->nodeid + i) % nslab;
2624         bw = (atc->nodeid - i + nslab) % nslab;
2625         if (n < nslab - 1)
2626         {
2627             atc->node_dest[n] = fw;
2628             atc->node_src[n]  = bw;
2629             n++;
2630         }
2631         if (n < nslab - 1)
2632         {
2633             atc->node_dest[n] = bw;
2634             atc->node_src[n]  = fw;
2635             n++;
2636         }
2637     }
2638 }
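/* Example of the resulting communication pattern (nslab = 4, nodeid = 0):
 * node_dest = {1, 3, 2} and node_src = {3, 1, 2}, i.e. forward and
 * backward neighbours are interleaved by increasing distance, and the
 * farthest slab appears only once when nslab is even.
 */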
2639
2640 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2641 {
2642     int thread;
2643
2644     if (NULL != log)
2645     {
2646         fprintf(log, "Destroying PME data structures.\n");
2647     }
2648
2649     sfree((*pmedata)->nnx);
2650     sfree((*pmedata)->nny);
2651     sfree((*pmedata)->nnz);
2652
2653     pmegrids_destroy(&(*pmedata)->pmegridA);
2654
2655     sfree((*pmedata)->fftgridA);
2656     sfree((*pmedata)->cfftgridA);
2657     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2658
2659     if ((*pmedata)->pmegridB.grid.grid != NULL)
2660     {
2661         pmegrids_destroy(&(*pmedata)->pmegridB);
2662         sfree((*pmedata)->fftgridB);
2663         sfree((*pmedata)->cfftgridB);
2664         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2665     }
2666     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2667     {
2668         free_work(&(*pmedata)->work[thread]);
2669     }
2670     sfree((*pmedata)->work);
2671
2672     sfree(*pmedata);
2673     *pmedata = NULL;
2674
2675     return 0;
2676 }
2677
2678 static int mult_up(int n, int f)
2679 {
2680     return ((n + f - 1)/f)*f;
2681 }
2682
2683
2684 static double pme_load_imbalance(gmx_pme_t pme)
2685 {
2686     int    nma, nmi;
2687     double n1, n2, n3;
2688
2689     nma = pme->nnodes_major;
2690     nmi = pme->nnodes_minor;
2691
2692     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2693     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2694     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2695
2696     /* pme_solve is roughly double the cost of an fft */
2697
2698     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2699 }
2700
2701 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc,
2702                           int dimind, gmx_bool bSpread)
2703 {
2704     int nk, k, s, thread;
2705
2706     atc->dimind    = dimind;
2707     atc->nslab     = 1;
2708     atc->nodeid    = 0;
2709     atc->pd_nalloc = 0;
2710 #ifdef GMX_MPI
2711     if (pme->nnodes > 1)
2712     {
2713         atc->mpi_comm = pme->mpi_comm_d[dimind];
2714         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2715         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2716     }
2717     if (debug)
2718     {
2719         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2720     }
2721 #endif
2722
2723     atc->bSpread   = bSpread;
2724     atc->pme_order = pme->pme_order;
2725
2726     if (atc->nslab > 1)
2727     {
2728         /* These three allocations are not required for particle decomp. */
2729         snew(atc->node_dest, atc->nslab);
2730         snew(atc->node_src, atc->nslab);
2731         setup_coordinate_communication(atc);
2732
2733         snew(atc->count_thread, pme->nthread);
2734         for (thread = 0; thread < pme->nthread; thread++)
2735         {
2736             snew(atc->count_thread[thread], atc->nslab);
2737         }
2738         atc->count = atc->count_thread[0];
2739         snew(atc->rcount, atc->nslab);
2740         snew(atc->buf_index, atc->nslab);
2741     }
2742
2743     atc->nthread = pme->nthread;
2744     if (atc->nthread > 1)
2745     {
2746         snew(atc->thread_plist, atc->nthread);
2747     }
2748     snew(atc->spline, atc->nthread);
2749     for (thread = 0; thread < atc->nthread; thread++)
2750     {
2751         if (atc->nthread > 1)
2752         {
2753             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2754             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2755         }
2756         snew(atc->spline[thread].thread_one, pme->nthread);
2757         atc->spline[thread].thread_one[thread] = 1;
2758     }
2759 }
2760
2761 static void
2762 init_overlap_comm(pme_overlap_t *  ol,
2763                   int              norder,
2764 #ifdef GMX_MPI
2765                   MPI_Comm         comm,
2766 #endif
2767                   int              nnodes,
2768                   int              nodeid,
2769                   int              ndata,
2770                   int              commplainsize)
2771 {
2772     int lbnd, rbnd, maxlr, b, i;
2773     int exten;
2774     int nn, nk;
2775     pme_grid_comm_t *pgc;
2776     gmx_bool bCont;
2777     int fft_start, fft_end, send_index1, recv_index1;
2778 #ifdef GMX_MPI
2779     MPI_Status stat;
2780
2781     ol->mpi_comm = comm;
2782 #endif
2783
2784     ol->nnodes = nnodes;
2785     ol->nodeid = nodeid;
2786
2787     /* Linear translation of the PME grid won't affect reciprocal space
2788      * calculations, so to optimize we only interpolate "upwards",
2789      * which also means we only have to consider overlap in one direction.
2790      * I.e., particles on this node might also be spread to grid indices
2791      * that belong to higher nodes (modulo nnodes)
2792      */
2793
2794     snew(ol->s2g0, ol->nnodes+1);
2795     snew(ol->s2g1, ol->nnodes);
2796     if (debug)
2797     {
2798         fprintf(debug, "PME slab boundaries:");
2799     }
2800     for (i = 0; i < nnodes; i++)
2801     {
2802         /* s2g0 the local interpolation grid start.
2803          * s2g1 the local interpolation grid end.
2804          * Because grid overlap communication only goes forward,
2806          * the slab boundaries used for the FFTs should be rounded down.
2806          */
2807         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2808         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2809
2810         if (debug)
2811         {
2812             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2813         }
2814     }
2815     ol->s2g0[nnodes] = ndata;
2816     if (debug)
2817     {
2818         fprintf(debug, "\n");
2819     }
2820
2821     /* Determine with how many nodes we need to communicate the grid overlap */
2822     b = 0;
2823     do
2824     {
2825         b++;
2826         bCont = FALSE;
2827         for (i = 0; i < nnodes; i++)
2828         {
2829             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2830                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2831             {
2832                 bCont = TRUE;
2833             }
2834         }
2835     }
2836     while (bCont && b < nnodes);
2837     ol->noverlap_nodes = b - 1;
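    /* Worked example (hypothetical sizes): with nnodes = 3, ndata = 10 and
     * norder = 4, s2g0 = {0, 3, 6, 10} and s2g1 = {7, 10, 13}, so slab 0
     * overlaps slabs 1 and 2 and the loop above settles on
     * noverlap_nodes = 2 communication pulses.
     */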
2838
2839     snew(ol->send_id, ol->noverlap_nodes);
2840     snew(ol->recv_id, ol->noverlap_nodes);
2841     for (b = 0; b < ol->noverlap_nodes; b++)
2842     {
2843         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2844         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2845     }
2846     snew(ol->comm_data, ol->noverlap_nodes);
2847
2848     ol->send_size = 0;
2849     for (b = 0; b < ol->noverlap_nodes; b++)
2850     {
2851         pgc = &ol->comm_data[b];
2852         /* Send */
2853         fft_start        = ol->s2g0[ol->send_id[b]];
2854         fft_end          = ol->s2g0[ol->send_id[b]+1];
2855         if (ol->send_id[b] < nodeid)
2856         {
2857             fft_start += ndata;
2858             fft_end   += ndata;
2859         }
2860         send_index1       = ol->s2g1[nodeid];
2861         send_index1       = min(send_index1, fft_end);
2862         pgc->send_index0  = fft_start;
2863         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2864         ol->send_size    += pgc->send_nindex;
2865
2866         /* We always start receiving at the first index of our slab */
2867         fft_start        = ol->s2g0[ol->nodeid];
2868         fft_end          = ol->s2g0[ol->nodeid+1];
2869         recv_index1      = ol->s2g1[ol->recv_id[b]];
2870         if (ol->recv_id[b] > nodeid)
2871         {
2872             recv_index1 -= ndata;
2873         }
2874         recv_index1      = min(recv_index1, fft_end);
2875         pgc->recv_index0 = fft_start;
2876         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2877     }
2878
2879 #ifdef GMX_MPI
2880     /* Communicate the buffer sizes to receive */
2881     for (b = 0; b < ol->noverlap_nodes; b++)
2882     {
2883         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2884                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2885                      ol->mpi_comm, &stat);
2886     }
2887 #endif
2888
2889     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2890     snew(ol->sendbuf, norder*commplainsize);
2891     snew(ol->recvbuf, norder*commplainsize);
2892 }
2893
2894 static void
2895 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2896                               int **global_to_local,
2897                               real **fraction_shift)
2898 {
2899     int i;
2900     int * gtl;
2901     real * fsh;
2902
2903     snew(gtl, 5*n);
2904     snew(fsh, 5*n);
2905     for (i = 0; (i < 5*n); i++)
2906     {
2907         /* Determine the global to local grid index */
2908         gtl[i] = (i - local_start + n) % n;
2909         /* For coordinates that fall within the local grid the fraction
2910          * is correct, we don't need to shift it.
2911          */
2912         fsh[i] = 0;
2913         if (local_range < n)
2914         {
2915             /* Due to rounding issues i could be 1 beyond the lower or
2916              * upper boundary of the local grid. Correct the index for this.
2917              * If we shift the index, we need to shift the fraction by
2918              * the same amount in the other direction to not affect
2919              * the weights.
2920              * Note that due to this shifting the weights at the end of
2921              * the spline might change, but that will only involve values
2922              * between zero and values close to the precision of a real,
2923              * which is anyhow the accuracy of the whole mesh calculation.
2924              */
2925             /* With local_range=0 we should not change i=local_start */
2926             if (i % n != local_start)
2927             {
2928                 if (gtl[i] == n-1)
2929                 {
2930                     gtl[i] = 0;
2931                     fsh[i] = -1;
2932                 }
2933                 else if (gtl[i] == local_range)
2934                 {
2935                     gtl[i] = local_range - 1;
2936                     fsh[i] = 1;
2937                 }
2938             }
2939         }
2940     }
2941
2942     *global_to_local = gtl;
2943     *fraction_shift  = fsh;
2944 }
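/* Illustration with made-up sizes: n = 8, local_start = 2, local_range = 4
 * gives gtl[i] = (i - 2 + 8) % 8, so global line 6 first maps to local
 * index 4 == local_range and is clamped to 3 with fsh = +1, while global
 * line 1 maps to 7 == n-1 and is wrapped to 0 with fsh = -1, matching the
 * boundary correction described in the comment above.
 */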
2945
2946 static pme_spline_work_t *make_pme_spline_work(int order)
2947 {
2948     pme_spline_work_t *work;
2949
2950 #ifdef PME_SSE
2951     float  tmp[8];
2952     __m128 zero_SSE;
2953     int    of, i;
2954
2955     snew_aligned(work, 1, 16);
2956
2957     zero_SSE = _mm_setzero_ps();
2958
2959     /* Generate bit masks to mask out the unused grid entries,
2960      * as we only operate on 'order' of the 8 grid entries that are
2961      * loaded into 2 SSE float registers.
2962      */
2963     for (of = 0; of < 8-(order-1); of++)
2964     {
2965         for (i = 0; i < 8; i++)
2966         {
2967             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2968         }
2969         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2970         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2971         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
2972         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
2973     }
2974 #else
2975     work = NULL;
2976 #endif
2977
2978     return work;
2979 }
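/* Example of the masks built above for order = 4 and offset of = 1:
 * tmp = {0,1,1,1,1,0,0,0}, so mask_SSE0[1] enables lanes 1-3 of the first
 * register and mask_SSE1[1] enables lane 0 of the second, i.e. exactly
 * the 4 grid entries that a spline of order 4 touches.
 */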
2980
2981 void gmx_pme_check_restrictions(int pme_order,
2982                                 int nkx, int nky, int nkz,
2983                                 int nnodes_major,
2984                                 int nnodes_minor,
2985                                 gmx_bool bUseThreads,
2986                                 gmx_bool bFatal,
2987                                 gmx_bool *bValidSettings)
2988 {
2989     if (pme_order > PME_ORDER_MAX)
2990     {
2991         if (!bFatal)
2992         {
2993             *bValidSettings = FALSE;
2994             return;
2995         }
2996         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
2997                   pme_order, PME_ORDER_MAX);
2998     }
2999
3000     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3001         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3002         nkz <= pme_order)
3003     {
3004         if (!bFatal)
3005         {
3006             *bValidSettings = FALSE;
3007             return;
3008         }
3009         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3010                   pme_order);
3011     }
3012
3013     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3014      * We only allow multiple communication pulses in dim 1, not in dim 0.
3015      */
3016     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3017                         nkx != nnodes_major*(pme_order - 1)))
3018     {
3019         if (!bFatal)
3020         {
3021             *bValidSettings = FALSE;
3022             return;
3023         }
3024         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pme_order-1. To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3025                   nkx/(double)nnodes_major, pme_order);
3026     }
3027
3028     if (bValidSettings != NULL)
3029     {
3030         *bValidSettings = TRUE;
3031     }
3032
3033     return;
3034 }
3035
3036 int gmx_pme_init(gmx_pme_t *         pmedata,
3037                  t_commrec *         cr,
3038                  int                 nnodes_major,
3039                  int                 nnodes_minor,
3040                  t_inputrec *        ir,
3041                  int                 homenr,
3042                  gmx_bool            bFreeEnergy,
3043                  gmx_bool            bReproducible,
3044                  int                 nthread)
3045 {
3046     gmx_pme_t pme = NULL;
3047
3048     int  use_threads, sum_use_threads;
3049     ivec ndata;
3050
3051     if (debug)
3052     {
3053         fprintf(debug, "Creating PME data structures.\n");
3054     }
3055     snew(pme, 1);
3056
3057     pme->redist_init         = FALSE;
3058     pme->sum_qgrid_tmp       = NULL;
3059     pme->sum_qgrid_dd_tmp    = NULL;
3060     pme->buf_nalloc          = 0;
3061     pme->redist_buf_nalloc   = 0;
3062
3063     pme->nnodes              = 1;
3064     pme->bPPnode             = TRUE;
3065
3066     pme->nnodes_major        = nnodes_major;
3067     pme->nnodes_minor        = nnodes_minor;
3068
3069 #ifdef GMX_MPI
3070     if (nnodes_major*nnodes_minor > 1)
3071     {
3072         pme->mpi_comm = cr->mpi_comm_mygroup;
3073
3074         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3075         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3076         if (pme->nnodes != nnodes_major*nnodes_minor)
3077         {
3078             gmx_incons("PME node count mismatch");
3079         }
3080     }
3081     else
3082     {
3083         pme->mpi_comm = MPI_COMM_NULL;
3084     }
3085 #endif
3086
3087     if (pme->nnodes == 1)
3088     {
3089 #ifdef GMX_MPI
3090         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3091         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3092 #endif
3093         pme->ndecompdim   = 0;
3094         pme->nodeid_major = 0;
3095         pme->nodeid_minor = 0;
3096 #ifdef GMX_MPI
3097         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3098 #endif
3099     }
3100     else
3101     {
3102         if (nnodes_minor == 1)
3103         {
3104 #ifdef GMX_MPI
3105             pme->mpi_comm_d[0] = pme->mpi_comm;
3106             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3107 #endif
3108             pme->ndecompdim   = 1;
3109             pme->nodeid_major = pme->nodeid;
3110             pme->nodeid_minor = 0;
3111
3112         }
3113         else if (nnodes_major == 1)
3114         {
3115 #ifdef GMX_MPI
3116             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3117             pme->mpi_comm_d[1] = pme->mpi_comm;
3118 #endif
3119             pme->ndecompdim   = 1;
3120             pme->nodeid_major = 0;
3121             pme->nodeid_minor = pme->nodeid;
3122         }
3123         else
3124         {
3125             if (pme->nnodes % nnodes_major != 0)
3126             {
3127                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3128             }
3129             pme->ndecompdim = 2;
3130
3131 #ifdef GMX_MPI
3132             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3133                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3134             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3135                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3136
3137             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3138             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3139             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3140             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3141 #endif
3142         }
3143         pme->bPPnode = (cr->duty & DUTY_PP);
3144     }
3145
3146     pme->nthread = nthread;
3147
3148     /* Check if any of the PME MPI ranks uses threads */
3149     use_threads = (pme->nthread > 1 ? 1 : 0);
3150 #ifdef GMX_MPI
3151     if (pme->nnodes > 1)
3152     {
3153         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3154                       MPI_SUM, pme->mpi_comm);
3155     }
3156     else
3157 #endif
3158     {
3159         sum_use_threads = use_threads;
3160     }
3161     pme->bUseThreads = (sum_use_threads > 0);
3162
3163     if (ir->ePBC == epbcSCREW)
3164     {
3165         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3166     }
3167
3168     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3169     pme->nkx         = ir->nkx;
3170     pme->nky         = ir->nky;
3171     pme->nkz         = ir->nkz;
3172     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3173     pme->pme_order   = ir->pme_order;
3174     pme->epsilon_r   = ir->epsilon_r;
3175
3176     /* If we violate restrictions, generate a fatal error here */
3177     gmx_pme_check_restrictions(pme->pme_order,
3178                                pme->nkx, pme->nky, pme->nkz,
3179                                pme->nnodes_major,
3180                                pme->nnodes_minor,
3181                                pme->bUseThreads,
3182                                TRUE,
3183                                NULL);
3184
3185     if (pme->nnodes > 1)
3186     {
3187         double imbal;
3188
3189 #ifdef GMX_MPI
3190         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3191         MPI_Type_commit(&(pme->rvec_mpi));
3192 #endif
3193
3194         /* Note that the charge spreading and force gathering, which usually
3195          * takes about the same amount of time as FFT+solve_pme,
3196          * is always fully load balanced
3197          * (unless the charge distribution is inhomogeneous).
3198          */
3199
3200         imbal = pme_load_imbalance(pme);
3201         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3202         {
3203             fprintf(stderr,
3204                     "\n"
3205                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3206                     "      For optimal PME load balancing\n"
3207                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3208                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3209                     "\n",
3210                     (int)((imbal-1)*100 + 0.5),
3211                     pme->nkx, pme->nky, pme->nnodes_major,
3212                     pme->nky, pme->nkz, pme->nnodes_minor);
3213         }
3214     }
3215
3216     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3217     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3218      * y is always copied through a buffer: we don't need padding in z,
3219      * but we do need the overlap in x because of the communication order.
3220      */
3221     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3222 #ifdef GMX_MPI
3223                       pme->mpi_comm_d[0],
3224 #endif
3225                       pme->nnodes_major, pme->nodeid_major,
3226                       pme->nkx,
3227                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3228
3229     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3230      * We do this with an offset buffer of equal size, so we need to allocate
3231      * extra for the offset. That's what the (+1)*pme->nkz is for.
3232      */
3233     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3234 #ifdef GMX_MPI
3235                       pme->mpi_comm_d[1],
3236 #endif
3237                       pme->nnodes_minor, pme->nodeid_minor,
3238                       pme->nky,
3239                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3240
3241     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3242      * Note that gmx_pme_check_restrictions checked for this already.
3243      */
3244     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3245     {
3246         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3247     }
3248
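    /* bsp_mod[XX/YY/ZZ] will hold the B-spline (or P3M) moduli along each
     * grid dimension; they are filled below by make_bspline_moduli or
     * make_p3m_bspline_moduli and enter the reciprocal-space solve.
     */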
3249     snew(pme->bsp_mod[XX], pme->nkx);
3250     snew(pme->bsp_mod[YY], pme->nky);
3251     snew(pme->bsp_mod[ZZ], pme->nkz);
3252
3253     /* The required size of the interpolation grid, including overlap.
3254      * The allocated size (pmegrid_n?) might be slightly larger.
3255      */
3256     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3257         pme->overlap[0].s2g0[pme->nodeid_major];
3258     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3259         pme->overlap[1].s2g0[pme->nodeid_minor];
3260     pme->pmegrid_nz_base = pme->nkz;
3261     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3262     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
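    /* Example: with pme_order = 4 the local grid gets nkz + 3 points in z
     * (before alignment), i.e. the pme_order-1 extra points used for the
     * spreading overlap in z.
     */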
3263
3264     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3265     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3266     pme->pmegrid_start_iz = 0;
3267
3268     make_gridindex5_to_localindex(pme->nkx,
3269                                   pme->pmegrid_start_ix,
3270                                   pme->pmegrid_nx - (pme->pme_order-1),
3271                                   &pme->nnx, &pme->fshx);
3272     make_gridindex5_to_localindex(pme->nky,
3273                                   pme->pmegrid_start_iy,
3274                                   pme->pmegrid_ny - (pme->pme_order-1),
3275                                   &pme->nny, &pme->fshy);
3276     make_gridindex5_to_localindex(pme->nkz,
3277                                   pme->pmegrid_start_iz,
3278                                   pme->pmegrid_nz_base,
3279                                   &pme->nnz, &pme->fshz);
3280
3281     pmegrids_init(&pme->pmegridA,
3282                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3283                   pme->pmegrid_nz_base,
3284                   pme->pme_order,
3285                   pme->bUseThreads,
3286                   pme->nthread,
3287                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3288                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3289
3290     pme->spline_work = make_pme_spline_work(pme->pme_order);
3291
3292     ndata[0] = pme->nkx;
3293     ndata[1] = pme->nky;
3294     ndata[2] = pme->nkz;
3295
3296     /* This routine will allocate the grid data to fit the FFTs */
3297     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3298                             &pme->fftgridA, &pme->cfftgridA,
3299                             pme->mpi_comm_d,
3300                             bReproducible, pme->nthread);
3301
3302     if (bFreeEnergy)
3303     {
3304         pmegrids_init(&pme->pmegridB,
3305                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3306                       pme->pmegrid_nz_base,
3307                       pme->pme_order,
3308                       pme->bUseThreads,
3309                       pme->nthread,
3310                       pme->nkx % pme->nnodes_major != 0,
3311                       pme->nky % pme->nnodes_minor != 0);
3312
3313         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3314                                 &pme->fftgridB, &pme->cfftgridB,
3315                                 pme->mpi_comm_d,
3316                                 bReproducible, pme->nthread);
3317     }
3318     else
3319     {
3320         pme->pmegridB.grid.grid = NULL;
3321         pme->fftgridB           = NULL;
3322         pme->cfftgridB          = NULL;
3323     }
3324
3325     if (!pme->bP3M)
3326     {
3327         /* Use plain SPME B-spline interpolation */
3328         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3329     }
3330     else
3331     {
3332         /* Use the P3M grid-optimized influence function */
3333         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3334     }
3335
3336     /* Use atc[0] for spreading */
3337     init_atomcomm(pme, &pme->atc[0], nnodes_major > 1 ? 0 : 1, TRUE);
3338     if (pme->ndecompdim >= 2)
3339     {
3340         init_atomcomm(pme, &pme->atc[1], 1, FALSE);
3341     }
3342
3343     if (pme->nnodes == 1)
3344     {
3345         pme->atc[0].n = homenr;
3346         pme_realloc_atomcomm_things(&pme->atc[0]);
3347     }
3348
3349     {
3350         int thread;
3351
3352         /* Use fft5d, order after FFT is y major, z, x minor */
3353
3354         snew(pme->work, pme->nthread);
3355         for (thread = 0; thread < pme->nthread; thread++)
3356         {
3357             realloc_work(&pme->work[thread], pme->nkx);
3358         }
3359     }
3360
3361     *pmedata = pme;
3362
3363     return 0;
3364 }
3365
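/* Let a newly initialized set of PME grids reuse the grid storage of an
 * existing set, provided the new grids are not larger in any dimension.
 * This avoids reallocation when switching grids during PME tuning; the
 * thread-local grids are reused as well when the thread counts match.
 */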
3366 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3367 {
3368     int d, t;
3369
3370     for (d = 0; d < DIM; d++)
3371     {
3372         if (new->grid.n[d] > old->grid.n[d])
3373         {
3374             return;
3375         }
3376     }
3377
3378     sfree_aligned(new->grid.grid);
3379     new->grid.grid = old->grid.grid;
3380
3381     if (new->grid_th != NULL && new->nthread == old->nthread)
3382     {
3383         sfree_aligned(new->grid_all);
3384         for (t = 0; t < new->nthread; t++)
3385         {
3386             new->grid_th[t].grid = old->grid_th[t].grid;
3387         }
3388     }
3389 }
3390
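/* Reinitialize PME for a new grid size, keeping the node layout and thread
 * count of pme_src and reusing its grid allocations where possible, e.g.
 * when switching grids during PME tuning.
 */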
3391 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3392                    t_commrec *         cr,
3393                    gmx_pme_t           pme_src,
3394                    const t_inputrec *  ir,
3395                    ivec                grid_size)
3396 {
3397     t_inputrec irc;
3398     int homenr;
3399     int ret;
3400
3401     irc     = *ir;
3402     irc.nkx = grid_size[XX];
3403     irc.nky = grid_size[YY];
3404     irc.nkz = grid_size[ZZ];
3405
3406     if (pme_src->nnodes == 1)
3407     {
3408         homenr = pme_src->atc[0].n;
3409     }
3410     else
3411     {
3412         homenr = -1;
3413     }
3414
3415     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3416                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3417
3418     if (ret == 0)
3419     {
3420         /* We can easily reuse the allocated pme grids in pme_src */
3421         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3422         /* We would like to reuse the fft grids, but that's harder */
3423     }
3424
3425     return ret;
3426 }
3427
3428
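/* Copy the non-overlapping part of a thread-local spreading grid into the
 * node-local real FFT grid, using the offsets and strides of both grids.
 */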
3429 static void copy_local_grid(gmx_pme_t pme,
3430                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3431 {
3432     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3433     int  fft_my, fft_mz;
3434     int  nsx, nsy, nsz;
3435     ivec nf;
3436     int  offx, offy, offz, x, y, z, i0, i0t;
3437     int  d;
3438     pmegrid_t *pmegrid;
3439     real *grid_th;
3440
3441     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3442                                    local_fft_ndata,
3443                                    local_fft_offset,
3444                                    local_fft_size);
3445     fft_my = local_fft_size[YY];
3446     fft_mz = local_fft_size[ZZ];
3447
3448     pmegrid = &pmegrids->grid_th[thread];
3449
3450     nsx = pmegrid->s[XX];
3451     nsy = pmegrid->s[YY];
3452     nsz = pmegrid->s[ZZ];
3453
3454     for (d = 0; d < DIM; d++)
3455     {
3456         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3457                     local_fft_ndata[d] - pmegrid->offset[d]);
3458     }
3459
3460     offx = pmegrid->offset[XX];
3461     offy = pmegrid->offset[YY];
3462     offz = pmegrid->offset[ZZ];
3463
3464     /* Directly copy the non-overlapping parts of the local grids.
3465      * This also initializes the full grid.
3466      */
3467     grid_th = pmegrid->grid;
3468     for (x = 0; x < nf[XX]; x++)
3469     {
3470         for (y = 0; y < nf[YY]; y++)
3471         {
3472             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3473             i0t = (x*nsy + y)*nsz;
3474             for (z = 0; z < nf[ZZ]; z++)
3475             {
3476                 fftgrid[i0+z] = grid_th[i0t+z];
3477             }
3478         }
3479     }
3480 }
3481
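/* Reduce the spreading contributions that neighboring threads (and, at the
 * node edges, neighboring PME nodes) made to the grid region owned by this
 * thread. Interior contributions are summed directly into fftgrid; edge
 * contributions destined for another node are accumulated in the
 * communication buffers commbuf_x/commbuf_y for sum_fftgrid_dd.
 */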
3482 static void
3483 reduce_threadgrid_overlap(gmx_pme_t pme,
3484                           const pmegrids_t *pmegrids, int thread,
3485                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3486 {
3487     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3488     int  fft_nx, fft_ny, fft_nz;
3489     int  fft_my, fft_mz;
3490     int  buf_my = -1;
3491     int  nsx, nsy, nsz;
3492     ivec ne;
3493     int  offx, offy, offz, x, y, z, i0, i0t;
3494     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3495     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3496     gmx_bool bCommX, bCommY;
3497     int  d;
3498     int  thread_f;
3499     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3500     const real *grid_th;
3501     real *commbuf = NULL;
3502
3503     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3504                                    local_fft_ndata,
3505                                    local_fft_offset,
3506                                    local_fft_size);
3507     fft_nx = local_fft_ndata[XX];
3508     fft_ny = local_fft_ndata[YY];
3509     fft_nz = local_fft_ndata[ZZ];
3510
3511     fft_my = local_fft_size[YY];
3512     fft_mz = local_fft_size[ZZ];
3513
3514     /* This routine is called when all threads have finished spreading.
3515      * Here each thread sums the grid contributions calculated by other
3516      * threads for its thread-local grid volume.
3517      * To minimize the number of grid copying operations,
3518      * this routine sums directly from the pmegrid into the fftgrid.
3519      */
3520
3521     /* Determine which part of the full node grid we should operate on:
3522      * our thread-local part of the full grid.
3523      */
3524     pmegrid = &pmegrids->grid_th[thread];
3525
3526     for (d = 0; d < DIM; d++)
3527     {
3528         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3529                     local_fft_ndata[d]);
3530     }
3531
3532     offx = pmegrid->offset[XX];
3533     offy = pmegrid->offset[YY];
3534     offz = pmegrid->offset[ZZ];
3535
3536
3537     bClearBufX  = TRUE;
3538     bClearBufY  = TRUE;
3539     bClearBufXY = TRUE;
3540
3541     /* Now loop over all the thread data blocks that contribute
3542      * to the grid region we (our thread) are operating on.
3543      */
3544     /* Note that fft_nx/ny is equal to the number of grid points
3545      * between the first point of our node grid and that of the next node.
3546      */
3547     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3548     {
3549         fx     = pmegrid->ci[XX] + sx;
3550         ox     = 0;
3551         bCommX = FALSE;
3552         if (fx < 0)
3553         {
3554             fx    += pmegrids->nc[XX];
3555             ox    -= fft_nx;
3556             bCommX = (pme->nnodes_major > 1);
3557         }
3558         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3559         ox       += pmegrid_g->offset[XX];
3560         if (!bCommX)
3561         {
3562             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3563         }
3564         else
3565         {
3566             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3567         }
3568
3569         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3570         {
3571             fy     = pmegrid->ci[YY] + sy;
3572             oy     = 0;
3573             bCommY = FALSE;
3574             if (fy < 0)
3575             {
3576                 fy    += pmegrids->nc[YY];
3577                 oy    -= fft_ny;
3578                 bCommY = (pme->nnodes_minor > 1);
3579             }
3580             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3581             oy       += pmegrid_g->offset[YY];
3582             if (!bCommY)
3583             {
3584                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3585             }
3586             else
3587             {
3588                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3589             }
3590
3591             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3592             {
3593                 fz = pmegrid->ci[ZZ] + sz;
3594                 oz = 0;
3595                 if (fz < 0)
3596                 {
3597                     fz += pmegrids->nc[ZZ];
3598                     oz -= fft_nz;
3599                 }
3600                 pmegrid_g = &pmegrids->grid_th[fz];
3601                 oz       += pmegrid_g->offset[ZZ];
3602                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3603
3604                 if (sx == 0 && sy == 0 && sz == 0)
3605                 {
3606                     /* We have already added our local contribution
3607                      * before calling this routine, so skip it here.
3608                      */
3609                     continue;
3610                 }
3611
3612                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3613
3614                 pmegrid_f = &pmegrids->grid_th[thread_f];
3615
3616                 grid_th = pmegrid_f->grid;
3617
3618                 nsx = pmegrid_f->s[XX];
3619                 nsy = pmegrid_f->s[YY];
3620                 nsz = pmegrid_f->s[ZZ];
3621
3622 #ifdef DEBUG_PME_REDUCE
3623                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3624                        pme->nodeid, thread, thread_f,
3625                        pme->pmegrid_start_ix,
3626                        pme->pmegrid_start_iy,
3627                        pme->pmegrid_start_iz,
3628                        sx, sy, sz,
3629                        offx-ox, tx1-ox, offx, tx1,
3630                        offy-oy, ty1-oy, offy, ty1,
3631                        offz-oz, tz1-oz, offz, tz1);
3632 #endif
3633
3634                 if (!(bCommX || bCommY))
3635                 {
3636                     /* Copy from the thread local grid to the node grid */
3637                     for (x = offx; x < tx1; x++)
3638                     {
3639                         for (y = offy; y < ty1; y++)
3640                         {
3641                             i0  = (x*fft_my + y)*fft_mz;
3642                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3643                             for (z = offz; z < tz1; z++)
3644                             {
3645                                 fftgrid[i0+z] += grid_th[i0t+z];
3646                             }
3647                         }
3648                     }
3649                 }
3650                 else
3651                 {
3652                     /* The order of this conditional decides
3653                      * where the corner volume gets stored with x+y decomp.
3654                      */
3655                     if (bCommY)
3656                     {
3657                         commbuf = commbuf_y;
3658                         buf_my  = ty1 - offy;
3659                         if (bCommX)
3660                         {
3661                             /* We index commbuf modulo the local grid size */
3662                             commbuf += buf_my*fft_nx*fft_nz;
3663
3664                             bClearBuf   = bClearBufXY;
3665                             bClearBufXY = FALSE;
3666                         }
3667                         else
3668                         {
3669                             bClearBuf  = bClearBufY;
3670                             bClearBufY = FALSE;
3671                         }
3672                     }
3673                     else
3674                     {
3675                         commbuf    = commbuf_x;
3676                         buf_my     = fft_ny;
3677                         bClearBuf  = bClearBufX;
3678                         bClearBufX = FALSE;
3679                     }
3680
3681                     /* Copy to the communication buffer */
3682                     for (x = offx; x < tx1; x++)
3683                     {
3684                         for (y = offy; y < ty1; y++)
3685                         {
3686                             i0  = (x*buf_my + y)*fft_nz;
3687                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3688
3689                             if (bClearBuf)
3690                             {
3691                                 /* First access of commbuf, initialize it */
3692                                 for (z = offz; z < tz1; z++)
3693                                 {
3694                                     commbuf[i0+z]  = grid_th[i0t+z];
3695                                 }
3696                             }
3697                             else
3698                             {
3699                                 for (z = offz; z < tz1; z++)
3700                                 {
3701                                     commbuf[i0+z] += grid_th[i0t+z];
3702                                 }
3703                             }
3704                         }
3705                     }
3706                 }
3707             }
3708         }
3709     }
3710 }
3711
3712
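/* Sum the grid overlap regions between PME nodes after thread-parallel
 * spreading: first communicate and reduce along the minor (y) dimension,
 * then along the major (x) dimension, adding the received data into the
 * local fftgrid. Only a single pulse is supported along x.
 */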
3713 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3714 {
3715     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3716     pme_overlap_t *overlap;
3717     int  send_index0, send_nindex;
3718     int  recv_nindex;
3719 #ifdef GMX_MPI
3720     MPI_Status stat;
3721 #endif
3722     int  send_size_y, recv_size_y;
3723     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3724     real *sendptr, *recvptr;
3725     int  x, y, z, indg, indb;
3726
3727     /* Note that this routine is only used for forward communication.
3728      * Since the force gathering, unlike the charge spreading,
3729      * can be trivially parallelized over the particles,
3730      * the backwards process is much simpler and can use the "old"
3731      * communication setup.
3732      */
3733
3734     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3735                                    local_fft_ndata,
3736                                    local_fft_offset,
3737                                    local_fft_size);
3738
3739     if (pme->nnodes_minor > 1)
3740     {
3741         /* Minor dimension */
3742         overlap = &pme->overlap[1];
3743
3744         if (pme->nnodes_major > 1)
3745         {
3746             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3747         }
3748         else
3749         {
3750             size_yx = 0;
3751         }
3752         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3753
3754         send_size_y = overlap->send_size;
3755
3756         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3757         {
3758             send_id       = overlap->send_id[ipulse];
3759             recv_id       = overlap->recv_id[ipulse];
3760             send_index0   =
3761                 overlap->comm_data[ipulse].send_index0 -
3762                 overlap->comm_data[0].send_index0;
3763             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3764             /* We don't use recv_index0, as we always receive starting at 0 */
3765             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3766             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3767
3768             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3769             recvptr = overlap->recvbuf;
3770
3771 #ifdef GMX_MPI
3772             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3773                          send_id, ipulse,
3774                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3775                          recv_id, ipulse,
3776                          overlap->mpi_comm, &stat);
3777 #endif
3778
3779             for (x = 0; x < local_fft_ndata[XX]; x++)
3780             {
3781                 for (y = 0; y < recv_nindex; y++)
3782                 {
3783                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3784                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3785                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3786                     {
3787                         fftgrid[indg+z] += recvptr[indb+z];
3788                     }
3789                 }
3790             }
3791
3792             if (pme->nnodes_major > 1)
3793             {
3794                 /* Copy from the received buffer to the send buffer for dim 0 */
3795                 sendptr = pme->overlap[0].sendbuf;
3796                 for (x = 0; x < size_yx; x++)
3797                 {
3798                     for (y = 0; y < recv_nindex; y++)
3799                     {
3800                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3801                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3802                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3803                         {
3804                             sendptr[indg+z] += recvptr[indb+z];
3805                         }
3806                     }
3807                 }
3808             }
3809         }
3810     }
3811
3812     /* We only support a single pulse here.
3813      * This is not a severe limitation, as this code is only used with
3814      * OpenMP, in which case the (PME) domains can be larger.
3815      */
3816     if (pme->nnodes_major > 1)
3817     {
3818         /* Major dimension */
3819         overlap = &pme->overlap[0];
3820
3821         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3822         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3823
3824         ipulse = 0;
3825
3826         send_id       = overlap->send_id[ipulse];
3827         recv_id       = overlap->recv_id[ipulse];
3828         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3829         /* We don't use recv_index0, as we always receive starting at 0 */
3830         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3831
3832         sendptr = overlap->sendbuf;
3833         recvptr = overlap->recvbuf;
3834
3835         if (debug != NULL)
3836         {
3837             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3838                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3839         }
3840
3841 #ifdef GMX_MPI
3842         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3843                      send_id, ipulse,
3844                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3845                      recv_id, ipulse,
3846                      overlap->mpi_comm, &stat);
3847 #endif
3848
3849         for (x = 0; x < recv_nindex; x++)
3850         {
3851             for (y = 0; y < local_fft_ndata[YY]; y++)
3852             {
3853                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3854                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3855                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3856                 {
3857                     fftgrid[indg+z] += recvptr[indb+z];
3858                 }
3859             }
3860         }
3861     }
3862 }
3863
3864
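/* Top-level charge spreading: optionally compute interpolation indices and
 * B-spline coefficients (bCalcSplines), spread the charges onto the
 * (thread-local) PME grids (bSpread), and, when running thread-parallel,
 * reduce the thread grids into fftgrid and communicate the node overlap.
 */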
3865 static void spread_on_grid(gmx_pme_t pme,
3866                            pme_atomcomm_t *atc, pmegrids_t *grids,
3867                            gmx_bool bCalcSplines, gmx_bool bSpread,
3868                            real *fftgrid)
3869 {
3870     int nthread, thread;
3871 #ifdef PME_TIME_THREADS
3872     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3873     static double cs1     = 0, cs2 = 0, cs3 = 0;
3874     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3875     static int cnt        = 0;
3876 #endif
3877
3878     nthread = pme->nthread;
3879     assert(nthread > 0);
3880
3881 #ifdef PME_TIME_THREADS
3882     c1 = omp_cyc_start();
3883 #endif
3884     if (bCalcSplines)
3885     {
3886 #pragma omp parallel for num_threads(nthread) schedule(static)
3887         for (thread = 0; thread < nthread; thread++)
3888         {
3889             int start, end;
3890
3891             start = atc->n* thread   /nthread;
3892             end   = atc->n*(thread+1)/nthread;
3893
3894             /* Compute fftgrid index for all atoms,
3895              * with help of some extra variables.
3896              */
3897             calc_interpolation_idx(pme, atc, start, end, thread);
3898         }
3899     }
3900 #ifdef PME_TIME_THREADS
3901     c1   = omp_cyc_end(c1);
3902     cs1 += (double)c1;
3903 #endif
3904
3905 #ifdef PME_TIME_THREADS
3906     c2 = omp_cyc_start();
3907 #endif
3908 #pragma omp parallel for num_threads(nthread) schedule(static)
3909     for (thread = 0; thread < nthread; thread++)
3910     {
3911         splinedata_t *spline;
3912         pmegrid_t *grid = NULL;
3913
3914         /* make local bsplines  */
3915         if (grids == NULL || !pme->bUseThreads)
3916         {
3917             spline = &atc->spline[0];
3918
3919             spline->n = atc->n;
3920
3921             if (bSpread)
3922             {
3923                 grid = &grids->grid;
3924             }
3925         }
3926         else
3927         {
3928             spline = &atc->spline[thread];
3929
3930             if (grids->nthread == 1)
3931             {
3932                 /* One thread, we operate on all charges */
3933                 spline->n = atc->n;
3934             }
3935             else
3936             {
3937                 /* Get the indices our thread should operate on */
3938                 make_thread_local_ind(atc, thread, spline);
3939             }
3940
3941             grid = &grids->grid_th[thread];
3942         }
3943
3944         if (bCalcSplines)
3945         {
3946             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
3947                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
3948         }
3949
3950         if (bSpread)
3951         {
3952             /* put local atoms on grid. */
3953 #ifdef PME_TIME_SPREAD
3954             ct1a = omp_cyc_start();
3955 #endif
3956             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
3957
3958             if (pme->bUseThreads)
3959             {
3960                 copy_local_grid(pme, grids, thread, fftgrid);
3961             }
3962 #ifdef PME_TIME_SPREAD
3963             ct1a          = omp_cyc_end(ct1a);
3964             cs1a[thread] += (double)ct1a;
3965 #endif
3966         }
3967     }
3968 #ifdef PME_TIME_THREADS
3969     c2   = omp_cyc_end(c2);
3970     cs2 += (double)c2;
3971 #endif
3972
3973     if (bSpread && pme->bUseThreads)
3974     {
3975 #ifdef PME_TIME_THREADS
3976         c3 = omp_cyc_start();
3977 #endif
3978 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3979         for (thread = 0; thread < grids->nthread; thread++)
3980         {
3981             reduce_threadgrid_overlap(pme, grids, thread,
3982                                       fftgrid,
3983                                       pme->overlap[0].sendbuf,
3984                                       pme->overlap[1].sendbuf);
3985         }
3986 #ifdef PME_TIME_THREADS
3987         c3   = omp_cyc_end(c3);
3988         cs3 += (double)c3;
3989 #endif
3990
3991         if (pme->nnodes > 1)
3992         {
3993             /* Communicate the overlapping part of the fftgrid.
3994              * For this communication call we need to check pme->bUseThreads
3995              * to have all ranks communicate here, regardless of pme->nthread.
3996              */
3997             sum_fftgrid_dd(pme, fftgrid);
3998         }
3999     }
4000
4001 #ifdef PME_TIME_THREADS
4002     cnt++;
4003     if (cnt % 20 == 0)
4004     {
4005         printf("idx %.2f spread %.2f red %.2f",
4006                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4007 #ifdef PME_TIME_SPREAD
4008         for (thread = 0; thread < nthread; thread++)
4009         {
4010             printf(" %.2f", cs1a[thread]*1e-9);
4011         }
4012 #endif
4013         printf("\n");
4014     }
4015 #endif
4016 }
4017
4018
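/* Debug helpers: dump_grid prints a (sub)grid one value per line with
 * absolute grid coordinates; dump_local_fftgrid prints the local part of
 * the real-space FFT grid.
 */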
4019 static void dump_grid(FILE *fp,
4020                       int sx, int sy, int sz, int nx, int ny, int nz,
4021                       int my, int mz, const real *g)
4022 {
4023     int x, y, z;
4024
4025     for (x = 0; x < nx; x++)
4026     {
4027         for (y = 0; y < ny; y++)
4028         {
4029             for (z = 0; z < nz; z++)
4030             {
4031                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4032                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4033             }
4034         }
4035     }
4036 }
4037
4038 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4039 {
4040     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4041
4042     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4043                                    local_fft_ndata,
4044                                    local_fft_offset,
4045                                    local_fft_size);
4046
4047     dump_grid(stderr,
4048               pme->pmegrid_start_ix,
4049               pme->pmegrid_start_iy,
4050               pme->pmegrid_start_iz,
4051               pme->pmegrid_nx-pme->pme_order+1,
4052               pme->pmegrid_ny-pme->pme_order+1,
4053               pme->pmegrid_nz-pme->pme_order+1,
4054               local_fft_size[YY],
4055               local_fft_size[ZZ],
4056               fftgrid);
4057 }
4058
4059
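/* Interpolate the mesh energy of n charges at positions x from the current
 * contents of the A-charge grid; only valid in serial and without free
 * energy. A minimal usage sketch (hypothetical variables, assuming a
 * previous gmx_pme_do call has filled the grid):
 *
 *     real V;
 *     gmx_pme_calc_energy(pme, nprobe, xprobe, qprobe, &V);
 */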
4060 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4061 {
4062     pme_atomcomm_t *atc;
4063     pmegrids_t *grid;
4064
4065     if (pme->nnodes > 1)
4066     {
4067         gmx_incons("gmx_pme_calc_energy called in parallel");
4068     }
4069     if (pme->bFEP > 1)
4070     {
4071         gmx_incons("gmx_pme_calc_energy with free energy");
4072     }
4073
4074     atc            = &pme->atc_energy;
4075     atc->nthread   = 1;
4076     if (atc->spline == NULL)
4077     {
4078         snew(atc->spline, atc->nthread);
4079     }
4080     atc->nslab     = 1;
4081     atc->bSpread   = TRUE;
4082     atc->pme_order = pme->pme_order;
4083     atc->n         = n;
4084     pme_realloc_atomcomm_things(atc);
4085     atc->x         = x;
4086     atc->q         = q;
4087
4088     /* We only use the A-charges grid */
4089     grid = &pme->pmegridA;
4090
4091     /* Only calculate the spline coefficients, don't actually spread */
4092     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4093
4094     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4095 }
4096
4097
4098 static void reset_pmeonly_counters(gmx_wallcycle_t wcycle,
4099                                    t_nrnb *nrnb, t_inputrec *ir,
4100                                    gmx_large_int_t step)
4101 {
4102     /* Reset all the counters related to performance over the run */
4103     wallcycle_stop(wcycle, ewcRUN);
4104     wallcycle_reset_all(wcycle);
4105     init_nrnb(nrnb);
4106     if (ir->nsteps >= 0)
4107     {
4108         /* ir->nsteps is not used here, but we update it for consistency */
4109         ir->nsteps -= step - ir->init_step;
4110     }
4111     ir->init_step = step;
4112     wallcycle_start(wcycle, ewcRUN);
4113 }
4114
4115
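/* Return the PME data structure matching grid_size from the list *pmedata;
 * if no entry matches, append one by reinitializing from the last entry
 * (reusing its allocations where possible) and return that.
 */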
4116 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4117                                ivec grid_size,
4118                                t_commrec *cr, t_inputrec *ir,
4119                                gmx_pme_t *pme_ret)
4120 {
4121     int ind;
4122     gmx_pme_t pme = NULL;
4123
4124     ind = 0;
4125     while (ind < *npmedata)
4126     {
4127         pme = (*pmedata)[ind];
4128         if (pme->nkx == grid_size[XX] &&
4129             pme->nky == grid_size[YY] &&
4130             pme->nkz == grid_size[ZZ])
4131         {
4132             *pme_ret = pme;
4133
4134             return;
4135         }
4136
4137         ind++;
4138     }
4139
4140     (*npmedata)++;
4141     srenew(*pmedata, *npmedata);
4142
4143     /* Generate a new PME data structure, copying part of the old pointers */
4144     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4145
4146     *pme_ret = (*pmedata)[ind];
4147 }
4148
4149
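/* Driver loop for a dedicated PME node: receive coordinates and charges
 * from the PP nodes, handle grid-switch and counter-reset requests, run
 * gmx_pme_do, and send forces, energy and virial back, until the PP nodes
 * signal that the run is finished.
 */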
4150 int gmx_pmeonly(gmx_pme_t pme,
4151                 t_commrec *cr,    t_nrnb *nrnb,
4152                 gmx_wallcycle_t wcycle,
4153                 real ewaldcoeff,
4154                 t_inputrec *ir)
4155 {
4156     int npmedata;
4157     gmx_pme_t *pmedata;
4158     gmx_pme_pp_t pme_pp;
4159     int  ret;
4160     int  natoms;
4161     matrix box;
4162     rvec *x_pp      = NULL, *f_pp = NULL;
4163     real *chargeA   = NULL, *chargeB = NULL;
4164     real lambda     = 0;
4165     int  maxshift_x = 0, maxshift_y = 0;
4166     real energy, dvdlambda;
4167     matrix vir;
4168     float cycles;
4169     int  count;
4170     gmx_bool bEnerVir;
4171     gmx_large_int_t step, step_rel;
4172     ivec grid_switch;
4173
4174     /* This data is only used with PME tuning, i.e. switching PME grids */
4175     npmedata = 1;
4176     snew(pmedata, npmedata);
4177     pmedata[0] = pme;
4178
4179     pme_pp = gmx_pme_pp_init(cr);
4180
4181     init_nrnb(nrnb);
4182
4183     count = 0;
4184     do /****** this is a quasi-loop over time steps! */
4185     {
4186         /* The reason for having a loop here is PME grid tuning/switching */
4187         do
4188         {
4189             /* Receive atom data (charges, coordinates, box) and settings from the PP nodes */
4190             ret = gmx_pme_recv_q_x(pme_pp,
4191                                    &natoms,
4192                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4193                                    &maxshift_x, &maxshift_y,
4194                                    &pme->bFEP, &lambda,
4195                                    &bEnerVir,
4196                                    &step,
4197                                    grid_switch, &ewaldcoeff);
4198
4199             if (ret == pmerecvqxSWITCHGRID)
4200             {
4201                 /* Switch the PME grid to grid_switch */
4202                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4203             }
4204
4205             if (ret == pmerecvqxRESETCOUNTERS)
4206             {
4207                 /* Reset the cycle and flop counters */
4208                 reset_pmeonly_counters(wcycle, nrnb, ir, step);
4209             }
4210         }
4211         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4212
4213         if (ret == pmerecvqxFINISH)
4214         {
4215             /* We should stop: break out of the loop */
4216             break;
4217         }
4218
4219         step_rel = step - ir->init_step;
4220
4221         if (count == 0)
4222         {
4223             wallcycle_start(wcycle, ewcRUN);
4224         }
4225
4226         wallcycle_start(wcycle, ewcPMEMESH);
4227
4228         dvdlambda = 0;
4229         clear_mat(vir);
4230         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4231                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4232                    &energy, lambda, &dvdlambda,
4233                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4234
4235         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4236
4237         gmx_pme_send_force_vir_ener(pme_pp,
4238                                     f_pp, vir, energy, dvdlambda,
4239                                     cycles);
4240
4241         count++;
4242     } /***** end of quasi-loop, we stop with the break above */
4243     while (TRUE);
4244
4245     return 0;
4246 }
4247
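/* Perform a full PME mesh step for the home atoms: redistribute the atoms
 * over the PME decomposition (if parallel), spread the charges, do the
 * forward 3D FFT, solve in reciprocal space, do the backward FFT, gather
 * the forces and redistribute them back. With free energy the loop below
 * runs over both the A and B charge sets and the results are interpolated
 * in lambda. Which stages run is controlled by the GMX_PME_* bits in flags.
 */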
4248 int gmx_pme_do(gmx_pme_t pme,
4249                int start,       int homenr,
4250                rvec x[],        rvec f[],
4251                real *chargeA,   real *chargeB,
4252                matrix box, t_commrec *cr,
4253                int  maxshift_x, int maxshift_y,
4254                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4255                matrix vir,      real ewaldcoeff,
4256                real *energy,    real lambda,
4257                real *dvdlambda, int flags)
4258 {
4259     int     q, d, i, j, ntot, npme;
4260     int     nx, ny, nz;
4261     int     n_d, local_ny;
4262     pme_atomcomm_t *atc = NULL;
4263     pmegrids_t *pmegrid = NULL;
4264     real    *grid       = NULL;
4265     real    *ptr;
4266     rvec    *x_d, *f_d;
4267     real    *charge = NULL, *q_d;
4268     real    energy_AB[2];
4269     matrix  vir_AB[2];
4270     gmx_bool bClearF;
4271     gmx_parallel_3dfft_t pfft_setup;
4272     real *  fftgrid;
4273     t_complex * cfftgrid;
4274     int     thread;
4275     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4276     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4277
4278     assert(pme->nnodes > 0);
4279     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4280
4281     if (pme->nnodes > 1)
4282     {
4283         atc      = &pme->atc[0];
4284         atc->npd = homenr;
4285         if (atc->npd > atc->pd_nalloc)
4286         {
4287             atc->pd_nalloc = over_alloc_dd(atc->npd);
4288             srenew(atc->pd, atc->pd_nalloc);
4289         }
4290         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4291     }
4292     else
4293     {
4294         /* This could be necessary for TPI */
4295         pme->atc[0].n = homenr;
4296     }
4297
4298     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4299     {
4300         if (q == 0)
4301         {
4302             pmegrid    = &pme->pmegridA;
4303             fftgrid    = pme->fftgridA;
4304             cfftgrid   = pme->cfftgridA;
4305             pfft_setup = pme->pfft_setupA;
4306             charge     = chargeA+start;
4307         }
4308         else
4309         {
4310             pmegrid    = &pme->pmegridB;
4311             fftgrid    = pme->fftgridB;
4312             cfftgrid   = pme->cfftgridB;
4313             pfft_setup = pme->pfft_setupB;
4314             charge     = chargeB+start;
4315         }
4316         grid = pmegrid->grid.grid;
4317         /* Unpack structure */
4318         if (debug)
4319         {
4320             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4321                     cr->nnodes, cr->nodeid);
4322             fprintf(debug, "Grid = %p\n", (void*)grid);
4323             if (grid == NULL)
4324             {
4325                 gmx_fatal(FARGS, "No grid!");
4326             }
4327         }
4328         where();
4329
4330         m_inv_ur0(box, pme->recipbox);
4331
4332         if (pme->nnodes == 1)
4333         {
4334             atc = &pme->atc[0];
4335             if (DOMAINDECOMP(cr))
4336             {
4337                 atc->n = homenr;
4338                 pme_realloc_atomcomm_things(atc);
4339             }
4340             atc->x = x;
4341             atc->q = charge;
4342             atc->f = f;
4343         }
4344         else
4345         {
4346             wallcycle_start(wcycle, ewcPME_REDISTXF);
4347             for (d = pme->ndecompdim-1; d >= 0; d--)
4348             {
4349                 if (d == pme->ndecompdim-1)
4350                 {
4351                     n_d = homenr;
4352                     x_d = x + start;
4353                     q_d = charge;
4354                 }
4355                 else
4356                 {
4357                     n_d = pme->atc[d+1].n;
4358                     x_d = atc->x;
4359                     q_d = atc->q;
4360                 }
4361                 atc      = &pme->atc[d];
4362                 atc->npd = n_d;
4363                 if (atc->npd > atc->pd_nalloc)
4364                 {
4365                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4366                     srenew(atc->pd, atc->pd_nalloc);
4367                 }
4368                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4369                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4370                 where();
4371
4372                 /* Redistribute x (only once) and qA or qB */
4373                 if (DOMAINDECOMP(cr))
4374                 {
4375                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4376                 }
4377                 else
4378                 {
4379                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4380                 }
4381             }
4382             where();
4383
4384             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4385         }
4386
4387         if (debug)
4388         {
4389             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4390                     cr->nodeid, atc->n);
4391         }
4392
4393         if (flags & GMX_PME_SPREAD_Q)
4394         {
4395             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4396
4397             /* Spread the charges on a grid */
4398             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4399
4400             if (q == 0)
4401             {
4402                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4403             }
4404             inc_nrnb(nrnb, eNR_SPREADQBSP,
4405                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4406
4407             if (!pme->bUseThreads)
4408             {
4409                 wrap_periodic_pmegrid(pme, grid);
4410
4411                 /* sum contributions to local grid from other nodes */
4412 #ifdef GMX_MPI
4413                 if (pme->nnodes > 1)
4414                 {
4415                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4416                     where();
4417                 }
4418 #endif
4419
4420                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4421             }
4422
4423             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4424
4425             /*
4426                dump_local_fftgrid(pme,fftgrid);
4427                exit(0);
4428              */
4429         }
4430
4431         /* Here we start a large thread parallel region */
4432 #pragma omp parallel num_threads(pme->nthread) private(thread)
4433         {
4434             thread = gmx_omp_get_thread_num();
4435             if (flags & GMX_PME_SOLVE)
4436             {
4437                 int loop_count;
4438
4439                 /* do 3d-fft */
4440                 if (thread == 0)
4441                 {
4442                     wallcycle_start(wcycle, ewcPME_FFT);
4443                 }
4444                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4445                                            thread, wcycle);
4446                 if (thread == 0)
4447                 {
4448                     wallcycle_stop(wcycle, ewcPME_FFT);
4449                 }
4450                 where();
4451
4452                 /* solve in k-space for our local cells */
4453                 if (thread == 0)
4454                 {
4455                     wallcycle_start(wcycle, ewcPME_SOLVE);
4456                 }
4457                 loop_count =
4458                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4459                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4460                                   bCalcEnerVir,
4461                                   pme->nthread, thread);
4462                 if (thread == 0)
4463                 {
4464                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4465                     where();
4466                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4467                 }
4468             }
4469
4470             if (bCalcF)
4471             {
4472                 /* do 3d-invfft */
4473                 if (thread == 0)
4474                 {
4475                     where();
4476                     wallcycle_start(wcycle, ewcPME_FFT);
4477                 }
4478                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4479                                            thread, wcycle);
4480                 if (thread == 0)
4481                 {
4482                     wallcycle_stop(wcycle, ewcPME_FFT);
4483
4484                     where();
4485
4486                     if (pme->nodeid == 0)
4487                     {
4488                         ntot  = pme->nkx*pme->nky*pme->nkz;
4489                         npme  = ntot*log((real)ntot)/log(2.0);
4490                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4491                     }
4492
4493                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4494                 }
4495
4496                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4497             }
4498         }
4499         /* End of thread parallel section.
4500          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4501          */
4502
4503         if (bCalcF)
4504         {
4505             /* distribute local grid to all nodes */
4506 #ifdef GMX_MPI
4507             if (pme->nnodes > 1)
4508             {
4509                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4510             }
4511 #endif
4512             where();
4513
4514             unwrap_periodic_pmegrid(pme, grid);
4515
4516             /* interpolate forces for our local atoms */
4517
4518             where();
4519
4520             /* If we are running without parallelization,
4521              * atc->f is the actual force array, not a buffer,
4522              * therefore we should not clear it.
4523              */
4524             bClearF = (q == 0 && PAR(cr));
4525 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4526             for (thread = 0; thread < pme->nthread; thread++)
4527             {
4528                 gather_f_bsplines(pme, grid, bClearF, atc,
4529                                   &atc->spline[thread],
4530                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4531             }
4532
4533             where();
4534
4535             inc_nrnb(nrnb, eNR_GATHERFBSP,
4536                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4537             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4538         }
4539
4540         if (bCalcEnerVir)
4541         {
4542             /* This should only be called on the master thread
4543              * and after the threads have synchronized.
4544              */
4545             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4546         }
4547     } /* of q-loop */
4548
4549     if (bCalcF && pme->nnodes > 1)
4550     {
4551         wallcycle_start(wcycle, ewcPME_REDISTXF);
4552         for (d = 0; d < pme->ndecompdim; d++)
4553         {
4554             atc = &pme->atc[d];
4555             if (d == pme->ndecompdim - 1)
4556             {
4557                 n_d = homenr;
4558                 f_d = f + start;
4559             }
4560             else
4561             {
4562                 n_d = pme->atc[d+1].n;
4563                 f_d = pme->atc[d+1].f;
4564             }
4565             if (DOMAINDECOMP(cr))
4566             {
4567                 dd_pmeredist_f(pme, atc, n_d, f_d,
4568                                d == pme->ndecompdim-1 && pme->bPPnode);
4569             }
4570             else
4571             {
4572                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4573             }
4574         }
4575
4576         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4577     }
4578     where();
4579
4580     if (bCalcEnerVir)
4581     {
4582         if (!pme->bFEP)
4583         {
4584             *energy = energy_AB[0];
4585             m_add(vir, vir_AB[0], vir);
4586         }
4587         else
4588         {
4589             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4590             *dvdlambda += energy_AB[1] - energy_AB[0];
4591             for (i = 0; i < DIM; i++)
4592             {
4593                 for (j = 0; j < DIM; j++)
4594                 {
4595                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4596                         lambda*vir_AB[1][i][j];
4597                 }
4598             }
4599         }
4600     }
4601     else
4602     {
4603         *energy = 0;
4604     }
4605
4606     if (debug)
4607     {
4608         fprintf(debug, "PME mesh energy: %g\n", *energy);
4609     }
4610
4611     return 0;
4612 }