Merge release-4-6 into master
[alexxy/gromacs.git] / src / gromacs / mdlib / pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include <stdio.h>
71 #include <string.h>
72 #include <math.h>
73 #include <assert.h>
74 #include "typedefs.h"
75 #include "txtdump.h"
76 #include "vec.h"
77 #include "gmxcomplex.h"
78 #include "smalloc.h"
79 #include "futil.h"
80 #include "coulomb.h"
81 #include "gmx_fatal.h"
82 #include "pme.h"
83 #include "network.h"
84 #include "physics.h"
85 #include "nrnb.h"
86 #include "copyrite.h"
87 #include "gmx_wallcycle.h"
88 #include "gmx_parallel_3dfft.h"
89 #include "pdbio.h"
90 #include "gmx_cyclecounter.h"
91 #include "gmx_omp.h"
92 #include "macros.h"
93
/* The SSE spreading/gathering kernels are only enabled for single-precision
 * builds that target SSE2 or a later x86 SIMD level.
 */
#if !defined(GMX_DOUBLE) && defined(GMX_X86_SSE2)
#include "gmx_x86_simd_single.h"
#define PME_SSE
#ifndef GMX_FAHCORE
/* Unaligned loads+stores are not enabled for GMX_FAHCORE builds, since some
 * old AMD processors could have problems with them.
 */
#define PME_SSE_UNALIGNED
#endif /* !GMX_FAHCORE */
#endif /* single precision with SSE2 */

/* Tolerance constant; its users are outside this part of the file */
#define DFT_TOL 1e-7

/* Optional debug/timing switches, all disabled by default:
 *   #define PRT_FORCE
 *   #define PME_TIME_THREADS
 * On-the-fly time measurement could for instance be conditioned on
 *   #define TAKETIME (step > 1 && timesteps < 10)
 */
#define TAKETIME FALSE

/* MPI datatype used for buffers of 'real' values (e.g. the charges) */
#ifndef GMX_DOUBLE
#define mpi_type MPI_FLOAT
#else
#define mpi_type MPI_DOUBLE
#endif

/* Separation size (presumably padding between per-thread data); it must
 * remain a multiple of 16 so that alignment is preserved.
 */
#define GMX_CACHE_SEP 64

/* Upper bound on the PME interpolation order. It exists only so that
 * fixed-size local arrays can be used instead of allocations; orders above
 * 12 should never be needed, not even for test cases. Raise it here if so.
 */
#define PME_ORDER_MAX 12
128
129 /* Internal datastructures */
/* Index ranges exchanged with one partner node during the grid overlap
 * communication (see pme_overlap_t::comm_data).
 */
typedef struct {
    int send_index0, send_nindex; /* Start and count of the indices to send    */
    int recv_index0, recv_nindex; /* Start and count of the indices to receive */
    int recv_size;                /* Receive buffer width, used with OpenMP     */
} pme_grid_comm_t;
137
138 typedef struct {
139 #ifdef GMX_MPI
140     MPI_Comm         mpi_comm;
141 #endif
142     int              nnodes, nodeid;
143     int             *s2g0;
144     int             *s2g1;
145     int              noverlap_nodes;
146     int             *send_id, *recv_id;
147     int              send_size; /* Send buffer width, used with OpenMP */
148     pme_grid_comm_t *comm_data;
149     real            *sendbuf;
150     real            *recvbuf;
151 } pme_overlap_t;
152
/* List of atom indices built by one thread, sorted on the spreading thread
 * that owns each atom.
 */
typedef struct {
    int *n;      /* Cumulative per-owner-thread counts (section ends in i) */
    int  nalloc; /* Allocated length of i */
    int *i;      /* Atom indices ordered on owner thread index (see n)     */
} thread_plist_t;
158
159 typedef struct {
160     int      *thread_one;
161     int       n;
162     int      *ind;
163     splinevec theta;
164     real     *ptr_theta_z;
165     splinevec dtheta;
166     real     *ptr_dtheta_z;
167 } splinedata_t;
168
169 typedef struct {
170     int      dimind;        /* The index of the dimension, 0=x, 1=y */
171     int      nslab;
172     int      nodeid;
173 #ifdef GMX_MPI
174     MPI_Comm mpi_comm;
175 #endif
176
177     int     *node_dest;     /* The nodes to send x and q to with DD */
178     int     *node_src;      /* The nodes to receive x and q from with DD */
179     int     *buf_index;     /* Index for commnode into the buffers */
180
181     int      maxshift;
182
183     int      npd;
184     int      pd_nalloc;
185     int     *pd;
186     int     *count;         /* The number of atoms to send to each node */
187     int    **count_thread;
188     int     *rcount;        /* The number of atoms to receive */
189
190     int      n;
191     int      nalloc;
192     rvec    *x;
193     real    *q;
194     rvec    *f;
195     gmx_bool bSpread;       /* These coordinates are used for spreading */
196     int      pme_order;
197     ivec    *idx;
198     rvec    *fractx;            /* Fractional coordinate relative to the
199                                  * lower cell boundary
200                                  */
201     int             nthread;
202     int            *thread_idx; /* Which thread should spread which charge */
203     thread_plist_t *thread_plist;
204     splinedata_t   *spline;
205 } pme_atomcomm_t;
206
/* Block-size constants; their users are outside this part of the file */
#define FLBSZ 4
#define FLBS  3
209
210 typedef struct {
211     ivec  ci;     /* The spatial location of this grid         */
212     ivec  n;      /* The used size of *grid, including order-1 */
213     ivec  offset; /* The grid offset from the full node grid   */
214     int   order;  /* PME spreading order                       */
215     ivec  s;      /* The allocated size of *grid, s >= n       */
216     real *grid;   /* The grid local thread, size n             */
217 } pmegrid_t;
218
219 typedef struct {
220     pmegrid_t  grid;         /* The full node grid (non thread-local)            */
221     int        nthread;      /* The number of threads operating on this grid     */
222     ivec       nc;           /* The local spatial decomposition over the threads */
223     pmegrid_t *grid_th;      /* Array of grids for each thread                   */
224     real      *grid_all;     /* Allocated array for the grids in *grid_th        */
225     int      **g2t;          /* The grid to thread index                         */
226     ivec       nthread_comm; /* The number of threads to communicate with        */
227 } pmegrids_t;
228
229
/* Work data for spreading and gathering */
typedef struct {
#ifndef PME_SSE
    int    dummy; /* Placeholder: C89 does not allow a struct without members */
#else
    /* Masks for the aligned SSE spreading and gathering */
    __m128 mask_SSE0[6];
    __m128 mask_SSE1[6];
#endif
} pme_spline_work_t;
238
239 typedef struct {
240     /* work data for solve_pme */
241     int      nalloc;
242     real *   mhx;
243     real *   mhy;
244     real *   mhz;
245     real *   m2;
246     real *   denom;
247     real *   tmp1_alloc;
248     real *   tmp1;
249     real *   eterm;
250     real *   m2inv;
251
252     real     energy;
253     matrix   vir;
254 } pme_work_t;
255
256 typedef struct gmx_pme {
257     int           ndecompdim; /* The number of decomposition dimensions */
258     int           nodeid;     /* Our nodeid in mpi->mpi_comm */
259     int           nodeid_major;
260     int           nodeid_minor;
261     int           nnodes;    /* The number of nodes doing PME */
262     int           nnodes_major;
263     int           nnodes_minor;
264
265     MPI_Comm      mpi_comm;
266     MPI_Comm      mpi_comm_d[2]; /* Indexed on dimension, 0=x, 1=y */
267 #ifdef GMX_MPI
268     MPI_Datatype  rvec_mpi;      /* the pme vector's MPI type */
269 #endif
270
271     int        nthread;       /* The number of threads doing PME */
272
273     gmx_bool   bPPnode;       /* Node also does particle-particle forces */
274     gmx_bool   bFEP;          /* Compute Free energy contribution */
275     int        nkx, nky, nkz; /* Grid dimensions */
276     gmx_bool   bP3M;          /* Do P3M: optimize the influence function */
277     int        pme_order;
278     real       epsilon_r;
279
280     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
281     pmegrids_t pmegridB;
282     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
283     int        pmegrid_nx, pmegrid_ny, pmegrid_nz;
284     /* pmegrid_nz might be larger than strictly necessary to ensure
285      * memory alignment, pmegrid_nz_base gives the real base size.
286      */
287     int     pmegrid_nz_base;
288     /* The local PME grid starting indices */
289     int     pmegrid_start_ix, pmegrid_start_iy, pmegrid_start_iz;
290
291     /* Work data for spreading and gathering */
292     pme_spline_work_t    *spline_work;
293
294     real                 *fftgridA; /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
295     real                 *fftgridB; /* inside the interpolation grid, but separate for 2D PME decomp. */
296     int                   fftgrid_nx, fftgrid_ny, fftgrid_nz;
297
298     t_complex            *cfftgridA;  /* Grids for complex FFT data */
299     t_complex            *cfftgridB;
300     int                   cfftgrid_nx, cfftgrid_ny, cfftgrid_nz;
301
302     gmx_parallel_3dfft_t  pfft_setupA;
303     gmx_parallel_3dfft_t  pfft_setupB;
304
305     int                  *nnx, *nny, *nnz;
306     real                 *fshx, *fshy, *fshz;
307
308     pme_atomcomm_t        atc[2]; /* Indexed on decomposition index */
309     matrix                recipbox;
310     splinevec             bsp_mod;
311
312     pme_overlap_t         overlap[2]; /* Indexed on dimension, 0=x, 1=y */
313
314     pme_atomcomm_t        atc_energy; /* Only for gmx_pme_calc_energy */
315
316     rvec                 *bufv;       /* Communication buffer */
317     real                 *bufr;       /* Communication buffer */
318     int                   buf_nalloc; /* The communication buffer size */
319
320     /* thread local work data for solve_pme */
321     pme_work_t *work;
322
323     /* Work data for PME_redist */
324     gmx_bool redist_init;
325     int *    scounts;
326     int *    rcounts;
327     int *    sdispls;
328     int *    rdispls;
329     int *    sidx;
330     int *    idxa;
331     real *   redist_buf;
332     int      redist_buf_nalloc;
333
334     /* Work data for sum_qgrid */
335     real *   sum_qgrid_tmp;
336     real *   sum_qgrid_dd_tmp;
337 } t_gmx_pme;
338
339
340 static void calc_interpolation_idx(gmx_pme_t pme, pme_atomcomm_t *atc,
341                                    int start, int end, int thread)
342 {
343     int             i;
344     int            *idxptr, tix, tiy, tiz;
345     real           *xptr, *fptr, tx, ty, tz;
346     real            rxx, ryx, ryy, rzx, rzy, rzz;
347     int             nx, ny, nz;
348     int             start_ix, start_iy, start_iz;
349     int            *g2tx, *g2ty, *g2tz;
350     gmx_bool        bThreads;
351     int            *thread_idx = NULL;
352     thread_plist_t *tpl        = NULL;
353     int            *tpl_n      = NULL;
354     int             thread_i;
355
356     nx  = pme->nkx;
357     ny  = pme->nky;
358     nz  = pme->nkz;
359
360     start_ix = pme->pmegrid_start_ix;
361     start_iy = pme->pmegrid_start_iy;
362     start_iz = pme->pmegrid_start_iz;
363
364     rxx = pme->recipbox[XX][XX];
365     ryx = pme->recipbox[YY][XX];
366     ryy = pme->recipbox[YY][YY];
367     rzx = pme->recipbox[ZZ][XX];
368     rzy = pme->recipbox[ZZ][YY];
369     rzz = pme->recipbox[ZZ][ZZ];
370
371     g2tx = pme->pmegridA.g2t[XX];
372     g2ty = pme->pmegridA.g2t[YY];
373     g2tz = pme->pmegridA.g2t[ZZ];
374
375     bThreads = (atc->nthread > 1);
376     if (bThreads)
377     {
378         thread_idx = atc->thread_idx;
379
380         tpl   = &atc->thread_plist[thread];
381         tpl_n = tpl->n;
382         for (i = 0; i < atc->nthread; i++)
383         {
384             tpl_n[i] = 0;
385         }
386     }
387
388     for (i = start; i < end; i++)
389     {
390         xptr   = atc->x[i];
391         idxptr = atc->idx[i];
392         fptr   = atc->fractx[i];
393
394         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
395         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
396         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
397         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
398
399         tix = (int)(tx);
400         tiy = (int)(ty);
401         tiz = (int)(tz);
402
403         /* Because decomposition only occurs in x and y,
404          * we never have a fraction correction in z.
405          */
406         fptr[XX] = tx - tix + pme->fshx[tix];
407         fptr[YY] = ty - tiy + pme->fshy[tiy];
408         fptr[ZZ] = tz - tiz;
409
410         idxptr[XX] = pme->nnx[tix];
411         idxptr[YY] = pme->nny[tiy];
412         idxptr[ZZ] = pme->nnz[tiz];
413
414 #ifdef DEBUG
415         range_check(idxptr[XX], 0, pme->pmegrid_nx);
416         range_check(idxptr[YY], 0, pme->pmegrid_ny);
417         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
418 #endif
419
420         if (bThreads)
421         {
422             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
423             thread_idx[i] = thread_i;
424             tpl_n[thread_i]++;
425         }
426     }
427
428     if (bThreads)
429     {
430         /* Make a list of particle indices sorted on thread */
431
432         /* Get the cumulative count */
433         for (i = 1; i < atc->nthread; i++)
434         {
435             tpl_n[i] += tpl_n[i-1];
436         }
437         /* The current implementation distributes particles equally
438          * over the threads, so we could actually allocate for that
439          * in pme_realloc_atomcomm_things.
440          */
441         if (tpl_n[atc->nthread-1] > tpl->nalloc)
442         {
443             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
444             srenew(tpl->i, tpl->nalloc);
445         }
446         /* Set tpl_n to the cumulative start */
447         for (i = atc->nthread-1; i >= 1; i--)
448         {
449             tpl_n[i] = tpl_n[i-1];
450         }
451         tpl_n[0] = 0;
452
453         /* Fill our thread local array with indices sorted on thread */
454         for (i = start; i < end; i++)
455         {
456             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
457         }
458         /* Now tpl_n contains the cummulative count again */
459     }
460 }
461
462 static void make_thread_local_ind(pme_atomcomm_t *atc,
463                                   int thread, splinedata_t *spline)
464 {
465     int             n, t, i, start, end;
466     thread_plist_t *tpl;
467
468     /* Combine the indices made by each thread into one index */
469
470     n     = 0;
471     start = 0;
472     for (t = 0; t < atc->nthread; t++)
473     {
474         tpl = &atc->thread_plist[t];
475         /* Copy our part (start - end) from the list of thread t */
476         if (thread > 0)
477         {
478             start = tpl->n[thread-1];
479         }
480         end = tpl->n[thread];
481         for (i = start; i < end; i++)
482         {
483             spline->ind[n++] = tpl->i[i];
484         }
485     }
486
487     spline->n = n;
488 }
489
490
491 static void pme_calc_pidx(int start, int end,
492                           matrix recipbox, rvec x[],
493                           pme_atomcomm_t *atc, int *count)
494 {
495     int   nslab, i;
496     int   si;
497     real *xptr, s;
498     real  rxx, ryx, rzx, ryy, rzy;
499     int  *pd;
500
501     /* Calculate PME task index (pidx) for each grid index.
502      * Here we always assign equally sized slabs to each node
503      * for load balancing reasons (the PME grid spacing is not used).
504      */
505
506     nslab = atc->nslab;
507     pd    = atc->pd;
508
509     /* Reset the count */
510     for (i = 0; i < nslab; i++)
511     {
512         count[i] = 0;
513     }
514
515     if (atc->dimind == 0)
516     {
517         rxx = recipbox[XX][XX];
518         ryx = recipbox[YY][XX];
519         rzx = recipbox[ZZ][XX];
520         /* Calculate the node index in x-dimension */
521         for (i = start; i < end; i++)
522         {
523             xptr   = x[i];
524             /* Fractional coordinates along box vectors */
525             s     = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
526             si    = (int)(s + 2*nslab) % nslab;
527             pd[i] = si;
528             count[si]++;
529         }
530     }
531     else
532     {
533         ryy = recipbox[YY][YY];
534         rzy = recipbox[ZZ][YY];
535         /* Calculate the node index in y-dimension */
536         for (i = start; i < end; i++)
537         {
538             xptr   = x[i];
539             /* Fractional coordinates along box vectors */
540             s     = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
541             si    = (int)(s + 2*nslab) % nslab;
542             pd[i] = si;
543             count[si]++;
544         }
545     }
546 }
547
548 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
549                                   pme_atomcomm_t *atc)
550 {
551     int nthread, thread, slab;
552
553     nthread = atc->nthread;
554
555 #pragma omp parallel for num_threads(nthread) schedule(static)
556     for (thread = 0; thread < nthread; thread++)
557     {
558         pme_calc_pidx(natoms* thread   /nthread,
559                       natoms*(thread+1)/nthread,
560                       recipbox, x, atc, atc->count_thread[thread]);
561     }
562     /* Non-parallel reduction, since nslab is small */
563
564     for (thread = 1; thread < nthread; thread++)
565     {
566         for (slab = 0; slab < atc->nslab; slab++)
567         {
568             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
569         }
570     }
571 }
572
573 static void realloc_splinevec(splinevec th, real **ptr_z, int nalloc)
574 {
575     const int padding = 4;
576     int       i;
577
578     srenew(th[XX], nalloc);
579     srenew(th[YY], nalloc);
580     /* In z we add padding, this is only required for the aligned SSE code */
581     srenew(*ptr_z, nalloc+2*padding);
582     th[ZZ] = *ptr_z + padding;
583
584     for (i = 0; i < padding; i++)
585     {
586         (*ptr_z)[               i] = 0;
587         (*ptr_z)[padding+nalloc+i] = 0;
588     }
589 }
590
591 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
592 {
593     int i, d;
594
595     srenew(spline->ind, atc->nalloc);
596     /* Initialize the index to identity so it works without threads */
597     for (i = 0; i < atc->nalloc; i++)
598     {
599         spline->ind[i] = i;
600     }
601
602     realloc_splinevec(spline->theta, &spline->ptr_theta_z,
603                       atc->pme_order*atc->nalloc);
604     realloc_splinevec(spline->dtheta, &spline->ptr_dtheta_z,
605                       atc->pme_order*atc->nalloc);
606 }
607
608 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
609 {
610     int nalloc_old, i, j, nalloc_tpl;
611
612     /* We have to avoid a NULL pointer for atc->x to avoid
613      * possible fatal errors in MPI routines.
614      */
615     if (atc->n > atc->nalloc || atc->nalloc == 0)
616     {
617         nalloc_old  = atc->nalloc;
618         atc->nalloc = over_alloc_dd(max(atc->n, 1));
619
620         if (atc->nslab > 1)
621         {
622             srenew(atc->x, atc->nalloc);
623             srenew(atc->q, atc->nalloc);
624             srenew(atc->f, atc->nalloc);
625             for (i = nalloc_old; i < atc->nalloc; i++)
626             {
627                 clear_rvec(atc->f[i]);
628             }
629         }
630         if (atc->bSpread)
631         {
632             srenew(atc->fractx, atc->nalloc);
633             srenew(atc->idx, atc->nalloc);
634
635             if (atc->nthread > 1)
636             {
637                 srenew(atc->thread_idx, atc->nalloc);
638             }
639
640             for (i = 0; i < atc->nthread; i++)
641             {
642                 pme_realloc_splinedata(&atc->spline[i], atc);
643             }
644         }
645     }
646 }
647
648 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
649                          int n, gmx_bool bXF, rvec *x_f, real *charge,
650                          pme_atomcomm_t *atc)
651 /* Redistribute particle data for PME calculation */
652 /* domain decomposition by x coordinate           */
653 {
654     int *idxa;
655     int  i, ii;
656
657     if (FALSE == pme->redist_init)
658     {
659         snew(pme->scounts, atc->nslab);
660         snew(pme->rcounts, atc->nslab);
661         snew(pme->sdispls, atc->nslab);
662         snew(pme->rdispls, atc->nslab);
663         snew(pme->sidx, atc->nslab);
664         pme->redist_init = TRUE;
665     }
666     if (n > pme->redist_buf_nalloc)
667     {
668         pme->redist_buf_nalloc = over_alloc_dd(n);
669         srenew(pme->redist_buf, pme->redist_buf_nalloc*DIM);
670     }
671
672     pme->idxa = atc->pd;
673
674 #ifdef GMX_MPI
675     if (forw && bXF)
676     {
677         /* forward, redistribution from pp to pme */
678
679         /* Calculate send counts and exchange them with other nodes */
680         for (i = 0; (i < atc->nslab); i++)
681         {
682             pme->scounts[i] = 0;
683         }
684         for (i = 0; (i < n); i++)
685         {
686             pme->scounts[pme->idxa[i]]++;
687         }
688         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
689
690         /* Calculate send and receive displacements and index into send
691            buffer */
692         pme->sdispls[0] = 0;
693         pme->rdispls[0] = 0;
694         pme->sidx[0]    = 0;
695         for (i = 1; i < atc->nslab; i++)
696         {
697             pme->sdispls[i] = pme->sdispls[i-1]+pme->scounts[i-1];
698             pme->rdispls[i] = pme->rdispls[i-1]+pme->rcounts[i-1];
699             pme->sidx[i]    = pme->sdispls[i];
700         }
701         /* Total # of particles to be received */
702         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
703
704         pme_realloc_atomcomm_things(atc);
705
706         /* Copy particle coordinates into send buffer and exchange*/
707         for (i = 0; (i < n); i++)
708         {
709             ii = DIM*pme->sidx[pme->idxa[i]];
710             pme->sidx[pme->idxa[i]]++;
711             pme->redist_buf[ii+XX] = x_f[i][XX];
712             pme->redist_buf[ii+YY] = x_f[i][YY];
713             pme->redist_buf[ii+ZZ] = x_f[i][ZZ];
714         }
715         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
716                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
717                       pme->rvec_mpi, atc->mpi_comm);
718     }
719     if (forw)
720     {
721         /* Copy charge into send buffer and exchange*/
722         for (i = 0; i < atc->nslab; i++)
723         {
724             pme->sidx[i] = pme->sdispls[i];
725         }
726         for (i = 0; (i < n); i++)
727         {
728             ii = pme->sidx[pme->idxa[i]];
729             pme->sidx[pme->idxa[i]]++;
730             pme->redist_buf[ii] = charge[i];
731         }
732         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
733                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
734                       atc->mpi_comm);
735     }
736     else   /* backward, redistribution from pme to pp */
737     {
738         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
739                       pme->redist_buf, pme->scounts, pme->sdispls,
740                       pme->rvec_mpi, atc->mpi_comm);
741
742         /* Copy data from receive buffer */
743         for (i = 0; i < atc->nslab; i++)
744         {
745             pme->sidx[i] = pme->sdispls[i];
746         }
747         for (i = 0; (i < n); i++)
748         {
749             ii          = DIM*pme->sidx[pme->idxa[i]];
750             x_f[i][XX] += pme->redist_buf[ii+XX];
751             x_f[i][YY] += pme->redist_buf[ii+YY];
752             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
753             pme->sidx[pme->idxa[i]]++;
754         }
755     }
756 #endif
757 }
758
759 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
760                             gmx_bool bBackward, int shift,
761                             void *buf_s, int nbyte_s,
762                             void *buf_r, int nbyte_r)
763 {
764 #ifdef GMX_MPI
765     int        dest, src;
766     MPI_Status stat;
767
768     if (bBackward == FALSE)
769     {
770         dest = atc->node_dest[shift];
771         src  = atc->node_src[shift];
772     }
773     else
774     {
775         dest = atc->node_src[shift];
776         src  = atc->node_dest[shift];
777     }
778
779     if (nbyte_s > 0 && nbyte_r > 0)
780     {
781         MPI_Sendrecv(buf_s, nbyte_s, MPI_BYTE,
782                      dest, shift,
783                      buf_r, nbyte_r, MPI_BYTE,
784                      src, shift,
785                      atc->mpi_comm, &stat);
786     }
787     else if (nbyte_s > 0)
788     {
789         MPI_Send(buf_s, nbyte_s, MPI_BYTE,
790                  dest, shift,
791                  atc->mpi_comm);
792     }
793     else if (nbyte_r > 0)
794     {
795         MPI_Recv(buf_r, nbyte_r, MPI_BYTE,
796                  src, shift,
797                  atc->mpi_comm, &stat);
798     }
799 #endif
800 }
801
802 static void dd_pmeredist_x_q(gmx_pme_t pme,
803                              int n, gmx_bool bX, rvec *x, real *charge,
804                              pme_atomcomm_t *atc)
805 {
806     int *commnode, *buf_index;
807     int  nnodes_comm, i, nsend, local_pos, buf_pos, node, scount, rcount;
808
809     commnode  = atc->node_dest;
810     buf_index = atc->buf_index;
811
812     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
813
814     nsend = 0;
815     for (i = 0; i < nnodes_comm; i++)
816     {
817         buf_index[commnode[i]] = nsend;
818         nsend                 += atc->count[commnode[i]];
819     }
820     if (bX)
821     {
822         if (atc->count[atc->nodeid] + nsend != n)
823         {
824             gmx_fatal(FARGS, "%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
825                       "This usually means that your system is not well equilibrated.",
826                       n - (atc->count[atc->nodeid] + nsend),
827                       pme->nodeid, 'x'+atc->dimind);
828         }
829
830         if (nsend > pme->buf_nalloc)
831         {
832             pme->buf_nalloc = over_alloc_dd(nsend);
833             srenew(pme->bufv, pme->buf_nalloc);
834             srenew(pme->bufr, pme->buf_nalloc);
835         }
836
837         atc->n = atc->count[atc->nodeid];
838         for (i = 0; i < nnodes_comm; i++)
839         {
840             scount = atc->count[commnode[i]];
841             /* Communicate the count */
842             if (debug)
843             {
844                 fprintf(debug, "dimind %d PME node %d send to node %d: %d\n",
845                         atc->dimind, atc->nodeid, commnode[i], scount);
846             }
847             pme_dd_sendrecv(atc, FALSE, i,
848                             &scount, sizeof(int),
849                             &atc->rcount[i], sizeof(int));
850             atc->n += atc->rcount[i];
851         }
852
853         pme_realloc_atomcomm_things(atc);
854     }
855
856     local_pos = 0;
857     for (i = 0; i < n; i++)
858     {
859         node = atc->pd[i];
860         if (node == atc->nodeid)
861         {
862             /* Copy direct to the receive buffer */
863             if (bX)
864             {
865                 copy_rvec(x[i], atc->x[local_pos]);
866             }
867             atc->q[local_pos] = charge[i];
868             local_pos++;
869         }
870         else
871         {
872             /* Copy to the send buffer */
873             if (bX)
874             {
875                 copy_rvec(x[i], pme->bufv[buf_index[node]]);
876             }
877             pme->bufr[buf_index[node]] = charge[i];
878             buf_index[node]++;
879         }
880     }
881
882     buf_pos = 0;
883     for (i = 0; i < nnodes_comm; i++)
884     {
885         scount = atc->count[commnode[i]];
886         rcount = atc->rcount[i];
887         if (scount > 0 || rcount > 0)
888         {
889             if (bX)
890             {
891                 /* Communicate the coordinates */
892                 pme_dd_sendrecv(atc, FALSE, i,
893                                 pme->bufv[buf_pos], scount*sizeof(rvec),
894                                 atc->x[local_pos], rcount*sizeof(rvec));
895             }
896             /* Communicate the charges */
897             pme_dd_sendrecv(atc, FALSE, i,
898                             pme->bufr+buf_pos, scount*sizeof(real),
899                             atc->q+local_pos, rcount*sizeof(real));
900             buf_pos   += scount;
901             local_pos += atc->rcount[i];
902         }
903     }
904 }
905
906 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
907                            int n, rvec *f,
908                            gmx_bool bAddF)
909 {
910     int *commnode, *buf_index;
911     int  nnodes_comm, local_pos, buf_pos, i, scount, rcount, node;
912
913     commnode  = atc->node_dest;
914     buf_index = atc->buf_index;
915
916     nnodes_comm = min(2*atc->maxshift, atc->nslab-1);
917
918     local_pos = atc->count[atc->nodeid];
919     buf_pos   = 0;
920     for (i = 0; i < nnodes_comm; i++)
921     {
922         scount = atc->rcount[i];
923         rcount = atc->count[commnode[i]];
924         if (scount > 0 || rcount > 0)
925         {
926             /* Communicate the forces */
927             pme_dd_sendrecv(atc, TRUE, i,
928                             atc->f[local_pos], scount*sizeof(rvec),
929                             pme->bufv[buf_pos], rcount*sizeof(rvec));
930             local_pos += scount;
931         }
932         buf_index[commnode[i]] = buf_pos;
933         buf_pos               += rcount;
934     }
935
936     local_pos = 0;
937     if (bAddF)
938     {
939         for (i = 0; i < n; i++)
940         {
941             node = atc->pd[i];
942             if (node == atc->nodeid)
943             {
944                 /* Add from the local force array */
945                 rvec_inc(f[i], atc->f[local_pos]);
946                 local_pos++;
947             }
948             else
949             {
950                 /* Add from the receive buffer */
951                 rvec_inc(f[i], pme->bufv[buf_index[node]]);
952                 buf_index[node]++;
953             }
954         }
955     }
956     else
957     {
958         for (i = 0; i < n; i++)
959         {
960             node = atc->pd[i];
961             if (node == atc->nodeid)
962             {
963                 /* Copy from the local force array */
964                 copy_rvec(atc->f[local_pos], f[i]);
965                 local_pos++;
966             }
967             else
968             {
969                 /* Copy from the receive buffer */
970                 copy_rvec(pme->bufv[buf_index[node]], f[i]);
971                 buf_index[node]++;
972             }
973         }
974     }
975 }
976
#ifdef GMX_MPI
/* Exchange the overlap (halo) regions of the local PME grid with the
 * neighboring PME ranks. The exchange runs first along the minor
 * decomposition dimension (pme->overlap[1], indexed relative to
 * pmegrid_start_iy) and then along the major one (pme->overlap[0], indexed
 * relative to pmegrid_start_ix).
 *
 * grid      - local grid, indexed below as
 *             [ix*(pmegrid_ny*pmegrid_nz) + iy*pmegrid_nz + iz]
 * direction - GMX_SUM_QGRID_FORWARD: data is sent from the send ranges and
 *             the received values are ADDED to the recv ranges of grid.
 *             Otherwise the send/recv roles of each pulse are swapped and
 *             the received values OVERWRITE the corresponding grid points.
 */
