Move FFT routines to src/gromacs/fft/.
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #include "gromacs/fft/parallel_3dfft.h"
64 #include "gromacs/utility/gmxmpi.h"
65
66 #include <stdio.h>
67 #include <string.h>
68 #include <math.h>
69 #include <assert.h>
70 #include "typedefs.h"
71 #include "txtdump.h"
72 #include "vec.h"
73 #include "gmxcomplex.h"
74 #include "smalloc.h"
75 #include "futil.h"
76 #include "coulomb.h"
77 #include "gmx_fatal.h"
78 #include "pme.h"
79 #include "network.h"
80 #include "physics.h"
81 #include "nrnb.h"
82 #include "copyrite.h"
83 #include "gmx_wallcycle.h"
84 #include "pdbio.h"
85 #include "gmx_cyclecounter.h"
86 #include "gmx_omp.h"
87 #include "macros.h"
88
/* Single precision, with SSE2 or higher available */
#if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)

#include "gmx_x86_simd_single.h"

/* Enables the SSE code paths of this file, e.g. the SSE masks stored in
 * pme_spline_work_t.
 */
#define PME_SSE
/* Some old AMD processors could have problems with unaligned loads+stores,
 * so unaligned SSE access is not enabled in GMX_FAHCORE builds.
 */
#ifndef GMX_FAHCORE
#define PME_SSE_UNALIGNED
#endif
#endif

#define DFT_TOL 1e-7
/* #define PRT_FORCE */
/* Condition for on-the-fly time measurement; currently disabled */
/* #define TAKETIME (step > 1 && timesteps < 10) */
#define TAKETIME FALSE

/* #define PME_TIME_THREADS */

/* MPI datatype matching real in the current precision
 * (MPI_DOUBLE for double precision builds, MPI_FLOAT otherwise).
 */
#ifdef GMX_DOUBLE
#define mpi_type MPI_DOUBLE
#else
#define mpi_type MPI_FLOAT
#endif

/* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
#define GMX_CACHE_SEP 64

/* We only define a maximum to be able to use local arrays without allocation.
 * An order larger than 12 should never be needed, even for test cases.
 * If needed it can be changed here.
 */
#define PME_ORDER_MAX 12
123
124 /* Internal data structures */
/* Grid slices exchanged with one overlap-communication partner.
 * index0/nindex presumably give the start and the number of grid lines.
 */
typedef struct {
    int send_index0, send_nindex; /* Slice we send to the partner            */
    int recv_index0, recv_nindex; /* Slice we receive from the partner       */
    int recv_size;                /* Receive buffer width, used with OpenMP  */
} pme_grid_comm_t;
132
133 typedef struct {
134 #ifdef GMX_MPI
135     MPI_Comm         mpi_comm;
136 #endif
137     int              nnodes, nodeid;
138     int             *s2g0;
139     int             *s2g1;
140     int              noverlap_nodes;
141     int             *send_id, *recv_id;
142     int              send_size; /* Send buffer width, used with OpenMP */
143     pme_grid_comm_t *comm_data;
144     real            *sendbuf;
145     real            *recvbuf;
146 } pme_overlap_t;
147
/* Particle list of one thread, bucketed by the thread that spreads each
 * particle.
 */
typedef struct {
    int *n;      /* n[t]: cumulative particle count up to and including thread t */
    int  nalloc; /* Number of elements allocated for i                          */
    int *i;      /* Particle indices, grouped by destination thread             */
} thread_plist_t;
153
154 typedef struct {
155     int      *thread_one;
156     int       n;
157     int      *ind;
158     splinevec theta;
159     real     *ptr_theta_z;
160     splinevec dtheta;
161     real     *ptr_dtheta_z;
162 } splinedata_t;
163
164 typedef struct {
165     int      dimind;        /* The index of the dimension, 0=x, 1=y */
166     int      nslab;
167     int      nodeid;
168 #ifdef GMX_MPI
169     MPI_Comm mpi_comm;
170 #endif
171
172     int     *node_dest;     /* The nodes to send x and q to with DD */
173     int     *node_src;      /* The nodes to receive x and q from with DD */
174     int     *buf_index;     /* Index for commnode into the buffers */
175
176     int      maxshift;
177
178     int      npd;
179     int      pd_nalloc;
180     int     *pd;
181     int     *count;         /* The number of atoms to send to each node */
182     int    **count_thread;
183     int     *rcount;        /* The number of atoms to receive */
184
185     int      n;
186     int      nalloc;
187     rvec    *x;
188     real    *q;
189     rvec    *f;
190     gmx_bool bSpread;       /* These coordinates are used for spreading */
191     int      pme_order;
192     ivec    *idx;
193     rvec    *fractx;            /* Fractional coordinate relative to the
194                                  * lower cell boundary
195                                  */
196     int             nthread;
197     int            *thread_idx; /* Which thread should spread which charge */
198     thread_plist_t *thread_plist;
199     splinedata_t   *spline;
200 } pme_atomcomm_t;
201
202 #define FLBS  3
203 #define FLBSZ 4
204
205 typedef struct {
206     ivec  ci;     /* The spatial location of this grid         */
207     ivec  n;      /* The used size of *grid, including order-1 */
208     ivec  offset; /* The grid offset from the full node grid   */
209     int   order;  /* PME spreading order                       */
210     ivec  s;      /* The allocated size of *grid, s >= n       */
211     real *grid;   /* The grid local thread, size n             */
212 } pmegrid_t;
213
214 typedef struct {
215     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
216     int        nthread;      /* The number of threads operating on this grid     */
217     ivec       nc;           /* The local spatial decomposition over the threads */
218     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
219     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
220     int      **g2t;          /* The grid to thread index                         */
221     ivec       nthread_comm; /* The number of threads to communicate with        */
222 } pmegrids_t;
223
224
/* Work data for spline spreading and gathering */
typedef struct {
#ifdef PME_SSE
    /* Masks for SSE-aligned spreading and gathering */
    __m128 mask_SSE0[6];
    __m128 mask_SSE1[6];
#else
    /* Placeholder: C89 does not allow a struct without members */
    int    dummy;
#endif
} pme_spline_work_t;
233
234 typedef struct {
235     /* work data for solve_pme */
236     int      nalloc;
237     real *   mhx;
238     real *   mhy;
239     real *   mhz;
240     real *   m2;
241     real *   denom;
242     real *   tmp1_alloc;
243     real *   tmp1;
244     real *   eterm;
245     real *   m2inv;
246
247     real     energy;
248     matrix   vir;
249 } pme_work_t;
250
251 typedef struct gmx_pme {
252     int           ndecompdim; /* The number of decomposition dimensions */
253     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
254     int           nodeid_major;
255     int           nodeid_minor;
256     int           nnodes;    /* The number of nodes doing PME */
257     int           nnodes_major;
258     int           nnodes_minor;
259
260     MPI_Comm      mpi_comm;
261     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
262 #ifdef GMX_MPI
263     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
264 #endif
265
266     gmx_bool   bUseThreads;   /* Does any of the PME ranks have nthread>1 ?  */
267     int        nthread;       /* The number of threads doing PME on our rank */
268
269     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
270     gmx_bool   bFEP;          /* Compute Free energy contribution */
271     int        nkx, nky, nkz; /* Grid dimensions */
272     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
273     int        pme_order;
274     real       epsilon_r;
275
276     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
277     pmegrids_t pmegridB;
278     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
279     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
280     /* pmegrid_nz might be larger than strictly necessary to ensure
281      * memory alignment, pmegrid_nz_base gives the real base size.
282      */
283     int     pmegrid_nz_base;
284     /* The local PME grid starting indices */
285     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
286
287     /* Work data for spreading and gathering */
288     pme_spline_work_t    *spline_work;
289
290     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
291     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
292     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
293
294     t_complex            *cfftgridA;  /* Grids for complex FFT data */
295     t_complex            *cfftgridB;
296     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
297
298     gmx_parallel_3dfft_t  pfft_setupA;
299     gmx_parallel_3dfft_t  pfft_setupB;
300
301     int                  *nnx, *nny, *nnz;
302     real                 *fshx, *fshy, *fshz;
303
304     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
305     matrix                recipbox;
306     splinevec             bsp_mod;
307
308     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
309
310     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
311
312     rvec                 *bufv;       /* Communication buffer */
313     real                 *bufr;       /* Communication buffer */
314     int                   buf_nalloc; /* The communication buffer size */
315
316     /* thread local work data for solve_pme */
317     pme_work_t *work;
318
319     /* Work data for PME_redist */
320     gmx_bool redist_init;
321     int *    scounts;
322     int *    rcounts;
323     int *    sdispls;
324     int *    rdispls;
325     int *    sidx;
326     int *    idxa;
327     real *   redist_buf;
328     int      redist_buf_nalloc;
329
330     /* Work data for sum_qgrid */
331     real *   sum_qgrid_tmp;
332     real *   sum_qgrid_dd_tmp;
333 } t_gmx_pme;
334
335
336 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
337                                    int start, int end, int thread)
338 {
339     int             i;
340     int            *idxptr, tix, tiy, tiz;
341     real           *xptr, *fptr, tx, ty, tz;
342     real            rxx, ryx, ryy, rzx, rzy, rzz;
343     int             nx, ny, nz;
344     int             start_ix, start_iy, start_iz;
345     int            *g2tx, *g2ty, *g2tz;
346     gmx_bool        bThreads;
347     int            *thread_idx = NULL;
348     thread_plist_t *tpl        = NULL;
349     int            *tpl_n      = NULL;
350     int             thread_i;
351
352     nx  = pme->nkx;
353     ny  = pme->nky;
354     nz  = pme->nkz;
355
356     start_ix = pme->pmegrid_start_ix;
357     start_iy = pme->pmegrid_start_iy;
358     start_iz = pme->pmegrid_start_iz;
359
360     rxx = pme->recipbox[XX][XX];
361     ryx = pme->recipbox[YY][XX];
362     ryy = pme->recipbox[YY][YY];
363     rzx = pme->recipbox[ZZ][XX];
364     rzy = pme->recipbox[ZZ][YY];
365     rzz = pme->recipbox[ZZ][ZZ];
366
367     g2tx = pme->pmegridA.g2t[XX];
368     g2ty = pme->pmegridA.g2t[YY];
369     g2tz = pme->pmegridA.g2t[ZZ];
370
371     bThreads = (atc->nthread > 1);
372     if (bThreads)
373     {
374         thread_idx = atc->thread_idx;
375
376         tpl   = &atc->thread_plist[thread];
377         tpl_n = tpl->n;
378         for (i = 0; i < atc->nthread; i++)
379         {
380             tpl_n[i] = 0;
381         }
382     }
383
384     for (i = start; i < end; i++)
385     {
386         xptr   = atc->x[i];
387         idxptr = atc->idx[i];
388         fptr   = atc->fractx[i];
389
390         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
391         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
392         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
393         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
394
395         tix = (int)(tx);
396         tiy = (int)(ty);
397         tiz = (int)(tz);
398
399         /* Because decomposition only occurs in x and y,
400          * we never have a fraction correction in z.
401          */
402         fptr[XX] = tx - tix + pme->fshx[tix];
403         fptr[YY] = ty - tiy + pme->fshy[tiy];
404         fptr[ZZ] = tz - tiz;
405
406         idxptr[XX] = pme->nnx[tix];
407         idxptr[YY] = pme->nny[tiy];
408         idxptr[ZZ] = pme->nnz[tiz];
409
410 #ifdef DEBUG
411         range_check(idxptr[XX], 0, pme->pmegrid_nx);
412         range_check(idxptr[YY], 0, pme->pmegrid_ny);
413         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
414 #endif
415
416         if (bThreads)
417         {
418             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
419             thread_idx[i] = thread_i;
420             tpl_n[thread_i]++;
421         }
422     }
423
424     if (bThreads)
425     {
426         /* Make a list of particle indices sorted on thread */
427
428         /* Get the cumulative count */
429         for (i = 1; i < atc->nthread; i++)
430         {
431             tpl_n[i] += tpl_n[i-1];
432         }
433         /* The current implementation distributes particles equally
434          * over the threads, so we could actually allocate for that
435          * in pme_realloc_atomcomm_things.
436          */
437         if (tpl_n[atc->nthread-1] > tpl->nalloc)
438         {
439             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
440             srenew(tpl->i, tpl->nalloc);
441         }
442         /* Set tpl_n to the cumulative start */
443         for (i = atc->nthread-1; i >= 1; i--)
444         {
445             tpl_n[i] = tpl_n[i-1];
446         }
447         tpl_n[0] = 0;
448
449         /* Fill our thread local array with indices sorted on thread */
450         for (i = start; i < end; i++)
451         {
452             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
453         }
454         /* Now tpl_n contains the cummulative count again */
455     }
456 }
457
458 static void make_thread_local_ind(pme_atomcomm_t *atc,
459                                   int thread, splinedata_t *spline)
460 {
461     int             n, t, i, start, end;
462     thread_plist_t *tpl;
463
464     /* Combine the indices made by each thread into one index */
465
466     n     = 0;
467     start = 0;
468     for (t = 0; t < atc->nthread; t++)
469     {
470         tpl = &atc->thread_plist[t];
471         /* Copy our part (start - end) from the list of thread t */
472         if (thread > 0)
473         {
474             start = tpl->n[thread-1];
475         }
476         end = tpl->n[thread];
477         for (i = start; i < end; i++)
478         {
479             spline->ind[n++] = tpl->i[i];
480         }
481     }
482
483     spline->n = n;
484 }
485
486
487 static void pme_calc_pidx(int start, int end,
488                           matrix recipbox, rvec x[],
489                           pme_atomcomm_t *atc, int *count)
490 {
491     int   nslab, i;
492     int   si;
493     real *xptr, s;
494     real  rxx, ryx, rzx, ryy, rzy;
495     int  *pd;
496
497     /* Calculate PME task index (pidx) for each grid index.
498      * Here we always assign equally sized slabs to each node
499      * for load balancing reasons (the PME grid spacing is not used).
500      */
501
502     nslab = atc->nslab;
503     pd    = atc->pd;
504
505     /* Reset the count */
506     for (i = 0; i < nslab; i++)
507     {
508         count[i] = 0;
509     }
510
511     if (atc->dimind == 0)
512     {
513         rxx = recipbox[XX][XX];
514         ryx = recipbox[YY][XX];
515         rzx = recipbox[ZZ][XX];
516         /* Calculate the node index in x-dimension */
517         for (i = start; i < end; i++)
518         {
519             xptr   = x[i];
520             /* Fractional coordinates along box vectors */
521             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
522             si    = (int)(s + 2*nslab) % nslab;
523             pd[i] = si;
524             count[si]++;
525         }
526     }
527     else
528     {
529         ryy = recipbox[YY][YY];
530         rzy = recipbox[ZZ][YY];
531         /* Calculate the node index in y-dimension */
532         for (i = start; i < end; i++)
533         {
534             xptr   = x[i];
535             /* Fractional coordinates along box vectors */
536             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
537             si    = (int)(s + 2*nslab) % nslab;
538             pd[i] = si;
539             count[si]++;
540         }
541     }
542 }
543
544 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
545                                   pme_atomcomm_t *atc)
546 {
547     int nthread, thread, slab;
548
549     nthread = atc->nthread;
550
551 #pragma omp parallel for num_threads(nthread) schedule(static)
552     for (thread = 0; thread < nthread; thread++)
553     {
554         pme_calc_pidx(natoms* thread   /nthread,
555                       natoms*(thread+1)/nthread,
556                       recipbox, x, atc, atc->count_thread[thread]);
557     }
558     /* Non-parallel reduction, since nslab is small */
559
560     for (thread = 1; thread < nthread; thread++)
561     {
562         for (slab = 0; slab < atc->nslab; slab++)
563         {
564             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
565         }
566     }
567 }
568
569 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
570 {
571     const int padding = 4;
572     int       i;
573
574     srenew(th[XX], nalloc);
575     srenew(th[YY], nalloc);
576     /* In z we add padding, this is only required for the aligned SSE code */
577     srenew(*ptr_z, nalloc+2*padding);
578     th[ZZ] = *ptr_z + padding;
579
580     for (i = 0; i < padding; i++)
581     {
582         (*ptr_z)[               i] = 0;
583         (*ptr_z)[padding+nalloc+i] = 0;
584     }
585 }
586
587 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
588 {
589     int i, d;
590
591     srenew(spline->ind, atc->nalloc);
592     /* Initialize the index to identity so it works without threads */
593     for (i = 0; i < atc->nalloc; i++)
594     {
595         spline->ind[i] = i;
596     }
597
598     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
599                       atc->pme_order*atc->nalloc);
600     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
601                       atc->pme_order*atc->nalloc);
602 }
603
604 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
605 {
606     int nalloc_old, i, j, nalloc_tpl;
607
608     /* We have to avoid a NULL pointer for atc->x to avoid
609      * possible fatal errors in MPI routines.
610      */
611     if (atc->n > atc->nalloc || atc->nalloc == 0)
612     {
613         nalloc_old  = atc->nalloc;
614         atc->nalloc = over_alloc_dd(max(atc->n, 1));
615
616         if (atc->nslab > 1)
617         {
618             srenew(atc->x, atc->nalloc);
619             srenew(atc->q, atc->nalloc);
620             srenew(atc->f, atc->nalloc);
621             for (i = nalloc_old; i < atc->nalloc; i++)
622             {
623                 clear_rvec(atc->f[i]);
624             }
625         }
626         if (atc->bSpread)
627         {
628             srenew(atc->fractx, atc->nalloc);
629             srenew(atc->idx, atc->nalloc);
630
631             if (atc->nthread > 1)
632             {
633                 srenew(atc->thread_idx, atc->nalloc);
634             }
635
636             for (i = 0; i < atc->nthread; i++)
637             {
638                 pme_realloc_splinedata(&atc->spline[i], atc);
639             }
640         }
641     }
642 }
643
644 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
645                          int n, gmx_bool bXF, rvec *x_f, real *charge,
646                          pme_atomcomm_t *atc)
647 /* Redistribute particle data for PME calculation */
648 /* domain decomposition by x coordinate           */
649 {
650     int *idxa;
651     int  i, ii;
652
653     if (FALSE == pme->redist_init)
654     {
655         snew(pme->scounts, atc->nslab);
656         snew(pme->rcounts, atc->nslab);
657         snew(pme->sdispls, atc->nslab);
658         snew(pme->rdispls, atc->nslab);
659         snew(pme->sidx, atc->nslab);
660         pme->redist_init = TRUE;
661     }
662     if (n > pme->redist_buf_nalloc)
663     {
664         pme->redist_buf_nalloc = over_alloc_dd(n);
665         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
666     }
667
668     pme->idxa = atc->pd;
669
670 #ifdef GMX_MPI
671     if (forw && bXF)
672     {
673         /* forward, redistribution from pp to pme */
674
675         /* Calculate send counts and exchange them with other nodes */
676         for (i = 0; (i < atc->nslab); i++)
677         {
678             pme->scounts[i] = 0;
679         }
680         for (i = 0; (i < n); i++)
681         {
682             pme->scounts[pme->idxa[i]]++;
683         }
684         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
685
686         /* Calculate send and receive displacements and index into send
687            buffer */
688         pme->sdispls[0] = 0;
689         pme->rdispls[0] = 0;
690         pme->sidx[0]    = 0;
691         for (i = 1; i < atc->nslab; i++)
692         {
693             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
694             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
695             pme->sidx[i]    = pme->sdispls[i];
696         }
697         /* Total # of particles to be received */
698         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
699
700         pme_realloc_atomcomm_things(atc);
701
702         /* Copy particle coordinates into send buffer and exchange*/
703         for (i = 0; (i < n); i++)
704         {
705             ii = DIM*pme->sidx[pme->idxa[i]];
706             pme->sidx[pme->idxa[i]]++;
707             pme->redist_buf[ii+XX] = x_f[i][XX];
708             pme->redist_buf[ii+YY] = x_f[i][YY];
709             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
710         }
711         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
712                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
713                       pme->rvec_mpi, atc->mpi_comm);
714     }
715     if (forw)
716     {
717         /* Copy charge into send buffer and exchange*/
718         for (i = 0; i < atc->nslab; i++)
719         {
720             pme->sidx[i] = pme->sdispls[i];
721         }
722         for (i = 0; (i < n); i++)
723         {
724             ii = pme->sidx[pme->idxa[i]];
725             pme->sidx[pme->idxa[i]]++;
726             pme->redist_buf[ii] = charge[i];
727         }
728         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
729                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
730                       atc->mpi_comm);
731     }
732     else   /* backward, redistribution from pme to pp */
733     {
734         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
735                       pme->redist_buf, pme->scounts, pme->sdispls,
736                       pme->rvec_mpi, atc->mpi_comm);
737
738         /* Copy data from receive buffer */
739         for (i = 0; i < atc->nslab; i++)
740         {
741             pme->sidx[i] = pme->sdispls[i];
742         }
743         for (i = 0; (i < n); i++)
744         {
745             ii          = DIM*pme->sidx[pme->idxa[i]];
746             x_f[i][XX] += pme->redist_buf[ii+XX];
747             x_f[i][YY] += pme->redist_buf[ii+YY];
748             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
749             pme->sidx[pme->idxa[i]]++;
750         }
751     }
752 #endif
753 }
754
755 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
756                             gmx_bool bBackward, int shift,
757                             void *buf_s, int nbyte_s,
758                             void *buf_r, int nbyte_r)
759 {
760 #ifdef GMX_MPI
761     int        dest, src;
762     MPI_Status stat;
763
764     if (bBackward == FALSE)
765     {
766         dest = atc->node_dest[shift];
767         src  = atc->node_src[shift];
768     }
769     else
770     {
771         dest = atc->node_src[shift];
772         src  = atc->node_dest[shift];
773     }
774
775     if (nbyte_s > 0 && nbyte_r > 0)
776     {
777         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
778                      dest, shift,
779                      buf_r, nbyte_r, MPI_BYTE,
780                      src, shift,
781                      atc->mpi_comm, &stat);
782     }
783     else if (nbyte_s > 0)
784     {
785         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
786                  dest, shift,
787                  atc->mpi_comm);
788     }
789     else if (nbyte_r > 0)
790     {
791         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
792                  src, shift,
793                  atc->mpi_comm, &stat);
794     }
795 #endif
796 }
797
/* Send the coordinates (only when bX is set) and charges of the n local
 * particles to the PME slab atc->pd[i] along the decomposition dimension of
 * atc. Communication is point-to-point with at most nnodes_comm neighboring
 * partners (via pme_dd_sendrecv), and the call order must match on all ranks.
 *
 * Expects atc->pd[] and the per-slab counts atc->count[] to be filled in
 * beforehand (presumably from pme_calc_pidx). Particles that stay on this
 * rank are stored first in atc->x/atc->q, followed by the particles received
 * from each partner, in partner order.
 * With bX, the receive counts atc->rcount[] are exchanged, atc->n is set and
 * the atc arrays are reallocated. Without bX the counts of the last call
 * with bX are reused. With bX, a fatal error is raised when the counts for
 * our own slab plus the partners do not add up to n.
 */