static void
gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            ipulse, send_id, recv_id, datasize;
    real          *p;
    real          *sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            /* Reverse the pulse: swap partners and send/recv ranges */
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        /* Pack all x, the send_nindex y-lines and z in [0, nkz) in
         * x-major, then y, then z order.
         */
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        /* Number of reals per y-line in the packed buffers */
        datasize      = pme->pmegrid_nx * pme->nkz;

        /* The pulse index is used as the message tag */
        MPI_Sendrecv(overlap->sendbuf, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     overlap->recvbuf, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        /* Unpack in the same order as packed: forward sums, backward copies */
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    if (direction == GMX_SUM_QGRID_FORWARD)
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
    {
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            /* Forward: receive into recvbuf so it can be summed below */
            send_id       = overlap->send_id[ipulse];
            recv_id       = overlap->recv_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recvptr       = overlap->recvbuf;
        }
        else
        {
            /* Backward: received values overwrite, so receive in place */
            send_id       = overlap->recv_id[ipulse];
            recv_id       = overlap->send_id[ipulse];
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        /* x-planes are contiguous, so send straight from the grid */
        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug, "PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug, "PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_QGRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for (i = 0; i < recv_nindex*datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif
1146
1147
1148 static int
1149 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1150 {
1151     ivec    local_fft_ndata, local_fft_offset, local_fft_size;
1152     ivec    local_pme_size;
1153     int     i, ix, iy, iz;
1154     int     pmeidx, fftidx;
1155
1156     /* Dimensions should be identical for A/B grid, so we just use A here */
1157     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1158                                    local_fft_ndata,
1159                                    local_fft_offset,
1160                                    local_fft_size);
1161
1162     local_pme_size[0] = pme->pmegrid_nx;
1163     local_pme_size[1] = pme->pmegrid_ny;
1164     local_pme_size[2] = pme->pmegrid_nz;
1165
1166     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1167        the offset is identical, and the PME grid always has more data (due to overlap)
1168      */
1169     {
1170 #ifdef DEBUG_PME
1171         FILE *fp, *fp2;
1172         char  fn[STRLEN], format[STRLEN];
1173         real  val;
1174         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
1175         fp = ffopen(fn, "w");
1176         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
1177         fp2 = ffopen(fn, "w");
1178         sprintf(format, "%s%s\n", pdbformat, "%6.2f%6.2f");
1179 #endif
1180
1181         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
1182         {
1183             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
1184             {
1185                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1186                 {
1187                     pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1188                     fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1189                     fftgrid[fftidx] = pmegrid[pmeidx];
1190 #ifdef DEBUG_PME
1191                     val = 100*pmegrid[pmeidx];
1192                     if (pmegrid[pmeidx] != 0)
1193                     {
1194                         fprintf(fp, format, "ATOM", pmeidx, "CA", "GLY", ' ', pmeidx, ' ',
1195                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val);
1196                     }
1197                     if (pmegrid[pmeidx] != 0)
1198                     {
1199                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
1200                                 "qgrid",
1201                                 pme->pmegrid_start_ix + ix,
1202                                 pme->pmegrid_start_iy + iy,
1203                                 pme->pmegrid_start_iz + iz,
1204                                 pmegrid[pmeidx]);
1205                     }
1206 #endif
1207                 }
1208             }
1209         }
1210 #ifdef DEBUG_PME
1211         ffclose(fp);
1212         ffclose(fp2);
1213 #endif
1214     }
1215     return 0;
1216 }
1217
1218
1219 static gmx_cycles_t omp_cyc_start()
1220 {
1221     return gmx_cycles_read();
1222 }
1223
1224 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1225 {
1226     return gmx_cycles_read() - c;
1227 }
1228
1229
1230 static int
1231 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1232                         int nthread, int thread)
1233 {
1234     ivec          local_fft_ndata, local_fft_offset, local_fft_size;
1235     ivec          local_pme_size;
1236     int           ixy0, ixy1, ixy, ix, iy, iz;
1237     int           pmeidx, fftidx;
1238 #ifdef PME_TIME_THREADS
1239     gmx_cycles_t  c1;
1240     static double cs1 = 0;
1241     static int    cnt = 0;
1242 #endif
1243
1244 #ifdef PME_TIME_THREADS
1245     c1 = omp_cyc_start();
1246 #endif
1247     /* Dimensions should be identical for A/B grid, so we just use A here */
1248     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1249                                    local_fft_ndata,
1250                                    local_fft_offset,
1251                                    local_fft_size);
1252
1253     local_pme_size[0] = pme->pmegrid_nx;
1254     local_pme_size[1] = pme->pmegrid_ny;
1255     local_pme_size[2] = pme->pmegrid_nz;
1256
1257     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1258        the offset is identical, and the PME grid always has more data (due to overlap)
1259      */
1260     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1261     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1262
1263     for (ixy = ixy0; ixy < ixy1; ixy++)
1264     {
1265         ix = ixy/local_fft_ndata[YY];
1266         iy = ixy - ix*local_fft_ndata[YY];
1267
1268         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1269         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1270         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
1271         {
1272             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1273         }
1274     }
1275
1276 #ifdef PME_TIME_THREADS
1277     c1   = omp_cyc_end(c1);
1278     cs1 += (double)c1;
1279     cnt++;
1280     if (cnt % 20 == 0)
1281     {
1282         printf("copy %.2f\n", cs1*1e-9);
1283     }
1284 #endif
1285
1286     return 0;
1287 }
1288
1289
1290 static void
1291 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1292 {
1293     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix, iy, iz;
1294
1295     nx = pme->nkx;
1296     ny = pme->nky;
1297     nz = pme->nkz;
1298
1299     pnx = pme->pmegrid_nx;
1300     pny = pme->pmegrid_ny;
1301     pnz = pme->pmegrid_nz;
1302
1303     overlap = pme->pme_order - 1;
1304
1305     /* Add periodic overlap in z */
1306     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1307     {
1308         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1309         {
1310             for (iz = 0; iz < overlap; iz++)
1311             {
1312                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1313                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1314             }
1315         }
1316     }
1317
1318     if (pme->nnodes_minor == 1)
1319     {
1320         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1321         {
1322             for (iy = 0; iy < overlap; iy++)
1323             {
1324                 for (iz = 0; iz < nz; iz++)
1325                 {
1326                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1327                         pmegrid[(ix*pny+ny+iy)*pnz+iz];
1328                 }
1329             }
1330         }
1331     }
1332
1333     if (pme->nnodes_major == 1)
1334     {
1335         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1336
1337         for (ix = 0; ix < overlap; ix++)
1338         {
1339             for (iy = 0; iy < ny_x; iy++)
1340             {
1341                 for (iz = 0; iz < nz; iz++)
1342                 {
1343                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1344                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1345                 }
1346             }
1347         }
1348     }
1349 }
1350
1351
1352 static void
1353 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1354 {
1355     int     nx, ny, nz, pnx, pny, pnz, ny_x, overlap, ix;
1356
1357     nx = pme->nkx;
1358     ny = pme->nky;
1359     nz = pme->nkz;
1360
1361     pnx = pme->pmegrid_nx;
1362     pny = pme->pmegrid_ny;
1363     pnz = pme->pmegrid_nz;
1364
1365     overlap = pme->pme_order - 1;
1366
1367     if (pme->nnodes_major == 1)
1368     {
1369         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1370
1371         for (ix = 0; ix < overlap; ix++)
1372         {
1373             int iy, iz;
1374
1375             for (iy = 0; iy < ny_x; iy++)
1376             {
1377                 for (iz = 0; iz < nz; iz++)
1378                 {
1379                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1380                         pmegrid[(ix*pny+iy)*pnz+iz];
1381                 }
1382             }
1383         }
1384     }
1385
1386     if (pme->nnodes_minor == 1)
1387     {
1388 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1389         for (ix = 0; ix < pme->pmegrid_nx; ix++)
1390         {
1391             int iy, iz;
1392
1393             for (iy = 0; iy < overlap; iy++)
1394             {
1395                 for (iz = 0; iz < nz; iz++)
1396                 {
1397                     pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1398                         pmegrid[(ix*pny+iy)*pnz+iz];
1399                 }
1400             }
1401         }
1402     }
1403
1404     /* Copy periodic overlap in z */
1405 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1406     for (ix = 0; ix < pme->pmegrid_nx; ix++)
1407     {
1408         int iy, iz;
1409
1410         for (iy = 0; iy < pme->pmegrid_ny; iy++)
1411         {
1412             for (iz = 0; iz < overlap; iz++)
1413             {
1414                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1415                     pmegrid[(ix*pny+iy)*pnz+iz];
1416             }
1417         }
1418     }
1419 }
1420
1421 static void clear_grid(int nx, int ny, int nz, real *grid,
1422                        ivec fs, int *flag,
1423                        int fx, int fy, int fz,
1424                        int order)
1425 {
1426     int nc, ncz;
1427     int fsx, fsy, fsz, gx, gy, gz, g0x, g0y, x, y, z;
1428     int flind;
1429
1430     nc  = 2 + (order - 2)/FLBS;
1431     ncz = 2 + (order - 2)/FLBSZ;
1432
1433     for (fsx = fx; fsx < fx+nc; fsx++)
1434     {
1435         for (fsy = fy; fsy < fy+nc; fsy++)
1436         {
1437             for (fsz = fz; fsz < fz+ncz; fsz++)
1438             {
1439                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1440                 if (flag[flind] == 0)
1441                 {
1442                     gx  = fsx*FLBS;
1443                     gy  = fsy*FLBS;
1444                     gz  = fsz*FLBSZ;
1445                     g0x = (gx*ny + gy)*nz + gz;
1446                     for (x = 0; x < FLBS; x++)
1447                     {
1448                         g0y = g0x;
1449                         for (y = 0; y < FLBS; y++)
1450                         {
1451                             for (z = 0; z < FLBSZ; z++)
1452                             {
1453                                 grid[g0y+z] = 0;
1454                             }
1455                             g0y += nz;
1456                         }
1457                         g0x += ny*nz;
1458                     }
1459
1460                     flag[flind] = 1;
1461                 }
1462             }
1463         }
1464     }
1465 }
1466
/* Spread one charge onto the grid with an order x order x order stencil:
 *
 *   grid[(i0+a)*pny*pnz + (j0+b)*pnz + (k0+c)] += qn*thx[a]*thy[b]*thz[c]
 *
 * for a, b, c in [0, order).
 *
 * This has to be a macro to enable full compiler optimization with xlC
 * (and probably others too). It is therefore not self-contained: the
 * enclosing scope must declare
 *   loop counters  ithx, ithy, ithz
 *   index temps    index_x, index_xy, index_xyz
 *   value temps    valx, valxy
 *   inputs         qn, thx, thy, thz, i0, j0, k0, pny, pnz, grid
 * The argument is parenthesized so an expression (e.g. `m & 3`) is
 * evaluated as a whole instead of binding to the `<` comparison.
 */