static void dd_pmeredist_x_q(gmx_pme_t pme,
                             int n, gmx_bool bX, rvec *x, real *charge,
                             pme_atomcomm_t *atc)
{
    int *commnode, *buf_index;
    int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;

    commnode  = atc->node_dest;
    buf_index = atc->buf_index;

    /* At most 2*maxshift partners, and never more than the other slabs */
    nnodes_comm = min(2*atc->maxshift, atc->nslab-1);

    /* Start offset of each partner's particles in the send buffers,
     * and the total number of particles to send.
     */
    nsend = 0;
    for (i = 0; i < nnodes_comm; i++)
    {
        buf_index[commnode[i]] = nsend;
        nsend                 += atc->count[commnode[i]];
    }
    if (bX)
    {
        /* Every particle must stay here or go to one of the partners */
        if (atc->count[atc->nodeid] + nsend != n)
        {
            gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
                      "This usually means that your system is not well equilibrated.",
                      n - (atc->count[atc->nodeid] + nsend),
                      pme->nodeid, 'x'+atc->dimind);
        }

        /* bufv and bufr share the allocation size buf_nalloc */
        if (nsend > pme->buf_nalloc)
        {
            pme->buf_nalloc = over_alloc_dd(nsend);
            srenew(pme->bufv, pme->buf_nalloc);
            srenew(pme->bufr, pme->buf_nalloc);
        }

        /* Total after redistribution: our own particles plus all received */
        atc->n = atc->count[atc->nodeid];
        for (i = 0; i < nnodes_comm; i++)
        {
            scount = atc->count[commnode[i]];
            /* Communicate the count */
            if (debug)
            {
                fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
                        atc->dimind, atc->nodeid, commnode[i], scount);
            }
            pme_dd_sendrecv(atc, FALSE, i,
                            &scount, sizeof(int),
                            &atc->rcount[i], sizeof(int));
            atc->n += atc->rcount[i];
        }

        pme_realloc_atomcomm_things(atc);
    }

    /* Split the particles over the local part of atc and the send buffers.
     * Note that this advances buf_index[node] past each partner's data.
     */
    local_pos = 0;
    for (i = 0; i < n; i++)
    {
        node = atc->pd[i];
        if (node == atc->nodeid)
        {
            /* Copy direct to the receive buffer */
            if (bX)
            {
                copy_rvec(x[i], atc->x[local_pos]);
            }
            atc->q[local_pos] = charge[i];
            local_pos++;
        }
        else
        {
            /* Copy to the send buffer */
            if (bX)
            {
                copy_rvec(x[i], pme->bufv[buf_index[node]]);
            }
            pme->bufr[buf_index[node]] = charge[i];
            buf_index[node]++;
        }
    }

    /* Exchange the data with each partner; received particles are appended
     * after the local ones in atc->x/atc->q.
     */
    buf_pos = 0;
    for (i = 0; i < nnodes_comm; i++)
    {
        scount = atc->count[commnode[i]];
        rcount = atc->rcount[i];
        if (scount > 0 || rcount > 0)
        {
            if (bX)
            {
                /* Communicate the coordinates */
                pme_dd_sendrecv(atc, FALSE, i,
                                pme->bufv[buf_pos], scount*sizeof(rvec),
                                atc->x[local_pos], rcount*sizeof(rvec));
            }
            /* Communicate the charges */
            pme_dd_sendrecv(atc, FALSE, i,
                            pme->bufr+buf_pos, scount*sizeof(real),
                            atc->q+local_pos, rcount*sizeof(real));
            buf_pos   += scount;
            local_pos += atc->rcount[i];
        }
    }
}
901
902 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
903                            int n, rvec *f,
904                            gmx_bool bAddF)
905 {
906     int *commnode, *buf_index;
907     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
908
909     commnode  = atc->node_dest;
910     buf_index = atc->buf_index;
911
912     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
913
914     local_pos = atc->count[atc->nodeid];
915     buf_pos   = 0;
916     for (i = 0; i < nnodes_comm; i++)
917     {
918         scount = atc->rcount[i];
919         rcount = atc->count[commnode[i]];
920         if (scount > 0 || rcount > 0)
921         {
922             /* Communicate the forces */
923             pme_dd_sendrecv(atc, TRUE, i,
924                             atc->f[local_pos], scount*sizeof(rvec),
925                             pme->bufv[buf_pos], rcount*sizeof(rvec));
926             local_pos += scount;
927         }
928         buf_index[commnode[i]] = buf_pos;
929         buf_pos               += rcount;
930     }
931
932     local_pos = 0;
933     if (bAddF)
934     {
935         for (i = 0; i < n; i++)
936         {
937             node = atc->pd[i];
938             if (node == atc->nodeid)
939             {
940                 /* Add from the local force array */
941                 rvec_inc(f[i], atc->f[local_pos]);
942                 local_pos++;
943             }
944             else
945             {
946                 /* Add from the receive buffer */
947                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
948                 buf_index[node]++;
949             }
950         }
951     }
952     else
953     {
954         for (i = 0; i < n; i++)
955         {
956             node = atc->pd[i];
957             if (node == atc->nodeid)
958             {
959                 /* Copy from the local force array */
960                 copy_rvec(atc->f[local_pos], f[i]);
961                 local_pos++;
962             }
963             else
964             {
965                 /* Copy from the receive buffer */
966                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
967                 buf_index[node]++;
968             }
969         }
970     }
971 }
972
#ifdef GMX_MPI
/* Exchange the overlap regions of the PME grid with neighboring PME ranks.
 *
 * Two passes are made:
 * 1. Minor rank dimension (pme->overlap[1], indexed along y relative to
 *    pme->pmegrid_start_iy): the non-contiguous data is packed into
 *    overlap->sendbuf, exchanged with MPI_Sendrecv and unpacked from
 *    overlap->recvbuf. Only z indices 0..nkz-1 are sent.
 * 2. Major rank dimension (pme->overlap[0], indexed along x relative to
 *    pme->pmegrid_start_ix): whole x-slabs of pmegrid_ny*pmegrid_nz values
 *    are sent straight from grid, without packing.
 *
 * direction == GMX_SUM_QGRID_FORWARD: data is sent to send_id, received
 * from recv_id and ADDED to the grid.
 * Otherwise the roles of send and receive are swapped. Received values
 * OVERWRITE the grid; in the major pass they are received in place.
 */