#define DO_BSPLINE(order)                                \
    for (ithx = 0; (ithx < (order)); ithx++)             \
    {                                                    \
        index_x = (i0+ithx)*pny*pnz;                     \
        valx    = qn*thx[ithx];                          \
                                                     \
        for (ithy = 0; (ithy < (order)); ithy++)         \
        {                                                \
            valxy    = valx*thy[ithy];                   \
            index_xy = index_x+(j0+ithy)*pnz;            \
                                                     \
            for (ithz = 0; (ithz < (order)); ithz++)     \
            {                                            \
                index_xyz        = index_xy+(k0+ithz);   \
                grid[index_xyz] += valxy*thz[ithz];      \
            }                                            \
        }                                                \
    }
1486
1487
/* Spread the charges of the atoms listed in spline->ind onto the
 * (thread-local) grid pmegrid, using the B-spline coefficients stored in
 * spline->theta.
 *
 * The whole grid (s[XX]*s[YY]*s[ZZ] points) is zeroed first. Atoms with zero
 * charge are skipped. Each atom's grid index atc->idx[n] is shifted by the
 * grid's offset, so the stencil is placed relative to this sub-grid's origin.
 * The coefficients of the nn-th listed atom start at theta[d] + nn*order.
 * With PME_SSE defined, orders 4 and 5 use the kernels in pme_sse_single.h;
 * all other cases use the generic DO_BSPLINE macro.
 */
static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
                                     pme_atomcomm_t *atc, splinedata_t *spline,
                                     pme_spline_work_t *work)
{

    /* spread charges from home atoms to local grid */
    /* NOTE(review): b, ol, localsize, bndsize and work are not referenced in
     * the visible code; presumably pme_sse_single.h uses some of them when it
     * is included below - TODO confirm before removing any of them.
     */
    real          *grid;
    pme_overlap_t *ol;
    int            b, i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    int       *    idxptr;
    int            order, norder, index_x, index_xy, index_xyz;
    real           valx, valxy, qn;
    real          *thx, *thy, *thz;
    int            localsize, bndsize;
    int            pnx, pny, pnz, ndatatot;
    int            offx, offy, offz;

    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the full (padded) grid before accumulating */
    ndatatot = pnx*pny*pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = atc->q[n];

        if (qn != 0)
        {
            idxptr = atc->idx[n];
            norder = nn*order;

            /* Stencil origin relative to this grid's offset */
            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta[XX] + norder;
            thy = spline->theta[YY] + norder;
            thz = spline->theta[ZZ] + norder;

            /* The SSE kernels are #included in place and expand inside this
             * switch, presumably operating on the locals above.
             */
            switch (order)
            {
                case 4:
#ifdef PME_SSE
#ifdef PME_SSE_UNALIGNED
#define PME_SPREAD_SSE_ORDER4
#else
#define PME_SPREAD_SSE_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_sse_single.h"
#else
                    DO_BSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SSE
#define PME_SPREAD_SSE_ALIGNED
#define PME_ORDER 5
#include "pme_sse_single.h"
#else
                    DO_BSPLINE(5);
#endif
                    break;
                default:
                    DO_BSPLINE(order);
                    break;
            }
        }
    }
}
1571
/* Pad the z size of a PME grid so the SSE spline kernels can use aligned
 * memory accesses.
 *
 * Only has an effect when compiled with PME_SSE. In that case *pmegrid_nz is
 * rounded up to a multiple of 4 for pme_order 5, and also for pme_order 4
 * unless PME_SSE_UNALIGNED is defined. In all other cases *pmegrid_nz is
 * left unchanged.
 */
static void set_grid_alignment(int *pmegrid_nz, int pme_order)
{
#ifdef PME_SSE
    int bNeedsAlignment;

    bNeedsAlignment = (pme_order == 5);
#ifndef PME_SSE_UNALIGNED
    if (pme_order == 4)
    {
        bNeedsAlignment = 1;
    }
#endif

    if (bNeedsAlignment)
    {
        /* Adding 3 and clearing the two low bits rounds up to a multiple of 4 */
        *pmegrid_nz = (*pmegrid_nz + 3) & ~3;
    }
#endif
}
1586
/* Enlarge the number of elements to allocate for a PME grid so the aligned
 * SSE kernels cannot access memory beyond the end of the allocation.
 *
 * Only the aligned SSE build (PME_SSE defined, PME_SSE_UNALIGNED not defined)
 * with pme_order 4 needs this; then 4 elements are added to *gridsize.
 * For pme_order 5 the z-size alignment from set_grid_alignment already
 * guarantees that accesses stay within the grid.
 */
static void set_gridsize_alignment(int *gridsize, int pme_order)
{
#if defined PME_SSE && !defined PME_SSE_UNALIGNED
    if (pme_order == 4)
    {
        *gridsize = *gridsize + 4;
    }
#endif
}
1603
1604 static void pmegrid_init(pmegrid_t *grid,
1605                          int cx, int cy, int cz,
1606                          int x0, int y0, int z0,
1607                          int x1, int y1, int z1,
1608                          gmx_bool set_alignment,
1609                          int pme_order,
1610                          real *ptr)
1611 {
1612     int nz, gridsize;
1613
1614     grid->ci[XX]     = cx;
1615     grid->ci[YY]     = cy;
1616     grid->ci[ZZ]     = cz;
1617     grid->offset[XX] = x0;
1618     grid->offset[YY] = y0;
1619     grid->offset[ZZ] = z0;
1620     grid->n[XX]      = x1 - x0 + pme_order - 1;
1621     grid->n[YY]      = y1 - y0 + pme_order - 1;
1622     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1623     copy_ivec(grid->n, grid->s);
1624
1625     nz = grid->s[ZZ];
1626     set_grid_alignment(&nz, pme_order);
1627     if (set_alignment)
1628     {
1629         grid->s[ZZ] = nz;
1630     }
1631     else if (nz != grid->s[ZZ])
1632     {
1633         gmx_incons("pmegrid_init call with an unaligned z size");
1634     }
1635
1636     grid->order = pme_order;
1637     if (ptr == NULL)
1638     {
1639         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1640         set_gridsize_alignment(&gridsize, pme_order);
1641         snew_aligned(grid->grid, gridsize, 16);
1642     }
1643     else
1644     {
1645         grid->grid = ptr;
1646     }
1647 }
1648
/* Integer division rounding up: returns ceil(enumerator/denominator) for
 * non-negative enumerator and positive denominator.
 */
static int div_round_up(int enumerator, int denominator)
{
    const int padded = enumerator + (denominator - 1);

    return padded/denominator;
}
1653
1654 static void make_subgrid_division(const ivec n, int ovl, int nthread,
1655                                   ivec nsub)
1656 {
1657     int gsize_opt, gsize;
1658     int nsx, nsy, nsz;
1659     char *env;
1660
1661     gsize_opt = -1;
1662     for (nsx = 1; nsx <= nthread; nsx++)
1663     {
1664         if (nthread % nsx == 0)
1665         {
1666             for (nsy = 1; nsy <= nthread; nsy++)
1667             {
1668                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1669                 {
1670                     nsz = nthread/(nsx*nsy);
1671
1672                     /* Determine the number of grid points per thread */
1673                     gsize =
1674                         (div_round_up(n[XX], nsx) + ovl)*
1675                         (div_round_up(n[YY], nsy) + ovl)*
1676                         (div_round_up(n[ZZ], nsz) + ovl);
1677
1678                     /* Minimize the number of grids points per thread
1679                      * and, secondarily, the number of cuts in minor dimensions.
1680                      */
1681                     if (gsize_opt == -1 ||
1682                         gsize < gsize_opt ||
1683                         (gsize == gsize_opt &&
1684                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1685                     {
1686                         nsub[XX]  = nsx;
1687                         nsub[YY]  = nsy;
1688                         nsub[ZZ]  = nsz;
1689                         gsize_opt = gsize;
1690                     }
1691                 }
1692             }
1693         }
1694     }
1695
1696     env = getenv("GMX_PME_THREAD_DIVISION");
1697     if (env != NULL)
1698     {
1699         sscanf(env, "%d %d %d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
1700     }
1701
1702     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1703     {
1704         gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
1705     }
1706 }
1707
/* Initialize the set of PME grids used by nthread threads.
 *
 * nx, ny, nz   - local grid size including pme_order-1 overlap points per
 *                dimension; n[d] below is the size without that overlap
 * nz_base      - used instead of n[ZZ] when choosing the thread decomposition
 * pme_order    - PME interpolation order
 * nthread      - number of threads that spread onto the grid
 * overlap_x/_y - communication widths in x/y used to compute
 *                grids->nthread_comm[XX]/[YY]
 *
 * Sets up grids->grid (the full grid; its z size must already be aligned,
 * since pmegrid_init is called with set_alignment FALSE). For nthread > 1
 * it also sets up one sub-grid per thread in grids->grid_th, all stored in
 * a single allocation grids->grid_all. Finally it fills the g2t lookup
 * tables and nthread_comm.
 */
static void pmegrids_init(pmegrids_t *grids,
                          int nx, int ny, int nz, int nz_base,
                          int pme_order,
                          int nthread,
                          int overlap_x,
                          int overlap_y)
{
    ivec n, n_base, g0, g1;
    int t, x, y, z, d, i, tfac;
    int max_comm_lines = -1;

    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
                 NULL);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);

    if (grids->nthread > 1)
    {
        ivec nst;
        int gridsize;

        /* Largest thread sub-grid size (incl. overlap) in each dimension */
        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
                    nx, ny, nz,
                    nst[XX], nst[YY], nst[ZZ]);
        }

        /* All thread grids share one aligned buffer. GMX_CACHE_SEP reals of
         * padding go before, between and after them (presumably to keep
         * threads off each other's cache lines).
         */
        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        snew_aligned(grids->grid_all,
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                     16);

        /* Thread index t runs with z fastest, then y, then x */
        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    pmegrid_init(&grids->grid_th[t],
                                 x, y, z,
                                 (n[XX]*(x  ))/grids->nc[XX],
                                 (n[YY]*(y  ))/grids->nc[YY],
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
                                 (n[XX]*(x+1))/grids->nc[XX],
                                 (n[YY]*(y+1))/grids->nc[YY],
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }

    /* g2t[d][i] is the contribution of grid line i in dimension d to the
     * linear index of the thread owning it. tfac is the stride of dimension
     * d in that index; d runs from ZZ down because z varies fastest.
     */
    snew(grids->g2t, DIM);
    tfac = 1;
    for (d = DIM-1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t*tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x;     break;
            case YY: max_comm_lines = overlap_y;     break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        /* nthread_comm[d]: smallest number of thread slabs whose combined
         * width (n[d]*count/nc[d]) reaches max_comm_lines, capped at nc[d].
         */
        grids->nthread_comm[d] = 0;
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
               grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != NULL)
        {
            fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
                    'x'+d, grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         */
        /* NOTE(review): the loop above caps nthread_comm[d] at nc[d], so the
         * '>' condition below can never be true. If the restriction described
         * above is meant to be enforced, this presumably should be '>=' -
         * TODO confirm before changing.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
        }
    }
}
1830
1831
1832 static void pmegrids_destroy(pmegrids_t *grids)
1833 {
1834     int t;
1835
1836     if (grids->grid.grid != NULL)
1837     {
1838         sfree(grids->grid.grid);
1839
1840         if (grids->nthread > 0)
1841         {
1842             for (t = 0; t < grids->nthread; t++)
1843             {
1844                 sfree(grids->grid_th[t].grid);
1845             }
1846             sfree(grids->grid_th);
1847         }
1848     }
1849 }
1850
1851
1852 static void realloc_work(pme_work_t *work, int nkx)
1853 {
1854     if (nkx > work->nalloc)
1855     {
1856         work->nalloc = nkx;
1857         srenew(work->mhx, work->nalloc);
1858         srenew(work->mhy, work->nalloc);
1859         srenew(work->mhz, work->nalloc);
1860         srenew(work->m2, work->nalloc);
1861         /* Allocate an aligned pointer for SSE operations, including 3 extra
1862          * elements at the end since SSE operates on 4 elements at a time.
1863          */
1864         sfree_aligned(work->denom);
1865         sfree_aligned(work->tmp1);
1866         sfree_aligned(work->eterm);
1867         snew_aligned(work->denom, work->nalloc+3, 16);
1868         snew_aligned(work->tmp1, work->nalloc+3, 16);
1869         snew_aligned(work->eterm, work->nalloc+3, 16);
1870         srenew(work->m2inv, work->nalloc);
1871     }
1872 }
1873
1874
1875 static void free_work(pme_work_t *work)
1876 {
1877     sfree(work->mhx);
1878     sfree(work->mhy);
1879     sfree(work->mhz);
1880     sfree(work->m2);
1881     sfree_aligned(work->denom);
1882     sfree_aligned(work->tmp1);
1883     sfree_aligned(work->eterm);
1884     sfree(work->m2inv);
1885 }
1886
1887
#ifdef PME_SSE
/* Calculate exponentials through SSE in float precision */
/* Computes e[kx] = f * exp(r[kx]) / d[kx] for use in the solve loop.
 *
 * This SSE variant handles 4 elements per iteration with aligned loads and
 * stores. It always starts at index 0 (the start argument is not used) and
 * runs up to end rounded up to a multiple of 4. The arrays must therefore be
 * 16-byte aligned and padded by up to 3 elements past end; realloc_work
 * allocates denom/tmp1/eterm that way. The reciprocal of d is obtained with
 * _mm_rcp_ps refined by one Newton-Raphson step, so it is not exact.
 * d_aligned and r_aligned are only read.
 */
inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
{
    {
        const __m128 two = _mm_set_ps(2.0f, 2.0f, 2.0f, 2.0f);
        __m128 f_sse;
        __m128 lu;
        __m128 tmp_d1, d_inv, tmp_r, tmp_e;
        int kx;
        f_sse = _mm_load1_ps(&f);
        /* Elements below start are computed as well; solve_pme_yzx only reads
         * eterm in [kxstart, kxend). NOTE(review): presumably start is always
         * small, so this extra work is negligible - confirm.
         */
        for (kx = 0; kx < end; kx += 4)
        {
            tmp_d1   = _mm_load_ps(d_aligned+kx);
            lu       = _mm_rcp_ps(tmp_d1);
            /* One Newton-Raphson iteration: 1/d ~= lu*(2 - lu*d) */
            d_inv    = _mm_mul_ps(lu, _mm_sub_ps(two, _mm_mul_ps(lu, tmp_d1)));
            tmp_r    = _mm_load_ps(r_aligned+kx);
            tmp_r    = gmx_mm_exp_ps(tmp_r);
            tmp_e    = _mm_mul_ps(f_sse, d_inv);
            tmp_e    = _mm_mul_ps(tmp_e, tmp_r);
            _mm_store_ps(e_aligned+kx, tmp_e);
        }
    }
}
#else
/* Plain C variant: e[kx] = f * exp(r[kx]) / d[kx] for kx in [start, end).
 * Unlike the SSE variant, this overwrites d[kx] with 1/d[kx] and r[kx] with
 * exp(r[kx]) in place. The three separate loops keep each loop body simple.
 */
inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
{
    int kx;
    for (kx = start; kx < end; kx++)
    {
        d[kx] = 1.0/d[kx];
    }
    for (kx = start; kx < end; kx++)
    {
        r[kx] = exp(r[kx]);
    }
    for (kx = start; kx < end; kx++)
    {
        e[kx] = f*r[kx]*d[kx];
    }
}
#endif
1930
1931
/* Apply the PME reciprocal-space convolution, in place, to this thread's
 * share of the local complex grid.
 *
 * Every Fourier coefficient except the k=(0,0,0) term is multiplied by
 *   eterm = elfac*exp(-factor*m^2) / (pi*vol*m^2*Bx(kx)*By(ky)*Bz(kz)),
 * where:
 *   - factor = pi^2/ewaldcoeff^2;
 *   - elfac = ONE_4PI_EPS0/epsilon_r;
 *   - m is the reciprocal vector built from the signed wave numbers and
 *     pme->recipbox (only its lower-triangular entries are read);
 *   - B* are the entries of pme->bsp_mod.
 *
 * The grid is stored y major, z middle, x minor. The local y*z lines are
 * split evenly over nthread threads. This call handles the share of thread
 * `thread` and uses the scratch buffers in pme->work[thread] (presumably
 * sized beforehand with realloc_work - TODO confirm).
 *
 * When bEnerVir is TRUE, the reciprocal energy and the symmetric virial of
 * this share are stored in pme->work[thread].energy and .vir. They still
 * have to be summed over threads (get_pme_ener_vir does that).
 *
 * Returns local_ndata[YY]*local_ndata[XX] (the "loop count").
 * NOTE(review): this value does not depend on thread/nthread and is not the
 * number of points processed here; confirm what callers use it for.
 */
static int solve_pme_yzx(gmx_pme_t pme, t_complex *grid,
                         real ewaldcoeff, real vol,
                         gmx_bool bEnerVir,
                         int nthread, int thread)
{
    /* do recip sum over local cells in grid */
    /* y major, z middle, x minor or continuous */
    t_complex *p0;
    int     kx, ky, kz, maxkx, maxky, maxkz;
    int     nx, ny, nz, iyz0, iyz1, iyz, iy, iz, kxstart, kxend;
    real    mx, my, mz;
    real    factor = M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
    real    ets2, struct2, vfactor, ets2vf;
    real    d1, d2, energy = 0;
    real    by, bz;
    real    virxx = 0, virxy = 0, virxz = 0, viryy = 0, viryz = 0, virzz = 0;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    pme_work_t *work;
    real    *mhx, *mhy, *mhz, *m2, *denom, *tmp1, *eterm, *m2inv;
    real    mhxk, mhyk, mhzk, m2k;
    real    corner_fac;
    ivec    complex_order;
    ivec    local_ndata, local_offset, local_size;
    real    elfac;

    elfac = ONE_4PI_EPS0/pme->epsilon_r;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
                                      complex_order,
                                      local_ndata,
                                      local_offset,
                                      local_size);

    rxx = pme->recipbox[XX][XX];
    ryx = pme->recipbox[YY][XX];
    ryy = pme->recipbox[YY][YY];
    rzx = pme->recipbox[ZZ][XX];
    rzy = pme->recipbox[ZZ][YY];
    rzz = pme->recipbox[ZZ][ZZ];

    /* Wave numbers k >= maxk are treated as the negative frequency k - n.
     * maxkz is computed but not used below.
     */
    maxkx = (nx+1)/2;
    maxky = (ny+1)/2;
    maxkz = nz/2+1;

    work  = &pme->work[thread];
    mhx   = work->mhx;
    mhy   = work->mhy;
    mhz   = work->mhz;
    m2    = work->m2;
    denom = work->denom;
    tmp1  = work->tmp1;
    eterm = work->eterm;
    m2inv = work->m2inv;

    /* Even split of the local y*z lines over the threads */
    iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
    iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;

    for (iyz = iyz0; iyz < iyz1; iyz++)
    {
        iy = iyz/local_ndata[ZZ];
        iz = iyz - iy*local_ndata[ZZ];

        ky = iy + local_offset[YY];

        if (ky < maxky)
        {
            my = ky;
        }
        else
        {
            my = (ky - ny);
        }

        /* pi*vol is folded into the y modulus factor */
        by = M_PI*vol*pme->bsp_mod[YY][ky];

        kz = iz + local_offset[ZZ];

        mz = kz;

        bz = pme->bsp_mod[ZZ][kz];

        /* 0.5 correction for corner points */
        /* struct2 below carries a factor 2, presumably for the Hermitian
         * partner of the half z spectrum. The kz==0 plane and, for even nz,
         * kz==nz/2 (== (nz+1)/2) have no separate partner, hence the 0.5.
         * (Rationale inferred from the code - TODO confirm.)
         */
        corner_fac = 1;
        if (kz == 0 || kz == (nz+1)/2)
        {
            corner_fac = 0.5;
        }

        p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];

        /* We should skip the k-space point (0,0,0) */
        if (local_offset[XX] > 0 || ky > 0 || kz > 0)
        {
            kxstart = local_offset[XX];
        }
        else
        {
            kxstart = local_offset[XX] + 1;
            p0++;
        }
        kxend = local_offset[XX] + local_ndata[XX];

        if (bEnerVir)
        {
            /* More expensive inner loop, especially because of the storage
             * of the mh elements in array's.
             * Because x is the minor grid index, all mh elements
             * depend on kx for triclinic unit cells.
             */

            /* Two explicit loops to avoid a conditional inside the loop */
            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                mhx[kx]   = mhxk;
                mhy[kx]   = mhyk;
                mhz[kx]   = mhzk;
                m2[kx]    = m2k;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = kxstart; kx < kxend; kx++)
            {
                m2inv[kx] = 1.0/m2[kx];
            }

            /* eterm[kx] = elfac*exp(tmp1[kx])/denom[kx]; the plain C version
             * also overwrites denom and tmp1 in place.
             */
            calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];

                struct2 = 2.0*(d1*d1+d2*d2);

                /* tmp1 is reused to hold eterm*2*|S(m)|^2 */
                tmp1[kx] = eterm[kx]*struct2;
            }

            /* Energy and virial:
             * vir_ab += ets2*(2*(1 + factor*m^2)/m^2 * m_a*m_b - delta_ab)
             */
            for (kx = kxstart; kx < kxend; kx++)
            {
                ets2     = corner_fac*tmp1[kx];
                vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
                energy  += ets2;

                ets2vf   = ets2*vfactor;
                virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
                virxy   += ets2vf*mhx[kx]*mhy[kx];
                virxz   += ets2vf*mhx[kx]*mhz[kx];
                viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
                viryz   += ets2vf*mhy[kx]*mhz[kx];
                virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
            }
        }
        else
        {
            /* We don't need to calculate the energy and the virial.
             * In this case the triclinic overhead is small.
             */

            /* Two explicit loops to avoid a conditional inside the loop */

            for (kx = kxstart; kx < maxkx; kx++)
            {
                mx = kx;

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            for (kx = maxkx; kx < kxend; kx++)
            {
                mx = (kx - nx);

                mhxk      = mx * rxx;
                mhyk      = mx * ryx + my * ryy;
                mhzk      = mx * rzx + my * rzy + mz * rzz;
                m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
                denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
                tmp1[kx]  = -factor*m2k;
            }

            calc_exponentials(kxstart, kxend, elfac, denom, tmp1, eterm);

            for (kx = kxstart; kx < kxend; kx++, p0++)
            {
                d1      = p0->re;
                d2      = p0->im;

                p0->re  = d1*eterm[kx];
                p0->im  = d2*eterm[kx];
            }
        }
    }

    if (bEnerVir)
    {
        /* Update virial with local values.
         * The virial is symmetric by definition.
         * this virial seems ok for isotropic scaling, but I'm
         * experiencing problems on semiisotropic membranes.
         * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
         */
        work->vir[XX][XX] = 0.25*virxx;
        work->vir[YY][YY] = 0.25*viryy;
        work->vir[ZZ][ZZ] = 0.25*virzz;
        work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
        work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
        work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;

        /* This energy should be corrected for a charged system */
        work->energy = 0.5*energy;
    }

    /* Return the loop count */
    return local_ndata[YY]*local_ndata[XX];
}
2182
2183 static void get_pme_ener_vir(const gmx_pme_t pme, int nthread,
2184                              real *mesh_energy, matrix vir)
2185 {
2186     /* This function sums output over threads
2187      * and should therefore only be called after thread synchronization.
2188      */
2189     int thread;
2190
2191     *mesh_energy = pme->work[0].energy;
2192     copy_mat(pme->work[0].vir, vir);
2193
2194     for (thread = 1; thread < nthread; thread++)
2195     {
2196         *mesh_energy += pme->work[thread].energy;
2197         m_add(vir, pme->work[thread].vir, vir);
2198     }
2199 }
2200
/* Accumulate the grid-index-space force components fx, fy, fz of one atom.
 *
 * Sums over the order^3 grid points starting at (i0, j0, k0), using x and y
 * strides pny*pnz and pnz. Each grid value is weighted by the product of the
 * spline weights (thx, thy, thz), with one factor replaced by its derivative
 * (dthx, dthy, dthz) for the corresponding component.
 *
 * Uses names from the enclosing function: grid, i0, j0, k0, pny, pnz, the
 * theta/dtheta pointers, and the temporaries ithx, ithy, ithz, index_x,
 * index_xy, tx, ty, dx, dy, fxy1, fz1 and gval. fx, fy, fz must be
 * initialized by the caller. Callers pass literal orders where possible,
 * presumably so the compiler can unroll the loops.
 */
#define DO_FSPLINE(order)                      \
    for (ithx = 0; (ithx < order); ithx++)              \
    {                                              \
        index_x = (i0+ithx)*pny*pnz;               \
        tx      = thx[ithx];                       \
        dx      = dthx[ithx];                      \
                                               \
        for (ithy = 0; (ithy < order); ithy++)          \
        {                                          \
            index_xy = index_x+(j0+ithy)*pnz;      \
            ty       = thy[ithy];                  \
            dy       = dthy[ithy];                 \
            fxy1     = fz1 = 0;                    \
                                               \
            for (ithz = 0; (ithz < order); ithz++)      \
            {                                      \
                gval  = grid[index_xy+(k0+ithz)];  \
                fxy1 += thz[ithz]*gval;            \
                fz1  += dthz[ithz]*gval;           \
            }                                      \
            fx += dx*ty*fxy1;                      \
            fy += tx*dy*fxy1;                      \
            fz += tx*ty*fz1;                       \
        }                                          \
    }
2226
2227
/* Interpolate PME forces from the potential grid `grid` for the atoms
 * listed in spline->ind.
 *
 * For each listed atom n with scale*q[n] != 0:
 *   - the B-spline weights (spline->theta) and their derivatives
 *     (spline->dtheta) are combined with the order^3 grid points starting
 *     at atc->idx[n], giving the gradient (fx, fy, fz) in grid-index units;
 *   - this gradient is converted to Cartesian components using the grid
 *     sizes and pme->recipbox;
 *   - the result is multiplied by -scale*q[n] and added to atc->f[n].
 * When bClearF is TRUE, atc->f[n] is zeroed first for every listed atom,
 * including uncharged ones.
 *
 * The spline arrays are indexed by the list position nn, not by the atom
 * index n. With PME_SSE, orders 4 and 5 use the kernels included from
 * pme_sse_single.h (presumably that header uses the local `work` and the
 * PME_* macros defined just before the include - TODO confirm). Other
 * orders, and all orders without SSE, use DO_FSPLINE.
 */
static void gather_f_bsplines(gmx_pme_t pme, real *grid,
                              gmx_bool bClearF, pme_atomcomm_t *atc,
                              splinedata_t *spline,
                              real scale)
{
    /* sum forces for local particles */
    int     nn, n, ithx, ithy, ithz, i0, j0, k0;
    int     index_x, index_xy;
    int     nx, ny, nz, pnx, pny, pnz;
    int *   idxptr;
    real    tx, ty, dx, dy, qn;
    real    fx, fy, fz, gval;
    real    fxy1, fz1;
    real    *thx, *thy, *thz, *dthx, *dthy, *dthz;
    int     norder;
    real    rxx, ryx, ryy, rzx, rzy, rzz;
    int     order;

    pme_spline_work_t *work;

    /* Not referenced directly in this function body; presumably used by the
     * SSE code included below (TODO confirm).
     */
    work = pme->spline_work;

    order = pme->pme_order;
    /* Base pointers; they are re-set with a per-atom offset in the loop */
    thx   = spline->theta[XX];
    thy   = spline->theta[YY];
    thz   = spline->theta[ZZ];
    dthx  = spline->dtheta[XX];
    dthy  = spline->dtheta[YY];
    dthz  = spline->dtheta[ZZ];
    nx    = pme->nkx;
    ny    = pme->nky;
    nz    = pme->nkz;
    pnx   = pme->pmegrid_nx;
    pny   = pme->pmegrid_ny;
    pnz   = pme->pmegrid_nz;

    rxx   = pme->recipbox[XX][XX];
    ryx   = pme->recipbox[YY][XX];
    ryy   = pme->recipbox[YY][YY];
    rzx   = pme->recipbox[ZZ][XX];
    rzy   = pme->recipbox[ZZ][YY];
    rzz   = pme->recipbox[ZZ][ZZ];

    for (nn = 0; nn < spline->n; nn++)
    {
        n  = spline->ind[nn];
        qn = scale*atc->q[n];

        if (bClearF)
        {
            atc->f[n][XX] = 0;
            atc->f[n][YY] = 0;
            atc->f[n][ZZ] = 0;
        }
        if (qn != 0)
        {
            fx     = 0;
            fy     = 0;
            fz     = 0;
            idxptr = atc->idx[n];
            /* Spline data are stored per list position nn, not per atom n */
            norder = nn*order;

            i0   = idxptr[XX];
            j0   = idxptr[YY];
            k0   = idxptr[ZZ];

            /* Pointer arithmetic alert, next six statements */
            thx  = spline->theta[XX] + norder;
            thy  = spline->theta[YY] + norder;
            thz  = spline->theta[ZZ] + norder;
            dthx = spline->dtheta[XX] + norder;
            dthy = spline->dtheta[YY] + norder;
            dthz = spline->dtheta[ZZ] + norder;

            /* Fixed-order kernels for the common orders 4 and 5 */
            switch (order)
            {
                case 4:
#ifdef PME_SSE
#ifdef PME_SSE_UNALIGNED
#define PME_GATHER_F_SSE_ORDER4
#else
#define PME_GATHER_F_SSE_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_sse_single.h"
#else
                    DO_FSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SSE
#define PME_GATHER_F_SSE_ALIGNED
#define PME_ORDER 5
#include "pme_sse_single.h"
#else
                    DO_FSPLINE(5);
#endif
                    break;
                default:
                    DO_FSPLINE(order);
                    break;
            }

            /* Cartesian component c gets sum_d (f_d*n_d)*recipbox[d][c];
             * only the lower-triangular recipbox entries are used.
             */
            atc->f[n][XX] += -qn*( fx*nx*rxx );
            atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
            atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
        }
    }
    /* Since the energy and not forces are interpolated
     * the net force might not be exactly zero.
     * This can be solved by also interpolating F, but
     * that comes at a cost.
     * A better hack is to remove the net force every
     * step, but that must be done at a higher level
     * since this routine doesn't see all atoms if running
     * in parallel. Don't know how important it is?  EL 990726
     */
}
2346
2347
2348 static real gather_energy_bsplines(gmx_pme_t pme, real *grid,
2349                                    pme_atomcomm_t *atc)
2350 {
2351     splinedata_t *spline;
2352     int     n, ithx, ithy, ithz, i0, j0, k0;
2353     int     index_x, index_xy;
2354     int *   idxptr;
2355     real    energy, pot, tx, ty, qn, gval;
2356     real    *thx, *thy, *thz;
2357     int     norder;
2358     int     order;
2359
2360     spline = &atc->spline[0];
2361
2362     order = pme->pme_order;
2363
2364     energy = 0;
2365     for (n = 0; (n < atc->n); n++)
2366     {
2367         qn      = atc->q[n];
2368
2369         if (qn != 0)
2370         {
2371             idxptr = atc->idx[n];
2372             norder = n*order;
2373
2374             i0   = idxptr[XX];
2375             j0   = idxptr[YY];
2376             k0   = idxptr[ZZ];
2377
2378             /* Pointer arithmetic alert, next three statements */
2379             thx  = spline->theta[XX] + norder;
2380             thy  = spline->theta[YY] + norder;
2381             thz  = spline->theta[ZZ] + norder;
2382
2383             pot = 0;
2384             for (ithx = 0; (ithx < order); ithx++)
2385             {
2386                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2387                 tx      = thx[ithx];
2388
2389                 for (ithy = 0; (ithy < order); ithy++)
2390                 {
2391                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2392                     ty       = thy[ithy];
2393
2394                     for (ithz = 0; (ithz < order); ithz++)
2395                     {
2396                         gval  = grid[index_xy+(k0+ithz)];
2397                         pot  += tx*ty*thz[ithz]*gval;
2398                     }
2399
2400                 }
2401             }
2402
2403             energy += pot*qn;
2404         }
2405     }
2406
2407     return energy;
2408 }
2409
/* Macro to force loop unrolling by fixing order.
 * This gives a significant performance gain.
 *
 * For atom slot i and each dimension j, this macro:
 *   - computes the `order` B-spline weights of the fractional offset
 *     dr = xptr[j], building them up from order 2;
 *   - takes the derivatives from the order-1 weights just before the final
 *     step (ddata[k] = data[k-1] - data[k]);
 *   - stores the results in theta[j][i*order+k] and dtheta[j][i*order+k].
 * Expects xptr, theta, dtheta and i to be in scope where it is expanded.
 */