static void
gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            ipulse, send_id, recv_id, datasize;
    real          *p;
    real          *sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            /* Backward: mirror the forward exchange */
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        /* Number of values per y line across all x: pmegrid_nx*nkz */
        datasize      = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    /* Forward sums contributions, backward overwrites */
                    if (direction == GMX_SUM_QGRID_FORWARD)
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            /* Receive into a scratch buffer, summed into grid below */
            recvptr       = overlap->recvbuf;
        }
        else
        {
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            /* Backward: received slabs overwrite the grid in place */
            recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        /* One full x-slab of the local grid */
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for (i = 0; i < recv_nindex*datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif
1142
1143
1144 static int
1145 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1146 {
1147     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1148     ivec    local_pme_size;
1149     int     i, ix, iy, iz;
1150     int     pmeidx, fftidx;
1151
1152     /* Dimensions should be identical for A/B grid, so we just use A here */
1153     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1154                                    local_fft_ndata,
1155                                    local_fft_offset,
1156                                    local_fft_size);
1157
1158     local_pme_size[0] = pme->pmegrid_nx;
1159     local_pme_size[1] = pme->pmegrid_ny;
1160     local_pme_size[2] = pme->pmegrid_nz;
1161
1162     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1163        the offset is identical, and the PME grid always has more data (due to overlap)
1164      */
1165     {
1166 #ifdef DEBUG_PME
1167         FILE *fp, *fp2;
1168         char  fn[STRLEN], format[STRLEN];
1169         real  val;
1170         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1171         fp = ffopen(fn, "w");
1172         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1173         fp2 = ffopen(fn, "w");
1174         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1175 #endif
1176
1177         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1178         {
1179             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1180             {
1181                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1182                 {
1183                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1184                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1185                     fftgrid[fftidx] = pmegrid[pmeidx];
1186 #ifdef DEBUG_PME
1187                     val = 100*pmegrid[pmeidx];
1188                     if (pmegrid[pmeidx] != 0)
1189                     {
1190                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1191                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1192                     }
1193                     if (pmegrid[pmeidx] != 0)
1194                     {
1195                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1196                                 "qgrid",
1197                                 pme->pmegrid_start_ix + ix,
1198                                 pme->pmegrid_start_iy + iy,
1199                                 pme->pmegrid_start_iz + iz,
1200                                 pmegrid[pmeidx]);
1201                     }
1202 #endif
1203                 }
1204             }
1205         }
1206 #ifdef DEBUG_PME
1207         ffclose(fp);
1208         ffclose(fp2);
1209 #endif
1210     }
1211     return 0;
1212 }
1213
1214
1215 static gmx_cycles_t omp_cyc_start()
1216 {
1217     return gmx_cycles_read();
1218 }
1219
1220 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1221 {
1222     return gmx_cycles_read() - c;
1223 }
1224
1225
1226 static int
1227 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1228                         int nthread, int thread)
1229 {
1230     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1231     ivec          local_pme_size;
1232     int           ixy0, ixy1, ixy, ix, iy, iz;
1233     int           pmeidx, fftidx;
1234 #ifdef PME_TIME_THREADS
1235     gmx_cycles_t  c1;
1236     static double cs1 = 0;
1237     static int    cnt = 0;
1238 #endif
1239
1240 #ifdef PME_TIME_THREADS
1241     c1 = omp_cyc_start();
1242 #endif
1243     /* Dimensions should be identical for A/B grid, so we just use A here */
1244     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1245                                    local_fft_ndata,
1246                                    local_fft_offset,
1247                                    local_fft_size);
1248
1249     local_pme_size[0] = pme->pmegrid_nx;
1250     local_pme_size[1] = pme->pmegrid_ny;
1251     local_pme_size[2] = pme->pmegrid_nz;
1252
1253     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1254        the offset is identical, and the PME grid always has more data (due to overlap)
1255      */
1256     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1257     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1258
1259     for (ixy = ixy0; ixy < ixy1; ixy++)
1260     {
1261         ix = ixy/local_fft_ndata[YY];
1262         iy = ixy - ix*local_fft_ndata[YY];
1263
1264         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1265         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1266         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1267         {
1268             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1269         }
1270     }
1271
1272 #ifdef PME_TIME_THREADS
1273     c1   = omp_cyc_end(c1);
1274     cs1 += (double)c1;
1275     cnt++;
1276     if (cnt % 20 == 0)
1277     {
1278         printf("copy %.2f\n", cs1*1e-9);
1279     }
1280 #endif
1281
1282     return 0;
1283 }
1284
1285
1286 static void
1287 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1288 {
1289     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1290
1291     nx = pme->nkx;
1292     ny = pme->nky;
1293     nz = pme->nkz;
1294
1295     pnx = pme->pmegrid_nx;
1296     pny = pme->pmegrid_ny;
1297     pnz = pme->pmegrid_nz;
1298
1299     overlap = pme->pme_order - 1;
1300
1301     /* Add periodic overlap in z */
1302     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1303     {
1304         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1305         {
1306             for (iz = 0; iz < overlap; iz++)
1307             {
1308                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1309                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1310             }
1311         }
1312     }
1313
1314     if (pme->nnodes_minor == 1)
1315     {
1316         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1317         {
1318             for (iy = 0; iy < overlap; iy++)
1319             {
1320                 for (iz = 0; iz < nz; iz++)
1321                 {
1322                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1323                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1324                 }
1325             }
1326         }
1327     }
1328
1329     if (pme->nnodes_major == 1)
1330     {
1331         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1332
1333         for (ix = 0; ix < overlap; ix++)
1334         {
1335             for (iy = 0; iy < ny_x; iy++)
1336             {
1337                 for (iz = 0; iz < nz; iz++)
1338                 {
1339                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1340                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1341                 }
1342             }
1343         }
1344     }
1345 }
1346
1347
1348 static void
1349 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1350 {
1351     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1352
1353     nx = pme->nkx;
1354     ny = pme->nky;
1355     nz = pme->nkz;
1356
1357     pnx = pme->pmegrid_nx;
1358     pny = pme->pmegrid_ny;
1359     pnz = pme->pmegrid_nz;
1360
1361     overlap = pme->pme_order - 1;
1362
1363     if (pme->nnodes_major == 1)
1364     {
1365         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1366
1367         for (ix = 0; ix < overlap; ix++)
1368         {
1369             int iy, iz;
1370
1371             for (iy = 0; iy < ny_x; iy++)
1372             {
1373                 for (iz = 0; iz < nz; iz++)
1374                 {
1375                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1376                         pmegrid[(ix*pny+iy)*pnz+iz];
1377                 }
1378             }
1379         }
1380     }
1381
1382     if (pme->nnodes_minor == 1)
1383     {
1384 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1385         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1386         {
1387             int iy, iz;
1388
1389             for (iy = 0; iy < overlap; iy++)
1390             {
1391                 for (iz = 0; iz < nz; iz++)
1392                 {
1393                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1394                         pmegrid[(ix*pny+iy)*pnz+iz];
1395                 }
1396             }
1397         }
1398     }
1399
1400     /* Copy periodic overlap in z */
1401 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1402     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1403     {
1404         int iy, iz;
1405
1406         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1407         {
1408             for (iz = 0; iz < overlap; iz++)
1409             {
1410                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1411                     pmegrid[(ix*pny+iy)*pnz+iz];
1412             }
1413         }
1414     }
1415 }
1416
1417 static void clear_grid(int nx, int ny, int nz, real *grid,
1418                        ivec fs, int *flag,
1419                        int fx, int fy, int fz,
1420                        int order)
1421 {
1422     int nc, ncz;
1423     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1424     int flind;
1425
1426     nc  = 2 + (order - 2)/FLBS;
1427     ncz = 2 + (order - 2)/FLBSZ;
1428
1429     for (fsx = fx; fsx < fx+nc; fsx++)
1430     {
1431         for (fsy = fy; fsy < fy+nc; fsy++)
1432         {
1433             for (fsz = fz; fsz < fz+ncz; fsz++)
1434             {
1435                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1436                 if (flag[flind] == 0)
1437                 {
1438                     gx  = fsx*FLBS;
1439                     gy  = fsy*FLBS;
1440                     gz  = fsz*FLBSZ;
1441                     g0x = (gx*ny + gy)*nz + gz;
1442                     for (x = 0; x < FLBS; x++)
1443                     {
1444                         g0y = g0x;
1445                         for (y = 0; y < FLBS; y++)
1446                         {
1447                             for (z = 0; z < FLBSZ; z++)
1448                             {
1449                                 grid[g0y+z] = 0;
1450                             }
1451                             g0y += nz;
1452                         }
1453                         g0x += ny*nz;
1454                     }
1455
1456                     flag[flind] = 1;
1457                 }
1458             }
1459         }
1460     }
1461 }
1462
/* Spread one charge qn onto grid with an order^3 B-spline stencil.
 *
 * For all ithx, ithy, ithz < order this adds
 *   qn*thx[ithx]*thy[ithy]*thz[ithz]
 * to grid[((i0+ithx)*pny + j0+ithy)*pnz + k0+ithz].
 * The macro reads and writes these names from the caller's scope:
 * ithx, ithy, ithz, index_x, index_xy, index_xyz, valx, valxy, qn,
 * thx, thy, thz, i0, j0, k0, pny, pnz and grid.
 *
 * This has to be a macro to enable full compiler optimization with xlC
 * (and probably others too).
 */
#define DO_BSPLINE(order)                            \
    for (ithx = 0; (ithx < order); ithx++)                    \
    {                                                    \
        index_x = (i0+ithx)*pny*pnz;                     \
        valx    = qn*thx[ithx];                          \
                                                     \
        for (ithy = 0; (ithy < order); ithy++)                \
        {                                                \
            valxy    = valx*thy[ithy];                   \
            index_xy = index_x+(j0+ithy)*pnz;            \
                                                     \
            for (ithz = 0; (ithz < order); ithz++)            \
            {                                            \
                index_xyz        = index_xy+(k0+ithz);   \
                grid[index_xyz] += valxy*thz[ithz];      \
            }                                            \
        }                                                \
    }
1482
1483
/* Spread the charges of the atoms listed in spline->ind onto pmegrid->grid.
 *
 * The whole grid (s[XX]*s[YY]*s[ZZ] values) is zeroed first. Then every
 * listed atom n with non-zero charge atc->q[n] adds its order^3 stencil,
 * anchored at atc->idx[n] minus pmegrid->offset. The spline coefficients
 * of the nn-th listed atom start at spline->theta[d] + nn*order.
 * For orders 4 and 5 the code in pme_sse_single.h is used when PME_SSE is
 * defined; otherwise, and for all other orders, DO_BSPLINE is used.
 *
 * NOTE(review): ol, b, localsize, bndsize and the work parameter are not
 * referenced in this body. Some are presumably used by the code included
 * from pme_sse_single.h, so confirm there before removing any of them.
 */
static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
                                     pme_atomcomm_t *atc, splinedata_t *spline,
                                     pme_spline_work_t *work)
{

    /* spread charges from home atoms to local grid */
    real          *grid;
    pme_overlap_t *ol;
    int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    int       *    idxptr;
    int            order, norder, index_x, index_xy, index_xyz;
    real           valx, valxy, qn;
    real          *thx, *thy, *thz;
    int            localsize, bndsize;
    int            pnx, pny, pnz, ndatatot;
    int            offx, offy, offz;

    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the complete local grid before accumulating */
    ndatatot = pnx*pny*pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = atc->q[n];

        /* Atoms without charge contribute nothing */
        if (qn != 0)
        {
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Stencil origin in the coordinates of this local grid */
            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta[XX] + norder;
            thy = spline->theta[YY] + norder;
            thz = spline->theta[ZZ] + norder;

            /* The SSE header is included in place and configured through
             * the PME_SPREAD_SSE_* and PME_ORDER macros defined just before it.
             */
            switch (order)
            {
                case 4:
#ifdef PME_SSE
#ifdef PME_SSE_UNALIGNED
#define PME_SPREAD_SSE_ORDER4
#else
#define PME_SPREAD_SSE_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_sse_single.h"
#else
                    DO_BSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SSE
#define PME_SPREAD_SSE_ALIGNED
#define PME_ORDER 5
#include "pme_sse_single.h"
#else
                    DO_BSPLINE(5);
#endif
                    break;
                default:
                    DO_BSPLINE(order);
                    break;
            }
        }
    }
}
1567
/* Round the z size of a PME grid up to a multiple of 4 elements when the
 * SSE spreading kernels for this pme_order require aligned z-lines.
 * Order 5 always needs this with PME_SSE. Order 4 needs it only when
 * PME_SSE_UNALIGNED is not defined. Without PME_SSE this is a no-op.
 */
static void set_grid_alignment(int *pmegrid_nz, int pme_order)
{
#ifdef PME_SSE
    int bAlign;

#ifdef PME_SSE_UNALIGNED
    bAlign = (pme_order == 5);
#else
    bAlign = (pme_order == 4 || pme_order == 5);
#endif

    if (bAlign)
    {
        /* Round nz up to a multiple of 4 to ensure alignment */
        *pmegrid_nz = (*pmegrid_nz + 3) & ~3;
    }
#endif
}
1582
/* Pad a grid allocation size for the aligned order-4 SSE kernels.
 *
 * With PME_SSE and without PME_SSE_UNALIGNED, order 4 gets 4 extra
 * elements so that aligned operations do not go beyond the allocated grid.
 * For pme_order=5 the z-size alignment from set_grid_alignment already
 * ensures this. In all other cases the size is left unchanged.
 */
static void set_gridsize_alignment(int *gridsize, int pme_order)
{
#if defined(PME_SSE) && !defined(PME_SSE_UNALIGNED)
    if (pme_order == 4)
    {
        *gridsize += 4;
    }
#endif
}
1599
1600 static void pmegrid_init(pmegrid_t *grid,
1601                          int cx, int cy, int cz,
1602                          int x0, int y0, int z0,
1603                          int x1, int y1, int z1,
1604                          gmx_bool set_alignment,
1605                          int pme_order,
1606                          real *ptr)
1607 {
1608     int nz, gridsize;
1609
1610     grid->ci[XX]     = cx;
1611     grid->ci[YY]     = cy;
1612     grid->ci[ZZ]     = cz;
1613     grid->offset[XX] = x0;
1614     grid->offset[YY] = y0;
1615     grid->offset[ZZ] = z0;
1616     grid->n[XX]      = x1 - x0 + pme_order - 1;
1617     grid->n[YY]      = y1 - y0 + pme_order - 1;
1618     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1619     copy_ivec(grid->n, grid->s);
1620
1621     nz = grid->s[ZZ];
1622     set_grid_alignment(&nz, pme_order);
1623     if (set_alignment)
1624     {
1625         grid->s[ZZ] = nz;
1626     }
1627     else if (nz != grid->s[ZZ])
1628     {
1629         gmx_incons("pmegrid_init call with an unaligned z size");
1630     }
1631
1632     grid->order = pme_order;
1633     if (ptr == NULL)
1634     {
1635         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1636         set_gridsize_alignment(&gridsize, pme_order);
1637         snew_aligned(grid->grid, gridsize, 16);
1638     }
1639     else
1640     {
1641         grid->grid = ptr;
1642     }
1643 }
1644
/* Integer division rounding up (ceiling), for a non-negative enumerator
 * and a positive denominator.
 */