#define CALC_SPLINE(order)                     \
    {                                              \
        int j, k, l;                                 \
        real dr, div;                               \
        real data[PME_ORDER_MAX];                  \
        real ddata[PME_ORDER_MAX];                 \
                                               \
        for (j = 0; (j < DIM); j++)                     \
        {                                          \
            dr  = xptr[j];                         \
                                               \
            /* dr is relative offset from lower cell limit */ \
            data[order-1] = 0;                     \
            data[1]       = dr;                          \
            data[0]       = 1 - dr;                      \
                                               \
            for (k = 3; (k < order); k++)               \
            {                                      \
                div       = 1.0/(k - 1.0);               \
                data[k-1] = div*dr*data[k-2];      \
                for (l = 1; (l < (k-1)); l++)           \
                {                                  \
                    data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
                                       data[k-l-1]);                \
                }                                  \
                data[0] = div*(1-dr)*data[0];      \
            }                                      \
            /* differentiate */                    \
            ddata[0] = -data[0];                   \
            for (k = 1; (k < order); k++)               \
            {                                      \
                ddata[k] = data[k-1] - data[k];    \
            }                                      \
                                               \
            div           = 1.0/(order - 1);                 \
            data[order-1] = div*dr*data[order-2];  \
            for (l = 1; (l < (order-1)); l++)           \
            {                                      \
                data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
                                       (order-l-dr)*data[order-l-1]); \
            }                                      \
            data[0] = div*(1 - dr)*data[0];        \
                                               \
            for (k = 0; k < order; k++)                 \
            {                                      \
                theta[j][i*order+k]  = data[k];    \
                dtheta[j][i*order+k] = ddata[k];   \
            }                                      \
        }                                          \
    }
2463
2464 void make_bsplines(splinevec theta, splinevec dtheta, int order,
2465                    rvec fractx[], int nr, int ind[], real charge[],
2466                    gmx_bool bFreeEnergy)
2467 {
2468     /* construct splines for local atoms */
2469     int  i, ii;
2470     real *xptr;
2471
2472     for (i = 0; i < nr; i++)
2473     {
2474         /* With free energy we do not use the charge check.
2475          * In most cases this will be more efficient than calling make_bsplines
2476          * twice, since usually more than half the particles have charges.
2477          */
2478         ii = ind[i];
2479         if (bFreeEnergy || charge[ii] != 0.0)
2480         {
2481             xptr = fractx[ii];
2482             switch (order)
2483             {
2484                 case 4:  CALC_SPLINE(4);     break;
2485                 case 5:  CALC_SPLINE(5);     break;
2486                 default: CALC_SPLINE(order); break;
2487             }
2488         }
2489     }
2490 }
2491
2492
2493 void make_dft_mod(real *mod, real *data, int ndata)
2494 {
2495     int i, j;
2496     real sc, ss, arg;
2497
2498     for (i = 0; i < ndata; i++)
2499     {
2500         sc = ss = 0;
2501         for (j = 0; j < ndata; j++)
2502         {
2503             arg = (2.0*M_PI*i*j)/ndata;
2504             sc += data[j]*cos(arg);
2505             ss += data[j]*sin(arg);
2506         }
2507         mod[i] = sc*sc+ss*ss;
2508     }
2509     for (i = 0; i < ndata; i++)
2510     {
2511         if (mod[i] < 1e-7)
2512         {
2513             mod[i] = (mod[i-1]+mod[i+1])*0.5;
2514         }
2515     }
2516 }
2517
2518
2519 static void make_bspline_moduli(splinevec bsp_mod,
2520                                 int nx, int ny, int nz, int order)
2521 {
2522     int nmax = max(nx, max(ny, nz));
2523     real *data, *ddata, *bsp_data;
2524     int i, k, l;
2525     real div;
2526
2527     snew(data, order);
2528     snew(ddata, order);
2529     snew(bsp_data, nmax);
2530
2531     data[order-1] = 0;
2532     data[1]       = 0;
2533     data[0]       = 1;
2534
2535     for (k = 3; k < order; k++)
2536     {
2537         div       = 1.0/(k-1.0);
2538         data[k-1] = 0;
2539         for (l = 1; l < (k-1); l++)
2540         {
2541             data[k-l-1] = div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2542         }
2543         data[0] = div*data[0];
2544     }
2545     /* differentiate */
2546     ddata[0] = -data[0];
2547     for (k = 1; k < order; k++)
2548     {
2549         ddata[k] = data[k-1]-data[k];
2550     }
2551     div           = 1.0/(order-1);
2552     data[order-1] = 0;
2553     for (l = 1; l < (order-1); l++)
2554     {
2555         data[order-l-1] = div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2556     }
2557     data[0] = div*data[0];
2558
2559     for (i = 0; i < nmax; i++)
2560     {
2561         bsp_data[i] = 0;
2562     }
2563     for (i = 1; i <= order; i++)
2564     {
2565         bsp_data[i] = data[i-1];
2566     }
2567
2568     make_dft_mod(bsp_mod[XX], bsp_data, nx);
2569     make_dft_mod(bsp_mod[YY], bsp_data, ny);
2570     make_dft_mod(bsp_mod[ZZ], bsp_data, nz);
2571
2572     sfree(data);
2573     sfree(ddata);
2574     sfree(bsp_data);
2575 }
2576
2577
/* Return the P3M optimal influence function polynomial in z for the given
 * interpolation order (2..8). Any other order yields 0.0.
 * The formula and most constants can be found in:
 * Ballenegger et al., JCTC 8, 936 (2012)
 */