static int div_round_up(int enumerator, int denominator)
{
    int biased;

    biased = enumerator + (denominator - 1);

    return biased/denominator;
}
1649
1650 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1651                                   ivec nsub)
1652 {
1653     int gsize_opt, gsize;
1654     int nsx, nsy, nsz;
1655     char *env;
1656
1657     gsize_opt = -1;
1658     for (nsx = 1; nsx <= nthread; nsx++)
1659     {
1660         if (nthread % nsx == 0)
1661         {
1662             for (nsy = 1; nsy <= nthread; nsy++)
1663             {
1664                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1665                 {
1666                     nsz = nthread/(nsx*nsy);
1667
1668                     /* Determine the number of grid points per thread */
1669                     gsize =
1670                         (div_round_up(n[XX], nsx) + ovl)*
1671                         (div_round_up(n[YY], nsy) + ovl)*
1672                         (div_round_up(n[ZZ], nsz) + ovl);
1673
1674                     /* Minimize the number of grids points per thread
1675                      * and, secondarily, the number of cuts in minor dimensions.
1676                      */
1677                     if (gsize_opt == -1 ||
1678                         gsize < gsize_opt ||
1679                         (gsize == gsize_opt &&
1680                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1681                     {
1682                         nsub[XX]  = nsx;
1683                         nsub[YY]  = nsy;
1684                         nsub[ZZ]  = nsz;
1685                         gsize_opt = gsize;
1686                     }
1687                 }
1688             }
1689         }
1690     }
1691
1692     env = getenv("GMX_PME_THREAD_DIVISION");
1693     if (env != NULL)
1694     {
1695         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1696     }
1697
1698     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1699     {
1700         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1701     }
1702 }
1703
/* Set up the local PME grid and, optionally, per-thread subgrids.
 *
 * nx, ny, nz are grid sizes that include the pme_order-1 overlap points.
 * The interior sizes n[d] = (nx|ny|nz) - (pme_order-1) are what get divided.
 * For choosing the thread division (make_subgrid_division), nz_base
 * replaces n[ZZ].
 *
 * With bUseThreads, a single 16-byte aligned block grids->grid_all holds
 * nthread subgrids, separated by GMX_CACHE_SEP elements. grid_th[t].grid
 * points into that block, and t increases with z fastest, then y, then x.
 * Without threads, grid_th is set to NULL and grid_all is not allocated.
 *
 * g2t[d][i] is the owning thread of interior line i in dimension d,
 * multiplied by that dimension's stride in the linear thread index t.
 * nthread_comm[d] is the smallest number of thread blocks whose combined
 * start offset n[d]*k/nc[d] reaches max_comm_lines (overlap_x, overlap_y,
 * or pme_order-1 for z), capped at nc[d].
 */
static void pmegrids_init(pmegrids_t *grids,
                          int nx, int ny, int nz, int nz_base,
                          int pme_order,
                          gmx_bool bUseThreads,
                          int nthread,
                          int overlap_x,
                          int overlap_y)
{
    ivec n, n_base, g0, g1;
    int t, x, y, z, d, i, tfac;
    int max_comm_lines = -1;

    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    /* The full local grid owns its own allocation (ptr == NULL) */
    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
                 NULL);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);

    if (bUseThreads)
    {
        ivec nst;
        int gridsize;

        /* Largest subgrid size per dimension, including overlap */
        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
                    nx, ny, nz,
                    nst[XX], nst[YY], nst[ZZ]);
        }

        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        /* One block for all thread grids, with GMX_CACHE_SEP padding
         * before, between and after them.
         */
        snew_aligned(grids->grid_all,
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                     16);

        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    /* Thread grids point into grid_all (not separately allocated) */
                    pmegrid_init(&grids->grid_th[t],
                                 x, y, z,
                                 (n[XX]*(x  ))/grids->nc[XX],
                                 (n[YY]*(y  ))/grids->nc[YY],
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
                                 (n[XX]*(x+1))/grids->nc[XX],
                                 (n[YY]*(y+1))/grids->nc[YY],
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }
    else
    {
        grids->grid_th = NULL;
    }

    /* Build the grid-line -> thread-index maps, z first so that tfac is the
     * stride of dimension d in the linear thread index.
     */
    snew(grids->g2t, DIM);
    tfac = 1;
    for (d = DIM-1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t*tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x;     break;
            case YY: max_comm_lines = overlap_y;     break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        grids->nthread_comm[d] = 0;
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
               grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != NULL)
        {
            fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
                    'x'+d, grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         *
         * NOTE(review): the while loop above never lets nthread_comm[d]
         * exceed nc[d], so this condition cannot be true as written. The
         * comment suggests '>=' may have been intended; confirm before
         * changing, since that would make currently-running setups fatal.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
        }
    }
}
1831
1832
1833 static void pmegrids_destroy(pmegrids_t *grids)
1834 {
1835     int t;
1836
1837     if (grids->grid.grid != NULL)
1838     {
1839         sfree(grids->grid.grid);
1840
1841         if (grids->nthread > 0)
1842         {
1843             for (t = 0; t < grids->nthread; t++)
1844             {
1845                 sfree(grids->grid_th[t].grid);
1846             }
1847             sfree(grids->grid_th);
1848         }
1849     }
1850 }
1851
1852
1853 static void realloc_work(pme_work_t *work, int nkx)
1854 {
1855     if (nkx > work->nalloc)
1856     {
1857         work->nalloc = nkx;
1858         srenew(work->mhx, work->nalloc);
1859         srenew(work->mhy, work->nalloc);
1860         srenew(work->mhz, work->nalloc);
1861         srenew(work->m2, work->nalloc);
1862         /* Allocate an aligned pointer for SSE operations, including 3 extra
1863          * elements at the end since SSE operates on 4 elements at a time.
1864          */
1865         sfree_aligned(work->denom);
1866         sfree_aligned(work->tmp1);
1867         sfree_aligned(work->eterm);
1868         snew_aligned(work->denom, work->nalloc+3, 16);
1869         snew_aligned(work->tmp1, work->nalloc+3, 16);
1870         snew_aligned(work->eterm, work->nalloc+3, 16);
1871         srenew(work->m2inv, work->nalloc);
1872     }
1873 }
1874
1875
1876 static void free_work(pme_work_t *work)
1877 {
1878     sfree(work->mhx);
1879     sfree(work->mhy);
1880     sfree(work->mhz);
1881     sfree(work->m2);
1882     sfree_aligned(work->denom);
1883     sfree_aligned(work->tmp1);
1884     sfree_aligned(work->eterm);
1885     sfree(work->m2inv);
1886 }
1887
1888
1889 #ifdef PME_SSE
1890 /* Calculate exponentials through SSE in float precision */
1891 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1892 {
1893     {
1894         const __m128 two = _mm_set_ps(2.0f, 2.0f, 2.0f, 2.0f);
1895         __m128 f_sse;
1896         __m128 lu;
1897         __m128 tmp_d1, d_inv, tmp_r, tmp_e;
1898         int kx;
1899         f_sse = _mm_load1_ps(&f);
1900         for (kx = 0; kx < end; kx += 4)
1901         {
1902             tmp_d1   = _mm_load_ps(d_aligned+kx);
1903             lu       = _mm_rcp_ps(tmp_d1);
1904             d_inv    = _mm_mul_ps(lu, _mm_sub_ps(two, _mm_mul_ps(lu, tmp_d1)));
1905             tmp_r    = _mm_load_ps(r_aligned+kx);
1906             tmp_r    = gmx_mm_exp_ps(tmp_r);
1907             tmp_e    = _mm_mul_ps(f_sse, d_inv);
1908             tmp_e    = _mm_mul_ps(tmp_e, tmp_r);
1909             _mm_store_ps(e_aligned+kx, tmp_e);
1910         }
1911     }
1912 }
1913 #else
1914 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1915 {
1916     int kx;
1917     for (kx = start; kx < end; kx++)
1918     {
1919         d[kx] = 1.0/d[kx];
1920     }
1921     for (kx = start; kx < end; kx++)
1922     {
1923         r[kx] = exp(r[kx]);
1924     }
1925     for (kx = start; kx < end; kx++)
1926     {
1927         e[kx] = f*r[kx]*d[kx];
1928     }
1929 }
1930 #endif
1931
1932
/* Apply the reciprocal-space convolution in place to this thread's
 * share of the complex PME grid. Optionally compute the
 * reciprocal-space energy and virial of that share.
 *
 * The complex grid is stored y major, z middle, x minor (continuous).
 * The local (y,z) lines, local_ndata[YY]*local_ndata[ZZ] of them, are
 * divided evenly over nthread threads. This call processes lines
 * [iyz0, iyz1) for the given thread.
 *
 * Every element (kx,ky,kz) is multiplied by
 *   eterm = elfac*exp(-factor*m2)/(m2*bz*by*bsp_mod[XX][kx]),
 * where:
 *   - factor = pi^2/ewaldcoeff^2,
 *   - by = pi*vol*bsp_mod[YY][ky] and bz = bsp_mod[ZZ][kz],
 *   - m2 = |m|^2, with m built from the signed wave numbers and
 *     pme->recipbox.
 * The k = (0,0,0) element is skipped.
 *
 * When bEnerVir is TRUE, this thread's energy and (symmetric) virial
 * are written to pme->work[thread].energy and pme->work[thread].vir;
 * get_pme_ener_vir sums them over threads.
 *
 * Returns local_ndata[YY]*local_ndata[XX] ("loop count").
 * NOTE(review): this is not the number of (y,z) lines handled by this
 * thread (iyz1-iyz0); confirm what the caller expects.
 */
static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
                         real ewaldcoeff, real vol,
                         gmx_bool bEnerVir,
                         int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    t_complex *p0;
    int     kx, ky, kz, maxkx, maxky, maxkz;
    int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
    real    mx, my, mz;
    real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real    ets2, struct2, vfactor, ets2vf;
    real    d1, d2, energy = 0;
    real    by, bz;
    real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    pme_work_t *work;
    real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
    real    mhxk, mhyk, mhzk, m2k;
    real    corner_fac;
    ivec    complex_order;
    ivec    local_ndata, local_offset, local_size;
    real    elfac;

    elfac = ONE_4PI_EPS0/pme->epsilon_r;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);

    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    /* Indices below maxkx/maxky are non-negative wave numbers, the rest
     * are mapped to negative ones (k - n) below.
     */
    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;
    /* NOTE(review): maxkz is not used anywhere below */
    maxkz = nz/2+1;

    /* Per-thread scratch arrays, sized by realloc_work */
    work  = &pme->work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    eterm = work->eterm;
    m2inv = work->m2inv;

    /* Even split of the local (y,z) lines over the threads */
    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;

    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        by = M_PI*vol*pme->bsp_mod[YY][ky];

        kz = iz + local_offset[ZZ];

        /* Only non-negative kz are stored in the complex grid */
        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        /* struct2 below includes a factor 2 for the implicit -kz partner
         * of each stored point. The kz == 0 plane and, for even nz only,
         * the kz == nz/2 plane ((nz+1)/2 equals nz/2 only when nz is
         * even) are halved instead.
         */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        /* Start of the x line for this (y,z) in the local grid */
        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];

        /* We should skip the k-space point (0,0,0) */
        if (local_offset[XX] > 0 || ky > 0 || kz > 0)
        {
            kxstart = local_offset[XX];
        }
        else
        {
            kxstart = local_offset[XX] + 1;
            p0++;
        }
        kxend = local_offset[XX] + local_ndata[XX];

        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the storage
             * of the mh elements in array's.
             * Because x is the minor grid index, all mh elements
             * depend on kx for triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = kxstart; kx < kxend; kx++)
            {
                m2inv[kx] = 1.0/m2[kx];
            }

            /* eterm[kx] = elfac*exp(tmp1[kx])/denom[kx] */
            calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];

                struct2 = 2.0*(d1*d1+d2*d2);

                /* tmp1 is reused here to hold the energy contribution */
                tmp1[kx] = eterm[kx]*struct2;
            }

            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
                energy  += ets2;

                ets2vf   = ets2*vfactor;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             * In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */

            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];
            }
        }
    }

    if (bEnerVir)
    {
        /* Update virial with local values.
         * The virial is symmetric by definition.
         * this virial seems ok for isotropic scaling, but I'm
         * experiencing problems on semiisotropic membranes.
         * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
         */
        work->vir[XX][XX] = 0.25*virxx;
        work->vir[YY][YY] = 0.25*viryy;
        work->vir[ZZ][ZZ] = 0.25*virzz;
        work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
        work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
        work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy = 0.5*energy;
    }

    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
2183
2184 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2185                              real *mesh_energy, matrix vir)
2186 {
2187     /* This function sums output over threads
2188      * and should therefore only be called after thread synchronization.
2189      */
2190     int thread;
2191
2192     *mesh_energy = pme->work[0].energy;
2193     copy_mat(pme->work[0].vir, vir);
2194
2195     for (thread = 1; thread < nthread; thread++)
2196     {
2197         *mesh_energy += pme->work[thread].energy;
2198         m_add(vir, pme->work[thread].vir, vir);
2199     }
2200 }
2201
/* Accumulate into fx, fy and fz the derivatives, with respect to the
 * three fractional grid coordinates, of the B-spline interpolated grid
 * values for one atom whose spline support starts at (i0, j0, k0).
 *
 * The macro uses variables of the enclosing function:
 *   - inputs: grid, pny, pnz, i0, j0, k0, thx/thy/thz (weights) and
 *     dthx/dthy/dthz (derivative weights);
 *   - scratch: ithx, ithy, ithz, index_x, index_xy, tx, ty, dx, dy,
 *     fxy1, fz1 and gval.
 * fx, fy and fz must be initialized by the caller.
 *
 * order is pasted into the loop bounds, so a literal argument gives
 * fixed trip counts. The expansion is a bare for statement, so
 * "DO_FSPLINE(n);" leaves a harmless empty statement.
 */
#define DO_FSPLINE(order)                      \
    for (ithx = 0; (ithx < order); ithx++)              \
    {                                              \
        index_x = (i0+ithx)*pny*pnz;               \
        tx      = thx[ithx];                       \
        dx      = dthx[ithx];                      \
                                               \
        for (ithy = 0; (ithy < order); ithy++)          \
        {                                          \
            index_xy = index_x+(j0+ithy)*pnz;      \
            ty       = thy[ithy];                  \
            dy       = dthy[ithy];                 \
            fxy1     = fz1 = 0;                    \
                                               \
            for (ithz = 0; (ithz < order); ithz++)      \
            {                                      \
                gval  = grid[index_xy+(k0+ithz)];  \
                fxy1 += thz[ithz]*gval;            \
                fz1  += dthz[ithz]*gval;           \
            }                                      \
            fx += dx*ty*fxy1;                      \
            fy += tx*dy*fxy1;                      \
            fz += tx*ty*fz1;                       \
        }                                          \
    }
2227
2228
/* Interpolate PME forces for the atoms listed in spline->ind from the
 * real-space grid. It uses the B-spline weights (theta) and derivative
 * weights (dtheta) stored in spline.
 *
 * For each listed atom n:
 *   - With bClearF, atc->f[n] is zeroed first (also for uncharged
 *     atoms).
 *   - When qn = scale*atc->q[n] is nonzero, the gradient (fx, fy, fz)
 *     of the interpolated grid values with respect to the fractional
 *     grid coordinates is accumulated.
 *   - That gradient is converted to a Cartesian force using the grid
 *     sizes and pme->recipbox, and added to atc->f[n] with factor -qn.
 *
 * The spline data of the nn-th listed atom start at offset
 * nn*pme_order. For orders 4 and 5 the SSE kernels from
 * pme_sse_single.h are used when PME_SSE is defined; all other cases
 * use the DO_FSPLINE macro.
 */
static void gather_f_bsplines(gmx_pme_t pme, real *grid,
                              gmx_bool bClearF, pme_atomcomm_t *atc,
                              splinedata_t *spline,
                              real scale)
{
    /* sum forces for local particles */
    int     nn, n, ithx, ithy, ithz, i0, j0, k0;
    int     index_x, index_xy;
    int     nx, ny, nz, pnx, pny, pnz;
    int *   idxptr;
    real    tx, ty, dx, dy, qn;
    real    fx, fy, fz, gval;
    real    fxy1, fz1;
    real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
    int     norder;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    int     order;

    pme_spline_work_t *work;

    /* work is not referenced directly in this function; presumably it
     * is used by the SSE code included from pme_sse_single.h.
     */
    work = pme->spline_work;

    order = pme->pme_order;
    /* These six pointers are reassigned per atom in the loop below */
    thx   = spline->theta[XX];
    thy   = spline->theta[YY];
    thz   = spline->theta[ZZ];
    dthx  = spline->dtheta[XX];
    dthy  = spline->dtheta[YY];
    dthz  = spline->dtheta[ZZ];
    nx    = pme->nkx;
    ny    = pme->nky;
    nz    = pme->nkz;
    pnx   = pme->pmegrid_nx;
    pny   = pme->pmegrid_ny;
    pnz   = pme->pmegrid_nz;

    rxx   = pme->recipbox[XX][XX];
    ryx   = pme->recipbox[YY][XX];
    ryy   = pme->recipbox[YY][YY];
    rzx   = pme->recipbox[ZZ][XX];
    rzy   = pme->recipbox[ZZ][YY];
    rzz   = pme->recipbox[ZZ][ZZ];

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = scale*atc->q[n];

        if (bClearF)
        {
            atc->f[n][XX] = 0;
            atc->f[n][YY] = 0;
            atc->f[n][ZZ] = 0;
        }
        if (qn != 0)
        {
            fx     = 0;
            fy     = 0;
            fz     = 0;
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Lower grid indices of the spline support of this atom */
            i0   = idxptr[XX];
            j0   = idxptr[YY];
            k0   = idxptr[ZZ];

            /* Pointer arithmetic alert, next six statements */
            thx  = spline->theta[XX] + norder;
            thy  = spline->theta[YY] + norder;
            thz  = spline->theta[ZZ] + norder;
            dthx = spline->dtheta[XX] + norder;
            dthy = spline->dtheta[YY] + norder;
            dthz = spline->dtheta[ZZ] + norder;

            switch (order)
            {
                case 4:
#ifdef PME_SSE
#ifdef PME_SSE_UNALIGNED
#define PME_GATHER_F_SSE_ORDER4
#else
#define PME_GATHER_F_SSE_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_sse_single.h"
#else
                    DO_FSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SSE
#define PME_GATHER_F_SSE_ALIGNED
#define PME_ORDER 5
#include "pme_sse_single.h"
#else
                    DO_FSPLINE(5);
#endif
                    break;
                default:
                    DO_FSPLINE(order);
                    break;
            }

            /* Convert fractional-coordinate derivatives to Cartesian forces */
            atc->f[n][XX] += -qn*( fx*nx*rxx );
            atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
            atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
        }
    }
    /* Since the energy and not forces are interpolated
     * the net force might not be exactly zero.
     * This can be solved by also interpolating F, but
     * that comes at a cost.
     * A better hack is to remove the net force every
     * step, but that must be done at a higher level
     * since this routine doesn't see all atoms if running
     * in parallel. Don't know how important it is?  EL 990726
     */
}
2347
2348
2349 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2350                                    pme_atomcomm_t *atc)
2351 {
2352     splinedata_t *spline;
2353     int     n, ithx, ithy, ithz, i0, j0, k0;
2354     int     index_x, index_xy;
2355     int *   idxptr;
2356     real    energy, pot, tx, ty, qn, gval;
2357     real    *thx, *thy, *thz;
2358     int     norder;
2359     int     order;
2360
2361     spline = &atc->spline[0];
2362
2363     order = pme->pme_order;
2364
2365     energy = 0;
2366     for (n = 0; (n < atc->n); n++)
2367     {
2368         qn      = atc->q[n];
2369
2370         if (qn != 0)
2371         {
2372             idxptr = atc->idx[n];
2373             norder = n*order;
2374
2375             i0   = idxptr[XX];
2376             j0   = idxptr[YY];
2377             k0   = idxptr[ZZ];
2378
2379             /* Pointer arithmetic alert, next three statements */
2380             thx  = spline->theta[XX] + norder;
2381             thy  = spline->theta[YY] + norder;
2382             thz  = spline->theta[ZZ] + norder;
2383
2384             pot = 0;
2385             for (ithx = 0; (ithx < order); ithx++)
2386             {
2387                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2388                 tx      = thx[ithx];
2389
2390                 for (ithy = 0; (ithy < order); ithy++)
2391                 {
2392                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2393                     ty       = thy[ithy];
2394
2395                     for (ithz = 0; (ithz < order); ithz++)
2396                     {
2397                         gval  = grid[index_xy+(k0+ithz)];
2398                         pot  += tx*ty*thz[ithz]*gval;
2399                     }
2400
2401                 }
2402             }
2403
2404             energy += pot*qn;
2405         }
2406     }
2407
2408     return energy;
2409 }
2410
/* Compute, for one particle and all DIM dimensions, the B-spline
 * weights of the given order and their derivatives.
 *
 * This is a macro to force loop unrolling by fixing order; this gives
 * a significant performance gain.
 *
 * It uses variables of the enclosing function:
 *   - xptr: the fractional offsets dr from the lower cell limit, one
 *     per dimension;
 *   - i: the output slot;
 *   - theta, dtheta: the output arrays.
 * For dimension j the weights go to theta[j][i*order+k] and the
 * derivatives to dtheta[j][i*order+k], k = 0..order-1.
 *
 * The weights start at order 2 (1-dr, dr) and are raised one order at
 * a time. The derivatives are the differences of the order-1 weights,
 * taken just before the final step raises the weights to order.
 */
#define CALC_SPLINE(order)                     \
    {                                              \
        int j, k, l;                                 \
        real dr, div;                               \
        real data[PME_ORDER_MAX];                  \
        real ddata[PME_ORDER_MAX];                 \
                                               \
        for (j = 0; (j < DIM); j++)                     \
        {                                          \
            dr  = xptr[j];                         \
                                               \
            /* dr is relative offset from lower cell limit */ \
            data[order-1] = 0;                     \
            data[1]       = dr;                          \
            data[0]       = 1 - dr;                      \
                                               \
            for (k = 3; (k < order); k++)               \
            {                                      \
                div       = 1.0/(k - 1.0);               \
                data[k-1] = div*dr*data[k-2];      \
                for (l = 1; (l < (k-1)); l++)           \
                {                                  \
                    data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
                                       data[k-l-1]);                \
                }                                  \
                data[0] = div*(1-dr)*data[0];      \
            }                                      \
            /* differentiate */                    \
            ddata[0] = -data[0];                   \
            for (k = 1; (k < order); k++)               \
            {                                      \
                ddata[k] = data[k-1] - data[k];    \
            }                                      \
                                               \
            div           = 1.0/(order - 1);                 \
            data[order-1] = div*dr*data[order-2];  \
            for (l = 1; (l < (order-1)); l++)           \
            {                                      \
                data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
                                       (order-l-dr)*data[order-l-1]); \
            }                                      \
            data[0] = div*(1 - dr)*data[0];        \
                                               \
            for (k = 0; k < order; k++)                 \
            {                                      \
                theta[j][i*order+k]  = data[k];    \
                dtheta[j][i*order+k] = ddata[k];   \
            }                                      \
        }                                          \
    }
2464
2465 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2466                    rvec fractx[], int nr, int ind[], real charge[],
2467                    gmx_bool bFreeEnergy)
2468 {
2469     /* construct splines for local atoms */
2470     int  i, ii;
2471     real *xptr;
2472
2473     for (i = 0; i < nr; i++)
2474     {
2475         /* With free energy we do not use the charge check.
2476          * In most cases this will be more efficient than calling make_bsplines
2477          * twice, since usually more than half the particles have charges.
2478          */
2479         ii = ind[i];
2480         if (bFreeEnergy || charge[ii] != 0.0)
2481         {
2482             xptr = fractx[ii];
2483             switch (order)
2484             {
2485                 case 4:  CALC_SPLINE(4);     break;
2486                 case 5:  CALC_SPLINE(5);     break;
2487                 default: CALC_SPLINE(order); break;
2488             }
2489         }
2490     }
2491 }
2492
2493
2494 void make_dft_mod(real *mod, real *data, int ndata)
2495 {
2496     int i, j;
2497     real sc, ss, arg;
2498
2499     for (i = 0; i < ndata; i++)
2500     {
2501         sc = ss = 0;
2502         for (j = 0; j < ndata; j++)
2503         {
2504             arg = (2.0*M_PI*i*j)/ndata;
2505             sc += data[j]*cos(arg);
2506             ss += data[j]*sin(arg);
2507         }
2508         mod[i] = sc*sc+ss*ss;
2509     }
2510     for (i = 0; i < ndata; i++)
2511     {
2512         if (mod[i] < 1e-7)
2513         {
2514             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2515         }
2516     }
2517 }
2518
2519
2520 static void make_bspline_moduli(splinevec bsp_mod,
2521                                 int nx, int ny, int nz, int order)
2522 {
2523     int nmax = max(nx, max(ny, nz));
2524     real *data, *ddata, *bsp_data;
2525     int i, k, l;
2526     real div;
2527
2528     snew(data, order);
2529     snew(ddata, order);
2530     snew(bsp_data, nmax);
2531
2532     data[order-1] = 0;
2533     data[1]       = 0;
2534     data[0]       = 1;
2535
2536     for (k = 3; k < order; k++)
2537     {
2538         div       = 1.0/(k-1.0);
2539         data[k-1] = 0;
2540         for (l = 1; l < (k-1); l++)
2541         {
2542             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2543         }
2544         data[0] = div*data[0];
2545     }
2546     /* differentiate */
2547     ddata[0] = -data[0];
2548     for (k = 1; k < order; k++)
2549     {
2550         ddata[k] = data[k-1]-data[k];
2551     }
2552     div           = 1.0/(order-1);
2553     data[order-1] = 0;
2554     for (l = 1; l < (order-1); l++)
2555     {
2556         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2557     }
2558     data[0] = div*data[0];
2559
2560     for (i = 0; i < nmax; i++)
2561     {
2562         bsp_data[i] = 0;
2563     }
2564     for (i = 1; i <= order; i++)
2565     {
2566         bsp_data[i] = data[i-1];
2567     }
2568
2569     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2570     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2571     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2572
2573     sfree(data);
2574     sfree(ddata);
2575     sfree(bsp_data);
2576 }
2577
2578
/* Return the P3M optimal influence function for the given assignment
 * order, evaluated at z. make_p3m_bspline_moduli_dim passes
 * z = sin(pi*k/n).
 *
 * The polynomials in z^2 for orders 2..8 are taken from
 * Ballenegger et al., JCTC 8, 936 (2012). Any other order yields 0.
 */
static double do_p3m_influence(double z, int order)
{
    const double z2 = z*z;
    const double z4 = z2*z2;
    double       result;

    switch (order)
    {
        case 2:
            result = 1.0 - 2.0*z2/3.0;
            break;
        case 3:
            result = 1.0 - z2 + 2.0*z4/15.0;
            break;
        case 4:
            result = 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
            break;
        case 5:
            result = 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
            break;
        case 6:
            result = 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
            break;
        case 7:
            result = 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
            break;
        case 8:
            result = 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
            break;
        default:
            /* Unsupported order */
            result = 0.0;
            break;
    }

    return result;
}
2616
2617 /* Calculate the P3M B-spline moduli for one dimension */
2618 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2619 {
2620     double zarg, zai, sinzai, infl;
2621     int    maxk, i;
2622
2623     if (order > 8)
2624     {
2625         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2626     }
2627
2628     zarg = M_PI/n;
2629
2630     maxk = (n + 1)/2;
2631
2632     for (i = -maxk; i < 0; i++)
2633     {
2634         zai          = zarg*i;
2635         sinzai       = sin(zai);
2636         infl         = do_p3m_influence(sinzai, order);
2637         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2638     }
2639     bsp_mod[0] = 1.0;
2640     for (i = 1; i < maxk; i++)
2641     {
2642         zai        = zarg*i;
2643         sinzai     = sin(zai);
2644         infl       = do_p3m_influence(sinzai, order);
2645         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2646     }
2647 }
2648
2649 /* Calculate the P3M B-spline moduli */
2650 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2651                                     int nx, int ny, int nz, int order)
2652 {
2653     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2654     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2655     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2656 }
2657
2658
2659 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2660 {
2661     int nslab, n, i;
2662     int fw, bw;
2663
2664     nslab = atc->nslab;
2665
2666     n = 0;
2667     for (i = 1; i <= nslab/2; i++)
2668     {
2669         fw = (atc->nodeid + i) % nslab;
2670         bw = (atc->nodeid - i + nslab) % nslab;
2671         if (n < nslab - 1)
2672         {
2673             atc->node_dest[n] = fw;
2674             atc->node_src[n]  = bw;
2675             n++;
2676         }
2677         if (n < nslab - 1)
2678         {
2679             atc->node_dest[n] = bw;
2680             atc->node_src[n]  = fw;
2681             n++;
2682         }
2683     }
2684 }
2685
2686 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2687 {
2688     int thread;
2689
2690     if (NULL != log)
2691     {
2692         fprintf(log, "Destroying PME data structures.\n");
2693     }
2694
2695     sfree((*pmedata)->nnx);
2696     sfree((*pmedata)->nny);
2697     sfree((*pmedata)->nnz);
2698
2699     pmegrids_destroy(&(*pmedata)->pmegridA);
2700
2701     sfree((*pmedata)->fftgridA);
2702     sfree((*pmedata)->cfftgridA);
2703     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2704
2705     if ((*pmedata)->pmegridB.grid.grid != NULL)
2706     {
2707         pmegrids_destroy(&(*pmedata)->pmegridB);
2708         sfree((*pmedata)->fftgridB);
2709         sfree((*pmedata)->cfftgridB);
2710         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2711     }
2712     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2713     {
2714         free_work(&(*pmedata)->work[thread]);
2715     }
2716     sfree((*pmedata)->work);
2717
2718     sfree(*pmedata);
2719     *pmedata = NULL;
2720
2721     return 0;
2722 }
2723
/* Round n up to the nearest multiple of f. It is intended for
 * non-negative n and positive f (grid sizes and node counts).
 */