static double do_p3m_influence(double z, int order)
{
    const double z2 = z*z;
    const double z4 = z2*z2;
    double       infl;

    switch (order)
    {
        case 2:
            infl = 1.0 - 2.0*z2/3.0;
            break;
        case 3:
            infl = 1.0 - z2 + 2.0*z4/15.0;
            break;
        case 4:
            infl = 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
            break;
        case 5:
            infl = 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
            break;
        case 6:
            infl = 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
            break;
        case 7:
            infl = 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
            break;
        case 8:
            infl = 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
            break;
        default:
            /* Unsupported order */
            infl = 0.0;
            break;
    }

    return infl;
}
2615
2616 /* Calculate the P3M B-spline moduli for one dimension */
2617 static void make_p3m_bspline_moduli_dim(real *bsp_mod, int n, int order)
2618 {
2619     double zarg, zai, sinzai, infl;
2620     int    maxk, i;
2621
2622     if (order > 8)
2623     {
2624         gmx_fatal(FARGS, "The current P3M code only supports orders up to 8");
2625     }
2626
2627     zarg = M_PI/n;
2628
2629     maxk = (n + 1)/2;
2630
2631     for (i = -maxk; i < 0; i++)
2632     {
2633         zai          = zarg*i;
2634         sinzai       = sin(zai);
2635         infl         = do_p3m_influence(sinzai, order);
2636         bsp_mod[n+i] = infl*infl*pow(sinzai/zai, -2.0*order);
2637     }
2638     bsp_mod[0] = 1.0;
2639     for (i = 1; i < maxk; i++)
2640     {
2641         zai        = zarg*i;
2642         sinzai     = sin(zai);
2643         infl       = do_p3m_influence(sinzai, order);
2644         bsp_mod[i] = infl*infl*pow(sinzai/zai, -2.0*order);
2645     }
2646 }
2647
2648 /* Calculate the P3M B-spline moduli */
2649 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2650                                     int nx, int ny, int nz, int order)
2651 {
2652     make_p3m_bspline_moduli_dim(bsp_mod[XX], nx, order);
2653     make_p3m_bspline_moduli_dim(bsp_mod[YY], ny, order);
2654     make_p3m_bspline_moduli_dim(bsp_mod[ZZ], nz, order);
2655 }
2656
2657
2658 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2659 {
2660     int nslab, n, i;
2661     int fw, bw;
2662
2663     nslab = atc->nslab;
2664
2665     n = 0;
2666     for (i = 1; i <= nslab/2; i++)
2667     {
2668         fw = (atc->nodeid + i) % nslab;
2669         bw = (atc->nodeid - i + nslab) % nslab;
2670         if (n < nslab - 1)
2671         {
2672             atc->node_dest[n] = fw;
2673             atc->node_src[n]  = bw;
2674             n++;
2675         }
2676         if (n < nslab - 1)
2677         {
2678             atc->node_dest[n] = bw;
2679             atc->node_src[n]  = fw;
2680             n++;
2681         }
2682     }
2683 }
2684
2685 int gmx_pme_destroy(FILE *log, gmx_pme_t *pmedata)
2686 {
2687     int thread;
2688
2689     if (NULL != log)
2690     {
2691         fprintf(log, "Destroying PME data structures.\n");
2692     }
2693
2694     sfree((*pmedata)->nnx);
2695     sfree((*pmedata)->nny);
2696     sfree((*pmedata)->nnz);
2697
2698     pmegrids_destroy(&(*pmedata)->pmegridA);
2699
2700     sfree((*pmedata)->fftgridA);
2701     sfree((*pmedata)->cfftgridA);
2702     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2703
2704     if ((*pmedata)->pmegridB.grid.grid != NULL)
2705     {
2706         pmegrids_destroy(&(*pmedata)->pmegridB);
2707         sfree((*pmedata)->fftgridB);
2708         sfree((*pmedata)->cfftgridB);
2709         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2710     }
2711     for (thread = 0; thread < (*pmedata)->nthread; thread++)
2712     {
2713         free_work(&(*pmedata)->work[thread]);
2714     }
2715     sfree((*pmedata)->work);
2716
2717     sfree(*pmedata);
2718     *pmedata = NULL;
2719
2720     return 0;
2721 }
2722
/* Round n up to the nearest multiple of f (f > 0, n >= 0). */
static int mult_up(int n, int f)
{
    const int nblocks = (n + f - 1)/f;

    return nblocks*f;
}
2727
2728
2729 static double pme_load_imbalance(gmx_pme_t pme)
2730 {
2731     int    nma, nmi;
2732     double n1, n2, n3;
2733
2734     nma = pme->nnodes_major;
2735     nmi = pme->nnodes_minor;
2736
2737     n1 = mult_up(pme->nkx, nma)*mult_up(pme->nky, nmi)*pme->nkz;
2738     n2 = mult_up(pme->nkx, nma)*mult_up(pme->nkz, nmi)*pme->nky;
2739     n3 = mult_up(pme->nky, nma)*mult_up(pme->nkz, nmi)*pme->nkx;
2740
2741     /* pme_solve is roughly double the cost of an fft */
2742
2743     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2744 }
2745
2746 static void init_atomcomm(gmx_pme_t pme, pme_atomcomm_t *atc, t_commrec *cr,
2747                           int dimind, gmx_bool bSpread)
2748 {
2749     int nk, k, s, thread;
2750
2751     atc->dimind    = dimind;
2752     atc->nslab     = 1;
2753     atc->nodeid    = 0;
2754     atc->pd_nalloc = 0;
2755 #ifdef GMX_MPI
2756     if (pme->nnodes > 1)
2757     {
2758         atc->mpi_comm = pme->mpi_comm_d[dimind];
2759         MPI_Comm_size(atc->mpi_comm, &atc->nslab);
2760         MPI_Comm_rank(atc->mpi_comm, &atc->nodeid);
2761     }
2762     if (debug)
2763     {
2764         fprintf(debug, "For PME atom communication in dimind %d: nslab %d rank %d\n", atc->dimind, atc->nslab, atc->nodeid);
2765     }
2766 #endif
2767
2768     atc->bSpread   = bSpread;
2769     atc->pme_order = pme->pme_order;
2770
2771     if (atc->nslab > 1)
2772     {
2773         /* These three allocations are not required for particle decomp. */
2774         snew(atc->node_dest, atc->nslab);
2775         snew(atc->node_src, atc->nslab);
2776         setup_coordinate_communication(atc);
2777
2778         snew(atc->count_thread, pme->nthread);
2779         for (thread = 0; thread < pme->nthread; thread++)
2780         {
2781             snew(atc->count_thread[thread], atc->nslab);
2782         }
2783         atc->count = atc->count_thread[0];
2784         snew(atc->rcount, atc->nslab);
2785         snew(atc->buf_index, atc->nslab);
2786     }
2787
2788     atc->nthread = pme->nthread;
2789     if (atc->nthread > 1)
2790     {
2791         snew(atc->thread_plist, atc->nthread);
2792     }
2793     snew(atc->spline, atc->nthread);
2794     for (thread = 0; thread < atc->nthread; thread++)
2795     {
2796         if (atc->nthread > 1)
2797         {
2798             snew(atc->thread_plist[thread].n, atc->nthread+2*GMX_CACHE_SEP);
2799             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2800         }
2801         snew(atc->spline[thread].thread_one, pme->nthread);
2802         atc->spline[thread].thread_one[thread] = 1;
2803     }
2804 }
2805
2806 static void
2807 init_overlap_comm(pme_overlap_t *  ol,
2808                   int              norder,
2809 #ifdef GMX_MPI
2810                   MPI_Comm         comm,
2811 #endif
2812                   int              nnodes,
2813                   int              nodeid,
2814                   int              ndata,
2815                   int              commplainsize)
2816 {
2817     int lbnd, rbnd, maxlr, b, i;
2818     int exten;
2819     int nn, nk;
2820     pme_grid_comm_t *pgc;
2821     gmx_bool bCont;
2822     int fft_start, fft_end, send_index1, recv_index1;
2823 #ifdef GMX_MPI
2824     MPI_Status stat;
2825
2826     ol->mpi_comm = comm;
2827 #endif
2828
2829     ol->nnodes = nnodes;
2830     ol->nodeid = nodeid;
2831
2832     /* Linear translation of the PME grid won't affect reciprocal space
2833      * calculations, so to optimize we only interpolate "upwards",
2834      * which also means we only have to consider overlap in one direction.
2835      * I.e., particles on this node might also be spread to grid indices
2836      * that belong to higher nodes (modulo nnodes)
2837      */
2838
2839     snew(ol->s2g0, ol->nnodes+1);
2840     snew(ol->s2g1, ol->nnodes);
2841     if (debug)
2842     {
2843         fprintf(debug, "PME slab boundaries:");
2844     }
2845     for (i = 0; i < nnodes; i++)
2846     {
2847         /* s2g0 the local interpolation grid start.
2848          * s2g1 the local interpolation grid end.
2849          * Because grid overlap communication only goes forward,
2850          * the grid the slabs for fft's should be rounded down.
2851          */
2852         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2853         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2854
2855         if (debug)
2856         {
2857             fprintf(debug, "  %3d %3d", ol->s2g0[i], ol->s2g1[i]);
2858         }
2859     }
2860     ol->s2g0[nnodes] = ndata;
2861     if (debug)
2862     {
2863         fprintf(debug, "\n");
2864     }
2865
2866     /* Determine with how many nodes we need to communicate the grid overlap */
2867     b = 0;
2868     do
2869     {
2870         b++;
2871         bCont = FALSE;
2872         for (i = 0; i < nnodes; i++)
2873         {
2874             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2875                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2876             {
2877                 bCont = TRUE;
2878             }
2879         }
2880     }
2881     while (bCont && b < nnodes);
2882     ol->noverlap_nodes = b - 1;
2883
2884     snew(ol->send_id, ol->noverlap_nodes);
2885     snew(ol->recv_id, ol->noverlap_nodes);
2886     for (b = 0; b < ol->noverlap_nodes; b++)
2887     {
2888         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2889         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2890     }
2891     snew(ol->comm_data, ol->noverlap_nodes);
2892
2893     ol->send_size = 0;
2894     for (b = 0; b < ol->noverlap_nodes; b++)
2895     {
2896         pgc = &ol->comm_data[b];
2897         /* Send */
2898         fft_start        = ol->s2g0[ol->send_id[b]];
2899         fft_end          = ol->s2g0[ol->send_id[b]+1];
2900         if (ol->send_id[b] < nodeid)
2901         {
2902             fft_start += ndata;
2903             fft_end   += ndata;
2904         }
2905         send_index1       = ol->s2g1[nodeid];
2906         send_index1       = min(send_index1, fft_end);
2907         pgc->send_index0  = fft_start;
2908         pgc->send_nindex  = max(0, send_index1 - pgc->send_index0);
2909         ol->send_size    += pgc->send_nindex;
2910
2911         /* We always start receiving to the first index of our slab */
2912         fft_start        = ol->s2g0[ol->nodeid];
2913         fft_end          = ol->s2g0[ol->nodeid+1];
2914         recv_index1      = ol->s2g1[ol->recv_id[b]];
2915         if (ol->recv_id[b] > nodeid)
2916         {
2917             recv_index1 -= ndata;
2918         }
2919         recv_index1      = min(recv_index1, fft_end);
2920         pgc->recv_index0 = fft_start;
2921         pgc->recv_nindex = max(0, recv_index1 - pgc->recv_index0);
2922     }
2923
2924 #ifdef GMX_MPI
2925     /* Communicate the buffer sizes to receive */
2926     for (b = 0; b < ol->noverlap_nodes; b++)
2927     {
2928         MPI_Sendrecv(&ol->send_size, 1, MPI_INT, ol->send_id[b], b,
2929                      &ol->comm_data[b].recv_size, 1, MPI_INT, ol->recv_id[b], b,
2930                      ol->mpi_comm, &stat);
2931     }
2932 #endif
2933
2934     /* For non-divisible grid we need pme_order iso pme_order-1 */
2935     snew(ol->sendbuf, norder*commplainsize);
2936     snew(ol->recvbuf, norder*commplainsize);
2937 }
2938
2939 static void
2940 make_gridindex5_to_localindex(int n, int local_start, int local_range,
2941                               int **global_to_local,
2942                               real **fraction_shift)
2943 {
2944     int i;
2945     int * gtl;
2946     real * fsh;
2947
2948     snew(gtl, 5*n);
2949     snew(fsh, 5*n);
2950     for (i = 0; (i < 5*n); i++)
2951     {
2952         /* Determine the global to local grid index */
2953         gtl[i] = (i - local_start + n) % n;
2954         /* For coordinates that fall within the local grid the fraction
2955          * is correct, we don't need to shift it.
2956          */
2957         fsh[i] = 0;
2958         if (local_range < n)
2959         {
2960             /* Due to rounding issues i could be 1 beyond the lower or
2961              * upper boundary of the local grid. Correct the index for this.
2962              * If we shift the index, we need to shift the fraction by
2963              * the same amount in the other direction to not affect
2964              * the weights.
2965              * Note that due to this shifting the weights at the end of
2966              * the spline might change, but that will only involve values
2967              * between zero and values close to the precision of a real,
2968              * which is anyhow the accuracy of the whole mesh calculation.
2969              */
2970             /* With local_range=0 we should not change i=local_start */
2971             if (i % n != local_start)
2972             {
2973                 if (gtl[i] == n-1)
2974                 {
2975                     gtl[i] = 0;
2976                     fsh[i] = -1;
2977                 }
2978                 else if (gtl[i] == local_range)
2979                 {
2980                     gtl[i] = local_range - 1;
2981                     fsh[i] = 1;
2982                 }
2983             }
2984         }
2985     }
2986
2987     *global_to_local = gtl;
2988     *fraction_shift  = fsh;
2989 }
2990
2991 static pme_spline_work_t *make_pme_spline_work(int order)
2992 {
2993     pme_spline_work_t *work;
2994
2995 #ifdef PME_SSE
2996     float  tmp[8];
2997     __m128 zero_SSE;
2998     int    of, i;
2999
3000     snew_aligned(work, 1, 16);
3001
3002     zero_SSE = _mm_setzero_ps();
3003
3004     /* Generate bit masks to mask out the unused grid entries,
3005      * as we only operate on order of the 8 grid entries that are
3006      * load into 2 SSE float registers.
3007      */
3008     for (of = 0; of < 8-(order-1); of++)
3009     {
3010         for (i = 0; i < 8; i++)
3011         {
3012             tmp[i] = (i >= of && i < of+order ? 1 : 0);
3013         }
3014         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
3015         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
3016         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of], zero_SSE);
3017         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of], zero_SSE);
3018     }
3019 #else
3020     work = NULL;
3021 #endif
3022
3023     return work;
3024 }
3025
3026 static void
3027 gmx_pme_check_grid_restrictions(FILE *fplog, char dim, int nnodes, int *nk)
3028 {
3029     int nk_new;
3030
3031     if (*nk % nnodes != 0)
3032     {
3033         nk_new = nnodes*(*nk/nnodes + 1);
3034
3035         if (2*nk_new >= 3*(*nk))
3036         {
3037             gmx_fatal(FARGS, "The PME grid size in dim %c (%d) is not divisble by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
3038                       dim, *nk, dim, nnodes, dim);
3039         }
3040
3041         if (fplog != NULL)
3042         {
3043             fprintf(fplog, "\nNOTE: The PME grid size in dim %c (%d) is not divisble by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
3044                     dim, *nk, dim, nnodes, dim, nk_new, dim);
3045         }
3046
3047         *nk = nk_new;
3048     }
3049 }
3050
3051 int gmx_pme_init(gmx_pme_t *         pmedata,
3052                  t_commrec *         cr,
3053                  int                 nnodes_major,
3054                  int                 nnodes_minor,
3055                  t_inputrec *        ir,
3056                  int                 homenr,
3057                  gmx_bool            bFreeEnergy,
3058                  gmx_bool            bReproducible,
3059                  int                 nthread)
3060 {
3061     gmx_pme_t pme = NULL;
3062
3063     pme_atomcomm_t *atc;
3064     ivec ndata;
3065
3066     if (debug)
3067     {
3068         fprintf(debug, "Creating PME data structures.\n");
3069     }
3070     snew(pme, 1);
3071
3072     pme->redist_init         = FALSE;
3073     pme->sum_qgrid_tmp       = NULL;
3074     pme->sum_qgrid_dd_tmp    = NULL;
3075     pme->buf_nalloc          = 0;
3076     pme->redist_buf_nalloc   = 0;
3077
3078     pme->nnodes              = 1;
3079     pme->bPPnode             = TRUE;
3080
3081     pme->nnodes_major        = nnodes_major;
3082     pme->nnodes_minor        = nnodes_minor;
3083
3084 #ifdef GMX_MPI
3085     if (nnodes_major*nnodes_minor > 1)
3086     {
3087         pme->mpi_comm = cr->mpi_comm_mygroup;
3088
3089         MPI_Comm_rank(pme->mpi_comm, &pme->nodeid);
3090         MPI_Comm_size(pme->mpi_comm, &pme->nnodes);
3091         if (pme->nnodes != nnodes_major*nnodes_minor)
3092         {
3093             gmx_incons("PME node count mismatch");
3094         }
3095     }
3096     else
3097     {
3098         pme->mpi_comm = MPI_COMM_NULL;
3099     }
3100 #endif
3101
3102     if (pme->nnodes == 1)
3103     {
3104 #ifdef GMX_MPI
3105         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3106         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3107 #endif
3108         pme->ndecompdim   = 0;
3109         pme->nodeid_major = 0;
3110         pme->nodeid_minor = 0;
3111 #ifdef GMX_MPI
3112         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3113 #endif
3114     }
3115     else
3116     {
3117         if (nnodes_minor == 1)
3118         {
3119 #ifdef GMX_MPI
3120             pme->mpi_comm_d[0] = pme->mpi_comm;
3121             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3122 #endif
3123             pme->ndecompdim   = 1;
3124             pme->nodeid_major = pme->nodeid;
3125             pme->nodeid_minor = 0;
3126
3127         }
3128         else if (nnodes_major == 1)
3129         {
3130 #ifdef GMX_MPI
3131             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3132             pme->mpi_comm_d[1] = pme->mpi_comm;
3133 #endif
3134             pme->ndecompdim   = 1;
3135             pme->nodeid_major = 0;
3136             pme->nodeid_minor = pme->nodeid;
3137         }
3138         else
3139         {
3140             if (pme->nnodes % nnodes_major != 0)
3141             {
3142                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3143             }
3144             pme->ndecompdim = 2;
3145
3146 #ifdef GMX_MPI
3147             MPI_Comm_split(pme->mpi_comm, pme->nodeid % nnodes_minor,
3148                            pme->nodeid, &pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3149             MPI_Comm_split(pme->mpi_comm, pme->nodeid/nnodes_minor,
3150                            pme->nodeid, &pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3151
3152             MPI_Comm_rank(pme->mpi_comm_d[0], &pme->nodeid_major);
3153             MPI_Comm_size(pme->mpi_comm_d[0], &pme->nnodes_major);
3154             MPI_Comm_rank(pme->mpi_comm_d[1], &pme->nodeid_minor);
3155             MPI_Comm_size(pme->mpi_comm_d[1], &pme->nnodes_minor);
3156 #endif
3157         }
3158         pme->bPPnode = (cr->duty & DUTY_PP);
3159     }
3160
3161     pme->nthread = nthread;
3162
3163     if (ir->ePBC == epbcSCREW)
3164     {
3165         gmx_fatal(FARGS, "pme does not (yet) work with pbc = screw");
3166     }
3167
3168     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3169     pme->nkx         = ir->nkx;
3170     pme->nky         = ir->nky;
3171     pme->nkz         = ir->nkz;
3172     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3173     pme->pme_order   = ir->pme_order;
3174     pme->epsilon_r   = ir->epsilon_r;
3175
3176     if (pme->pme_order > PME_ORDER_MAX)
3177     {
3178         gmx_fatal(FARGS, "pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3179                   pme->pme_order, PME_ORDER_MAX);
3180     }
3181
3182     /* Currently pme.c supports only the fft5d FFT code.
3183      * Therefore the grid always needs to be divisible by nnodes.
3184      * When the old 1D code is also supported again, change this check.
3185      *
3186      * This check should be done before calling gmx_pme_init
3187      * and fplog should be passed iso stderr.
3188      *
3189        if (pme->ndecompdim >= 2)
3190      */
3191     if (pme->ndecompdim >= 1)
3192     {
3193         /*
3194            gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3195                                         'x',nnodes_major,&pme->nkx);
3196            gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3197                                         'y',nnodes_minor,&pme->nky);
3198          */
3199     }
3200
3201     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3202         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3203         pme->nkz <= pme->pme_order)
3204     {
3205         gmx_fatal(FARGS, "The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order", pme->pme_order);
3206     }
3207
3208     if (pme->nnodes > 1)
3209     {
3210         double imbal;
3211
3212 #ifdef GMX_MPI
3213         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3214         MPI_Type_commit(&(pme->rvec_mpi));
3215 #endif
3216
3217         /* Note that the charge spreading and force gathering, which usually
3218          * takes about the same amount of time as FFT+solve_pme,
3219          * is always fully load balanced
3220          * (unless the charge distribution is inhomogeneous).
3221          */
3222
3223         imbal = pme_load_imbalance(pme);
3224         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3225         {
3226             fprintf(stderr,
3227                     "\n"
3228                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3229                     "      For optimal PME load balancing\n"
3230                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3231                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3232                     "\n",
3233                     (int)((imbal-1)*100 + 0.5),
3234                     pme->nkx, pme->nky, pme->nnodes_major,
3235                     pme->nky, pme->nkz, pme->nnodes_minor);
3236         }
3237     }
3238
3239     /* For non-divisible grid we need pme_order iso pme_order-1 */
3240     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3241      * y is always copied through a buffer: we don't need padding in z,
3242      * but we do need the overlap in x because of the communication order.
3243      */
3244     init_overlap_comm(&pme->overlap[0], pme->pme_order,
3245 #ifdef GMX_MPI
3246                       pme->mpi_comm_d[0],
3247 #endif
3248                       pme->nnodes_major, pme->nodeid_major,
3249                       pme->nkx,
3250                       (div_round_up(pme->nky, pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3251
3252     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3253      * We do this with an offset buffer of equal size, so we need to allocate
3254      * extra for the offset. That's what the (+1)*pme->nkz is for.
3255      */
3256     init_overlap_comm(&pme->overlap[1], pme->pme_order,
3257 #ifdef GMX_MPI
3258                       pme->mpi_comm_d[1],
3259 #endif
3260                       pme->nnodes_minor, pme->nodeid_minor,
3261                       pme->nky,
3262                       (div_round_up(pme->nkx, pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3263
3264     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3265      * We only allow multiple communication pulses in dim 1, not in dim 0.
3266      */
3267     if (pme->nthread > 1 && (pme->overlap[0].noverlap_nodes > 1 ||
3268                              pme->nkx < pme->nnodes_major*pme->pme_order))
3269     {
3270         gmx_fatal(FARGS, "The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x and should be >= pme_order (%d). To resolve this issue, use less nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3271                   pme->nkx/(double)pme->nnodes_major, pme->pme_order);
3272     }
3273
3274     snew(pme->bsp_mod[XX], pme->nkx);
3275     snew(pme->bsp_mod[YY], pme->nky);
3276     snew(pme->bsp_mod[ZZ], pme->nkz);
3277
3278     /* The required size of the interpolation grid, including overlap.
3279      * The allocated size (pmegrid_n?) might be slightly larger.
3280      */
3281     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3282         pme->overlap[0].s2g0[pme->nodeid_major];
3283     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3284         pme->overlap[1].s2g0[pme->nodeid_minor];
3285     pme->pmegrid_nz_base = pme->nkz;
3286     pme->pmegrid_nz      = pme->pmegrid_nz_base + pme->pme_order - 1;
3287     set_grid_alignment(&pme->pmegrid_nz, pme->pme_order);
3288
3289     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3290     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3291     pme->pmegrid_start_iz = 0;
3292
3293     make_gridindex5_to_localindex(pme->nkx,
3294                                   pme->pmegrid_start_ix,
3295                                   pme->pmegrid_nx - (pme->pme_order-1),
3296                                   &pme->nnx, &pme->fshx);
3297     make_gridindex5_to_localindex(pme->nky,
3298                                   pme->pmegrid_start_iy,
3299                                   pme->pmegrid_ny - (pme->pme_order-1),
3300                                   &pme->nny, &pme->fshy);
3301     make_gridindex5_to_localindex(pme->nkz,
3302                                   pme->pmegrid_start_iz,
3303                                   pme->pmegrid_nz_base,
3304                                   &pme->nnz, &pme->fshz);
3305
3306     pmegrids_init(&pme->pmegridA,
3307                   pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3308                   pme->pmegrid_nz_base,
3309                   pme->pme_order,
3310                   pme->nthread,
3311                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3312                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3313
3314     pme->spline_work = make_pme_spline_work(pme->pme_order);
3315
3316     ndata[0] = pme->nkx;
3317     ndata[1] = pme->nky;
3318     ndata[2] = pme->nkz;
3319
3320     /* This routine will allocate the grid data to fit the FFTs */
3321     gmx_parallel_3dfft_init(&pme->pfft_setupA, ndata,
3322                             &pme->fftgridA, &pme->cfftgridA,
3323                             pme->mpi_comm_d,
3324                             pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3325                             bReproducible, pme->nthread);
3326
3327     if (bFreeEnergy)
3328     {
3329         pmegrids_init(&pme->pmegridB,
3330                       pme->pmegrid_nx, pme->pmegrid_ny, pme->pmegrid_nz,
3331                       pme->pmegrid_nz_base,
3332                       pme->pme_order,
3333                       pme->nthread,
3334                       pme->nkx % pme->nnodes_major != 0,
3335                       pme->nky % pme->nnodes_minor != 0);
3336
3337         gmx_parallel_3dfft_init(&pme->pfft_setupB, ndata,
3338                                 &pme->fftgridB, &pme->cfftgridB,
3339                                 pme->mpi_comm_d,
3340                                 pme->overlap[0].s2g0, pme->overlap[1].s2g0,
3341                                 bReproducible, pme->nthread);
3342     }
3343     else
3344     {
3345         pme->pmegridB.grid.grid = NULL;
3346         pme->fftgridB           = NULL;
3347         pme->cfftgridB          = NULL;
3348     }
3349
3350     if (!pme->bP3M)
3351     {
3352         /* Use plain SPME B-spline interpolation */
3353         make_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3354     }
3355     else
3356     {
3357         /* Use the P3M grid-optimized influence function */
3358         make_p3m_bspline_moduli(pme->bsp_mod, pme->nkx, pme->nky, pme->nkz, pme->pme_order);
3359     }
3360
3361     /* Use atc[0] for spreading */
3362     init_atomcomm(pme, &pme->atc[0], cr, nnodes_major > 1 ? 0 : 1, TRUE);
3363     if (pme->ndecompdim >= 2)
3364     {
3365         init_atomcomm(pme, &pme->atc[1], cr, 1, FALSE);
3366     }
3367
3368     if (pme->nnodes == 1)
3369     {
3370         pme->atc[0].n = homenr;
3371         pme_realloc_atomcomm_things(&pme->atc[0]);
3372     }
3373
3374     {
3375         int thread;
3376
3377         /* Use fft5d, order after FFT is y major, z, x minor */
3378
3379         snew(pme->work, pme->nthread);
3380         for (thread = 0; thread < pme->nthread; thread++)
3381         {
3382             realloc_work(&pme->work[thread], pme->nkx);
3383         }
3384     }
3385
3386     *pmedata = pme;
3387
3388     return 0;
3389 }
3390
3391 static void reuse_pmegrids(const pmegrids_t *old, pmegrids_t *new)
3392 {
3393     int d, t;
3394
3395     for (d = 0; d < DIM; d++)
3396     {
3397         if (new->grid.n[d] > old->grid.n[d])
3398         {
3399             return;
3400         }
3401     }
3402
3403     sfree_aligned(new->grid.grid);
3404     new->grid.grid = old->grid.grid;
3405
3406     if (new->nthread > 1 && new->nthread == old->nthread)
3407     {
3408         sfree_aligned(new->grid_all);
3409         for (t = 0; t < new->nthread; t++)
3410         {
3411             new->grid_th[t].grid = old->grid_th[t].grid;
3412         }
3413     }
3414 }
3415
3416 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3417                    t_commrec *         cr,
3418                    gmx_pme_t           pme_src,
3419                    const t_inputrec *  ir,
3420                    ivec                grid_size)
3421 {
3422     t_inputrec irc;
3423     int homenr;
3424     int ret;
3425
3426     irc     = *ir;
3427     irc.nkx = grid_size[XX];
3428     irc.nky = grid_size[YY];
3429     irc.nkz = grid_size[ZZ];
3430
3431     if (pme_src->nnodes == 1)
3432     {
3433         homenr = pme_src->atc[0].n;
3434     }
3435     else
3436     {
3437         homenr = -1;
3438     }
3439
3440     ret = gmx_pme_init(pmedata, cr, pme_src->nnodes_major, pme_src->nnodes_minor,
3441                        &irc, homenr, pme_src->bFEP, FALSE, pme_src->nthread);
3442
3443     if (ret == 0)
3444     {
3445         /* We can easily reuse the allocated pme grids in pme_src */
3446         reuse_pmegrids(&pme_src->pmegridA, &(*pmedata)->pmegridA);
3447         /* We would like to reuse the fft grids, but that's harder */
3448     }
3449
3450     return ret;
3451 }
3452
3453
3454 static void copy_local_grid(gmx_pme_t pme,
3455                             pmegrids_t *pmegrids, int thread, real *fftgrid)
3456 {
3457     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3458     int  fft_my, fft_mz;
3459     int  nsx, nsy, nsz;
3460     ivec nf;
3461     int  offx, offy, offz, x, y, z, i0, i0t;
3462     int  d;
3463     pmegrid_t *pmegrid;
3464     real *grid_th;
3465
3466     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3467                                    local_fft_ndata,
3468                                    local_fft_offset,
3469                                    local_fft_size);
3470     fft_my = local_fft_size[YY];
3471     fft_mz = local_fft_size[ZZ];
3472
3473     pmegrid = &pmegrids->grid_th[thread];
3474
3475     nsx = pmegrid->s[XX];
3476     nsy = pmegrid->s[YY];
3477     nsz = pmegrid->s[ZZ];
3478
3479     for (d = 0; d < DIM; d++)
3480     {
3481         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3482                     local_fft_ndata[d] - pmegrid->offset[d]);
3483     }
3484
3485     offx = pmegrid->offset[XX];
3486     offy = pmegrid->offset[YY];
3487     offz = pmegrid->offset[ZZ];
3488
3489     /* Directly copy the non-overlapping parts of the local grids.
3490      * This also initializes the full grid.
3491      */
3492     grid_th = pmegrid->grid;
3493     for (x = 0; x < nf[XX]; x++)
3494     {
3495         for (y = 0; y < nf[YY]; y++)
3496         {
3497             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3498             i0t = (x*nsy + y)*nsz;
3499             for (z = 0; z < nf[ZZ]; z++)
3500             {
3501                 fftgrid[i0+z] = grid_th[i0t+z];
3502             }
3503         }
3504     }
3505 }
3506
3507 static void
3508 reduce_threadgrid_overlap(gmx_pme_t pme,
3509                           const pmegrids_t *pmegrids, int thread,
3510                           real *fftgrid, real *commbuf_x, real *commbuf_y)
3511 {
3512     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3513     int  fft_nx, fft_ny, fft_nz;
3514     int  fft_my, fft_mz;
3515     int  buf_my = -1;
3516     int  nsx, nsy, nsz;
3517     ivec ne;
3518     int  offx, offy, offz, x, y, z, i0, i0t;
3519     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
3520     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
3521     gmx_bool bCommX, bCommY;
3522     int  d;
3523     int  thread_f;
3524     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
3525     const real *grid_th;
3526     real *commbuf = NULL;
3527
3528     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3529                                    local_fft_ndata,
3530                                    local_fft_offset,
3531                                    local_fft_size);
3532     fft_nx = local_fft_ndata[XX];
3533     fft_ny = local_fft_ndata[YY];
3534     fft_nz = local_fft_ndata[ZZ];
3535
3536     fft_my = local_fft_size[YY];
3537     fft_mz = local_fft_size[ZZ];
3538
3539     /* This routine is called when all thread have finished spreading.
3540      * Here each thread sums grid contributions calculated by other threads
3541      * to the thread local grid volume.
3542      * To minimize the number of grid copying operations,
3543      * this routines sums immediately from the pmegrid to the fftgrid.
3544      */
3545
3546     /* Determine which part of the full node grid we should operate on,
3547      * this is our thread local part of the full grid.
3548      */
3549     pmegrid = &pmegrids->grid_th[thread];
3550
3551     for (d = 0; d < DIM; d++)
3552     {
3553         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3554                     local_fft_ndata[d]);
3555     }
3556
3557     offx = pmegrid->offset[XX];
3558     offy = pmegrid->offset[YY];
3559     offz = pmegrid->offset[ZZ];
3560
3561
3562     bClearBufX  = TRUE;
3563     bClearBufY  = TRUE;
3564     bClearBufXY = TRUE;
3565
3566     /* Now loop over all the thread data blocks that contribute
3567      * to the grid region we (our thread) are operating on.
3568      */
3569     /* Note that ffy_nx/y is equal to the number of grid points
3570      * between the first point of our node grid and the one of the next node.
3571      */
3572     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
3573     {
3574         fx     = pmegrid->ci[XX] + sx;
3575         ox     = 0;
3576         bCommX = FALSE;
3577         if (fx < 0)
3578         {
3579             fx    += pmegrids->nc[XX];
3580             ox    -= fft_nx;
3581             bCommX = (pme->nnodes_major > 1);
3582         }
3583         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3584         ox       += pmegrid_g->offset[XX];
3585         if (!bCommX)
3586         {
3587             tx1 = min(ox + pmegrid_g->n[XX], ne[XX]);
3588         }
3589         else
3590         {
3591             tx1 = min(ox + pmegrid_g->n[XX], pme->pme_order);
3592         }
3593
3594         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
3595         {
3596             fy     = pmegrid->ci[YY] + sy;
3597             oy     = 0;
3598             bCommY = FALSE;
3599             if (fy < 0)
3600             {
3601                 fy    += pmegrids->nc[YY];
3602                 oy    -= fft_ny;
3603                 bCommY = (pme->nnodes_minor > 1);
3604             }
3605             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3606             oy       += pmegrid_g->offset[YY];
3607             if (!bCommY)
3608             {
3609                 ty1 = min(oy + pmegrid_g->n[YY], ne[YY]);
3610             }
3611             else
3612             {
3613                 ty1 = min(oy + pmegrid_g->n[YY], pme->pme_order);
3614             }
3615
3616             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
3617             {
3618                 fz = pmegrid->ci[ZZ] + sz;
3619                 oz = 0;
3620                 if (fz < 0)
3621                 {
3622                     fz += pmegrids->nc[ZZ];
3623                     oz -= fft_nz;
3624                 }
3625                 pmegrid_g = &pmegrids->grid_th[fz];
3626                 oz       += pmegrid_g->offset[ZZ];
3627                 tz1       = min(oz + pmegrid_g->n[ZZ], ne[ZZ]);
3628
3629                 if (sx == 0 && sy == 0 && sz == 0)
3630                 {
3631                     /* We have already added our local contribution
3632                      * before calling this routine, so skip it here.
3633                      */
3634                     continue;
3635                 }
3636
3637                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3638
3639                 pmegrid_f = &pmegrids->grid_th[thread_f];
3640
3641                 grid_th = pmegrid_f->grid;
3642
3643                 nsx = pmegrid_f->s[XX];
3644                 nsy = pmegrid_f->s[YY];
3645                 nsz = pmegrid_f->s[ZZ];
3646
3647 #ifdef DEBUG_PME_REDUCE
3648                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3649                        pme->nodeid, thread, thread_f,
3650                        pme->pmegrid_start_ix,
3651                        pme->pmegrid_start_iy,
3652                        pme->pmegrid_start_iz,
3653                        sx, sy, sz,
3654                        offx-ox, tx1-ox, offx, tx1,
3655                        offy-oy, ty1-oy, offy, ty1,
3656                        offz-oz, tz1-oz, offz, tz1);
3657 #endif
3658
3659                 if (!(bCommX || bCommY))
3660                 {
3661                     /* Copy from the thread local grid to the node grid */
3662                     for (x = offx; x < tx1; x++)
3663                     {
3664                         for (y = offy; y < ty1; y++)
3665                         {
3666                             i0  = (x*fft_my + y)*fft_mz;
3667                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3668                             for (z = offz; z < tz1; z++)
3669                             {
3670                                 fftgrid[i0+z] += grid_th[i0t+z];
3671                             }
3672                         }
3673                     }
3674                 }
3675                 else
3676                 {
3677                     /* The order of this conditional decides
3678                      * where the corner volume gets stored with x+y decomp.
3679                      */
3680                     if (bCommY)
3681                     {
3682                         commbuf = commbuf_y;
3683                         buf_my  = ty1 - offy;
3684                         if (bCommX)
3685                         {
3686                             /* We index commbuf modulo the local grid size */
3687                             commbuf += buf_my*fft_nx*fft_nz;
3688
3689                             bClearBuf   = bClearBufXY;
3690                             bClearBufXY = FALSE;
3691                         }
3692                         else
3693                         {
3694                             bClearBuf  = bClearBufY;
3695                             bClearBufY = FALSE;
3696                         }
3697                     }
3698                     else
3699                     {
3700                         commbuf    = commbuf_x;
3701                         buf_my     = fft_ny;
3702                         bClearBuf  = bClearBufX;
3703                         bClearBufX = FALSE;
3704                     }
3705
3706                     /* Copy to the communication buffer */
3707                     for (x = offx; x < tx1; x++)
3708                     {
3709                         for (y = offy; y < ty1; y++)
3710                         {
3711                             i0  = (x*buf_my + y)*fft_nz;
3712                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3713
3714                             if (bClearBuf)
3715                             {
3716                                 /* First access of commbuf, initialize it */
3717                                 for (z = offz; z < tz1; z++)
3718                                 {
3719                                     commbuf[i0+z]  = grid_th[i0t+z];
3720                                 }
3721                             }
3722                             else
3723                             {
3724                                 for (z = offz; z < tz1; z++)
3725                                 {
3726                                     commbuf[i0+z] += grid_th[i0t+z];
3727                                 }
3728                             }
3729                         }
3730                     }
3731                 }
3732             }
3733         }
3734     }
3735 }
3736
3737
3738 static void sum_fftgrid_dd(gmx_pme_t pme, real *fftgrid)
3739 {
3740     ivec local_fft_ndata, local_fft_offset, local_fft_size;
3741     pme_overlap_t *overlap;
3742     int  send_index0, send_nindex;
3743     int  recv_nindex;
3744 #ifdef GMX_MPI
3745     MPI_Status stat;
3746 #endif
3747     int  send_size_y, recv_size_y;
3748     int  ipulse, send_id, recv_id, datasize, gridsize, size_yx;
3749     real *sendptr, *recvptr;
3750     int  x, y, z, indg, indb;
3751
3752     /* Note that this routine is only used for forward communication.
3753      * Since the force gathering, unlike the charge spreading,
3754      * can be trivially parallelized over the particles,
3755      * the backwards process is much simpler and can use the "old"
3756      * communication setup.
3757      */
3758
3759     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3760                                    local_fft_ndata,
3761                                    local_fft_offset,
3762                                    local_fft_size);
3763
3764     if (pme->nnodes_minor > 1)
3765     {
3766         /* Major dimension */
3767         overlap = &pme->overlap[1];
3768
3769         if (pme->nnodes_major > 1)
3770         {
3771             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3772         }
3773         else
3774         {
3775             size_yx = 0;
3776         }
3777         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3778
3779         send_size_y = overlap->send_size;
3780
3781         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
3782         {
3783             send_id       = overlap->send_id[ipulse];
3784             recv_id       = overlap->recv_id[ipulse];
3785             send_index0   =
3786                 overlap->comm_data[ipulse].send_index0 -
3787                 overlap->comm_data[0].send_index0;
3788             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3789             /* We don't use recv_index0, as we always receive starting at 0 */
3790             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3791             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3792
3793             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3794             recvptr = overlap->recvbuf;
3795
3796 #ifdef GMX_MPI
3797             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
3798                          send_id, ipulse,
3799                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
3800                          recv_id, ipulse,
3801                          overlap->mpi_comm, &stat);
3802 #endif
3803
3804             for (x = 0; x < local_fft_ndata[XX]; x++)
3805             {
3806                 for (y = 0; y < recv_nindex; y++)
3807                 {
3808                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3809                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3810                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
3811                     {
3812                         fftgrid[indg+z] += recvptr[indb+z];
3813                     }
3814                 }
3815             }
3816
3817             if (pme->nnodes_major > 1)
3818             {
3819                 /* Copy from the received buffer to the send buffer for dim 0 */
3820                 sendptr = pme->overlap[0].sendbuf;
3821                 for (x = 0; x < size_yx; x++)
3822                 {
3823                     for (y = 0; y < recv_nindex; y++)
3824                     {
3825                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3826                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3827                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
3828                         {
3829                             sendptr[indg+z] += recvptr[indb+z];
3830                         }
3831                     }
3832                 }
3833             }
3834         }
3835     }
3836
3837     /* We only support a single pulse here.
3838      * This is not a severe limitation, as this code is only used
3839      * with OpenMP and with OpenMP the (PME) domains can be larger.
3840      */
3841     if (pme->nnodes_major > 1)
3842     {
3843         /* Major dimension */
3844         overlap = &pme->overlap[0];
3845
3846         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3847         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3848
3849         ipulse = 0;
3850
3851         send_id       = overlap->send_id[ipulse];
3852         recv_id       = overlap->recv_id[ipulse];
3853         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3854         /* We don't use recv_index0, as we always receive starting at 0 */
3855         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3856
3857         sendptr = overlap->sendbuf;
3858         recvptr = overlap->recvbuf;
3859
3860         if (debug != NULL)
3861         {
3862             fprintf(debug, "PME fftgrid comm %2d x %2d x %2d\n",
3863                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
3864         }
3865
3866 #ifdef GMX_MPI
3867         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
3868                      send_id, ipulse,
3869                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
3870                      recv_id, ipulse,
3871                      overlap->mpi_comm, &stat);
3872 #endif
3873
3874         for (x = 0; x < recv_nindex; x++)
3875         {
3876             for (y = 0; y < local_fft_ndata[YY]; y++)
3877             {
3878                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3879                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3880                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
3881                 {
3882                     fftgrid[indg+z] += recvptr[indb+z];
3883                 }
3884             }
3885         }
3886     }
3887 }
3888
3889
3890 static void spread_on_grid(gmx_pme_t pme,
3891                            pme_atomcomm_t *atc, pmegrids_t *grids,
3892                            gmx_bool bCalcSplines, gmx_bool bSpread,
3893                            real *fftgrid)
3894 {
3895     int nthread, thread;
3896 #ifdef PME_TIME_THREADS
3897     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
3898     static double cs1     = 0, cs2 = 0, cs3 = 0;
3899     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
3900     static int cnt        = 0;
3901 #endif
3902
3903     nthread = pme->nthread;
3904     assert(nthread > 0);
3905
3906 #ifdef PME_TIME_THREADS
3907     c1 = omp_cyc_start();
3908 #endif
3909     if (bCalcSplines)
3910     {
3911 #pragma omp parallel for num_threads(nthread) schedule(static)
3912         for (thread = 0; thread < nthread; thread++)
3913         {
3914             int start, end;
3915
3916             start = atc->n* thread   /nthread;
3917             end   = atc->n*(thread+1)/nthread;
3918
3919             /* Compute fftgrid index for all atoms,
3920              * with help of some extra variables.
3921              */
3922             calc_interpolation_idx(pme, atc, start, end, thread);
3923         }
3924     }
3925 #ifdef PME_TIME_THREADS
3926     c1   = omp_cyc_end(c1);
3927     cs1 += (double)c1;
3928 #endif
3929
3930 #ifdef PME_TIME_THREADS
3931     c2 = omp_cyc_start();
3932 #endif
3933 #pragma omp parallel for num_threads(nthread) schedule(static)
3934     for (thread = 0; thread < nthread; thread++)
3935     {
3936         splinedata_t *spline;
3937         pmegrid_t *grid;
3938
3939         /* make local bsplines  */
3940         if (grids == NULL || grids->nthread == 1)
3941         {
3942             spline = &atc->spline[0];
3943
3944             spline->n = atc->n;
3945
3946             grid = &grids->grid;
3947         }
3948         else
3949         {
3950             spline = &atc->spline[thread];
3951
3952             make_thread_local_ind(atc, thread, spline);
3953
3954             grid = &grids->grid_th[thread];
3955         }
3956
3957         if (bCalcSplines)
3958         {
3959             make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
3960                           atc->fractx, spline->n, spline->ind, atc->q, pme->bFEP);
3961         }
3962
3963         if (bSpread)
3964         {
3965             /* put local atoms on grid. */
3966 #ifdef PME_TIME_SPREAD
3967             ct1a = omp_cyc_start();
3968 #endif
3969             spread_q_bsplines_thread(grid, atc, spline, pme->spline_work);
3970
3971             if (grids->nthread > 1)
3972             {
3973                 copy_local_grid(pme, grids, thread, fftgrid);
3974             }
3975 #ifdef PME_TIME_SPREAD
3976             ct1a          = omp_cyc_end(ct1a);
3977             cs1a[thread] += (double)ct1a;
3978 #endif
3979         }
3980     }
3981 #ifdef PME_TIME_THREADS
3982     c2   = omp_cyc_end(c2);
3983     cs2 += (double)c2;
3984 #endif
3985
3986     if (bSpread && grids->nthread > 1)
3987     {
3988 #ifdef PME_TIME_THREADS
3989         c3 = omp_cyc_start();
3990 #endif
3991 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3992         for (thread = 0; thread < grids->nthread; thread++)
3993         {
3994             reduce_threadgrid_overlap(pme, grids, thread,
3995                                       fftgrid,
3996                                       pme->overlap[0].sendbuf,
3997                                       pme->overlap[1].sendbuf);
3998         }
3999 #ifdef PME_TIME_THREADS
4000         c3   = omp_cyc_end(c3);
4001         cs3 += (double)c3;
4002 #endif
4003
4004         if (pme->nnodes > 1)
4005         {
4006             /* Communicate the overlapping part of the fftgrid */
4007             sum_fftgrid_dd(pme, fftgrid);
4008         }
4009     }
4010
4011 #ifdef PME_TIME_THREADS
4012     cnt++;
4013     if (cnt % 20 == 0)
4014     {
4015         printf("idx %.2f spread %.2f red %.2f",
4016                cs1*1e-9, cs2*1e-9, cs3*1e-9);
4017 #ifdef PME_TIME_SPREAD
4018         for (thread = 0; thread < nthread; thread++)
4019         {
4020             printf(" %.2f", cs1a[thread]*1e-9);
4021         }
4022 #endif
4023         printf("\n");
4024     }
4025 #endif
4026 }
4027
4028
4029 static void dump_grid(FILE *fp,
4030                       int sx, int sy, int sz, int nx, int ny, int nz,
4031                       int my, int mz, const real *g)
4032 {
4033     int x, y, z;
4034
4035     for (x = 0; x < nx; x++)
4036     {
4037         for (y = 0; y < ny; y++)
4038         {
4039             for (z = 0; z < nz; z++)
4040             {
4041                 fprintf(fp, "%2d %2d %2d %6.3f\n",
4042                         sx+x, sy+y, sz+z, g[(x*my + y)*mz + z]);
4043             }
4044         }
4045     }
4046 }
4047
4048 static void dump_local_fftgrid(gmx_pme_t pme, const real *fftgrid)
4049 {
4050     ivec local_fft_ndata, local_fft_offset, local_fft_size;
4051
4052     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
4053                                    local_fft_ndata,
4054                                    local_fft_offset,
4055                                    local_fft_size);
4056
4057     dump_grid(stderr,
4058               pme->pmegrid_start_ix,
4059               pme->pmegrid_start_iy,
4060               pme->pmegrid_start_iz,
4061               pme->pmegrid_nx-pme->pme_order+1,
4062               pme->pmegrid_ny-pme->pme_order+1,
4063               pme->pmegrid_nz-pme->pme_order+1,
4064               local_fft_size[YY],
4065               local_fft_size[ZZ],
4066               fftgrid);
4067 }
4068
4069
4070 void gmx_pme_calc_energy(gmx_pme_t pme, int n, rvec *x, real *q, real *V)
4071 {
4072     pme_atomcomm_t *atc;
4073     pmegrids_t *grid;
4074
4075     if (pme->nnodes > 1)
4076     {
4077         gmx_incons("gmx_pme_calc_energy called in parallel");
4078     }
4079     if (pme->bFEP > 1)
4080     {
4081         gmx_incons("gmx_pme_calc_energy with free energy");
4082     }
4083
4084     atc            = &pme->atc_energy;
4085     atc->nthread   = 1;
4086     if (atc->spline == NULL)
4087     {
4088         snew(atc->spline, atc->nthread);
4089     }
4090     atc->nslab     = 1;
4091     atc->bSpread   = TRUE;
4092     atc->pme_order = pme->pme_order;
4093     atc->n         = n;
4094     pme_realloc_atomcomm_things(atc);
4095     atc->x         = x;
4096     atc->q         = q;
4097
4098     /* We only use the A-charges grid */
4099     grid = &pme->pmegridA;
4100
4101     spread_on_grid(pme, atc, NULL, TRUE, FALSE, pme->fftgridA);
4102
4103     *V = gather_energy_bsplines(pme, grid->grid.grid, atc);
4104 }
4105
4106
4107 static void reset_pmeonly_counters(t_commrec *cr, gmx_wallcycle_t wcycle,
4108                                    t_nrnb *nrnb, t_inputrec *ir,
4109                                    gmx_large_int_t step)
4110 {
4111     /* Reset all the counters related to performance over the run */
4112     wallcycle_stop(wcycle, ewcRUN);
4113     wallcycle_reset_all(wcycle);
4114     init_nrnb(nrnb);
4115     if (ir->nsteps >= 0)
4116     {
4117         /* ir->nsteps is not used here, but we update it for consistency */
4118         ir->nsteps -= step - ir->init_step;
4119     }
4120     ir->init_step = step;
4121     wallcycle_start(wcycle, ewcRUN);
4122 }
4123
4124
4125 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4126                                ivec grid_size,
4127                                t_commrec *cr, t_inputrec *ir,
4128                                gmx_pme_t *pme_ret)
4129 {
4130     int ind;
4131     gmx_pme_t pme = NULL;
4132
4133     ind = 0;
4134     while (ind < *npmedata)
4135     {
4136         pme = (*pmedata)[ind];
4137         if (pme->nkx == grid_size[XX] &&
4138             pme->nky == grid_size[YY] &&
4139             pme->nkz == grid_size[ZZ])
4140         {
4141             *pme_ret = pme;
4142
4143             return;
4144         }
4145
4146         ind++;
4147     }
4148
4149     (*npmedata)++;
4150     srenew(*pmedata, *npmedata);
4151
4152     /* Generate a new PME data structure, copying part of the old pointers */
4153     gmx_pme_reinit(&((*pmedata)[ind]), cr, pme, ir, grid_size);
4154
4155     *pme_ret = (*pmedata)[ind];
4156 }
4157
4158
4159 int gmx_pmeonly(gmx_pme_t pme,
4160                 t_commrec *cr,    t_nrnb *nrnb,
4161                 gmx_wallcycle_t wcycle,
4162                 real ewaldcoeff,  gmx_bool bGatherOnly,
4163                 t_inputrec *ir)
4164 {
4165     int npmedata;
4166     gmx_pme_t *pmedata;
4167     gmx_pme_pp_t pme_pp;
4168     int  ret;
4169     int  natoms;
4170     matrix box;
4171     rvec *x_pp      = NULL, *f_pp = NULL;
4172     real *chargeA   = NULL, *chargeB = NULL;
4173     real lambda     = 0;
4174     int  maxshift_x = 0, maxshift_y = 0;
4175     real energy, dvdlambda;
4176     matrix vir;
4177     float cycles;
4178     int  count;
4179     gmx_bool bEnerVir;
4180     gmx_large_int_t step, step_rel;
4181     ivec grid_switch;
4182
4183     /* This data will only use with PME tuning, i.e. switching PME grids */
4184     npmedata = 1;
4185     snew(pmedata, npmedata);
4186     pmedata[0] = pme;
4187
4188     pme_pp = gmx_pme_pp_init(cr);
4189
4190     init_nrnb(nrnb);
4191
4192     count = 0;
4193     do /****** this is a quasi-loop over time steps! */
4194     {
4195         /* The reason for having a loop here is PME grid tuning/switching */
4196         do
4197         {
4198             /* Domain decomposition */
4199             ret = gmx_pme_recv_q_x(pme_pp,
4200                                    &natoms,
4201                                    &chargeA, &chargeB, box, &x_pp, &f_pp,
4202                                    &maxshift_x, &maxshift_y,
4203                                    &pme->bFEP, &lambda,
4204                                    &bEnerVir,
4205                                    &step,
4206                                    grid_switch, &ewaldcoeff);
4207
4208             if (ret == pmerecvqxSWITCHGRID)
4209             {
4210                 /* Switch the PME grid to grid_switch */
4211                 gmx_pmeonly_switch(&npmedata, &pmedata, grid_switch, cr, ir, &pme);
4212             }
4213
4214             if (ret == pmerecvqxRESETCOUNTERS)
4215             {
4216                 /* Reset the cycle and flop counters */
4217                 reset_pmeonly_counters(cr, wcycle, nrnb, ir, step);
4218             }
4219         }
4220         while (ret == pmerecvqxSWITCHGRID || ret == pmerecvqxRESETCOUNTERS);
4221
4222         if (ret == pmerecvqxFINISH)
4223         {
4224             /* We should stop: break out of the loop */
4225             break;
4226         }
4227
4228         step_rel = step - ir->init_step;
4229
4230         if (count == 0)
4231         {
4232             wallcycle_start(wcycle, ewcRUN);
4233         }
4234
4235         wallcycle_start(wcycle, ewcPMEMESH);
4236
4237         dvdlambda = 0;
4238         clear_mat(vir);
4239         gmx_pme_do(pme, 0, natoms, x_pp, f_pp, chargeA, chargeB, box,
4240                    cr, maxshift_x, maxshift_y, nrnb, wcycle, vir, ewaldcoeff,
4241                    &energy, lambda, &dvdlambda,
4242                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4243
4244         cycles = wallcycle_stop(wcycle, ewcPMEMESH);
4245
4246         gmx_pme_send_force_vir_ener(pme_pp,
4247                                     f_pp, vir, energy, dvdlambda,
4248                                     cycles);
4249
4250         count++;
4251     } /***** end of quasi-loop, we stop with the break above */
4252     while (TRUE);
4253
4254     return 0;
4255 }
4256
4257 int gmx_pme_do(gmx_pme_t pme,
4258                int start,       int homenr,
4259                rvec x[],        rvec f[],
4260                real *chargeA,   real *chargeB,
4261                matrix box, t_commrec *cr,
4262                int  maxshift_x, int maxshift_y,
4263                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4264                matrix vir,      real ewaldcoeff,
4265                real *energy,    real lambda,
4266                real *dvdlambda, int flags)
4267 {
4268     int     q, d, i, j, ntot, npme;
4269     int     nx, ny, nz;
4270     int     n_d, local_ny;
4271     pme_atomcomm_t *atc = NULL;
4272     pmegrids_t *pmegrid = NULL;
4273     real    *grid       = NULL;
4274     real    *ptr;
4275     rvec    *x_d, *f_d;
4276     real    *charge = NULL, *q_d;
4277     real    energy_AB[2];
4278     matrix  vir_AB[2];
4279     gmx_bool bClearF;
4280     gmx_parallel_3dfft_t pfft_setup;
4281     real *  fftgrid;
4282     t_complex * cfftgrid;
4283     int     thread;
4284     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4285     const gmx_bool bCalcF       = flags & GMX_PME_CALC_F;
4286
4287     assert(pme->nnodes > 0);
4288     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4289
4290     if (pme->nnodes > 1)
4291     {
4292         atc      = &pme->atc[0];
4293         atc->npd = homenr;
4294         if (atc->npd > atc->pd_nalloc)
4295         {
4296             atc->pd_nalloc = over_alloc_dd(atc->npd);
4297             srenew(atc->pd, atc->pd_nalloc);
4298         }
4299         atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4300     }
4301     else
4302     {
4303         /* This could be necessary for TPI */
4304         pme->atc[0].n = homenr;
4305     }
4306
4307     for (q = 0; q < (pme->bFEP ? 2 : 1); q++)
4308     {
4309         if (q == 0)
4310         {
4311             pmegrid    = &pme->pmegridA;
4312             fftgrid    = pme->fftgridA;
4313             cfftgrid   = pme->cfftgridA;
4314             pfft_setup = pme->pfft_setupA;
4315             charge     = chargeA+start;
4316         }
4317         else
4318         {
4319             pmegrid    = &pme->pmegridB;
4320             fftgrid    = pme->fftgridB;
4321             cfftgrid   = pme->cfftgridB;
4322             pfft_setup = pme->pfft_setupB;
4323             charge     = chargeB+start;
4324         }
4325         grid = pmegrid->grid.grid;
4326         /* Unpack structure */
4327         if (debug)
4328         {
4329             fprintf(debug, "PME: nnodes = %d, nodeid = %d\n",
4330                     cr->nnodes, cr->nodeid);
4331             fprintf(debug, "Grid = %p\n", (void*)grid);
4332             if (grid == NULL)
4333             {
4334                 gmx_fatal(FARGS, "No grid!");
4335             }
4336         }
4337         where();
4338
4339         m_inv_ur0(box, pme->recipbox);
4340
4341         if (pme->nnodes == 1)
4342         {
4343             atc = &pme->atc[0];
4344             if (DOMAINDECOMP(cr))
4345             {
4346                 atc->n = homenr;
4347                 pme_realloc_atomcomm_things(atc);
4348             }
4349             atc->x = x;
4350             atc->q = charge;
4351             atc->f = f;
4352         }
4353         else
4354         {
4355             wallcycle_start(wcycle, ewcPME_REDISTXF);
4356             for (d = pme->ndecompdim-1; d >= 0; d--)
4357             {
4358                 if (d == pme->ndecompdim-1)
4359                 {
4360                     n_d = homenr;
4361                     x_d = x + start;
4362                     q_d = charge;
4363                 }
4364                 else
4365                 {
4366                     n_d = pme->atc[d+1].n;
4367                     x_d = atc->x;
4368                     q_d = atc->q;
4369                 }
4370                 atc      = &pme->atc[d];
4371                 atc->npd = n_d;
4372                 if (atc->npd > atc->pd_nalloc)
4373                 {
4374                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4375                     srenew(atc->pd, atc->pd_nalloc);
4376                 }
4377                 atc->maxshift = (atc->dimind == 0 ? maxshift_x : maxshift_y);
4378                 pme_calc_pidx_wrapper(n_d, pme->recipbox, x_d, atc);
4379                 where();
4380
4381                 /* Redistribute x (only once) and qA or qB */
4382                 if (DOMAINDECOMP(cr))
4383                 {
4384                     dd_pmeredist_x_q(pme, n_d, q == 0, x_d, q_d, atc);
4385                 }
4386                 else
4387                 {
4388                     pmeredist_pd(pme, TRUE, n_d, q == 0, x_d, q_d, atc);
4389                 }
4390             }
4391             where();
4392
4393             wallcycle_stop(wcycle, ewcPME_REDISTXF);
4394         }
4395
4396         if (debug)
4397         {
4398             fprintf(debug, "Node= %6d, pme local particles=%6d\n",
4399                     cr->nodeid, atc->n);
4400         }
4401
4402         if (flags & GMX_PME_SPREAD_Q)
4403         {
4404             wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4405
4406             /* Spread the charges on a grid */
4407             spread_on_grid(pme, &pme->atc[0], pmegrid, q == 0, TRUE, fftgrid);
4408
4409             if (q == 0)
4410             {
4411                 inc_nrnb(nrnb, eNR_WEIGHTS, DIM*atc->n);
4412             }
4413             inc_nrnb(nrnb, eNR_SPREADQBSP,
4414                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4415
4416             if (pme->nthread == 1)
4417             {
4418                 wrap_periodic_pmegrid(pme, grid);
4419
4420                 /* sum contributions to local grid from other nodes */
4421 #ifdef GMX_MPI
4422                 if (pme->nnodes > 1)
4423                 {
4424                     gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_FORWARD);
4425                     where();
4426                 }
4427 #endif
4428
4429                 copy_pmegrid_to_fftgrid(pme, grid, fftgrid);
4430             }
4431
4432             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4433
4434             /*
4435                dump_local_fftgrid(pme,fftgrid);
4436                exit(0);
4437              */
4438         }
4439
4440         /* Here we start a large thread parallel region */
4441 #pragma omp parallel num_threads(pme->nthread) private(thread)
4442         {
4443             thread = gmx_omp_get_thread_num();
4444             if (flags & GMX_PME_SOLVE)
4445             {
4446                 int loop_count;
4447
4448                 /* do 3d-fft */
4449                 if (thread == 0)
4450                 {
4451                     wallcycle_start(wcycle, ewcPME_FFT);
4452                 }
4453                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_REAL_TO_COMPLEX,
4454                                            fftgrid, cfftgrid, thread, wcycle);
4455                 if (thread == 0)
4456                 {
4457                     wallcycle_stop(wcycle, ewcPME_FFT);
4458                 }
4459                 where();
4460
4461                 /* solve in k-space for our local cells */
4462                 if (thread == 0)
4463                 {
4464                     wallcycle_start(wcycle, ewcPME_SOLVE);
4465                 }
4466                 loop_count =
4467                     solve_pme_yzx(pme, cfftgrid, ewaldcoeff,
4468                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4469                                   bCalcEnerVir,
4470                                   pme->nthread, thread);
4471                 if (thread == 0)
4472                 {
4473                     wallcycle_stop(wcycle, ewcPME_SOLVE);
4474                     where();
4475                     inc_nrnb(nrnb, eNR_SOLVEPME, loop_count);
4476                 }
4477             }
4478
4479             if (bCalcF)
4480             {
4481                 /* do 3d-invfft */
4482                 if (thread == 0)
4483                 {
4484                     where();
4485                     wallcycle_start(wcycle, ewcPME_FFT);
4486                 }
4487                 gmx_parallel_3dfft_execute(pfft_setup, GMX_FFT_COMPLEX_TO_REAL,
4488                                            cfftgrid, fftgrid, thread, wcycle);
4489                 if (thread == 0)
4490                 {
4491                     wallcycle_stop(wcycle, ewcPME_FFT);
4492
4493                     where();
4494
4495                     if (pme->nodeid == 0)
4496                     {
4497                         ntot  = pme->nkx*pme->nky*pme->nkz;
4498                         npme  = ntot*log((real)ntot)/log(2.0);
4499                         inc_nrnb(nrnb, eNR_FFT, 2*npme);
4500                     }
4501
4502                     wallcycle_start(wcycle, ewcPME_SPREADGATHER);
4503                 }
4504
4505                 copy_fftgrid_to_pmegrid(pme, fftgrid, grid, pme->nthread, thread);
4506             }
4507         }
4508         /* End of thread parallel section.
4509          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4510          */
4511
4512         if (bCalcF)
4513         {
4514             /* distribute local grid to all nodes */
4515 #ifdef GMX_MPI
4516             if (pme->nnodes > 1)
4517             {
4518                 gmx_sum_qgrid_dd(pme, grid, GMX_SUM_QGRID_BACKWARD);
4519             }
4520 #endif
4521             where();
4522
4523             unwrap_periodic_pmegrid(pme, grid);
4524
4525             /* interpolate forces for our local atoms */
4526
4527             where();
4528
4529             /* If we are running without parallelization,
4530              * atc->f is the actual force array, not a buffer,
4531              * therefore we should not clear it.
4532              */
4533             bClearF = (q == 0 && PAR(cr));
4534 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4535             for (thread = 0; thread < pme->nthread; thread++)
4536             {
4537                 gather_f_bsplines(pme, grid, bClearF, atc,
4538                                   &atc->spline[thread],
4539                                   pme->bFEP ? (q == 0 ? 1.0-lambda : lambda) : 1.0);
4540             }
4541
4542             where();
4543
4544             inc_nrnb(nrnb, eNR_GATHERFBSP,
4545                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4546             wallcycle_stop(wcycle, ewcPME_SPREADGATHER);
4547         }
4548
4549         if (bCalcEnerVir)
4550         {
4551             /* This should only be called on the master thread
4552              * and after the threads have synchronized.
4553              */
4554             get_pme_ener_vir(pme, pme->nthread, &energy_AB[q], vir_AB[q]);
4555         }
4556     } /* of q-loop */
4557
4558     if (bCalcF && pme->nnodes > 1)
4559     {
4560         wallcycle_start(wcycle, ewcPME_REDISTXF);
4561         for (d = 0; d < pme->ndecompdim; d++)
4562         {
4563             atc = &pme->atc[d];
4564             if (d == pme->ndecompdim - 1)
4565             {
4566                 n_d = homenr;
4567                 f_d = f + start;
4568             }
4569             else
4570             {
4571                 n_d = pme->atc[d+1].n;
4572                 f_d = pme->atc[d+1].f;
4573             }
4574             if (DOMAINDECOMP(cr))
4575             {
4576                 dd_pmeredist_f(pme, atc, n_d, f_d,
4577                                d == pme->ndecompdim-1 && pme->bPPnode);
4578             }
4579             else
4580             {
4581                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4582             }
4583         }
4584
4585         wallcycle_stop(wcycle, ewcPME_REDISTXF);
4586     }
4587     where();
4588
4589     if (bCalcEnerVir)
4590     {
4591         if (!pme->bFEP)
4592         {
4593             *energy = energy_AB[0];
4594             m_add(vir, vir_AB[0], vir);
4595         }
4596         else
4597         {
4598             *energy     = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4599             *dvdlambda += energy_AB[1] - energy_AB[0];
4600             for (i = 0; i < DIM; i++)
4601             {
4602                 for (j = 0; j < DIM; j++)
4603                 {
4604                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] +
4605                         lambda*vir_AB[1][i][j];
4606                 }
4607             }
4608         }
4609     }
4610     else
4611     {
4612         *energy = 0;
4613     }
4614
4615     if (debug)
4616     {
4617         fprintf(debug, "PME mesh energy: %g\n", *energy);
4618     }
4619
4620     return 0;
4621 }