static int mult_up(int n, int f)
{
    const int nblocks = (n + f - 1)/f;

    return nblocks*f;
}
2728
2729
2730 static double pme_load_imbalance(gmx_pme_t pme)
2731 {
2732     int    nma, nmi;
2733     double n1, n2, n3;
2734
2735     nma = pme->nnodes_major;
2736     nmi = pme->nnodes_minor;
2737
2738     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2739     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2740     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2741
2742     /* pme_solve is roughly double the cost of an fft */
2743
2744     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2745 }
2746
2747 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2748                           int dimind, gmx_bool bSpread)
2749 {
2750     int nk, k, s, thread;
2751
2752     atc->dimind    = dimind;
2753     atc->nslab     = 1;
2754     atc->nodeid    = 0;
2755     atc->pd_nalloc = 0;
2756 #ifdef GMX_MPI
2757     if (pme->nnodes > 1)
2758     {
2759         atc->mpi_comm = pme->mpi_comm_d[dimind];
2760         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2761         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2762     }
2763     if (debug)
2764     {
2765         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2766     }
2767 #endif
2768
2769     atc->bSpread   = bSpread;
2770     atc->pme_order = pme->pme_order;
2771
2772     if (atc->nslab > 1)
2773     {
2774         /* These three allocations are not required for particle decomp. */
2775         snew(atc->node_dest, atc->nslab);
2776         snew(atc->node_src, atc->nslab);
2777         setup_coordinate_communication(atc);
2778
2779         snew(atc->count_thread, pme->nthread);
2780         for (thread = 0; thread < pme->nthread; thread++)
2781         {
2782             snew(atc->count_thread[thread], atc->nslab);
2783         }
2784         atc->count = atc->count_thread[0];
2785         snew(atc->rcount, atc->nslab);
2786         snew(atc->buf_index, atc->nslab);
2787     }
2788
2789     atc->nthread = pme->nthread;
2790     if (atc->nthread > 1)
2791     {
2792         snew(atc->thread_plist, atc->nthread);
2793     }
2794     snew(atc->spline, atc->nthread);
2795     for (thread = 0; thread < atc->nthread; thread++)
2796     {
2797         if (atc->nthread > 1)
2798         {
2799             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2800             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2801         }
2802         snew(atc->spline[thread].thread_one, pme->nthread);
2803         atc->spline[thread].thread_one[thread] = 1;
2804     }
2805 }
2806
2807 static void
2808 init_overlap_comm(pme_overlap_t *  ol,
2809                   int              norder,
2810 #ifdef GMX_MPI
2811                   MPI_Comm         comm,
2812 #endif
2813                   int              nnodes,
2814                   int              nodeid,
2815                   int              ndata,
2816                   int              commplainsize)
2817 {
2818     int lbnd, rbnd, maxlr, b, i;
2819     int exten;
2820     int nn, nk;
2821     pme_grid_comm_t *pgc;
2822     gmx_bool bCont;
2823     int fft_start, fft_end, send_index1, recv_index1;
2824 #ifdef GMX_MPI
2825     MPI_Status stat;
2826
2827     ol->mpi_comm = comm;
2828 #endif
2829
2830     ol->nnodes = nnodes;
2831     ol->nodeid = nodeid;
2832
2833     /* Linear translation of the PME grid won't affect reciprocal space
2834      * calculations, so to optimize we only interpolate "upwards",
2835      * which also means we only have to consider overlap in one direction.
2836      * I.e., particles on this node might also be spread to grid indices
2837      * that belong to higher nodes (modulo nnodes)
2838      */
2839
2840     snew(ol->s2g0, ol->nnodes+1);
2841     snew(ol->s2g1, ol->nnodes);
2842     if (debug)
2843     {
2844         fprintf(debug, "PME slab boundaries:");
2845     }
2846     for (i = 0; i < nnodes; i++)
2847     {
2848         /* s2g0 the local interpolation grid start.
2849          * s2g1 the local interpolation grid end.
2850          * Because grid overlap communication only goes forward,
2851          * the grid the slabs for fft's should be rounded down.
2852          */
2853         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2854         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2855
2856         if (debug)
2857         {
2858             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2859         }
2860     }
2861     ol->s2g0[nnodes] = ndata;
2862     if (debug)
2863     {
2864         fprintf(debug, "\n");
2865     }
2866
2867     /* Determine with how many nodes we need to communicate the grid overlap */
2868     b = 0;
2869     do
2870     {
2871         b++;
2872         bCont = FALSE;
2873         for (i = 0; i < nnodes; i++)
2874         {
2875             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2876                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2877             {
2878                 bCont = TRUE;
2879             }
2880         }
2881     }
2882     while (bCont && b < nnodes);
2883     ol->noverlap_nodes = b - 1;
2884
2885     snew(ol->send_id, ol->noverlap_nodes);
2886     snew(ol->recv_id, ol->noverlap_nodes);
2887     for (b = 0; b < ol->noverlap_nodes; b++)
2888     {
2889         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2890         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2891     }
2892     snew(ol->comm_data, ol->noverlap_nodes);
2893
2894     ol->send_size = 0;
2895     for (b = 0; b < ol->noverlap_nodes; b++)
2896     {
2897         pgc = &ol->comm_data[b];
2898         /* Send */
2899         fft_start        = ol->s2g0[ol->send_id[b]];
2900         fft_end          = ol->s2g0[ol->send_id[b]+1];
2901         if (ol->send_id[b] < nodeid)
2902         {
2903             fft_start += ndata;
2904             fft_end   += ndata;
2905         }
2906         send_index1       = ol->s2g1[nodeid];
2907         send_index1       = min(send_index1, fft_end);
2908         pgc->send_index0  = fft_start;
2909         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2910         ol->send_size    += pgc->send_nindex;
2911
2912         /* We always start receiving to the first index of our slab */
2913         fft_start        = ol->s2g0[ol->nodeid];
2914         fft_end          = ol->s2g0[ol->nodeid+1];
2915         recv_index1      = ol->s2g1[ol->recv_id[b]];
2916         if (ol->recv_id[b] > nodeid)
2917         {
2918             recv_index1 -= ndata;
2919         }
2920         recv_index1      = min(recv_index1, fft_end);
2921         pgc->recv_index0 = fft_start;
2922         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2923     }
2924
2925 #ifdef GMX_MPI
2926     /* Communicate the buffer sizes to receive */
2927     for (b = 0; b < ol->noverlap_nodes; b++)
2928     {
2929         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2930                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2931                      ol->mpi_comm, &stat);
2932     }
2933 #endif
2934
2935     /* For non-divisible grid we need pme_order iso pme_order-1 */
2936     snew(ol->sendbuf, norder*commplainsize);
2937     snew(ol->recvbuf, norder*commplainsize);
2938 }
2939
2940 static void
2941 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2942                               int **global_to_local,
2943                               real **fraction_shift)
2944 {
2945     int i;
2946     int * gtl;
2947     real * fsh;
2948
2949     snew(gtl, 5*n);
2950     snew(fsh, 5*n);
2951     for (i = 0; (i < 5*n); i++)
2952     {
2953         /* Determine the global to local grid index */
2954         gtl[i] = (i - local_start + n) % n;
2955         /* For coordinates that fall within the local grid the fraction
2956          * is correct, we don't need to shift it.
2957          */
2958         fsh[i] = 0;
2959         if (local_range < n)
2960         {
2961             /* Due to rounding issues i could be 1 beyond the lower or
2962              * upper boundary of the local grid. Correct the index for this.
2963              * If we shift the index, we need to shift the fraction by
2964              * the same amount in the other direction to not affect
2965              * the weights.
2966              * Note that due to this shifting the weights at the end of
2967              * the spline might change, but that will only involve values
2968              * between zero and values close to the precision of a real,
2969              * which is anyhow the accuracy of the whole mesh calculation.
2970              */
2971             /* With local_range=0 we should not change i=local_start */
2972             if (i % n != local_start)
2973             {
2974                 if (gtl[i] == n-1)
2975                 {
2976                     gtl[i] = 0;
2977                     fsh[i] = -1;
2978                 }
2979                 else if (gtl[i] == local_range)
2980                 {
2981                     gtl[i] = local_range - 1;
2982                     fsh[i] = 1;
2983                 }
2984             }
2985         }
2986     }
2987
2988     *global_to_local = gtl;
2989     *fraction_shift  = fsh;
2990 }
2991
2992 static pme_spline_work_t *make_pme_spline_work(int order)
2993 {
2994     pme_spline_work_t *work;
2995
2996 #ifdef PME_SSE
2997     float  tmp[8];
2998     __m128 zero_SSE;
2999     int    of, i;
3000
3001     snew_aligned(work, 1, 16);
3002
3003     zero_SSE = _mm_setzero_ps();
3004
3005     /* Generate bit masks to mask out the unused grid entries,
3006      * as we only operate on order of the 8 grid entries that are
3007      * load into 2 SSE float registers.
3008      */
3009     for (of = 0; of < 8-(order-1); of++)
3010     {
3011         for (i = 0; i < 8; i++)
3012         {
3013             tmp[i] = (i >= of && i < of+order ? 1 : 0);
3014         }
3015         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
3016         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
3017         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
3018         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
3019     }
3020 #else
3021     work = NULL;
3022 #endif
3023
3024     return work;
3025 }
3026
3027 void gmx_pme_check_restrictions(int pme_order,
3028                                 int nkx, int nky, int nkz,
3029                                 int nnodes_major,
3030                                 int nnodes_minor,
3031                                 gmx_bool bUseThreads,
3032                                 gmx_bool bFatal,
3033                                 gmx_bool *bValidSettings)
3034 {
3035     if (pme_order > PME_ORDER_MAX)
3036     {
3037         if (!bFatal)
3038         {
3039             *bValidSettings = FALSE;
3040             return;
3041         }
3042         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3043                   pme_order, PME_ORDER_MAX);
3044     }
3045
3046     if (nkx <= pme_order*(nnodes_major > 1 ? 2 : 1) ||
3047         nky <= pme_order*(nnodes_minor > 1 ? 2 : 1) ||
3048         nkz <= pme_order)
3049     {
3050         if (!bFatal)
3051         {
3052             *bValidSettings = FALSE;
3053             return;
3054         }
3055         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",
3056                   pme_order);
3057     }
3058
3059     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3060      * We only allow multiple communication pulses in dim 1, not in dim 0.
3061      */
3062     if (bUseThreads && (nkx < nnodes_major*pme_order &&
3063                         nkx != nnodes_major*(pme_order - 1)))
3064     {
3065         if (!bFatal)
3066         {
3067             *bValidSettings = FALSE;
3068             return;
3069         }
3070         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d) or = pmeorder-1. To resolve this issue, use less nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3071                   nkx/(double)nnodes_major, pme_order);
3072     }
3073
3074     if (bValidSettings != NULL)
3075     {
3076         *bValidSettings = TRUE;
3077     }
3078
3079     return;
3080 }
3081
3082 int gmx_pme_init(gmx_pme_t *         pmedata,
3083                  t_commrec *         cr,
3084                  int                 nnodes_major,
3085                  int                 nnodes_minor,
3086                  t_inputrec *        ir,
3087                  int                 homenr,
3088                  gmx_bool            bFreeEnergy,
3089                  gmx_bool            bReproducible,
3090                  int                 nthread)
3091 {
3092     gmx_pme_t pme = NULL;
3093
3094     int  use_threads, sum_use_threads;
3095     ivec ndata;
3096
3097     if (debug)
3098     {
3099         fprintf(debug, "Creating PME data structures.\n");
3100     }
3101     snew(pme, 1);
3102
3103     pme->redist_init         = FALSE;
3104     pme->sum_qgrid_tmp       = NULL;
3105     pme->sum_qgrid_dd_tmp    = NULL;
3106     pme->buf_nalloc          = 0;
3107     pme->redist_buf_nalloc   = 0;
3108
3109     pme->nnodes              = 1;
3110     pme->bPPnode             = TRUE;
3111
3112     pme->nnodes_major        = nnodes_major;
3113     pme->nnodes_minor        = nnodes_minor;
3114
3115 #ifdef GMX_MPI
3116     if (nnodes_major*nnodes_minor > 1)
3117     {
3118         pme->mpi_comm = cr->mpi_comm_mygroup;
3119
3120         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3121         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3122         if (pme->nnodes != nnodes_major*nnodes_minor)
3123         {
3124             gmx_incons("PME node count mismatch");
3125         }
3126     }
3127     else
3128     {
3129         pme->mpi_comm = MPI_COMM_NULL;
3130     }
3131 #endif
3132
3133     if (pme->nnodes == 1)
3134     {
3135 #ifdef GMX_MPI
3136         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3137         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3138 #endif
3139         pme->ndecompdim   = 0;
3140         pme->nodeid_major = 0;
3141         pme->nodeid_minor = 0;
3142 #ifdef GMX_MPI
3143         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3144 #endif
3145     }
3146     else
3147     {
3148         if (nnodes_minor == 1)
3149         {
3150 #ifdef GMX_MPI
3151             pme->mpi_comm_d[0] = pme->mpi_comm;
3152             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3153 #endif
3154             pme->ndecompdim   = 1;
3155             pme->nodeid_major = pme->nodeid;
3156             pme->nodeid_minor = 0;
3157
3158         }
3159         else if (nnodes_major == 1)
3160         {
3161 #ifdef GMX_MPI
3162             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3163             pme->mpi_comm_d[1] = pme->mpi_comm;
3164 #endif
3165             pme->ndecompdim   = 1;
3166             pme->nodeid_major = 0;
3167             pme->nodeid_minor = pme->nodeid;
3168         }
3169         else
3170         {
3171             if (pme->nnodes % nnodes_major != 0)
3172             {
3173                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3174             }
3175             pme->ndecompdim = 2;
3176
3177 #ifdef GMX_MPI
3178             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3179                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3180             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3181                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3182
3183             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3184             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3185             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3186             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3187 #endif
3188         }
3189         pme->bPPnode = (cr->duty & DUTY_PP);
3190     }
3191
3192     pme->nthread = nthread;
3193
3194      /* Check if any of the PME MPI ranks uses threads */
3195     use_threads = (pme->nthread > 1 ? 1 : 0);
3196 #ifdef GMX_MPI
3197     if (pme->nnodes > 1)
3198     {
3199         MPI_Allreduce(&use_threads, &sum_use_threads, 1, MPI_INT,
3200                       MPI_SUM, pme->mpi_comm);
3201     }
3202     else
3203 #endif
3204     {
3205         sum_use_threads = use_threads;
3206     }
3207     pme->bUseThreads = (sum_use_threads > 0);
3208
3209     if (ir->ePBC == epbcSCREW)
3210     {
3211         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3212     }
3213
3214     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3215     pme->nkx         = ir->nkx;
3216     pme->nky         = ir->nky;
3217     pme->nkz         = ir->nkz;
3218     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3219     pme->pme_order   = ir->pme_order;
3220     pme->epsilon_r   = ir->epsilon_r;
3221
3222     /* If we violate restrictions, generate a fatal error here */
3223     gmx_pme_check_restrictions(pme->pme_order,
3224                                pme->nkx, pme->nky, pme->nkz,
3225                                pme->nnodes_major,
3226                                pme->nnodes_minor,
3227                                pme->bUseThreads,
3228                                TRUE,
3229                                NULL);
3230
3231     if (pme->nnodes > 1)
3232     {
3233         double imbal;
3234
3235 #ifdef GMX_MPI
3236         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3237         MPI_Type_commit(&(pme->rvec_mpi));
3238 #endif
3239
3240         /* Note that the charge spreading and force gathering, which usually
3241          * takes about the same amount of time as FFT+solve_pme,
3242          * is always fully load balanced
3243          * (unless the charge distribution is inhomogeneous).
3244          */
3245
3246         imbal = pme_load_imbalance(pme);
3247         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3248         {
3249             fprintf(stderr,
3250                     "\n"
3251                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3252                     "      For optimal PME load balancing\n"
3253                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3254                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3255                     "\n",
3256                     (int)((imbal-1)*100 + 0.5),
3257                     pme->nkx, pme->nky, pme->nnodes_major,
3258                     pme->nky, pme->nkz, pme->nnodes_minor);
3259         }
3260     }
3261
3262     /* For non-divisible grid we need pme_order iso pme_order-1 */
3263     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3264      * y is always copied through a buffer: we don't need padding in z,
3265      * but we do need the overlap in x because of the communication order.
3266      */
3267     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3268 #ifdef GMX_MPI
3269                       pme->mpi_comm_d[0],
3270 #endif
3271                       pme->nnodes_major, pme->nodeid_major,
3272                       pme->nkx,
3273                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3274
3275     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3276      * We do this with an offset buffer of equal size, so we need to allocate
3277      * extra for the offset. That's what the (+1)*pme->nkz is for.
3278      */
3279     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3280 #ifdef GMX_MPI
3281                       pme->mpi_comm_d[1],
3282 #endif
3283                       pme->nnodes_minor, pme->nodeid_minor,
3284                       pme->nky,
3285                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3286
3287     /* Double-check for a limitation of the (current) sum_fftgrid_dd code.
3288      * Note that gmx_pme_check_restrictions checked for this already.
3289      */
3290     if (pme->bUseThreads && pme->overlap[0].noverlap_nodes > 1)
3291     {
3292         gmx_incons("More than one communication pulse required for grid overlap communication along the major dimension while using threads");
3293     }
3294
3295     snew(pme->bsp_mod[XX], pme->nkx);
3296     snew(pme->bsp_mod[YY], pme->nky);
3297     snew(pme->bsp_mod[ZZ], pme->nkz);
3298
3299     /* The required size of the interpolation grid, including overlap.
3300      * The allocated size (pmegrid_n?) might be slightly larger.
3301      */
3302     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3303         pme->overlap[0].s2g0[pme->nodeid_major];
3304     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3305         pme->overlap[1].s2g0[pme->nodeid_minor];
3306     pme->pmegrid_nz_base = pme->nkz;
3307     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3308     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3309
3310     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3311     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3312     pme->pmegrid_start_iz = 0;
3313
3314     make_gridindex5_to_localindex(pme->nkx,
3315                                   pme->pmegrid_start_ix,
3316                                   pme->pmegrid_nx - (pme->pme_order-1),
3317                                   &pme->nnx, &pme->fshx);
3318     make_gridindex5_to_localindex(pme->nky,
3319                                   pme->pmegrid_start_iy,
3320                                   pme->pmegrid_ny - (pme->pme_order-1),
3321                                   &pme->nny, &pme->fshy);
3322     make_gridindex5_to_localindex(pme->nkz,
3323                                   pme->pmegrid_start_iz,
3324                                   pme->pmegrid_nz_base,
3325                                   &pme->nnz, &pme->fshz);
3326
3327     pmegrids_init(&pme->pmegridA,
3328                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3329                   pme->pmegrid_nz_base,
3330                   pme->pme_order,
3331                   pme->bUseThreads,
3332                   pme->nthread,
3333                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3334                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3335
3336     pme->spline_work = make_pme_spline_work(pme->pme_order);
3337
3338     ndata[0] = pme->nkx;
3339     ndata[1] = pme->nky;
3340     ndata[2] = pme->nkz;
3341
3342     /* This routine will allocate the grid data to fit the FFTs */
3343     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3344                             &pme->fftgridA, &pme->cfftgridA,
3345                             pme->mpi_comm_d,
3346                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3347                             bReproducible, pme->nthread);
3348
3349     if (bFreeEnergy)
3350     {
3351         pmegrids_init(&pme->pmegridB,
3352                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3353                       pme->pmegrid_nz_base,
3354                       pme->pme_order,
3355                       pme->bUseThreads,
3356                       pme->nthread,
3357                       pme->nkx % pme->nnodes_major != 0,
3358                       pme->nky % pme->nnodes_minor != 0);
3359
3360         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3361                                 &pme->fftgridB, &pme->cfftgridB,
3362                                 pme->mpi_comm_d,
3363                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3364                                 bReproducible, pme->nthread);
3365     }
3366     else
3367     {
3368         pme->pmegridB.grid.grid = NULL;
3369         pme->fftgridB           = NULL;
3370         pme->cfftgridB          = NULL;
3371     }
3372
3373     if (!pme->bP3M)
3374     {
3375         /* Use plain SPME B-spline interpolation */
3376         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3377     }
3378     else
3379     {
3380         /* Use the P3M grid-optimized influence function */
3381         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3382     }
3383
3384     /* Use atc[0] for spreading */
3385     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3386     if (pme->ndecompdim >= 2)
3387     {
3388         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3389     }
3390
3391     if (pme->nnodes == 1)
3392     {
3393         pme->atc[0].n = homenr;
3394         pme_realloc_atomcomm_things(&pme->atc[0]);
3395     }
3396
3397     {
3398         int thread;
3399
3400         /* Use fft5d, order after FFT is y major, z, x minor */
3401
3402         snew(pme->work, pme->nthread);
3403         for (thread = 0; thread < pme->nthread; thread++)
3404         {
3405             realloc_work(&pme->work[thread], pme->nkx);
3406         }
3407     }
3408
3409     *pmedata = pme;
3410
3411     return 0;
3412 }
3413
3414 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3415 {
3416     int d, t;
3417
3418     for (d = 0; d < DIM; d++)
3419     {
3420         if (new->grid.n[d] > old->grid.n[d])
3421         {
3422             return;
3423         }
3424     }
3425
3426     sfree_aligned(new->grid.grid);
3427     new->grid.grid = old->grid.grid;
3428
3429     if (new->grid_th != NULL && new->nthread == old->nthread)
3430     {
3431         sfree_aligned(new->grid_all);
3432         for (t = 0; t < new->nthread; t++)
3433         {
3434             new->grid_th[t].grid = old->grid_th[t].grid;
3435         }
3436     }
3437 }
3438
3439 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3440                    t_commrec *         cr,
3441                    gmx_pme_t           pme_src,
3442                    const t_inputrec *  ir,
3443                    ivec                grid_size)
3444 {
3445     t_inputrec irc;
3446     int homenr;
3447     int ret;
3448
3449     irc     = *ir;
3450     irc.nkx = grid_size[XX];
3451     irc.nky = grid_size[YY];
3452     irc.nkz = grid_size[ZZ];
3453
3454     if (pme_src->nnodes == 1)
3455     {
3456         homenr = pme_src->atc[0].n;
3457     }
3458     else
3459     {
3460         homenr = -1;
3461     }
3462
3463     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3464                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3465
3466     if (ret == 0)
3467     {
3468         /* We can easily reuse the allocated pme grids in pme_src */
3469         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3470         /* We would like to reuse the fft grids, but that's harder */
3471     }
3472
3473     return ret;
3474 }
3475
3476
3477 static void copy_local_grid(gmx_pme_t pme,
3478                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3479 {
3480     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3481     int  fft_my, fft_mz;
3482     int  nsx, nsy, nsz;
3483     ivec nf;
3484     int  offx, offy, offz, x, y, z, i0, i0t;
3485     int  d;
3486     pmegrid_t *pmegrid;
3487     real *grid_th;
3488
3489     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3490                                    local_fft_ndata,
3491                                    local_fft_offset,
3492                                    local_fft_size);
3493     fft_my = local_fft_size[YY];
3494     fft_mz = local_fft_size[ZZ];
3495
3496     pmegrid = &pmegrids->grid_th[thread];
3497
3498     nsx = pmegrid->s[XX];
3499     nsy = pmegrid->s[YY];
3500     nsz = pmegrid->s[ZZ];
3501
3502     for (d = 0; d < DIM; d++)
3503     {
3504         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3505                     local_fft_ndata[d] - pmegrid->offset[d]);
3506     }
3507
3508     offx = pmegrid->offset[XX];
3509     offy = pmegrid->offset[YY];
3510     offz = pmegrid->offset[ZZ];
3511
3512     /* Directly copy the non-overlapping parts of the local grids.
3513      * This also initializes the full grid.
3514      */
3515     grid_th = pmegrid->grid;
3516     for (x = 0; x < nf[XX]; x++)
3517     {
3518         for (y = 0; y < nf[YY]; y++)
3519         {
3520             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3521             i0t = (x*nsy + y)*nsz;
3522             for (z = 0; z < nf[ZZ]; z++)
3523             {
3524                 fftgrid[i0+z] = grid_th[i0t+z];
3525             }
3526         }
3527     }
3528 }
3529
3530 static void
3531 reduce_threadgrid_overlap(gmx_pme_t pme,
3532                           const pmegrids_t *pmegrids, int thread,
3533                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3534 {
3535     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3536     int  fft_nx, fft_ny, fft_nz;
3537     int  fft_my, fft_mz;
3538     int  buf_my = -1;
3539     int  nsx, nsy, nsz;
3540     ivec ne;
3541     int  offx, offy, offz, x, y, z, i0, i0t;
3542     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3543     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3544     gmx_bool bCommX, bCommY;
3545     int  d;
3546     int  thread_f;
3547     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3548     const real *grid_th;
3549     real *commbuf = NULL;
3550
3551     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3552                                    local_fft_ndata,
3553                                    local_fft_offset,
3554                                    local_fft_size);
3555     fft_nx = local_fft_ndata[XX];
3556     fft_ny = local_fft_ndata[YY];
3557     fft_nz = local_fft_ndata[ZZ];
3558
3559     fft_my = local_fft_size[YY];
3560     fft_mz = local_fft_size[ZZ];
3561
3562     /* This routine is called when all thread have finished spreading.
3563      * Here each thread sums grid contributions calculated by other threads
3564      * to the thread local grid volume.
3565      * To minimize the number of grid copying operations,
3566      * this routines sums immediately from the pmegrid to the fftgrid.
3567      */
3568
3569     /* Determine which part of the full node grid we should operate on,
3570      * this is our thread local part of the full grid.
3571      */
3572     pmegrid = &pmegrids->grid_th[thread];
3573
3574     for (d = 0; d < DIM; d++)
3575     {
3576         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3577                     local_fft_ndata[d]);
3578     }
3579
3580     offx = pmegrid->offset[XX];
3581     offy = pmegrid->offset[YY];
3582     offz = pmegrid->offset[ZZ];
3583
3584
3585     bClearBufX  = TRUE;
3586     bClearBufY  = TRUE;
3587     bClearBufXY = TRUE;
3588
3589     /* Now loop over all the thread data blocks that contribute
3590      * to the grid region we (our thread) are operating on.
3591      */
3592     /* Note that ffy_nx/y is equal to the number of grid points
3593      * between the first point of our node grid and the one of the next node.
3594      */
3595     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3596     {
3597         fx     = pmegrid->ci[XX] + sx;
3598         ox     = 0;
3599         bCommX = FALSE;
3600         if (fx < 0)
3601         {
3602             fx    += pmegrids->nc[XX];
3603             ox    -= fft_nx;
3604             bCommX = (pme->nnodes_major > 1);
3605         }
3606         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3607         ox       += pmegrid_g->offset[XX];
3608         if (!bCommX)
3609         {
3610             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3611         }
3612         else
3613         {
3614             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3615         }
3616
3617         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3618         {
3619             fy     = pmegrid->ci[YY] + sy;
3620             oy     = 0;
3621             bCommY = FALSE;
3622             if (fy < 0)
3623             {
3624                 fy    += pmegrids->nc[YY];
3625                 oy    -= fft_ny;
3626                 bCommY = (pme->nnodes_minor > 1);
3627             }
3628             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3629             oy       += pmegrid_g->offset[YY];
3630             if (!bCommY)
3631             {
3632                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3633             }
3634             else
3635             {
3636                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3637             }
3638
3639             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3640             {
3641                 fz = pmegrid->ci[ZZ] + sz;
3642                 oz = 0;
3643                 if (fz < 0)
3644                 {
3645                     fz += pmegrids->nc[ZZ];
3646                     oz -= fft_nz;
3647                 }
3648                 pmegrid_g = &pmegrids->grid_th[fz];
3649                 oz       += pmegrid_g->offset[ZZ];
3650                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3651
3652                 if (sx == 0 && sy == 0 && sz == 0)
3653                 {
3654                     /* We have already added our local contribution
3655                      * before calling this routine, so skip it here.
3656                      */
3657                     continue;
3658                 }
3659
3660                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3661
3662                 pmegrid_f = &pmegrids->grid_th[thread_f];
3663
3664                 grid_th = pmegrid_f->grid;
3665
3666                 nsx = pmegrid_f->s[XX];
3667                 nsy = pmegrid_f->s[YY];
3668                 nsz = pmegrid_f->s[ZZ];
3669
3670 #ifdef DEBUG_PME_REDUCE
3671                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3672                        pme->nodeid, thread, thread_f,
3673                        pme->pmegrid_start_ix,
3674                        pme->pmegrid_start_iy,
3675                        pme->pmegrid_start_iz,
3676                        sx, sy, sz,
3677                        offx-ox, tx1-ox, offx, tx1,
3678                        offy-oy, ty1-oy, offy, ty1,
3679                        offz-oz, tz1-oz, offz, tz1);
3680 #endif
3681
3682                 if (!(bCommX || bCommY))
3683                 {
3684                     /* Copy from the thread local grid to the node grid */
3685                     for (x = offx; x < tx1; x++)
3686                     {
3687                         for (y = offy; y < ty1; y++)
3688                         {
3689                             i0  = (x*fft_my + y)*fft_mz;
3690                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3691                             for (z = offz; z < tz1; z++)
3692                             {
3693                                 fftgrid[i0+z] += grid_th[i0t+z];
3694                             }
3695                         }
3696                     }
3697                 }
3698                 else
3699                 {
3700                     /* The order of this conditional decides
3701                      * where the corner volume gets stored with x+y decomp.
3702                      */
3703                     if (bCommY)
3704                     {
3705                         commbuf = commbuf_y;
3706                         buf_my  = ty1 - offy;
3707                         if (bCommX)
3708                         {
3709                             /* We index commbuf modulo the local grid size */
3710                             commbuf += buf_my*fft_nx*fft_nz;
3711
3712                             bClearBuf   = bClearBufXY;
3713                             bClearBufXY = FALSE;
3714                         }
3715                         else
3716                         {
3717                             bClearBuf  = bClearBufY;
3718                             bClearBufY = FALSE;
3719                         }
3720                     }
3721                     else
3722                     {
3723                         commbuf    = commbuf_x;
3724                         buf_my     = fft_ny;
3725                         bClearBuf  = bClearBufX;
3726                         bClearBufX = FALSE;
3727                     }
3728
3729                     /* Copy to the communication buffer */
3730                     for (x = offx; x < tx1; x++)
3731                     {
3732                         for (y = offy; y < ty1; y++)
3733                         {
3734                             i0  = (x*buf_my + y)*fft_nz;
3735                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3736
3737                             if (bClearBuf)
3738                             {
3739                                 /* First access of commbuf, initialize it */
3740                                 for (z = offz; z < tz1; z++)
3741                                 {
3742                                     commbuf[i0+z]  = grid_th[i0t+z];
3743                                 }
3744                             }
3745                             else
3746                             {
3747                                 for (z = offz; z < tz1; z++)
3748                                 {
3749                                     commbuf[i0+z] += grid_th[i0t+z];
3750                                 }
3751                             }
3752                         }
3753                     }
3754                 }
3755             }
3756         }
3757     }
3758 }
3759
3760
3761 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3762 {
3763     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3764     pme_overlap_t *overlap;
3765     int  send_index0, send_nindex;
3766     int  recv_nindex;
3767 #ifdef GMX_MPI
3768     MPI_Status stat;
3769 #endif
3770     int  send_size_y, recv_size_y;
3771     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3772     real *sendptr, *recvptr;
3773     int  x, y, z, indg, indb;
3774
3775     /* Note that this routine is only used for forward communication.
3776      * Since the force gathering, unlike the charge spreading,
3777      * can be trivially parallelized over the particles,
3778      * the backwards process is much simpler and can use the "old"
3779      * communication setup.
3780      */
3781
3782     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3783                                    local_fft_ndata,
3784                                    local_fft_offset,
3785                                    local_fft_size);
3786
3787     if (pme->nnodes_minor > 1)
3788     {
3789         /* Major dimension */
3790         overlap = &pme->overlap[1];
3791
3792         if (pme->nnodes_major > 1)
3793         {
3794             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3795         }
3796         else
3797         {
3798             size_yx = 0;
3799         }
3800         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3801
3802         send_size_y = overlap->send_size;
3803
3804         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3805         {
3806             send_id       = overlap->send_id[ipulse];
3807             recv_id       = overlap->recv_id[ipulse];
3808             send_index0   =
3809                 overlap->comm_data[ipulse].send_index0 -
3810                 overlap->comm_data[0].send_index0;
3811             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3812             /* We don't use recv_index0, as we always receive starting at 0 */
3813             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3814             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3815
3816             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3817             recvptr = overlap->recvbuf;
3818
3819 #ifdef GMX_MPI
3820             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3821                          send_id, ipulse,
3822                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3823                          recv_id, ipulse,
3824                          overlap->mpi_comm, &stat);
3825 #endif
3826
3827             for (x = 0; x < local_fft_ndata[XX]; x++)
3828             {
3829                 for (y = 0; y < recv_nindex; y++)
3830                 {
3831                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3832                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3833                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3834                     {
3835                         fftgrid[indg+z] += recvptr[indb+z];
3836                     }
3837                 }
3838             }
3839
3840             if (pme->nnodes_major > 1)
3841             {
3842                 /* Copy from the received buffer to the send buffer for dim 0 */
3843                 sendptr = pme->overlap[0].sendbuf;
3844                 for (x = 0; x < size_yx; x++)
3845                 {
3846                     for (y = 0; y < recv_nindex; y++)
3847                     {
3848                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3849                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3850                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3851                         {
3852                             sendptr[indg+z] += recvptr[indb+z];
3853                         }
3854                     }
3855                 }
3856             }
3857         }
3858     }
3859
3860     /* We only support a single pulse here.
3861      * This is not a severe limitation, as this code is only used
3862      * with OpenMP and with OpenMP the (PME) domains can be larger.
3863      */
3864     if (pme->nnodes_major > 1)
3865     {
3866         /* Major dimension */
3867         overlap = &pme->overlap[0];
3868
3869         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3870         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3871
3872         ipulse = 0;
3873
3874         send_id       = overlap->send_id[ipulse];
3875         recv_id       = overlap->recv_id[ipulse];
3876         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3877         /* We don't use recv_index0, as we always receive starting at 0 */
3878         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3879
3880         sendptr = overlap->sendbuf;
3881         recvptr = overlap->recvbuf;
3882
3883         if (debug != NULL)
3884         {
3885             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3886                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3887         }
3888
3889 #ifdef GMX_MPI
3890         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3891                      send_id, ipulse,
3892                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3893                      recv_id, ipulse,
3894                      overlap->mpi_comm, &stat);
3895 #endif
3896
3897         for (x = 0; x < recv_nindex; x++)
3898         {
3899             for (y = 0; y < local_fft_ndata[YY]; y++)
3900             {
3901                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3902                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3903                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3904                 {
3905                     fftgrid[indg+z] += recvptr[indb+z];
3906                 }
3907             }
3908         }
3909     }
3910 }
3911
3912
3913 static void spread_on_grid(gmx_pme_t pme,
3914                            pme_atomcomm_t *atc, pmegrids_t *grids,
3915                            gmx_bool bCalcSplines, gmx_bool bSpread,
3916                            real *fftgrid)
3917 {
3918     int nthread, thread;
3919 #ifdef PME_TIME_THREADS
3920     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3921     static double cs1     = 0, cs2 = 0, cs3 = 0;
3922     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3923     static int cnt        = 0;
3924 #endif
3925
3926     nthread = pme->nthread;
3927     assert(nthread > 0);
3928
3929 #ifdef PME_TIME_THREADS
3930     c1 = omp_cyc_start();
3931 #endif
3932     if (bCalcSplines)
3933     {
3934 #pragma omp parallel for num_threads(nthread) schedule(static)
3935         for (thread = 0; thread < nthread; thread++)
3936         {
3937             int start, end;
3938
3939             start = atc->n* thread   /nthread;
3940             end   = atc->n*(thread+1)/nthread;
3941
3942             /* Compute fftgrid index for all atoms,
3943              * with help of some extra variables.
3944              */
3945             calc_interpolation_idx(pme, atc, start, end, thread);
3946         }
3947     }
3948 #ifdef PME_TIME_THREADS
3949     c1   = omp_cyc_end(c1);
3950     cs1 += (double)c1;
3951 #endif
3952
3953 #ifdef PME_TIME_THREADS
3954     c2 = omp_cyc_start();
3955 #endif
3956 #pragma omp parallel for num_threads(nthread) schedule(static)
3957     for (thread = 0; thread < nthread; thread++)
3958     {
3959         splinedata_t *spline;
3960         pmegrid_t *grid = NULL;
3961
3962         /* make local bsplines  */
3963         if (grids == NULL || !pme->bUseThreads)
3964         {
3965             spline = &atc->spline[0];
3966
3967             spline->n = atc->n;
3968
3969             if (bSpread)
3970             {
3971                 grid = &grids->grid;
3972             }
3973         }
3974         else
3975         {
3976             spline = &atc->spline[thread];
3977
3978             if (grids->nthread == 1)
3979             {
3980                 /* One thread, we operate on all charges */
3981                 spline->n = atc->n;
3982             }
3983             else
3984             {
3985                 /* Get the indices our thread should operate on */
3986                 make_thread_local_ind(atc, thread, spline);
3987             }
3988
3989             grid = &grids->grid_th[thread];
3990         }
3991
3992         if (bCalcSplines)
3993         {
3994             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
3995                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
3996         }
3997
3998         if (bSpread)
3999         {
4000             /* put local atoms on grid. */
4001 #ifdef PME_TIME_SPREAD
4002             ct1a = omp_cyc_start();
4003 #endif
4004             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
4005
4006             if (pme->bUseThreads)
4007             {
4008                 copy_local_grid(pme, grids, thread, fftgrid);
4009             }
4010 #ifdef PME_TIME_SPREAD
4011             ct1a          = omp_cyc_end(ct1a);
4012             cs1a[thread] += (double)ct1a;
4013 #endif
4014         }
4015     }
4016 #ifdef PME_TIME_THREADS
4017     c2   = omp_cyc_end(c2);
4018     cs2 += (double)c2;
4019 #endif
4020
4021     if (bSpread && pme->bUseThreads)
4022     {
4023 #ifdef PME_TIME_THREADS
4024         c3 = omp_cyc_start();
4025 #endif
4026 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
4027         for (thread = 0; thread < grids->nthread; thread++)
4028         {
4029             reduce_threadgrid_overlap(pme, grids, thread,
4030                                       fftgrid,
4031                                       pme->overlap[0].sendbuf,
4032                                       pme->overlap[1].sendbuf);
4033         }
4034 #ifdef PME_TIME_THREADS
4035         c3   = omp_cyc_end(c3);
4036         cs3 += (double)c3;
4037 #endif
4038
4039         if (pme->nnodes > 1)
4040         {
4041             /* Communicate the overlapping part of the fftgrid.
4042              * For this communication call we need to check pme->bUseThreads
4043              * to have all ranks communicate here, regardless of pme->nthread.
4044              */
4045             sum_fftgrid_dd(pme, fftgrid);
4046         }
4047     }
4048
4049 #ifdef PME_TIME_THREADS
4050     cnt++;
4051     if (cnt % 20 == 0)
4052     {
4053         printf("idx %.2f spread %.2f red %.2f",
4054                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4055 #ifdef PME_TIME_SPREAD
4056         for (thread = 0; thread < nthread; thread++)
4057         {
4058             printf(" %.2f", cs1a[thread]*1e-9);
4059         }
4060 #endif
4061         printf("\n");
4062     }
4063 #endif
4064 }
4065
4066
4067 static void dump_grid(FILE *fp,
4068                       int sx, int sy, int sz, int nx, int ny, int nz,
4069                       int my, int mz, const real *g)
4070 {
4071     int x, y, z;
4072
4073     for (x = 0; x < nx; x++)
4074     {
4075         for (y = 0; y < ny; y++)
4076         {
4077             for (z = 0; z < nz; z++)
4078             {
4079                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4080                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4081             }
4082         }
4083     }
4084 }
4085
4086 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4087 {
4088     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4089
4090     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4091                                    local_fft_ndata,
4092                                    local_fft_offset,
4093                                    local_fft_size);
4094
4095     dump_grid(stderr,
4096               pme->pmegrid_start_ix,
4097               pme->pmegrid_start_iy,
4098               pme->pmegrid_start_iz,
4099               pme->pmegrid_nx-pme->pme_order+1,
4100               pme->pmegrid_ny-pme->pme_order+1,
4101               pme->pmegrid_nz-pme->pme_order+1,
4102               local_fft_size[YY],
4103               local_fft_size[ZZ],
4104               fftgrid);
4105 }
4106
4107
4108 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4109 {
4110     pme_atomcomm_t *atc;
4111     pmegrids_t *grid;
4112
4113     if (pme->nnodes > 1)
4114     {
4115         gmx_incons("gmx_pme_calc_energy called in parallel");
4116     }
4117     if (pme->bFEP > 1)
4118     {
4119         gmx_incons("gmx_pme_calc_energy with free energy");
4120     }
4121
4122     atc            = &pme->atc_energy;
4123     atc->nthread   = 1;
4124     if (atc->spline == NULL)
4125     {
4126         snew(atc->spline, atc->nthread);
4127     }
4128     atc->nslab     = 1;
4129     atc->bSpread   = TRUE;
4130     atc->pme_order = pme->pme_order;
4131     atc->n         = n;
4132     pme_realloc_atomcomm_things(atc);
4133     atc->x         = x;
4134     atc->q         = q;
4135
4136     /* We only use the A-charges grid */
4137     grid = &pme->pmegridA;
4138
4139     /* Only calculate the spline coefficients, don't actually spread */
4140     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4141
4142     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4143 }
4144
4145
4146 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4147                                    t_nrnb *nrnb, t_inputrec *ir,
4148                                    gmx_large_int_t step)
4149 {
4150     /* Reset all the counters related to performance over the run */
4151     wallcycle_stop(wcycle, ewcRUN);
4152     wallcycle_reset_all(wcycle);
4153     init_nrnb(nrnb);
4154     if (ir->nsteps >= 0)
4155     {
4156         /* ir->nsteps is not used here, but we update it for consistency */
4157         ir->nsteps -= step - ir->init_step;
4158     }
4159     ir->init_step = step;
4160     wallcycle_start(wcycle, ewcRUN);
4161 }
4162
4163
4164 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4165                                ivec grid_size,
4166                                t_commrec *cr, t_inputrec *ir,
4167                                gmx_pme_t *pme_ret)
4168 {
4169     int ind;
4170     gmx_pme_t pme = NULL;
4171
4172     ind = 0;
4173     while (ind < *npmedata)
4174     {
4175         pme = (*pmedata)[ind];
4176         if (pme->nkx == grid_size[XX] &&
4177             pme->nky == grid_size[YY] &&
4178             pme->nkz == grid_size[ZZ])
4179         {
4180             *pme_ret = pme;
4181
4182             return;
4183         }
4184
4185         ind++;
4186     }
4187
4188     (*npmedata)++;
4189     srenew(*pmedata, *npmedata);
4190
4191     /* Generate a new PME data structure, copying part of the old pointers */
4192     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4193
4194     *pme_ret = (*pmedata)[ind];
4195 }
4196
4197
4198 int gmx_pmeonly(gmx_pme_t pme,
4199                 t_commrec *cr,    t_nrnb *nrnb,
4200                 gmx_wallcycle_t wcycle,
4201                 real ewaldcoeff,  gmx_bool bGatherOnly,
4202                 t_inputrec *ir)
4203 {
4204     int npmedata;
4205     gmx_pme_t *pmedata;
4206     gmx_pme_pp_t pme_pp;
4207     int  ret;
4208     int  natoms;
4209     matrix box;
4210     rvec *x_pp      = NULL, *f_pp = NULL;
4211     real *chargeA   = NULL, *chargeB = NULL;
4212     real lambda     = 0;
4213     int  maxshift_x = 0, maxshift_y = 0;
4214     real energy, dvdlambda;
4215     matrix vir;
4216     float cycles;
4217     int  count;
4218     gmx_bool bEnerVir;
4219     gmx_large_int_t step, step_rel;
4220     ivec grid_switch;
4221
4222     /* This data will only use with PME tuning, i.e. switching PME grids */
4223     npmedata = 1;
4224     snew(pmedata, npmedata);
4225     pmedata[0] = pme;
4226
4227     pme_pp = gmx_pme_pp_init(cr);
4228
4229     init_nrnb(nrnb);
4230
4231     count = 0;
4232     do /****** this is a quasi-loop over time steps! */
4233     {
4234         /* The reason for having a loop here is PME grid tuning/switching */
4235         do
4236         {
4237             /* Domain decomposition */
4238             ret = gmx_pme_recv_q_x(pme_pp,
4239                                    &natoms,
4240                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4241                                    &maxshift_x, &maxshift_y,
4242                                    &pme->bFEP, &lambda,
4243                                    &bEnerVir,
4244                                    &step,
4245                                    grid_switch, &ewaldcoeff);
4246
4247             if (ret == pmerecvqxSWITCHGRID)
4248             {
4249                 /* Switch the PME grid to grid_switch */
4250                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4251             }
4252
4253             if (ret == pmerecvqxRESETCOUNTERS)
4254             {
4255                 /* Reset the cycle and flop counters */
4256                 reset_pmeonly_counters(cr, wcycle, nrnb, ir, step);
4257             }
4258         }
4259         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4260
4261         if (ret == pmerecvqxFINISH)
4262         {
4263             /* We should stop: break out of the loop */
4264             break;
4265         }
4266
4267         step_rel = step - ir->init_step;
4268
4269         if (count == 0)
4270         {
4271             wallcycle_start(wcycle, ewcRUN);
4272         }
4273
4274         wallcycle_start(wcycle, ewcPMEMESH);
4275
4276         dvdlambda = 0;
4277         clear_mat(vir);
4278         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4279                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4280                    &energy, lambda, &dvdlambda,
4281                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4282
4283         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4284
4285         gmx_pme_send_force_vir_ener(pme_pp,
4286                                     f_pp, vir, energy, dvdlambda,
4287                                     cycles);
4288
4289         count++;
4290     } /***** end of quasi-loop, we stop with the break above */
4291     while (TRUE);
4292
4293     return 0;
4294 }
4295
4296 int gmx_pme_do(gmx_pme_t pme,
4297                int start,       int homenr,
4298                rvec x[],        rvec f[],
4299                real *chargeA,   real *chargeB,
4300                matrix box, t_commrec *cr,
4301                int  maxshift_x, int maxshift_y,
4302                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4303                matrix vir,      real ewaldcoeff,
4304                real *energy,    real lambda,
4305                real *dvdlambda, int flags)
4306 {
4307     int     q, d, i, j, ntot, npme;
4308     int     nx, ny, nz;
4309     int     n_d, local_ny;
4310     pme_atomcomm_t *atc = NULL;
4311     pmegrids_t *pmegrid = NULL;
4312     real    *grid       = NULL;
4313     real    *ptr;
4314     rvec    *x_d, *f_d;
4315     real    *charge = NULL, *q_d;
4316     real    energy_AB[2];
4317     matrix  vir_AB[2];
4318     gmx_bool bClearF;
4319     gmx_parallel_3dfft_t pfft_setup;
4320     real *  fftgrid;
4321     t_complex * cfftgrid;
4322     int     thread;
4323     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4324     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4325
4326     assert(pme->nnodes > 0);
4327     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4328
4329     if (pme->nnodes > 1)
4330     {
4331         atc      = &pme->atc[0];
4332         atc->npd = homenr;
4333         if (atc->npd > atc->pd_nalloc)
4334         {
4335             atc->pd_nalloc = over_alloc_dd(atc->npd);
4336             srenew(atc->pd, atc->pd_nalloc);
4337         }
4338         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4339     }
4340     else
4341     {
4342         /* This could be necessary for TPI */
4343         pme->atc[0].n = homenr;
4344     }
4345
4346     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4347     {
4348         if (q == 0)
4349         {
4350             pmegrid    = &pme->pmegridA;
4351             fftgrid    = pme->fftgridA;
4352             cfftgrid   = pme->cfftgridA;
4353             pfft_setup = pme->pfft_setupA;
4354             charge     = chargeA+start;
4355         }
4356         else
4357         {
4358             pmegrid    = &pme->pmegridB;
4359             fftgrid    = pme->fftgridB;
4360             cfftgrid   = pme->cfftgridB;
4361             pfft_setup = pme->pfft_setupB;
4362             charge     = chargeB+start;
4363         }
4364         grid = pmegrid->grid.grid;
4365         /* Unpack structure */
4366         if (debug)
4367         {
4368             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4369                     cr->nnodes, cr->nodeid);
4370             fprintf(debug, "Grid = %p\n", (void*)grid);
4371             if (grid == NULL)
4372             {
4373                 gmx_fatal(FARGS, "No grid!");
4374             }
4375         }
4376         where();
4377
4378         m_inv_ur0(box, pme->recipbox);
4379
4380         if (pme->nnodes == 1)
4381         {
4382             atc = &pme->atc[0];
4383             if (DOMAINDECOMP(cr))
4384             {
4385                 atc->n = homenr;
4386                 pme_realloc_atomcomm_things(atc);
4387             }
4388             atc->x = x;
4389             atc->q = charge;
4390             atc->f = f;
4391         }
4392         else
4393         {
4394             wallcycle_start(wcycle, ewcPME_REDISTXF);
4395             for (d = pme->ndecompdim-1; d >= 0; d--)
4396             {
4397                 if (d == pme->ndecompdim-1)
4398                 {
4399                     n_d = homenr;
4400                     x_d = x + start;
4401                     q_d = charge;
4402                 }
4403                 else
4404                 {
4405                     n_d = pme->atc[d+1].n;
4406                     x_d = atc->x;
4407                     q_d = atc->q;
4408                 }
4409                 atc      = &pme->atc[d];
4410                 atc->npd = n_d;
4411                 if (atc->npd > atc->pd_nalloc)
4412                 {
4413                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4414                     srenew(atc->pd, atc->pd_nalloc);
4415                 }
4416                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4417                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4418                 where();
4419
4420                 /* Redistribute x (only once) and qA or qB */
4421                 if (DOMAINDECOMP(cr))
4422                 {
4423                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4424                 }
4425                 else
4426                 {
4427                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4428                 }
4429             }
4430             where();
4431
4432             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4433         }
4434
4435         if (debug)
4436         {
4437             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4438                     cr->nodeid, atc->n);
4439         }
4440
4441         if (flags & GMX_PME_SPREAD_Q)
4442         {
4443             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4444
4445             /* Spread the charges on a grid */
4446             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4447
4448             if (q == 0)
4449             {
4450                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4451             }
4452             inc_nrnb(nrnb, eNR_SPREADQBSP,
4453                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4454
4455             if (!pme->bUseThreads)
4456             {
4457                 wrap_periodic_pmegrid(pme, grid);
4458
4459                 /* sum contributions to local grid from other nodes */
4460 #ifdef GMX_MPI
4461                 if (pme->nnodes > 1)
4462                 {
4463                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4464                     where();
4465                 }
4466 #endif
4467
4468                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4469             }
4470
4471             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4472
4473             /*
4474                dump_local_fftgrid(pme,fftgrid);
4475                exit(0);
4476              */
4477         }
4478
4479         /* Here we start a large thread parallel region */
4480 #pragma omp parallel num_threads(pme->nthread) private(thread)
4481         {
4482             thread = gmx_omp_get_thread_num();
4483             if (flags & GMX_PME_SOLVE)
4484             {
4485                 int loop_count;
4486
4487                 /* do 3d-fft */
4488                 if (thread == 0)
4489                 {
4490                     wallcycle_start(wcycle, ewcPME_FFT);
4491                 }
4492                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4493                                            fftgrid, cfftgrid, thread, wcycle);
4494                 if (thread == 0)
4495                 {
4496                     wallcycle_stop(wcycle, ewcPME_FFT);
4497                 }
4498                 where();
4499
4500                 /* solve in k-space for our local cells */
4501                 if (thread == 0)
4502                 {
4503                     wallcycle_start(wcycle, ewcPME_SOLVE);
4504                 }
4505                 loop_count =
4506                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4507                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4508                                   bCalcEnerVir,
4509                                   pme->nthread, thread);
4510                 if (thread == 0)
4511                 {
4512                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4513                     where();
4514                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4515                 }
4516             }
4517
4518             if (bCalcF)
4519             {
4520                 /* do 3d-invfft */
4521                 if (thread == 0)
4522                 {
4523                     where();
4524                     wallcycle_start(wcycle, ewcPME_FFT);
4525                 }
4526                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4527                                            cfftgrid, fftgrid, thread, wcycle);
4528                 if (thread == 0)
4529                 {
4530                     wallcycle_stop(wcycle, ewcPME_FFT);
4531
4532                     where();
4533
4534                     if (pme->nodeid == 0)
4535                     {
4536                         ntot  = pme->nkx*pme->nky*pme->nkz;
4537                         npme  = ntot*log((real)ntot)/log(2.0);
4538                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4539                     }
4540
4541                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4542                 }
4543
4544                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4545             }
4546         }
4547         /* End of thread parallel section.
4548          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4549          */
4550
4551         if (bCalcF)
4552         {
4553             /* distribute local grid to all nodes */
4554 #ifdef GMX_MPI
4555             if (pme->nnodes > 1)
4556             {
4557                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4558             }
4559 #endif
4560             where();
4561
4562             unwrap_periodic_pmegrid(pme, grid);
4563
4564             /* interpolate forces for our local atoms */
4565
4566             where();
4567
4568             /* If we are running without parallelization,
4569              * atc->f is the actual force array, not a buffer,
4570              * therefore we should not clear it.
4571              */
4572             bClearF = (q == 0 && PAR(cr));
4573 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4574             for (thread = 0; thread < pme->nthread; thread++)
4575             {
4576                 gather_f_bsplines(pme, grid, bClearF, atc,
4577                                   &atc->spline[thread],
4578                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4579             }
4580
4581             where();
4582
4583             inc_nrnb(nrnb, eNR_GATHERFBSP,
4584                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4585             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4586         }
4587
4588         if (bCalcEnerVir)
4589         {
4590             /* This should only be called on the master thread
4591              * and after the threads have synchronized.
4592              */
4593             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4594         }
4595     } /* of q-loop */
4596
4597     if (bCalcF && pme->nnodes > 1)
4598     {
4599         wallcycle_start(wcycle, ewcPME_REDISTXF);
4600         for (d = 0; d < pme->ndecompdim; d++)
4601         {
4602             atc = &pme->atc[d];
4603             if (d == pme->ndecompdim - 1)
4604             {
4605                 n_d = homenr;
4606                 f_d = f + start;
4607             }
4608             else
4609             {
4610                 n_d = pme->atc[d+1].n;
4611                 f_d = pme->atc[d+1].f;
4612             }
4613             if (DOMAINDECOMP(cr))
4614             {
4615                 dd_pmeredist_f(pme, atc, n_d, f_d,
4616                                d == pme->ndecompdim-1 && pme->bPPnode);
4617             }
4618             else
4619             {
4620                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4621             }
4622         }
4623
4624         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4625     }
4626     where();
4627
4628     if (bCalcEnerVir)
4629     {
4630         if (!pme->bFEP)
4631         {
4632             *energy = energy_AB[0];
4633             m_add(vir, vir_AB[0], vir);
4634         }
4635         else
4636         {
4637             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4638             *dvdlambda += energy_AB[1] - energy_AB[0];
4639             for (i = 0; i < DIM; i++)
4640             {
4641                 for (j = 0; j < DIM; j++)
4642                 {
4643                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4644                         lambda*vir_AB[1][i][j];
4645                 }
4646             }
4647         }
4648     }
4649     else
4650     {
4651         *energy = 0;
4652     }
4653
4654     if (debug)
4655     {
4656         fprintf(debug, "PME mesh energy: %g\n", *energy);
4657     }
4658
4659     return 0;
4660 }