1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
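/* A minimal sketch of check 1 above, assuming a hypothetical helper
 * run_pme_energy() that performs one PME evaluation for the given box and
 * returns the reciprocal-space energy.  It is wrapped in #if 0 so it never
 * affects compilation; it only illustrates the shape of the test.
 */
#if 0
static gmx_bool check_triclinic_equivalence(void)
{
    matrix box_rect = { { 2, 0, 0 }, { 0, 2, 0 }, { 0, 0, 6 } };
    matrix box_tric = { { 2, 0, 0 }, { 0, 2, 0 }, { 2, 2, 6 } };
    real   e_rect,e_tric;

    e_rect = run_pme_energy(box_rect); /* hypothetical helper */
    e_tric = run_pme_energy(box_tric); /* hypothetical helper */

    /* The two energies should agree to (at least) the PME precision */
    return (fabs(e_rect - e_tric) <= 1e-4*fabs(e_rect));
}
#endif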
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include <stdio.h>
71 #include <string.h>
72 #include <math.h>
73 #include <assert.h>
74 #include "typedefs.h"
75 #include "txtdump.h"
76 #include "vec.h"
77 #include "gmxcomplex.h"
78 #include "smalloc.h"
79 #include "futil.h"
80 #include "coulomb.h"
81 #include "gmx_fatal.h"
82 #include "pme.h"
83 #include "network.h"
84 #include "physics.h"
85 #include "nrnb.h"
86 #include "copyrite.h"
87 #include "gmx_wallcycle.h"
88 #include "gmx_parallel_3dfft.h"
89 #include "pdbio.h"
90 #include "gmx_cyclecounter.h"
91
92 /* Single precision, with SSE2 or higher available */
93 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
94
95 #include "gmx_x86_simd_single.h"
96
97 #define PME_SSE
98 /* Some old AMD processors could have problems with unaligned loads+stores */
99 #ifndef GMX_FAHCORE
100 #define PME_SSE_UNALIGNED
101 #endif
102 #endif
103
104 #include "mpelogging.h"
105
106 #define DFT_TOL 1e-7
107 /* #define PRT_FORCE */
108 /* conditions for on-the-fly time measurement */
109 /* #define TAKETIME (step > 1 && timesteps < 10) */
110 #define TAKETIME FALSE
111
112 /* #define PME_TIME_THREADS */
113
114 #ifdef GMX_DOUBLE
115 #define mpi_type MPI_DOUBLE
116 #else
117 #define mpi_type MPI_FLOAT
118 #endif
119
120 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
121 #define GMX_CACHE_SEP 64
122
123 /* We only define a maximum to be able to use local arrays without allocation.
124  * An order larger than 12 should never be needed, even for test cases.
125  * If needed it can be changed here.
126  */
127 #define PME_ORDER_MAX 12
128
129 /* Internal datastructures */
130 typedef struct {
131     int send_index0;
132     int send_nindex;
133     int recv_index0;
134     int recv_nindex;
135     int recv_size;   /* Receive buffer width, used with OpenMP */
136 } pme_grid_comm_t;
137
138 typedef struct {
139 #ifdef GMX_MPI
140     MPI_Comm mpi_comm;
141 #endif
142     int  nnodes,nodeid;
143     int  *s2g0;
144     int  *s2g1;
145     int  noverlap_nodes;
146     int  *send_id,*recv_id;
147     int  send_size;             /* Send buffer width, used with OpenMP */
148     pme_grid_comm_t *comm_data;
149     real *sendbuf;
150     real *recvbuf;
151 } pme_overlap_t;
152
153 typedef struct {
154     int *n;     /* Cumulative counts of the number of particles per thread */
155     int nalloc; /* Allocation size of i */
156     int *i;     /* Particle indices ordered on thread index (n) */
157 } thread_plist_t;
158
159 typedef struct {
160     int  *thread_one;
161     int  n;
162     int  *ind;
163     splinevec theta;
164     real *ptr_theta_z;
165     splinevec dtheta;
166     real *ptr_dtheta_z;
167 } splinedata_t;
168
169 typedef struct {
170     int  dimind;            /* The index of the dimension, 0=x, 1=y */
171     int  nslab;
172     int  nodeid;
173 #ifdef GMX_MPI
174     MPI_Comm mpi_comm;
175 #endif
176
177     int  *node_dest;        /* The nodes to send x and q to with DD */
178     int  *node_src;         /* The nodes to receive x and q from with DD */
179     int  *buf_index;        /* Index for commnode into the buffers */
180
181     int  maxshift;
182
183     int  npd;
184     int  pd_nalloc;
185     int  *pd;
186     int  *count;            /* The number of atoms to send to each node */
187     int  **count_thread;
188     int  *rcount;           /* The number of atoms to receive */
189
190     int  n;
191     int  nalloc;
192     rvec *x;
193     real *q;
194     rvec *f;
195     gmx_bool bSpread;       /* These coordinates are used for spreading */
196     int  pme_order;
197     ivec *idx;
198     rvec *fractx;            /* Fractional coordinate relative to the
199                               * lower cell boundary
200                               */
201     int  nthread;
202     int  *thread_idx;        /* Which thread should spread which charge */
203     thread_plist_t *thread_plist;
204     splinedata_t *spline;
205 } pme_atomcomm_t;
206
207 #define FLBS  3
208 #define FLBSZ 4
209
210 typedef struct {
211     ivec ci;     /* The spatial location of this grid         */
212     ivec n;      /* The used size of *grid, including order-1 */
213     ivec offset; /* The grid offset from the full node grid   */
214     int  order;  /* PME spreading order                       */
215     ivec s;      /* The allocated size of *grid, s >= n       */
216     real *grid;  /* The grid of the local thread, size n      */
217 } pmegrid_t;
218
219 typedef struct {
220     pmegrid_t grid;     /* The full node grid (non thread-local)            */
221     int  nthread;       /* The number of threads operating on this grid     */
222     ivec nc;            /* The local spatial decomposition over the threads */
223     pmegrid_t *grid_th; /* Array of grids for each thread                   */
224     real *grid_all;     /* Allocated array for the grids in *grid_th        */
225     int  **g2t;         /* The grid to thread index                         */
226     ivec nthread_comm;  /* The number of threads to communicate with        */
227 } pmegrids_t;
228
229
230 typedef struct {
231 #ifdef PME_SSE
232     /* Masks for SSE aligned spreading and gathering */
233     __m128 mask_SSE0[6],mask_SSE1[6];
234 #else
235     int dummy; /* C89 requires that struct has at least one member */
236 #endif
237 } pme_spline_work_t;
238
239 typedef struct {
240     /* work data for solve_pme */
241     int      nalloc;
242     real *   mhx;
243     real *   mhy;
244     real *   mhz;
245     real *   m2;
246     real *   denom;
247     real *   tmp1_alloc;
248     real *   tmp1;
249     real *   eterm;
250     real *   m2inv;
251
252     real     energy;
253     matrix   vir;
254 } pme_work_t;
255
256 typedef struct gmx_pme {
257     int  ndecompdim;         /* The number of decomposition dimensions */
258     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
259     int  nodeid_major;
260     int  nodeid_minor;
261     int  nnodes;             /* The number of nodes doing PME */
262     int  nnodes_major;
263     int  nnodes_minor;
264
265     MPI_Comm mpi_comm;
266     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
267 #ifdef GMX_MPI
268     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
269 #endif
270
271     int  nthread;            /* The number of threads doing PME */
272
273     gmx_bool bPPnode;        /* Node also does particle-particle forces */
274     gmx_bool bFEP;           /* Compute Free energy contribution */
275     int nkx,nky,nkz;         /* Grid dimensions */
276     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
277     int pme_order;
278     real epsilon_r;
279
280     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
281     pmegrids_t pmegridB;
282     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
283     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
284     /* pmegrid_nz might be larger than strictly necessary to ensure
285      * memory alignment; pmegrid_nz_base gives the real base size.
286      */
287     int     pmegrid_nz_base;
288     /* The local PME grid starting indices */
289     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
290
291     /* Work data for spreading and gathering */
292     pme_spline_work_t *spline_work;
293
294     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
295     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
296     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
297
298     t_complex *cfftgridA;             /* Grids for complex FFT data */
299     t_complex *cfftgridB;
300     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
301
302     gmx_parallel_3dfft_t  pfft_setupA;
303     gmx_parallel_3dfft_t  pfft_setupB;
304
305     int  *nnx,*nny,*nnz;
306     real *fshx,*fshy,*fshz;
307
308     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
309     matrix    recipbox;
310     splinevec bsp_mod;
311
312     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
313
314     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
315
316     rvec *bufv;             /* Communication buffer */
317     real *bufr;             /* Communication buffer */
318     int  buf_nalloc;        /* The communication buffer size */
319
320     /* thread local work data for solve_pme */
321     pme_work_t *work;
322
323     /* Work data for PME_redist */
324     gmx_bool redist_init;
325     int *    scounts;
326     int *    rcounts;
327     int *    sdispls;
328     int *    rdispls;
329     int *    sidx;
330     int *    idxa;
331     real *   redist_buf;
332     int      redist_buf_nalloc;
333
334     /* Work data for sum_qgrid */
335     real *   sum_qgrid_tmp;
336     real *   sum_qgrid_dd_tmp;
337 } t_gmx_pme;
338
339
340 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
341                                    int start,int end,int thread)
342 {
343     int  i;
344     int  *idxptr,tix,tiy,tiz;
345     real *xptr,*fptr,tx,ty,tz;
346     real rxx,ryx,ryy,rzx,rzy,rzz;
347     int  nx,ny,nz;
348     int  start_ix,start_iy,start_iz;
349     int  *g2tx,*g2ty,*g2tz;
350     gmx_bool bThreads;
351     int  *thread_idx=NULL;
352     thread_plist_t *tpl=NULL;
353     int  *tpl_n=NULL;
354     int  thread_i;
355
356     nx  = pme->nkx;
357     ny  = pme->nky;
358     nz  = pme->nkz;
359
360     start_ix = pme->pmegrid_start_ix;
361     start_iy = pme->pmegrid_start_iy;
362     start_iz = pme->pmegrid_start_iz;
363
364     rxx = pme->recipbox[XX][XX];
365     ryx = pme->recipbox[YY][XX];
366     ryy = pme->recipbox[YY][YY];
367     rzx = pme->recipbox[ZZ][XX];
368     rzy = pme->recipbox[ZZ][YY];
369     rzz = pme->recipbox[ZZ][ZZ];
370
371     g2tx = pme->pmegridA.g2t[XX];
372     g2ty = pme->pmegridA.g2t[YY];
373     g2tz = pme->pmegridA.g2t[ZZ];
374
375     bThreads = (atc->nthread > 1);
376     if (bThreads)
377     {
378         thread_idx = atc->thread_idx;
379
380         tpl   = &atc->thread_plist[thread];
381         tpl_n = tpl->n;
382         for(i=0; i<atc->nthread; i++)
383         {
384             tpl_n[i] = 0;
385         }
386     }
387
388     for(i=start; i<end; i++) {
389         xptr   = atc->x[i];
390         idxptr = atc->idx[i];
391         fptr   = atc->fractx[i];
392
393         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
394         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
395         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
396         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
397
398         tix = (int)(tx);
399         tiy = (int)(ty);
400         tiz = (int)(tz);
401
402         /* Because decomposition only occurs in x and y,
403          * we never have a fraction correction in z.
404          */
405         fptr[XX] = tx - tix + pme->fshx[tix];
406         fptr[YY] = ty - tiy + pme->fshy[tiy];
407         fptr[ZZ] = tz - tiz;
408
409         idxptr[XX] = pme->nnx[tix];
410         idxptr[YY] = pme->nny[tiy];
411         idxptr[ZZ] = pme->nnz[tiz];
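        /* Worked example (rectangular box): with box length L along x and
         * nx grid points, rxx = 1/L and ryx = rzx = 0, so an atom at
         * x = 0.25*L gives tx = nx*(0.25 + 2.0) = 2.25*nx.  The integer
         * part tix is mapped back to a grid index by the nnx[] table,
         * while the remainder tx - tix (plus the fshx[] shift used with
         * domain decomposition) is the fraction that enters the B-splines.
         */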
412
413 #ifdef DEBUG
414         range_check(idxptr[XX],0,pme->pmegrid_nx);
415         range_check(idxptr[YY],0,pme->pmegrid_ny);
416         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
417 #endif
418
419         if (bThreads)
420         {
421             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
422             thread_idx[i] = thread_i;
423             tpl_n[thread_i]++;
424         }
425     }
426
427     if (bThreads)
428     {
429         /* Make a list of particle indices sorted on thread */
430
431         /* Get the cumulative count */
432         for(i=1; i<atc->nthread; i++)
433         {
434             tpl_n[i] += tpl_n[i-1];
435         }
436         /* The current implementation distributes particles equally
437          * over the threads, so we could actually allocate for that
438          * in pme_realloc_atomcomm_things.
439          */
440         if (tpl_n[atc->nthread-1] > tpl->nalloc)
441         {
442             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
443             srenew(tpl->i,tpl->nalloc);
444         }
445         /* Set tpl_n to the cumulative start */
446         for(i=atc->nthread-1; i>=1; i--)
447         {
448             tpl_n[i] = tpl_n[i-1];
449         }
450         tpl_n[0] = 0;
451
452         /* Fill our thread local array with indices sorted on thread */
453         for(i=start; i<end; i++)
454         {
455             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
456         }
457         /* Now tpl_n contains the cumulative count again */
458     }
459 }
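/* A minimal standalone sketch of the counting-sort pattern used above to
 * order the particle indices by thread.  Here key[] plays the role of
 * atc->thread_idx, nkey the number of threads, count[] the role of tpl->n
 * and order[] the role of tpl->i; the helper itself is illustrative only
 * and therefore kept inside #if 0.
 */
#if 0
static void counting_sort_indices(const int *key,int n,int nkey,
                                  int *count,int *order)
{
    int i;

    /* Count how many indices carry each key */
    for(i=0; i<nkey; i++) { count[i] = 0; }
    for(i=0; i<n; i++)    { count[key[i]]++; }
    /* Cumulative counts: count[k] = number of indices with key <= k */
    for(i=1; i<nkey; i++) { count[i] += count[i-1]; }
    /* Shift to cumulative starts: count[k] = first slot for key k */
    for(i=nkey-1; i>=1; i--) { count[i] = count[i-1]; }
    count[0] = 0;
    /* Scatter the indices, sorted (stably) on their key */
    for(i=0; i<n; i++)    { order[count[key[i]]++] = i; }
}
#endif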
460
461 static void make_thread_local_ind(pme_atomcomm_t *atc,
462                                   int thread,splinedata_t *spline)
463 {
464     int  n,t,i,start,end;
465     thread_plist_t *tpl;
466
467     /* Combine the indices made by each thread into one index */
468
469     n = 0;
470     start = 0;
471     for(t=0; t<atc->nthread; t++)
472     {
473         tpl = &atc->thread_plist[t];
474         /* Copy our part (start - end) from the list of thread t */
475         if (thread > 0)
476         {
477             start = tpl->n[thread-1];
478         }
479         end = tpl->n[thread];
480         for(i=start; i<end; i++)
481         {
482             spline->ind[n++] = tpl->i[i];
483         }
484     }
485
486     spline->n = n;
487 }
488
489
490 static void pme_calc_pidx(int start, int end,
491                           matrix recipbox, rvec x[],
492                           pme_atomcomm_t *atc, int *count)
493 {
494     int  nslab,i;
495     int  si;
496     real *xptr,s;
497     real rxx,ryx,rzx,ryy,rzy;
498     int *pd;
499
500     /* Calculate the PME task index (pidx) for each particle.
501      * Here we always assign equally sized slabs to each node
502      * for load balancing reasons (the PME grid spacing is not used).
503      */
504
505     nslab = atc->nslab;
506     pd    = atc->pd;
507
508     /* Reset the count */
509     for(i=0; i<nslab; i++)
510     {
511         count[i] = 0;
512     }
513
514     if (atc->dimind == 0)
515     {
516         rxx = recipbox[XX][XX];
517         ryx = recipbox[YY][XX];
518         rzx = recipbox[ZZ][XX];
519         /* Calculate the node index in x-dimension */
520         for(i=start; i<end; i++)
521         {
522             xptr   = x[i];
523             /* Fractional coordinates along box vectors */
524             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
525             si = (int)(s + 2*nslab) % nslab;
526             pd[i] = si;
527             count[si]++;
528         }
529     }
530     else
531     {
532         ryy = recipbox[YY][YY];
533         rzy = recipbox[ZZ][YY];
534         /* Calculate the node index in y-dimension */
535         for(i=start; i<end; i++)
536         {
537             xptr   = x[i];
538             /* Fractional coordinates along box vectors */
539             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
540             si = (int)(s + 2*nslab) % nslab;
541             pd[i] = si;
542             count[si]++;
543         }
544     }
545 }
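/* Worked example for the slab assignment above: with nslab = 4 and
 * s = nslab*(fractional coordinate) = 2.4, we get
 * si = (int)(2.4 + 2*4) % 4 = 10 % 4 = 2, i.e. the atom goes to slab 2.
 * The 2*nslab offset only ensures a positive argument before the modulo,
 * so slightly negative fractional coordinates still map correctly.
 */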
546
547 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
548                                   pme_atomcomm_t *atc)
549 {
550     int nthread,thread,slab;
551
552     nthread = atc->nthread;
553
554 #pragma omp parallel for num_threads(nthread) schedule(static)
555     for(thread=0; thread<nthread; thread++)
556     {
557         pme_calc_pidx(natoms* thread   /nthread,
558                       natoms*(thread+1)/nthread,
559                       recipbox,x,atc,atc->count_thread[thread]);
560     }
561     /* Non-parallel reduction, since nslab is small */
562
563     for(thread=1; thread<nthread; thread++)
564     {
565         for(slab=0; slab<atc->nslab; slab++)
566         {
567             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
568         }
569     }
570 }
571
572 static void realloc_splinevec(splinevec th,real **ptr_z,int nalloc)
573 {
574     const int padding=4;
575     int i;
576
577     srenew(th[XX],nalloc);
578     srenew(th[YY],nalloc);
579     /* In z we add padding, this is only required for the aligned SSE code */
580     srenew(*ptr_z,nalloc+2*padding);
581     th[ZZ] = *ptr_z + padding;
582
583     for(i=0; i<padding; i++)
584     {
585         (*ptr_z)[               i] = 0;
586         (*ptr_z)[padding+nalloc+i] = 0;
587     }
588 }
589
590 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
591 {
592     int i,d;
593
594     srenew(spline->ind,atc->nalloc);
595     /* Initialize the index to identity so it works without threads */
596     for(i=0; i<atc->nalloc; i++)
597     {
598         spline->ind[i] = i;
599     }
600
601     realloc_splinevec(spline->theta,&spline->ptr_theta_z,
602                       atc->pme_order*atc->nalloc);
603     realloc_splinevec(spline->dtheta,&spline->ptr_dtheta_z,
604                       atc->pme_order*atc->nalloc);
605 }
606
607 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
608 {
609     int nalloc_old,i,j,nalloc_tpl;
610
611     /* We have to keep atc->x from being a NULL pointer, since that
612      * could cause fatal errors in MPI routines.
613      */
614     if (atc->n > atc->nalloc || atc->nalloc == 0)
615     {
616         nalloc_old = atc->nalloc;
617         atc->nalloc = over_alloc_dd(max(atc->n,1));
618
619         if (atc->nslab > 1) {
620             srenew(atc->x,atc->nalloc);
621             srenew(atc->q,atc->nalloc);
622             srenew(atc->f,atc->nalloc);
623             for(i=nalloc_old; i<atc->nalloc; i++)
624             {
625                 clear_rvec(atc->f[i]);
626             }
627         }
628         if (atc->bSpread) {
629             srenew(atc->fractx,atc->nalloc);
630             srenew(atc->idx   ,atc->nalloc);
631
632             if (atc->nthread > 1)
633             {
634                 srenew(atc->thread_idx,atc->nalloc);
635             }
636
637             for(i=0; i<atc->nthread; i++)
638             {
639                 pme_realloc_splinedata(&atc->spline[i],atc);
640             }
641         }
642     }
643 }
644
645 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
646                          int n, gmx_bool bXF, rvec *x_f, real *charge,
647                          pme_atomcomm_t *atc)
648 /* Redistribute particle data for PME calculation */
649 /* domain decomposition by x coordinate           */
650 {
651     int *idxa;
652     int i, ii;
653
654     if(FALSE == pme->redist_init) {
655         snew(pme->scounts,atc->nslab);
656         snew(pme->rcounts,atc->nslab);
657         snew(pme->sdispls,atc->nslab);
658         snew(pme->rdispls,atc->nslab);
659         snew(pme->sidx,atc->nslab);
660         pme->redist_init = TRUE;
661     }
662     if (n > pme->redist_buf_nalloc) {
663         pme->redist_buf_nalloc = over_alloc_dd(n);
664         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
665     }
666
667     pme->idxa = atc->pd;
668
669 #ifdef GMX_MPI
670     if (forw && bXF) {
671         /* forward, redistribution from pp to pme */
672
673         /* Calculate send counts and exchange them with other nodes */
674         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
675         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
676         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
677
678         /* Calculate send and receive displacements and index into send
679            buffer */
680         pme->sdispls[0]=0;
681         pme->rdispls[0]=0;
682         pme->sidx[0]=0;
683         for(i=1; i<atc->nslab; i++) {
684             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
685             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
686             pme->sidx[i]=pme->sdispls[i];
687         }
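        /* Worked example: with 3 slabs and scounts = {3,1,2} the loop above
         * gives sdispls = {0,3,4}, so the particles for slab k are packed
         * contiguously starting at element sdispls[k] of the send buffer,
         * which is the layout MPI_Alltoallv expects.
         */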
688         /* Total # of particles to be received */
689         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
690
691         pme_realloc_atomcomm_things(atc);
692
693         /* Copy particle coordinates into send buffer and exchange */
694         for(i=0; (i<n); i++) {
695             ii=DIM*pme->sidx[pme->idxa[i]];
696             pme->sidx[pme->idxa[i]]++;
697             pme->redist_buf[ii+XX]=x_f[i][XX];
698             pme->redist_buf[ii+YY]=x_f[i][YY];
699             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
700         }
701         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
702                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
703                       pme->rvec_mpi, atc->mpi_comm);
704     }
705     if (forw) {
706         /* Copy charge into send buffer and exchange */
707         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
708         for(i=0; (i<n); i++) {
709             ii=pme->sidx[pme->idxa[i]];
710             pme->sidx[pme->idxa[i]]++;
711             pme->redist_buf[ii]=charge[i];
712         }
713         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
714                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
715                       atc->mpi_comm);
716     }
717     else { /* backward, redistribution from pme to pp */
718         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
719                       pme->redist_buf, pme->scounts, pme->sdispls,
720                       pme->rvec_mpi, atc->mpi_comm);
721
722         /* Copy data from receive buffer */
723         for(i=0; i<atc->nslab; i++)
724             pme->sidx[i] = pme->sdispls[i];
725         for(i=0; (i<n); i++) {
726             ii = DIM*pme->sidx[pme->idxa[i]];
727             x_f[i][XX] += pme->redist_buf[ii+XX];
728             x_f[i][YY] += pme->redist_buf[ii+YY];
729             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
730             pme->sidx[pme->idxa[i]]++;
731         }
732     }
733 #endif
734 }
735
736 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
737                             gmx_bool bBackward,int shift,
738                             void *buf_s,int nbyte_s,
739                             void *buf_r,int nbyte_r)
740 {
741 #ifdef GMX_MPI
742     int dest,src;
743     MPI_Status stat;
744
745     if (bBackward == FALSE) {
746         dest = atc->node_dest[shift];
747         src  = atc->node_src[shift];
748     } else {
749         dest = atc->node_src[shift];
750         src  = atc->node_dest[shift];
751     }
752
753     if (nbyte_s > 0 && nbyte_r > 0) {
754         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
755                      dest,shift,
756                      buf_r,nbyte_r,MPI_BYTE,
757                      src,shift,
758                      atc->mpi_comm,&stat);
759     } else if (nbyte_s > 0) {
760         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
761                  dest,shift,
762                  atc->mpi_comm);
763     } else if (nbyte_r > 0) {
764         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
765                  src,shift,
766                  atc->mpi_comm,&stat);
767     }
768 #endif
769 }
770
771 static void dd_pmeredist_x_q(gmx_pme_t pme,
772                              int n, gmx_bool bX, rvec *x, real *charge,
773                              pme_atomcomm_t *atc)
774 {
775     int *commnode,*buf_index;
776     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
777
778     commnode  = atc->node_dest;
779     buf_index = atc->buf_index;
780
781     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
782
783     nsend = 0;
784     for(i=0; i<nnodes_comm; i++) {
785         buf_index[commnode[i]] = nsend;
786         nsend += atc->count[commnode[i]];
787     }
788     if (bX) {
789         if (atc->count[atc->nodeid] + nsend != n)
790             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
791                       "This usually means that your system is not well equilibrated.",
792                       n - (atc->count[atc->nodeid] + nsend),
793                       pme->nodeid,'x'+atc->dimind);
794
795         if (nsend > pme->buf_nalloc) {
796             pme->buf_nalloc = over_alloc_dd(nsend);
797             srenew(pme->bufv,pme->buf_nalloc);
798             srenew(pme->bufr,pme->buf_nalloc);
799         }
800
801         atc->n = atc->count[atc->nodeid];
802         for(i=0; i<nnodes_comm; i++) {
803             scount = atc->count[commnode[i]];
804             /* Communicate the count */
805             if (debug)
806                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
807                         atc->dimind,atc->nodeid,commnode[i],scount);
808             pme_dd_sendrecv(atc,FALSE,i,
809                             &scount,sizeof(int),
810                             &atc->rcount[i],sizeof(int));
811             atc->n += atc->rcount[i];
812         }
813
814         pme_realloc_atomcomm_things(atc);
815     }
816
817     local_pos = 0;
818     for(i=0; i<n; i++) {
819         node = atc->pd[i];
820         if (node == atc->nodeid) {
821             /* Copy direct to the receive buffer */
822             if (bX) {
823                 copy_rvec(x[i],atc->x[local_pos]);
824             }
825             atc->q[local_pos] = charge[i];
826             local_pos++;
827         } else {
828             /* Copy to the send buffer */
829             if (bX) {
830                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
831             }
832             pme->bufr[buf_index[node]] = charge[i];
833             buf_index[node]++;
834         }
835     }
836
837     buf_pos = 0;
838     for(i=0; i<nnodes_comm; i++) {
839         scount = atc->count[commnode[i]];
840         rcount = atc->rcount[i];
841         if (scount > 0 || rcount > 0) {
842             if (bX) {
843                 /* Communicate the coordinates */
844                 pme_dd_sendrecv(atc,FALSE,i,
845                                 pme->bufv[buf_pos],scount*sizeof(rvec),
846                                 atc->x[local_pos],rcount*sizeof(rvec));
847             }
848             /* Communicate the charges */
849             pme_dd_sendrecv(atc,FALSE,i,
850                             pme->bufr+buf_pos,scount*sizeof(real),
851                             atc->q+local_pos,rcount*sizeof(real));
852             buf_pos   += scount;
853             local_pos += atc->rcount[i];
854         }
855     }
856 }
857
858 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
859                            int n, rvec *f,
860                            gmx_bool bAddF)
861 {
862   int *commnode,*buf_index;
863   int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
864
865   commnode  = atc->node_dest;
866   buf_index = atc->buf_index;
867
868   nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
869
870   local_pos = atc->count[atc->nodeid];
871   buf_pos = 0;
872   for(i=0; i<nnodes_comm; i++) {
873     scount = atc->rcount[i];
874     rcount = atc->count[commnode[i]];
875     if (scount > 0 || rcount > 0) {
876       /* Communicate the forces */
877       pme_dd_sendrecv(atc,TRUE,i,
878                       atc->f[local_pos],scount*sizeof(rvec),
879                       pme->bufv[buf_pos],rcount*sizeof(rvec));
880       local_pos += scount;
881     }
882     buf_index[commnode[i]] = buf_pos;
883     buf_pos   += rcount;
884   }
885
886     local_pos = 0;
887     if (bAddF)
888     {
889         for(i=0; i<n; i++)
890         {
891             node = atc->pd[i];
892             if (node == atc->nodeid)
893             {
894                 /* Add from the local force array */
895                 rvec_inc(f[i],atc->f[local_pos]);
896                 local_pos++;
897             }
898             else
899             {
900                 /* Add from the receive buffer */
901                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
902                 buf_index[node]++;
903             }
904         }
905     }
906     else
907     {
908         for(i=0; i<n; i++)
909         {
910             node = atc->pd[i];
911             if (node == atc->nodeid)
912             {
913                 /* Copy from the local force array */
914                 copy_rvec(atc->f[local_pos],f[i]);
915                 local_pos++;
916             }
917             else
918             {
919                 /* Copy from the receive buffer */
920                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
921                 buf_index[node]++;
922             }
923         }
924     }
925 }
926
927 #ifdef GMX_MPI
928 static void
929 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
930 {
931     pme_overlap_t *overlap;
932     int send_index0,send_nindex;
933     int recv_index0,recv_nindex;
934     MPI_Status stat;
935     int i,j,k,ix,iy,iz,icnt;
936     int ipulse,send_id,recv_id,datasize;
937     real *p;
938     real *sendptr,*recvptr;
939
940     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
941     overlap = &pme->overlap[1];
942
943     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
944     {
945         /* Since we have already (un)wrapped the overlap in the z-dimension,
946          * we only have to communicate 0 to nkz (not pmegrid_nz).
947          */
948         if (direction==GMX_SUM_QGRID_FORWARD)
949         {
950             send_id = overlap->send_id[ipulse];
951             recv_id = overlap->recv_id[ipulse];
952             send_index0   = overlap->comm_data[ipulse].send_index0;
953             send_nindex   = overlap->comm_data[ipulse].send_nindex;
954             recv_index0   = overlap->comm_data[ipulse].recv_index0;
955             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
956         }
957         else
958         {
959             send_id = overlap->recv_id[ipulse];
960             recv_id = overlap->send_id[ipulse];
961             send_index0   = overlap->comm_data[ipulse].recv_index0;
962             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
963             recv_index0   = overlap->comm_data[ipulse].send_index0;
964             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
965         }
966
967         /* Copy data to contiguous send buffer */
968         if (debug)
969         {
970             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
971                     pme->nodeid,overlap->nodeid,send_id,
972                     pme->pmegrid_start_iy,
973                     send_index0-pme->pmegrid_start_iy,
974                     send_index0-pme->pmegrid_start_iy+send_nindex);
975         }
976         icnt = 0;
977         for(i=0;i<pme->pmegrid_nx;i++)
978         {
979             ix = i;
980             for(j=0;j<send_nindex;j++)
981             {
982                 iy = j + send_index0 - pme->pmegrid_start_iy;
983                 for(k=0;k<pme->nkz;k++)
984                 {
985                     iz = k;
986                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
987                 }
988             }
989         }
990
991         datasize      = pme->pmegrid_nx * pme->nkz;
992
993         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
994                      send_id,ipulse,
995                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
996                      recv_id,ipulse,
997                      overlap->mpi_comm,&stat);
998
999         /* Get data from contiguous recv buffer */
1000         if (debug)
1001         {
1002             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1003                     pme->nodeid,overlap->nodeid,recv_id,
1004                     pme->pmegrid_start_iy,
1005                     recv_index0-pme->pmegrid_start_iy,
1006                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1007         }
1008         icnt = 0;
1009         for(i=0;i<pme->pmegrid_nx;i++)
1010         {
1011             ix = i;
1012             for(j=0;j<recv_nindex;j++)
1013             {
1014                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1015                 for(k=0;k<pme->nkz;k++)
1016                 {
1017                     iz = k;
1018                     if(direction==GMX_SUM_QGRID_FORWARD)
1019                     {
1020                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1021                     }
1022                     else
1023                     {
1024                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1025                     }
1026                 }
1027             }
1028         }
1029     }
1030
1031     /* Major dimension is easier, no copying required,
1032      * but we might have to sum to a separate array.
1033      * Since we don't copy, we have to communicate up to pmegrid_nz,
1034      * not nkz as for the minor direction.
1035      */
1036     overlap = &pme->overlap[0];
1037
1038     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1039     {
1040         if(direction==GMX_SUM_QGRID_FORWARD)
1041         {
1042             send_id = overlap->send_id[ipulse];
1043             recv_id = overlap->recv_id[ipulse];
1044             send_index0   = overlap->comm_data[ipulse].send_index0;
1045             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1046             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1047             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1048             recvptr   = overlap->recvbuf;
1049         }
1050         else
1051         {
1052             send_id = overlap->recv_id[ipulse];
1053             recv_id = overlap->send_id[ipulse];
1054             send_index0   = overlap->comm_data[ipulse].recv_index0;
1055             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1056             recv_index0   = overlap->comm_data[ipulse].send_index0;
1057             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1058             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1059         }
1060
1061         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1062         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1063
1064         if (debug)
1065         {
1066             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1067                     pme->nodeid,overlap->nodeid,send_id,
1068                     pme->pmegrid_start_ix,
1069                     send_index0-pme->pmegrid_start_ix,
1070                     send_index0-pme->pmegrid_start_ix+send_nindex);
1071             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1072                     pme->nodeid,overlap->nodeid,recv_id,
1073                     pme->pmegrid_start_ix,
1074                     recv_index0-pme->pmegrid_start_ix,
1075                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1076         }
1077
1078         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1079                      send_id,ipulse,
1080                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1081                      recv_id,ipulse,
1082                      overlap->mpi_comm,&stat);
1083
1084         /* ADD data from contiguous recv buffer */
1085         if(direction==GMX_SUM_QGRID_FORWARD)
1086         {
1087             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1088             for(i=0;i<recv_nindex*datasize;i++)
1089             {
1090                 p[i] += overlap->recvbuf[i];
1091             }
1092         }
1093     }
1094 }
1095 #endif
1096
1097
1098 static int
1099 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1100 {
1101     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1102     ivec    local_pme_size;
1103     int     i,ix,iy,iz;
1104     int     pmeidx,fftidx;
1105
1106     /* Dimensions should be identical for A/B grid, so we just use A here */
1107     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1108                                    local_fft_ndata,
1109                                    local_fft_offset,
1110                                    local_fft_size);
1111
1112     local_pme_size[0] = pme->pmegrid_nx;
1113     local_pme_size[1] = pme->pmegrid_ny;
1114     local_pme_size[2] = pme->pmegrid_nz;
1115
1116     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid;
1117      the offset is identical, and the PME grid always has more data (due to overlap)
1118      */
1119     {
1120 #ifdef DEBUG_PME
1121         FILE *fp,*fp2;
1122         char fn[STRLEN],format[STRLEN];
1123         real val;
1124         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1125         fp = ffopen(fn,"w");
1126         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1127         fp2 = ffopen(fn,"w");
1128      sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1129 #endif
1130
1131     for(ix=0;ix<local_fft_ndata[XX];ix++)
1132     {
1133         for(iy=0;iy<local_fft_ndata[YY];iy++)
1134         {
1135             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1136             {
1137                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1138                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1139                 fftgrid[fftidx] = pmegrid[pmeidx];
1140 #ifdef DEBUG_PME
1141                 val = 100*pmegrid[pmeidx];
1142                 if (pmegrid[pmeidx] != 0)
1143                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1144                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1145                 if (pmegrid[pmeidx] != 0)
1146                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1147                             "qgrid",
1148                             pme->pmegrid_start_ix + ix,
1149                             pme->pmegrid_start_iy + iy,
1150                             pme->pmegrid_start_iz + iz,
1151                             pmegrid[pmeidx]);
1152 #endif
1153             }
1154         }
1155     }
1156 #ifdef DEBUG_PME
1157     ffclose(fp);
1158     ffclose(fp2);
1159 #endif
1160     }
1161     return 0;
1162 }
1163
1164
1165 static gmx_cycles_t omp_cyc_start()
1166 {
1167     return gmx_cycles_read();
1168 }
1169
1170 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1171 {
1172     return gmx_cycles_read() - c;
1173 }
1174
1175
1176 static int
1177 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1178                         int nthread,int thread)
1179 {
1180     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1181     ivec    local_pme_size;
1182     int     ixy0,ixy1,ixy,ix,iy,iz;
1183     int     pmeidx,fftidx;
1184 #ifdef PME_TIME_THREADS
1185     gmx_cycles_t c1;
1186     static double cs1=0;
1187     static int cnt=0;
1188 #endif
1189
1190 #ifdef PME_TIME_THREADS
1191     c1 = omp_cyc_start();
1192 #endif
1193     /* Dimensions should be identical for A/B grid, so we just use A here */
1194     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1195                                    local_fft_ndata,
1196                                    local_fft_offset,
1197                                    local_fft_size);
1198
1199     local_pme_size[0] = pme->pmegrid_nx;
1200     local_pme_size[1] = pme->pmegrid_ny;
1201     local_pme_size[2] = pme->pmegrid_nz;
1202
1203     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid;
1204      the offset is identical, and the PME grid always has more data (due to overlap)
1205      */
1206     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1207     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
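    /* Worked example: with nthread = 4 and 10 x/y grid columns in total,
     * the threads get the half-open ranges [0,2), [2,5), [5,7) and [7,10),
     * i.e. the columns are split as evenly as integer arithmetic allows.
     */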
1208
1209     for(ixy=ixy0;ixy<ixy1;ixy++)
1210     {
1211         ix = ixy/local_fft_ndata[YY];
1212         iy = ixy - ix*local_fft_ndata[YY];
1213
1214         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1215         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1216         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1217         {
1218             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1219         }
1220     }
1221
1222 #ifdef PME_TIME_THREADS
1223     c1 = omp_cyc_end(c1);
1224     cs1 += (double)c1;
1225     cnt++;
1226     if (cnt % 20 == 0)
1227     {
1228         printf("copy %.2f\n",cs1*1e-9);
1229     }
1230 #endif
1231
1232     return 0;
1233 }
1234
1235
1236 static void
1237 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1238 {
1239     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1240
1241     nx = pme->nkx;
1242     ny = pme->nky;
1243     nz = pme->nkz;
1244
1245     pnx = pme->pmegrid_nx;
1246     pny = pme->pmegrid_ny;
1247     pnz = pme->pmegrid_nz;
1248
1249     overlap = pme->pme_order - 1;
1250
1251     /* Add periodic overlap in z */
1252     for(ix=0; ix<pme->pmegrid_nx; ix++)
1253     {
1254         for(iy=0; iy<pme->pmegrid_ny; iy++)
1255         {
1256             for(iz=0; iz<overlap; iz++)
1257             {
1258                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1259                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1260             }
1261         }
1262     }
1263
1264     if (pme->nnodes_minor == 1)
1265     {
1266        for(ix=0; ix<pme->pmegrid_nx; ix++)
1267        {
1268            for(iy=0; iy<overlap; iy++)
1269            {
1270                for(iz=0; iz<nz; iz++)
1271                {
1272                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1273                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1274                }
1275            }
1276        }
1277     }
1278
1279     if (pme->nnodes_major == 1)
1280     {
1281         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1282
1283         for(ix=0; ix<overlap; ix++)
1284         {
1285             for(iy=0; iy<ny_x; iy++)
1286             {
1287                 for(iz=0; iz<nz; iz++)
1288                 {
1289                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1290                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1291                 }
1292             }
1293         }
1294     }
1295 }
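/* A one-dimensional sketch of the wrapping above: the grid stores n real
 * points plus 'overlap' extra points in a periodic dimension, and the
 * charge spread into the extra points is folded back onto the start of
 * the grid before the FFT.  Illustrative only, hence #if 0.
 */
#if 0
static void wrap_periodic_1d(real *grid,int n,int overlap)
{
    int i;

    for(i=0; i<overlap; i++)
    {
        grid[i] += grid[n+i];
    }
}
#endif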
1296
1297
1298 static void
1299 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1300 {
1301     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1302
1303     nx = pme->nkx;
1304     ny = pme->nky;
1305     nz = pme->nkz;
1306
1307     pnx = pme->pmegrid_nx;
1308     pny = pme->pmegrid_ny;
1309     pnz = pme->pmegrid_nz;
1310
1311     overlap = pme->pme_order - 1;
1312
1313     if (pme->nnodes_major == 1)
1314     {
1315         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1316
1317         for(ix=0; ix<overlap; ix++)
1318         {
1319             int iy,iz;
1320
1321             for(iy=0; iy<ny_x; iy++)
1322             {
1323                 for(iz=0; iz<nz; iz++)
1324                 {
1325                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1326                         pmegrid[(ix*pny+iy)*pnz+iz];
1327                 }
1328             }
1329         }
1330     }
1331
1332     if (pme->nnodes_minor == 1)
1333     {
1334 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1335        for(ix=0; ix<pme->pmegrid_nx; ix++)
1336        {
1337            int iy,iz;
1338
1339            for(iy=0; iy<overlap; iy++)
1340            {
1341                for(iz=0; iz<nz; iz++)
1342                {
1343                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1344                        pmegrid[(ix*pny+iy)*pnz+iz];
1345                }
1346            }
1347        }
1348     }
1349
1350     /* Copy periodic overlap in z */
1351 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1352     for(ix=0; ix<pme->pmegrid_nx; ix++)
1353     {
1354         int iy,iz;
1355
1356         for(iy=0; iy<pme->pmegrid_ny; iy++)
1357         {
1358             for(iz=0; iz<overlap; iz++)
1359             {
1360                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1361                     pmegrid[(ix*pny+iy)*pnz+iz];
1362             }
1363         }
1364     }
1365 }
1366
1367 static void clear_grid(int nx,int ny,int nz,real *grid,
1368                        ivec fs,int *flag,
1369                        int fx,int fy,int fz,
1370                        int order)
1371 {
1372     int nc,ncz;
1373     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1374     int flind;
1375
1376     nc  = 2 + (order - 2)/FLBS;
1377     ncz = 2 + (order - 2)/FLBSZ;
1378
1379     for(fsx=fx; fsx<fx+nc; fsx++)
1380     {
1381         for(fsy=fy; fsy<fy+nc; fsy++)
1382         {
1383             for(fsz=fz; fsz<fz+ncz; fsz++)
1384             {
1385                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1386                 if (flag[flind] == 0)
1387                 {
1388                     gx = fsx*FLBS;
1389                     gy = fsy*FLBS;
1390                     gz = fsz*FLBSZ;
1391                     g0x = (gx*ny + gy)*nz + gz;
1392                     for(x=0; x<FLBS; x++)
1393                     {
1394                         g0y = g0x;
1395                         for(y=0; y<FLBS; y++)
1396                         {
1397                             for(z=0; z<FLBSZ; z++)
1398                             {
1399                                 grid[g0y+z] = 0;
1400                             }
1401                             g0y += nz;
1402                         }
1403                         g0x += ny*nz;
1404                     }
1405
1406                     flag[flind] = 1;
1407                 }
1408             }
1409         }
1410     }
1411 }
1412
1413 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1414 #define DO_BSPLINE(order)                            \
1415 for(ithx=0; (ithx<order); ithx++)                    \
1416 {                                                    \
1417     index_x = (i0+ithx)*pny*pnz;                     \
1418     valx    = qn*thx[ithx];                          \
1419                                                      \
1420     for(ithy=0; (ithy<order); ithy++)                \
1421     {                                                \
1422         valxy    = valx*thy[ithy];                   \
1423         index_xy = index_x+(j0+ithy)*pnz;            \
1424                                                      \
1425         for(ithz=0; (ithz<order); ithz++)            \
1426         {                                            \
1427             index_xyz        = index_xy+(k0+ithz);   \
1428             grid[index_xyz] += valxy*thz[ithz];      \
1429         }                                            \
1430     }                                                \
1431 }
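/* In other words, for order 4 the macro above accumulates
 *
 *     grid[i0+ithx][j0+ithy][k0+ithz] += qn * thx[ithx]*thy[ithy]*thz[ithz]
 *
 * over the 4 x 4 x 4 grid points surrounding the atom: the charge qn is
 * spread with the separable tensor product of the three B-spline tables
 * (the flat index_xyz arithmetic is just the row-major form of this).
 */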
1432
1433
1434 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1435                                      pme_atomcomm_t *atc, splinedata_t *spline,
1436                                      pme_spline_work_t *work)
1437 {
1438
1439     /* spread charges from home atoms to local grid */
1440     real     *grid;
1441     pme_overlap_t *ol;
1442     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1443     int *    idxptr;
1444     int      order,norder,index_x,index_xy,index_xyz;
1445     real     valx,valxy,qn;
1446     real     *thx,*thy,*thz;
1447     int      localsize, bndsize;
1448     int      pnx,pny,pnz,ndatatot;
1449     int      offx,offy,offz;
1450
1451     pnx = pmegrid->s[XX];
1452     pny = pmegrid->s[YY];
1453     pnz = pmegrid->s[ZZ];
1454
1455     offx = pmegrid->offset[XX];
1456     offy = pmegrid->offset[YY];
1457     offz = pmegrid->offset[ZZ];
1458
1459     ndatatot = pnx*pny*pnz;
1460     grid = pmegrid->grid;
1461     for(i=0;i<ndatatot;i++)
1462     {
1463         grid[i] = 0;
1464     }
1465     
1466     order = pmegrid->order;
1467
1468     for(nn=0; nn<spline->n; nn++)
1469     {
1470         n  = spline->ind[nn];
1471         qn = atc->q[n];
1472
1473         if (qn != 0)
1474         {
1475             idxptr = atc->idx[n];
1476             norder = nn*order;
1477
1478             i0   = idxptr[XX] - offx;
1479             j0   = idxptr[YY] - offy;
1480             k0   = idxptr[ZZ] - offz;
1481
1482             thx = spline->theta[XX] + norder;
1483             thy = spline->theta[YY] + norder;
1484             thz = spline->theta[ZZ] + norder;
1485
1486             switch (order) {
1487             case 4:
1488 #ifdef PME_SSE
1489 #ifdef PME_SSE_UNALIGNED
1490 #define PME_SPREAD_SSE_ORDER4
1491 #else
1492 #define PME_SPREAD_SSE_ALIGNED
1493 #define PME_ORDER 4
1494 #endif
1495 #include "pme_sse_single.h"
1496 #else
1497                 DO_BSPLINE(4);
1498 #endif
1499                 break;
1500             case 5:
1501 #ifdef PME_SSE
1502 #define PME_SPREAD_SSE_ALIGNED
1503 #define PME_ORDER 5
1504 #include "pme_sse_single.h"
1505 #else
1506                 DO_BSPLINE(5);
1507 #endif
1508                 break;
1509             default:
1510                 DO_BSPLINE(order);
1511                 break;
1512             }
1513         }
1514     }
1515 }
1516
1517 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1518 {
1519 #ifdef PME_SSE
1520     if (pme_order == 5
1521 #ifndef PME_SSE_UNALIGNED
1522         || pme_order == 4
1523 #endif
1524         )
1525     {
1526         /* Round nz up to a multiple of 4 to ensure alignment */
1527         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
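        /* Worked example: *pmegrid_nz = 10 becomes (10 + 3) & ~3 = 12,
         * while a value that is already a multiple of 4 is unchanged,
         * e.g. (12 + 3) & ~3 = 12.
         */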
1528     }
1529 #endif
1530 }
1531
1532 static void set_gridsize_alignment(int *gridsize,int pme_order)
1533 {
1534 #ifdef PME_SSE
1535 #ifndef PME_SSE_UNALIGNED
1536     if (pme_order == 4)
1537     {
1538         /* Add extra elements to ensure that aligned operations do not go
1539          * beyond the allocated grid size.
1540          * Note that for pme_order=5, the pme grid z-size alignment
1541          * ensures that we will not go beyond the grid size.
1542          */
1543          *gridsize += 4;
1544     }
1545 #endif
1546 #endif
1547 }
1548
1549 static void pmegrid_init(pmegrid_t *grid,
1550                          int cx, int cy, int cz,
1551                          int x0, int y0, int z0,
1552                          int x1, int y1, int z1,
1553                          gmx_bool set_alignment,
1554                          int pme_order,
1555                          real *ptr)
1556 {
1557     int nz,gridsize;
1558
1559     grid->ci[XX] = cx;
1560     grid->ci[YY] = cy;
1561     grid->ci[ZZ] = cz;
1562     grid->offset[XX] = x0;
1563     grid->offset[YY] = y0;
1564     grid->offset[ZZ] = z0;
1565     grid->n[XX]      = x1 - x0 + pme_order - 1;
1566     grid->n[YY]      = y1 - y0 + pme_order - 1;
1567     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1568     copy_ivec(grid->n,grid->s);
1569
1570     nz = grid->s[ZZ];
1571     set_grid_alignment(&nz,pme_order);
1572     if (set_alignment)
1573     {
1574         grid->s[ZZ] = nz;
1575     }
1576     else if (nz != grid->s[ZZ])
1577     {
1578         gmx_incons("pmegrid_init call with an unaligned z size");
1579     }
1580
1581     grid->order = pme_order;
1582     if (ptr == NULL)
1583     {
1584         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1585         set_gridsize_alignment(&gridsize,pme_order);
1586         snew_aligned(grid->grid,gridsize,16);
1587     }
1588     else
1589     {
1590         grid->grid = ptr;
1591     }
1592 }
1593
1594 static int div_round_up(int enumerator,int denominator)
1595 {
1596     return (enumerator + denominator - 1)/denominator;
1597 }
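/* E.g. div_round_up(10,4) = (10 + 4 - 1)/4 = 13/4 = 3 in integer division,
 * i.e. the quotient rounded up, while div_round_up(8,4) = 11/4 = 2.
 */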
1598
1599 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1600                                   ivec nsub)
1601 {
1602     int gsize_opt,gsize;
1603     int nsx,nsy,nsz;
1604     char *env;
1605
1606     gsize_opt = -1;
1607     for(nsx=1; nsx<=nthread; nsx++)
1608     {
1609         if (nthread % nsx == 0)
1610         {
1611             for(nsy=1; nsy<=nthread; nsy++)
1612             {
1613                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1614                 {
1615                     nsz = nthread/(nsx*nsy);
1616
1617                     /* Determine the number of grid points per thread */
1618                     gsize =
1619                         (div_round_up(n[XX],nsx) + ovl)*
1620                         (div_round_up(n[YY],nsy) + ovl)*
1621                         (div_round_up(n[ZZ],nsz) + ovl);
1622
1623                     /* Minimize the number of grid points per thread
1624                      * and, secondarily, the number of cuts in minor dimensions.
1625                      */
1626                     if (gsize_opt == -1 ||
1627                         gsize < gsize_opt ||
1628                         (gsize == gsize_opt &&
1629                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1630                     {
1631                         nsub[XX] = nsx;
1632                         nsub[YY] = nsy;
1633                         nsub[ZZ] = nsz;
1634                         gsize_opt = gsize;
1635                     }
1636                 }
1637             }
1638         }
1639     }
1640
1641     env = getenv("GMX_PME_THREAD_DIVISION");
1642     if (env != NULL)
1643     {
1644         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1645     }
1646
1647     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1648     {
1649         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1650     }
1651 }
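/* Usage note for the override above: GMX_PME_THREAD_DIVISION takes three
 * integers whose product must equal the number of PME threads, e.g.
 * (illustrative, POSIX shell syntax):
 *
 *     export GMX_PME_THREAD_DIVISION="2 2 1"
 *
 * which requests a 2 x 2 x 1 thread grid in x, y and z for 4 threads.
 */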
1652
1653 static void pmegrids_init(pmegrids_t *grids,
1654                           int nx,int ny,int nz,int nz_base,
1655                           int pme_order,
1656                           int nthread,
1657                           int overlap_x,
1658                           int overlap_y)
1659 {
1660     ivec n,n_base,g0,g1;
1661     int t,x,y,z,d,i,tfac;
1662     int max_comm_lines=-1;
1663
1664     n[XX] = nx - (pme_order - 1);
1665     n[YY] = ny - (pme_order - 1);
1666     n[ZZ] = nz - (pme_order - 1);
1667
1668     copy_ivec(n,n_base);
1669     n_base[ZZ] = nz_base;
1670
1671     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1672                  NULL);
1673
1674     grids->nthread = nthread;
1675
1676     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1677
1678     if (grids->nthread > 1)
1679     {
1680         ivec nst;
1681         int gridsize;
1682
1683         for(d=0; d<DIM; d++)
1684         {
1685             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1686         }
1687         set_grid_alignment(&nst[ZZ],pme_order);
1688
1689         if (debug)
1690         {
1691             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1692                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1693             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1694                     nx,ny,nz,
1695                     nst[XX],nst[YY],nst[ZZ]);
1696         }
1697
1698         snew(grids->grid_th,grids->nthread);
1699         t = 0;
1700         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1701         set_gridsize_alignment(&gridsize,pme_order);
1702         snew_aligned(grids->grid_all,
1703                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1704                      16);
1705
1706         for(x=0; x<grids->nc[XX]; x++)
1707         {
1708             for(y=0; y<grids->nc[YY]; y++)
1709             {
1710                 for(z=0; z<grids->nc[ZZ]; z++)
1711                 {
1712                     pmegrid_init(&grids->grid_th[t],
1713                                  x,y,z,
1714                                  (n[XX]*(x  ))/grids->nc[XX],
1715                                  (n[YY]*(y  ))/grids->nc[YY],
1716                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1717                                  (n[XX]*(x+1))/grids->nc[XX],
1718                                  (n[YY]*(y+1))/grids->nc[YY],
1719                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1720                                  TRUE,
1721                                  pme_order,
1722                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1723                     t++;
1724                 }
1725             }
1726         }
1727     }
1728
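    /* Build the grid-line -> thread lookup tables. After the loop below,
     * g2t[d][i] holds the (stride-multiplied) thread coordinate along
     * dimension d of the thread owning interpolation grid line i, such that
     * the owner of grid cell (ix,iy,iz) is simply
     * g2t[XX][ix] + g2t[YY][iy] + g2t[ZZ][iz].
     */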
1729     snew(grids->g2t,DIM);
1730     tfac = 1;
1731     for(d=DIM-1; d>=0; d--)
1732     {
1733         snew(grids->g2t[d],n[d]);
1734         t = 0;
1735         for(i=0; i<n[d]; i++)
1736         {
1737             /* The second check should match the parameters
1738              * of the pmegrid_init call above.
1739              */
1740             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1741             {
1742                 t++;
1743             }
1744             grids->g2t[d][i] = t*tfac;
1745         }
1746
1747         tfac *= grids->nc[d];
1748
1749         switch (d)
1750         {
1751         case XX: max_comm_lines = overlap_x;     break;
1752         case YY: max_comm_lines = overlap_y;     break;
1753         case ZZ: max_comm_lines = pme_order - 1; break;
1754         }
1755         grids->nthread_comm[d] = 0;
1756         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1757                grids->nthread_comm[d] < grids->nc[d])
1758         {
1759             grids->nthread_comm[d]++;
1760         }
1761         if (debug != NULL)
1762         {
1763             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1764                     'x'+d,grids->nthread_comm[d]);
1765         }
1766         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1767          * work, but this is not a problematic restriction.
1768          */
1769         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1770         {
1771             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines; reduce the number of threads doing PME",grids->nthread);
1772         }
1773     }
1774 }
1775
1776
1777 static void pmegrids_destroy(pmegrids_t *grids)
1778 {
1779     int t;
1780
1781     if (grids->grid.grid != NULL)
1782     {
1783         sfree(grids->grid.grid);
1784
1785         if (grids->nthread > 0)
1786         {
1787             for(t=0; t<grids->nthread; t++)
1788             {
1789                 sfree(grids->grid_th[t].grid);
1790             }
1791             sfree(grids->grid_th);
1792         }
1793     }
1794 }
1795
1796
1797 static void realloc_work(pme_work_t *work,int nkx)
1798 {
1799     if (nkx > work->nalloc)
1800     {
1801         work->nalloc = nkx;
1802         srenew(work->mhx  ,work->nalloc);
1803         srenew(work->mhy  ,work->nalloc);
1804         srenew(work->mhz  ,work->nalloc);
1805         srenew(work->m2   ,work->nalloc);
1806         /* Allocate an aligned pointer for SSE operations, including 3 extra
1807          * elements at the end since SSE operates on 4 elements at a time.
1808          */
1809         sfree_aligned(work->denom);
1810         sfree_aligned(work->tmp1);
1811         sfree_aligned(work->eterm);
1812         snew_aligned(work->denom,work->nalloc+3,16);
1813         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1814         snew_aligned(work->eterm,work->nalloc+3,16);
1815         srenew(work->m2inv,work->nalloc);
1816     }
1817 }
1818
1819
1820 static void free_work(pme_work_t *work)
1821 {
1822     sfree(work->mhx);
1823     sfree(work->mhy);
1824     sfree(work->mhz);
1825     sfree(work->m2);
1826     sfree_aligned(work->denom);
1827     sfree_aligned(work->tmp1);
1828     sfree_aligned(work->eterm);
1829     sfree(work->m2inv);
1830 }
1831
1832
1833 #ifdef PME_SSE
1834     /* Calculate exponentials through SSE in float precision */
1835 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1836 {
1837     {
1838         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1839         __m128 f_sse;
1840         __m128 lu;
1841         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1842         int kx;
1843         f_sse = _mm_load1_ps(&f);
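        /* Process 4 k-vectors per iteration. _mm_rcp_ps only gives ~12 bits
         * of precision, so the multiply below adds one Newton-Raphson step,
         * 1/d ~= lu*(2 - lu*d), to refine the reciprocal of denom to (nearly)
         * full single precision. Note that the loop starts at 0 rather than
         * 'start' and can run up to 3 elements past 'end'; the work arrays
         * are padded for this and the extra elements are never read by the
         * caller.
         */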
1844         for(kx=0; kx<end; kx+=4)
1845         {
1846             tmp_d1   = _mm_load_ps(d_aligned+kx);
1847             lu       = _mm_rcp_ps(tmp_d1);
1848             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1849             tmp_r    = _mm_load_ps(r_aligned+kx);
1850             tmp_r    = gmx_mm_exp_ps(tmp_r);
1851             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1852             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1853             _mm_store_ps(e_aligned+kx,tmp_e);
1854         }
1855     }
1856 }
1857 #else
1858 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1859 {
1860     int kx;
1861     for(kx=start; kx<end; kx++)
1862     {
1863         d[kx] = 1.0/d[kx];
1864     }
1865     for(kx=start; kx<end; kx++)
1866     {
1867         r[kx] = exp(r[kx]);
1868     }
1869     for(kx=start; kx<end; kx++)
1870     {
1871         e[kx] = f*r[kx]*d[kx];
1872     }
1873 }
1874 #endif
1875
1876
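/* solve_pme_yzx performs the reciprocal-space part of (S)PME: each Fourier
 * coefficient of the charge grid is scaled by the B-spline moduli and by the
 * Gaussian screening factor exp(-pi^2 m^2/ewaldcoeff^2)/m^2, and, when
 * bEnerVir is set, the reciprocal-space energy and virial are accumulated.
 * Only half of the k-space points along z are stored (real-to-complex FFT),
 * which is why struct2 below carries a factor 2 and the self-conjugate
 * boundary planes get corner_fac = 0.5.
 */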
1877 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1878                          real ewaldcoeff,real vol,
1879                          gmx_bool bEnerVir,
1880                          int nthread,int thread)
1881 {
1882     /* do recip sum over local cells in grid */
1883     /* y major, z middle, x minor or continuous */
1884     t_complex *p0;
1885     int     kx,ky,kz,maxkx,maxky,maxkz;
1886     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1887     real    mx,my,mz;
1888     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1889     real    ets2,struct2,vfactor,ets2vf;
1890     real    d1,d2,energy=0;
1891     real    by,bz;
1892     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1893     real    rxx,ryx,ryy,rzx,rzy,rzz;
1894     pme_work_t *work;
1895     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1896     real    mhxk,mhyk,mhzk,m2k;
1897     real    corner_fac;
1898     ivec    complex_order;
1899     ivec    local_ndata,local_offset,local_size;
1900     real    elfac;
1901
1902     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1903
1904     nx = pme->nkx;
1905     ny = pme->nky;
1906     nz = pme->nkz;
1907
1908     /* Dimensions should be identical for A/B grid, so we just use A here */
1909     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1910                                       complex_order,
1911                                       local_ndata,
1912                                       local_offset,
1913                                       local_size);
1914
1915     rxx = pme->recipbox[XX][XX];
1916     ryx = pme->recipbox[YY][XX];
1917     ryy = pme->recipbox[YY][YY];
1918     rzx = pme->recipbox[ZZ][XX];
1919     rzy = pme->recipbox[ZZ][YY];
1920     rzz = pme->recipbox[ZZ][ZZ];
1921
1922     maxkx = (nx+1)/2;
1923     maxky = (ny+1)/2;
1924     maxkz = nz/2+1;
1925
1926     work = &pme->work[thread];
1927     mhx   = work->mhx;
1928     mhy   = work->mhy;
1929     mhz   = work->mhz;
1930     m2    = work->m2;
1931     denom = work->denom;
1932     tmp1  = work->tmp1;
1933     eterm = work->eterm;
1934     m2inv = work->m2inv;
1935
1936     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1937     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1938
1939     for(iyz=iyz0; iyz<iyz1; iyz++)
1940     {
1941         iy = iyz/local_ndata[ZZ];
1942         iz = iyz - iy*local_ndata[ZZ];
1943
1944         ky = iy + local_offset[YY];
1945
1946         if (ky < maxky)
1947         {
1948             my = ky;
1949         }
1950         else
1951         {
1952             my = (ky - ny);
1953         }
1954
1955         by = M_PI*vol*pme->bsp_mod[YY][ky];
1956
1957         kz = iz + local_offset[ZZ];
1958
1959         mz = kz;
1960
1961         bz = pme->bsp_mod[ZZ][kz];
1962
1963         /* 0.5 correction for corner points */
1964         corner_fac = 1;
1965         if (kz == 0 || kz == (nz+1)/2)
1966         {
1967             corner_fac = 0.5;
1968         }
1969
1970         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1971
1972         /* We should skip the k-space point (0,0,0) */
1973         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1974         {
1975             kxstart = local_offset[XX];
1976         }
1977         else
1978         {
1979             kxstart = local_offset[XX] + 1;
1980             p0++;
1981         }
1982         kxend = local_offset[XX] + local_ndata[XX];
1983
1984         if (bEnerVir)
1985         {
1986             /* More expensive inner loop, especially because of the storage
1987              * of the mh elements in arrays.
1988              * Because x is the minor grid index, all mh elements
1989              * depend on kx for triclinic unit cells.
1990              */
1991
1992                 /* Two explicit loops to avoid a conditional inside the loop */
1993             for(kx=kxstart; kx<maxkx; kx++)
1994             {
1995                 mx = kx;
1996
1997                 mhxk      = mx * rxx;
1998                 mhyk      = mx * ryx + my * ryy;
1999                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2000                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2001                 mhx[kx]   = mhxk;
2002                 mhy[kx]   = mhyk;
2003                 mhz[kx]   = mhzk;
2004                 m2[kx]    = m2k;
2005                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2006                 tmp1[kx]  = -factor*m2k;
2007             }
2008
2009             for(kx=maxkx; kx<kxend; kx++)
2010             {
2011                 mx = (kx - nx);
2012
2013                 mhxk      = mx * rxx;
2014                 mhyk      = mx * ryx + my * ryy;
2015                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2016                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2017                 mhx[kx]   = mhxk;
2018                 mhy[kx]   = mhyk;
2019                 mhz[kx]   = mhzk;
2020                 m2[kx]    = m2k;
2021                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2022                 tmp1[kx]  = -factor*m2k;
2023             }
2024
2025             for(kx=kxstart; kx<kxend; kx++)
2026             {
2027                 m2inv[kx] = 1.0/m2[kx];
2028             }
2029
2030             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2031
2032             for(kx=kxstart; kx<kxend; kx++,p0++)
2033             {
2034                 d1      = p0->re;
2035                 d2      = p0->im;
2036
2037                 p0->re  = d1*eterm[kx];
2038                 p0->im  = d2*eterm[kx];
2039
2040                 struct2 = 2.0*(d1*d1+d2*d2);
2041
2042                 tmp1[kx] = eterm[kx]*struct2;
2043             }
2044
2045             for(kx=kxstart; kx<kxend; kx++)
2046             {
2047                 ets2     = corner_fac*tmp1[kx];
2048                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2049                 energy  += ets2;
2050
2051                 ets2vf   = ets2*vfactor;
2052                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2053                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2054                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2055                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2056                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2057                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2058             }
2059         }
2060         else
2061         {
2062             /* We don't need to calculate the energy and the virial.
2063              * In this case the triclinic overhead is small.
2064              */
2065
2066             /* Two explicit loops to avoid a conditional inside the loop */
2067
2068             for(kx=kxstart; kx<maxkx; kx++)
2069             {
2070                 mx = kx;
2071
2072                 mhxk      = mx * rxx;
2073                 mhyk      = mx * ryx + my * ryy;
2074                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2075                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2076                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2077                 tmp1[kx]  = -factor*m2k;
2078             }
2079
2080             for(kx=maxkx; kx<kxend; kx++)
2081             {
2082                 mx = (kx - nx);
2083
2084                 mhxk      = mx * rxx;
2085                 mhyk      = mx * ryx + my * ryy;
2086                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2087                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2088                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2089                 tmp1[kx]  = -factor*m2k;
2090             }
2091
2092             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2093
2094             for(kx=kxstart; kx<kxend; kx++,p0++)
2095             {
2096                 d1      = p0->re;
2097                 d2      = p0->im;
2098
2099                 p0->re  = d1*eterm[kx];
2100                 p0->im  = d2*eterm[kx];
2101             }
2102         }
2103     }
2104
2105     if (bEnerVir)
2106     {
2107         /* Update virial with local values.
2108          * The virial is symmetric by definition.
2109          * this virial seems ok for isotropic scaling, but I'm
2110          * experiencing problems on semiisotropic membranes.
2111          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2112          */
2113         work->vir[XX][XX] = 0.25*virxx;
2114         work->vir[YY][YY] = 0.25*viryy;
2115         work->vir[ZZ][ZZ] = 0.25*virzz;
2116         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2117         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2118         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2119
2120         /* This energy should be corrected for a charged system */
2121         work->energy = 0.5*energy;
2122     }
2123
2124     /* Return the loop count */
2125     return local_ndata[YY]*local_ndata[XX];
2126 }
2127
2128 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2129                              real *mesh_energy,matrix vir)
2130 {
2131     /* This function sums output over threads
2132      * and should therefore only be called after thread synchronization.
2133      */
2134     int thread;
2135
2136     *mesh_energy = pme->work[0].energy;
2137     copy_mat(pme->work[0].vir,vir);
2138
2139     for(thread=1; thread<nthread; thread++)
2140     {
2141         *mesh_energy += pme->work[thread].energy;
2142         m_add(vir,pme->work[thread].vir,vir);
2143     }
2144 }
2145
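/* DO_FSPLINE accumulates, for one charge, the derivatives of the interpolated
 * potential with respect to the three fractional grid coordinates: a tensor
 * product of B-spline values (thx/thy/thz) and derivatives (dthx/dthy/dthz)
 * over the order^3 surrounding grid points. The conversion of (fx,fy,fz) to
 * Cartesian forces is done afterwards in gather_f_bsplines.
 */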
2146 #define DO_FSPLINE(order)                      \
2147 for(ithx=0; (ithx<order); ithx++)              \
2148 {                                              \
2149     index_x = (i0+ithx)*pny*pnz;               \
2150     tx      = thx[ithx];                       \
2151     dx      = dthx[ithx];                      \
2152                                                \
2153     for(ithy=0; (ithy<order); ithy++)          \
2154     {                                          \
2155         index_xy = index_x+(j0+ithy)*pnz;      \
2156         ty       = thy[ithy];                  \
2157         dy       = dthy[ithy];                 \
2158         fxy1     = fz1 = 0;                    \
2159                                                \
2160         for(ithz=0; (ithz<order); ithz++)      \
2161         {                                      \
2162             gval  = grid[index_xy+(k0+ithz)];  \
2163             fxy1 += thz[ithz]*gval;            \
2164             fz1  += dthz[ithz]*gval;           \
2165         }                                      \
2166         fx += dx*ty*fxy1;                      \
2167         fy += tx*dy*fxy1;                      \
2168         fz += tx*ty*fz1;                       \
2169     }                                          \
2170 }
2171
2172
2173 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2174                               gmx_bool bClearF,pme_atomcomm_t *atc,
2175                               splinedata_t *spline,
2176                               real scale)
2177 {
2178     /* sum forces for local particles */
2179     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2180     int     index_x,index_xy;
2181     int     nx,ny,nz,pnx,pny,pnz;
2182     int *   idxptr;
2183     real    tx,ty,dx,dy,qn;
2184     real    fx,fy,fz,gval;
2185     real    fxy1,fz1;
2186     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2187     int     norder;
2188     real    rxx,ryx,ryy,rzx,rzy,rzz;
2189     int     order;
2190
2191     pme_spline_work_t *work;
2192
2193     work = pme->spline_work;
2194
2195     order = pme->pme_order;
2196     thx   = spline->theta[XX];
2197     thy   = spline->theta[YY];
2198     thz   = spline->theta[ZZ];
2199     dthx  = spline->dtheta[XX];
2200     dthy  = spline->dtheta[YY];
2201     dthz  = spline->dtheta[ZZ];
2202     nx    = pme->nkx;
2203     ny    = pme->nky;
2204     nz    = pme->nkz;
2205     pnx   = pme->pmegrid_nx;
2206     pny   = pme->pmegrid_ny;
2207     pnz   = pme->pmegrid_nz;
2208
2209     rxx   = pme->recipbox[XX][XX];
2210     ryx   = pme->recipbox[YY][XX];
2211     ryy   = pme->recipbox[YY][YY];
2212     rzx   = pme->recipbox[ZZ][XX];
2213     rzy   = pme->recipbox[ZZ][YY];
2214     rzz   = pme->recipbox[ZZ][ZZ];
2215
2216     for(nn=0; nn<spline->n; nn++)
2217     {
2218         n  = spline->ind[nn];
2219         qn = scale*atc->q[n];
2220
2221         if (bClearF)
2222         {
2223             atc->f[n][XX] = 0;
2224             atc->f[n][YY] = 0;
2225             atc->f[n][ZZ] = 0;
2226         }
2227         if (qn != 0)
2228         {
2229             fx     = 0;
2230             fy     = 0;
2231             fz     = 0;
2232             idxptr = atc->idx[n];
2233             norder = nn*order;
2234
2235             i0   = idxptr[XX];
2236             j0   = idxptr[YY];
2237             k0   = idxptr[ZZ];
2238
2239             /* Pointer arithmetic alert, next six statements */
2240             thx  = spline->theta[XX] + norder;
2241             thy  = spline->theta[YY] + norder;
2242             thz  = spline->theta[ZZ] + norder;
2243             dthx = spline->dtheta[XX] + norder;
2244             dthy = spline->dtheta[YY] + norder;
2245             dthz = spline->dtheta[ZZ] + norder;
2246
2247             switch (order) {
2248             case 4:
2249 #ifdef PME_SSE
2250 #ifdef PME_SSE_UNALIGNED
2251 #define PME_GATHER_F_SSE_ORDER4
2252 #else
2253 #define PME_GATHER_F_SSE_ALIGNED
2254 #define PME_ORDER 4
2255 #endif
2256 #include "pme_sse_single.h"
2257 #else
2258                 DO_FSPLINE(4);
2259 #endif
2260                 break;
2261             case 5:
2262 #ifdef PME_SSE
2263 #define PME_GATHER_F_SSE_ALIGNED
2264 #define PME_ORDER 5
2265 #include "pme_sse_single.h"
2266 #else
2267                 DO_FSPLINE(5);
2268 #endif
2269                 break;
2270             default:
2271                 DO_FSPLINE(order);
2272                 break;
2273             }
2274
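            /* fx, fy and fz are derivatives with respect to the fractional
             * grid coordinates; multiplying by the grid dimensions and the
             * reciprocal box converts them to Cartesian, and the factor -qn
             * turns the potential gradient into the force on charge qn.
             */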
2275             atc->f[n][XX] += -qn*( fx*nx*rxx );
2276             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2277             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2278         }
2279     }
2280     /* Since the energy and not the forces are interpolated,
2281      * the net force might not be exactly zero.
2282      * This can be solved by also interpolating F, but
2283      * that comes at a cost.
2284      * A better hack is to remove the net force every
2285      * step, but that must be done at a higher level
2286      * since this routine doesn't see all atoms if running
2287      * in parallel. It is unclear how important this is.  EL 990726
2288      */
2289 }
2290
2291
2292 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2293                                    pme_atomcomm_t *atc)
2294 {
2295     splinedata_t *spline;
2296     int     n,ithx,ithy,ithz,i0,j0,k0;
2297     int     index_x,index_xy;
2298     int *   idxptr;
2299     real    energy,pot,tx,ty,qn,gval;
2300     real    *thx,*thy,*thz;
2301     int     norder;
2302     int     order;
2303
2304     spline = &atc->spline[0];
2305
2306     order = pme->pme_order;
2307
2308     energy = 0;
2309     for(n=0; (n<atc->n); n++) {
2310         qn      = atc->q[n];
2311
2312         if (qn != 0) {
2313             idxptr = atc->idx[n];
2314             norder = n*order;
2315
2316             i0   = idxptr[XX];
2317             j0   = idxptr[YY];
2318             k0   = idxptr[ZZ];
2319
2320             /* Pointer arithmetic alert, next three statements */
2321             thx  = spline->theta[XX] + norder;
2322             thy  = spline->theta[YY] + norder;
2323             thz  = spline->theta[ZZ] + norder;
2324
2325             pot = 0;
2326             for(ithx=0; (ithx<order); ithx++)
2327             {
2328                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2329                 tx      = thx[ithx];
2330
2331                 for(ithy=0; (ithy<order); ithy++)
2332                 {
2333                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2334                     ty       = thy[ithy];
2335
2336                     for(ithz=0; (ithz<order); ithz++)
2337                     {
2338                         gval  = grid[index_xy+(k0+ithz)];
2339                         pot  += tx*ty*thz[ithz]*gval;
2340                     }
2341
2342                 }
2343             }
2344
2345             energy += pot*qn;
2346         }
2347     }
2348
2349     return energy;
2350 }
2351
2352 /* Macro to force loop unrolling by fixing order.
2353  * This gives a significant performance gain.
2354  */
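/* CALC_SPLINE evaluates the cardinal B-spline weights with the recursion
 * M_n(u) = [u*M_{n-1}(u) + (n-u)*M_{n-1}(u-1)]/(n-1) and the derivative
 * relation dM_n(u)/du = M_{n-1}(u) - M_{n-1}(u-1), as used in smooth PME
 * (Essmann et al., J. Chem. Phys. 103, 8577 (1995)), for the pme_order grid
 * points surrounding each charge.
 */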
2355 #define CALC_SPLINE(order)                     \
2356 {                                              \
2357     int j,k,l;                                 \
2358     real dr,div;                               \
2359     real data[PME_ORDER_MAX];                  \
2360     real ddata[PME_ORDER_MAX];                 \
2361                                                \
2362     for(j=0; (j<DIM); j++)                     \
2363     {                                          \
2364         dr  = xptr[j];                         \
2365                                                \
2366         /* dr is relative offset from lower cell limit */ \
2367         data[order-1] = 0;                     \
2368         data[1] = dr;                          \
2369         data[0] = 1 - dr;                      \
2370                                                \
2371         for(k=3; (k<order); k++)               \
2372         {                                      \
2373             div = 1.0/(k - 1.0);               \
2374             data[k-1] = div*dr*data[k-2];      \
2375             for(l=1; (l<(k-1)); l++)           \
2376             {                                  \
2377                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2378                                    data[k-l-1]);                \
2379             }                                  \
2380             data[0] = div*(1-dr)*data[0];      \
2381         }                                      \
2382         /* differentiate */                    \
2383         ddata[0] = -data[0];                   \
2384         for(k=1; (k<order); k++)               \
2385         {                                      \
2386             ddata[k] = data[k-1] - data[k];    \
2387         }                                      \
2388                                                \
2389         div = 1.0/(order - 1);                 \
2390         data[order-1] = div*dr*data[order-2];  \
2391         for(l=1; (l<(order-1)); l++)           \
2392         {                                      \
2393             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2394                                (order-l-dr)*data[order-l-1]); \
2395         }                                      \
2396         data[0] = div*(1 - dr)*data[0];        \
2397                                                \
2398         for(k=0; k<order; k++)                 \
2399         {                                      \
2400             theta[j][i*order+k]  = data[k];    \
2401             dtheta[j][i*order+k] = ddata[k];   \
2402         }                                      \
2403     }                                          \
2404 }
2405
2406 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2407                    rvec fractx[],int nr,int ind[],real charge[],
2408                    gmx_bool bFreeEnergy)
2409 {
2410     /* construct splines for local atoms */
2411     int  i,ii;
2412     real *xptr;
2413
2414     for(i=0; i<nr; i++)
2415     {
2416         /* With free energy we do not use the charge check.
2417          * In most cases this will be more efficient than calling make_bsplines
2418          * twice, since usually more than half the particles have charges.
2419          */
2420         ii = ind[i];
2421         if (bFreeEnergy || charge[ii] != 0.0) {
2422             xptr = fractx[ii];
2423             switch(order) {
2424             case 4:  CALC_SPLINE(4);     break;
2425             case 5:  CALC_SPLINE(5);     break;
2426             default: CALC_SPLINE(order); break;
2427             }
2428         }
2429     }
2430 }
2431
2432
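/* make_dft_mod computes, for each m, the squared modulus of the discrete
 * Fourier transform of the B-spline coefficients:
 * mod[m] = |sum_j data[j]*exp(2*pi*i*m*j/ndata)|^2.
 * Entries that are numerically zero are replaced by the average of their
 * neighbours so that the reciprocal-space solver never divides by zero.
 */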
2433 void make_dft_mod(real *mod,real *data,int ndata)
2434 {
2435   int i,j;
2436   real sc,ss,arg;
2437
2438   for(i=0;i<ndata;i++) {
2439     sc=ss=0;
2440     for(j=0;j<ndata;j++) {
2441       arg=(2.0*M_PI*i*j)/ndata;
2442       sc+=data[j]*cos(arg);
2443       ss+=data[j]*sin(arg);
2444     }
2445     mod[i]=sc*sc+ss*ss;
2446   }
2447   for(i=0;i<ndata;i++)
2448     if(mod[i]<1e-7)
2449       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2450 }
2451
2452
2453 static void make_bspline_moduli(splinevec bsp_mod,
2454                                 int nx,int ny,int nz,int order)
2455 {
2456   int nmax=max(nx,max(ny,nz));
2457   real *data,*ddata,*bsp_data;
2458   int i,k,l;
2459   real div;
2460
2461   snew(data,order);
2462   snew(ddata,order);
2463   snew(bsp_data,nmax);
2464
2465   data[order-1]=0;
2466   data[1]=0;
2467   data[0]=1;
2468
2469   for(k=3;k<order;k++) {
2470     div=1.0/(k-1.0);
2471     data[k-1]=0;
2472     for(l=1;l<(k-1);l++)
2473       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2474     data[0]=div*data[0];
2475   }
2476   /* differentiate */
2477   ddata[0]=-data[0];
2478   for(k=1;k<order;k++)
2479     ddata[k]=data[k-1]-data[k];
2480   div=1.0/(order-1);
2481   data[order-1]=0;
2482   for(l=1;l<(order-1);l++)
2483     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2484   data[0]=div*data[0];
2485
2486   for(i=0;i<nmax;i++)
2487     bsp_data[i]=0;
2488   for(i=1;i<=order;i++)
2489     bsp_data[i]=data[i-1];
2490
2491   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2492   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2493   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2494
2495   sfree(data);
2496   sfree(ddata);
2497   sfree(bsp_data);
2498 }
2499
2500
2501 /* Return the P3M optimal influence function */
2502 static double do_p3m_influence(double z, int order)
2503 {
2504     double z2,z4;
2505
2506     z2 = z*z;
2507     z4 = z2*z2;
2508
2509     /* The formula and most constants can be found in:
2510      * Ballenegger et al., JCTC 8, 936 (2012)
2511      */
2512     switch(order)
2513     {
2514     case 2:
2515         return 1.0 - 2.0*z2/3.0;
2516         break;
2517     case 3:
2518         return 1.0 - z2 + 2.0*z4/15.0;
2519         break;
2520     case 4:
2521         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2522         break;
2523     case 5:
2524         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2525         break;
2526     case 6:
2527         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2528         break;
2529     case 7:
2530         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2531     case 8:
2532         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2533         break;
2534     }
2535
2536     return 0.0;
2537 }
2538
2539 /* Calculate the P3M B-spline moduli for one dimension */
2540 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2541 {
2542     double zarg,zai,sinzai,infl;
2543     int    maxk,i;
2544
2545     if (order > 8)
2546     {
2547         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2548     }
2549
2550     zarg = M_PI/n;
2551
2552     maxk = (n + 1)/2;
2553
2554     for(i=-maxk; i<0; i++)
2555     {
2556         zai    = zarg*i;
2557         sinzai = sin(zai);
2558         infl   = do_p3m_influence(sinzai,order);
2559         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2560     }
2561     bsp_mod[0] = 1.0;
2562     for(i=1; i<maxk; i++)
2563     {
2564         zai    = zarg*i;
2565         sinzai = sin(zai);
2566         infl   = do_p3m_influence(sinzai,order);
2567         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2568     }
2569 }
2570
2571 /* Calculate the P3M B-spline moduli */
2572 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2573                                     int nx,int ny,int nz,int order)
2574 {
2575     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2576     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2577     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2578 }
2579
2580
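/* Set up the order in which slabs exchange coordinates: alternating forward
 * and backward neighbours, nearest first. Illustrative example (hypothetical):
 * for nslab = 4 and nodeid = 0 this gives node_dest = {1,3,2} and
 * node_src = {3,1,2}.
 */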
2581 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2582 {
2583   int nslab,n,i;
2584   int fw,bw;
2585
2586   nslab = atc->nslab;
2587
2588   n = 0;
2589   for(i=1; i<=nslab/2; i++) {
2590     fw = (atc->nodeid + i) % nslab;
2591     bw = (atc->nodeid - i + nslab) % nslab;
2592     if (n < nslab - 1) {
2593       atc->node_dest[n] = fw;
2594       atc->node_src[n]  = bw;
2595       n++;
2596     }
2597     if (n < nslab - 1) {
2598       atc->node_dest[n] = bw;
2599       atc->node_src[n]  = fw;
2600       n++;
2601     }
2602   }
2603 }
2604
2605 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2606 {
2607     int thread;
2608
2609     if(NULL != log)
2610     {
2611         fprintf(log,"Destroying PME data structures.\n");
2612     }
2613
2614     sfree((*pmedata)->nnx);
2615     sfree((*pmedata)->nny);
2616     sfree((*pmedata)->nnz);
2617
2618     pmegrids_destroy(&(*pmedata)->pmegridA);
2619
2620     sfree((*pmedata)->fftgridA);
2621     sfree((*pmedata)->cfftgridA);
2622     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2623
2624     if ((*pmedata)->pmegridB.grid.grid != NULL)
2625     {
2626         pmegrids_destroy(&(*pmedata)->pmegridB);
2627         sfree((*pmedata)->fftgridB);
2628         sfree((*pmedata)->cfftgridB);
2629         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2630     }
2631     for(thread=0; thread<(*pmedata)->nthread; thread++)
2632     {
2633         free_work(&(*pmedata)->work[thread]);
2634     }
2635     sfree((*pmedata)->work);
2636
2637     sfree(*pmedata);
2638     *pmedata = NULL;
2639
2640     return 0;
2641 }
2642
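/* Round n up to the nearest multiple of f, e.g. mult_up(10,4) == 12. */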
2643 static int mult_up(int n,int f)
2644 {
2645     return ((n + f - 1)/f)*f;
2646 }
2647
2648
2649 static double pme_load_imbalance(gmx_pme_t pme)
2650 {
2651     int    nma,nmi;
2652     double n1,n2,n3;
2653
2654     nma = pme->nnodes_major;
2655     nmi = pme->nnodes_minor;
2656
2657     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2658     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2659     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2660
2661     /* pme_solve is roughly double the cost of an fft */
2662
2663     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2664 }
2665
2666 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2667                           int dimind,gmx_bool bSpread)
2668 {
2669     int nk,k,s,thread;
2670
2671     atc->dimind = dimind;
2672     atc->nslab  = 1;
2673     atc->nodeid = 0;
2674     atc->pd_nalloc = 0;
2675 #ifdef GMX_MPI
2676     if (pme->nnodes > 1)
2677     {
2678         atc->mpi_comm = pme->mpi_comm_d[dimind];
2679         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2680         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2681     }
2682     if (debug)
2683     {
2684         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2685     }
2686 #endif
2687
2688     atc->bSpread   = bSpread;
2689     atc->pme_order = pme->pme_order;
2690
2691     if (atc->nslab > 1)
2692     {
2693         /* These three allocations are not required for particle decomp. */
2694         snew(atc->node_dest,atc->nslab);
2695         snew(atc->node_src,atc->nslab);
2696         setup_coordinate_communication(atc);
2697
2698         snew(atc->count_thread,pme->nthread);
2699         for(thread=0; thread<pme->nthread; thread++)
2700         {
2701             snew(atc->count_thread[thread],atc->nslab);
2702         }
2703         atc->count = atc->count_thread[0];
2704         snew(atc->rcount,atc->nslab);
2705         snew(atc->buf_index,atc->nslab);
2706     }
2707
2708     atc->nthread = pme->nthread;
2709     if (atc->nthread > 1)
2710     {
2711         snew(atc->thread_plist,atc->nthread);
2712     }
2713     snew(atc->spline,atc->nthread);
2714     for(thread=0; thread<atc->nthread; thread++)
2715     {
2716         if (atc->nthread > 1)
2717         {
2718             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2719             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2720         }
2721         snew(atc->spline[thread].thread_one,pme->nthread);
2722         atc->spline[thread].thread_one[thread] = 1;
2723     }
2724 }
2725
2726 static void
2727 init_overlap_comm(pme_overlap_t *  ol,
2728                   int              norder,
2729 #ifdef GMX_MPI
2730                   MPI_Comm         comm,
2731 #endif
2732                   int              nnodes,
2733                   int              nodeid,
2734                   int              ndata,
2735                   int              commplainsize)
2736 {
2737     int lbnd,rbnd,maxlr,b,i;
2738     int exten;
2739     int nn,nk;
2740     pme_grid_comm_t *pgc;
2741     gmx_bool bCont;
2742     int fft_start,fft_end,send_index1,recv_index1;
2743 #ifdef GMX_MPI
2744     MPI_Status stat;
2745
2746     ol->mpi_comm = comm;
2747 #endif
2748
2749     ol->nnodes = nnodes;
2750     ol->nodeid = nodeid;
2751
2752     /* Linear translation of the PME grid won't affect reciprocal space
2753      * calculations, so to optimize we only interpolate "upwards",
2754      * which also means we only have to consider overlap in one direction.
2755      * I.e., particles on this node might also be spread to grid indices
2756      * that belong to higher nodes (modulo nnodes)
2757      */
2758
2759     snew(ol->s2g0,ol->nnodes+1);
2760     snew(ol->s2g1,ol->nnodes);
2761     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2762     for(i=0; i<nnodes; i++)
2763     {
2764         /* s2g0 the local interpolation grid start.
2765          * s2g1 the local interpolation grid end.
2766          * Because grid overlap communication only goes forward,
2767          * the grid slabs for the FFTs should be rounded down.
2768          */
2769         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2770         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2771
2772         if (debug)
2773         {
2774             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2775         }
2776     }
2777     ol->s2g0[nnodes] = ndata;
2778     if (debug) { fprintf(debug,"\n"); }
2779
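    /* Illustrative example (hypothetical numbers): with ndata = 20, nnodes = 3
     * and norder = 4 the loop above gives s2g0 = {0,6,13,20} and
     * s2g1 = {10,17,23}, i.e. node 0 spreads onto grid lines 0-9, node 1 onto
     * 6-16 and node 2 onto 13-22 (wrapping around the grid), so each node only
     * overlaps with its forward neighbour and noverlap_nodes becomes 1 below.
     */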
2780     /* Determine with how many nodes we need to communicate the grid overlap */
2781     b = 0;
2782     do
2783     {
2784         b++;
2785         bCont = FALSE;
2786         for(i=0; i<nnodes; i++)
2787         {
2788             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2789                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2790             {
2791                 bCont = TRUE;
2792             }
2793         }
2794     }
2795     while (bCont && b < nnodes);
2796     ol->noverlap_nodes = b - 1;
2797
2798     snew(ol->send_id,ol->noverlap_nodes);
2799     snew(ol->recv_id,ol->noverlap_nodes);
2800     for(b=0; b<ol->noverlap_nodes; b++)
2801     {
2802         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2803         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2804     }
2805     snew(ol->comm_data, ol->noverlap_nodes);
2806
2807     ol->send_size = 0;
2808     for(b=0; b<ol->noverlap_nodes; b++)
2809     {
2810         pgc = &ol->comm_data[b];
2811         /* Send */
2812         fft_start        = ol->s2g0[ol->send_id[b]];
2813         fft_end          = ol->s2g0[ol->send_id[b]+1];
2814         if (ol->send_id[b] < nodeid)
2815         {
2816             fft_start += ndata;
2817             fft_end   += ndata;
2818         }
2819         send_index1      = ol->s2g1[nodeid];
2820         send_index1      = min(send_index1,fft_end);
2821         pgc->send_index0 = fft_start;
2822         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2823         ol->send_size    += pgc->send_nindex;
2824
2825         /* We always start receiving to the first index of our slab */
2826         fft_start        = ol->s2g0[ol->nodeid];
2827         fft_end          = ol->s2g0[ol->nodeid+1];
2828         recv_index1      = ol->s2g1[ol->recv_id[b]];
2829         if (ol->recv_id[b] > nodeid)
2830         {
2831             recv_index1 -= ndata;
2832         }
2833         recv_index1      = min(recv_index1,fft_end);
2834         pgc->recv_index0 = fft_start;
2835         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2836     }
2837
2838 #ifdef GMX_MPI
2839     /* Communicate the buffer sizes to receive */
2840     for(b=0; b<ol->noverlap_nodes; b++)
2841     {
2842         MPI_Sendrecv(&ol->send_size             ,1,MPI_INT,ol->send_id[b],b,
2843                      &ol->comm_data[b].recv_size,1,MPI_INT,ol->recv_id[b],b,
2844                      ol->mpi_comm,&stat);
2845     }
2846 #endif
2847
2848     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2849     snew(ol->sendbuf,norder*commplainsize);
2850     snew(ol->recvbuf,norder*commplainsize);
2851 }
2852
2853 static void
2854 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2855                               int **global_to_local,
2856                               real **fraction_shift)
2857 {
2858     int i;
2859     int * gtl;
2860     real * fsh;
2861
2862     snew(gtl,5*n);
2863     snew(fsh,5*n);
2864     for(i=0; (i<5*n); i++)
2865     {
2866         /* Determine the global to local grid index */
2867         gtl[i] = (i - local_start + n) % n;
2868         /* For coordinates that fall within the local grid the fraction
2869          * is correct, we don't need to shift it.
2870          */
2871         fsh[i] = 0;
2872         if (local_range < n)
2873         {
2874             /* Due to rounding issues i could be 1 beyond the lower or
2875              * upper boundary of the local grid. Correct the index for this.
2876              * If we shift the index, we need to shift the fraction by
2877              * the same amount in the other direction to not affect
2878              * the weights.
2879              * Note that due to this shifting the weights at the end of
2880              * the spline might change, but that will only involve values
2881              * between zero and values close to the precision of a real,
2882              * which is anyhow the accuracy of the whole mesh calculation.
2883              */
2884             /* With local_range=0 we should not change i=local_start */
2885             if (i % n != local_start)
2886             {
2887                 if (gtl[i] == n-1)
2888                 {
2889                     gtl[i] = 0;
2890                     fsh[i] = -1;
2891                 }
2892                 else if (gtl[i] == local_range)
2893                 {
2894                     gtl[i] = local_range - 1;
2895                     fsh[i] = 1;
2896                 }
2897             }
2898         }
2899     }
2900
2901     *global_to_local = gtl;
2902     *fraction_shift  = fsh;
2903 }
2904
2905 static pme_spline_work_t *make_pme_spline_work(int order)
2906 {
2907     pme_spline_work_t *work;
2908
2909 #ifdef PME_SSE
2910     float  tmp[8];
2911     __m128 zero_SSE;
2912     int    of,i;
2913
2914     snew_aligned(work,1,16);
2915
2916     zero_SSE = _mm_setzero_ps();
2917
2918     /* Generate bit masks to mask out the unused grid entries,
2919      * as we only operate on 'order' of the 8 grid entries that are
2920      * loaded into 2 SSE float registers.
2921      */
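    /* For example, with order = 4 and offset of = 1 the two masks together
     * select elements 1-4 of the 8 floats loaded (mask_SSE0 covers elements
     * 0-3, mask_SSE1 covers elements 4-7).
     */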
2922     for(of=0; of<8-(order-1); of++)
2923     {
2924         for(i=0; i<8; i++)
2925         {
2926             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2927         }
2928         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2929         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2930         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2931         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2932     }
2933 #else
2934     work = NULL;
2935 #endif
2936
2937     return work;
2938 }
2939
2940 static void
2941 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2942 {
2943     int nk_new;
2944
2945     if (*nk % nnodes != 0)
2946     {
2947         nk_new = nnodes*(*nk/nnodes + 1);
2948
2949         if (2*nk_new >= 3*(*nk))
2950         {
2951             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2952                       dim,*nk,dim,nnodes,dim);
2953         }
2954
2955         if (fplog != NULL)
2956         {
2957             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2958                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2959         }
2960
2961         *nk = nk_new;
2962     }
2963 }
2964
2965 int gmx_pme_init(gmx_pme_t *         pmedata,
2966                  t_commrec *         cr,
2967                  int                 nnodes_major,
2968                  int                 nnodes_minor,
2969                  t_inputrec *        ir,
2970                  int                 homenr,
2971                  gmx_bool            bFreeEnergy,
2972                  gmx_bool            bReproducible,
2973                  int                 nthread)
2974 {
2975     gmx_pme_t pme=NULL;
2976
2977     pme_atomcomm_t *atc;
2978     ivec ndata;
2979
2980     if (debug)
2981         fprintf(debug,"Creating PME data structures.\n");
2982     snew(pme,1);
2983
2984     pme->redist_init         = FALSE;
2985     pme->sum_qgrid_tmp       = NULL;
2986     pme->sum_qgrid_dd_tmp    = NULL;
2987     pme->buf_nalloc          = 0;
2988     pme->redist_buf_nalloc   = 0;
2989
2990     pme->nnodes              = 1;
2991     pme->bPPnode             = TRUE;
2992
2993     pme->nnodes_major        = nnodes_major;
2994     pme->nnodes_minor        = nnodes_minor;
2995
2996 #ifdef GMX_MPI
2997     if (nnodes_major*nnodes_minor > 1)
2998     {
2999         pme->mpi_comm = cr->mpi_comm_mygroup;
3000
3001         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
3002         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
3003         if (pme->nnodes != nnodes_major*nnodes_minor)
3004         {
3005             gmx_incons("PME node count mismatch");
3006         }
3007     }
3008     else
3009     {
3010         pme->mpi_comm = MPI_COMM_NULL;
3011     }
3012 #endif
3013
3014     if (pme->nnodes == 1)
3015     {
3016 #ifdef GMX_MPI
3017         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3018         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3019 #endif
3020         pme->ndecompdim = 0;
3021         pme->nodeid_major = 0;
3022         pme->nodeid_minor = 0;
3026     }
3027     else
3028     {
3029         if (nnodes_minor == 1)
3030         {
3031 #ifdef GMX_MPI
3032             pme->mpi_comm_d[0] = pme->mpi_comm;
3033             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3034 #endif
3035             pme->ndecompdim = 1;
3036             pme->nodeid_major = pme->nodeid;
3037             pme->nodeid_minor = 0;
3038
3039         }
3040         else if (nnodes_major == 1)
3041         {
3042 #ifdef GMX_MPI
3043             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3044             pme->mpi_comm_d[1] = pme->mpi_comm;
3045 #endif
3046             pme->ndecompdim = 1;
3047             pme->nodeid_major = 0;
3048             pme->nodeid_minor = pme->nodeid;
3049         }
3050         else
3051         {
3052             if (pme->nnodes % nnodes_major != 0)
3053             {
3054                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3055             }
3056             pme->ndecompdim = 2;
3057
3058 #ifdef GMX_MPI
3059             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3060                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3061             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3062                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3063
3064             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3065             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3066             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3067             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3068 #endif
3069         }
3070         pme->bPPnode = (cr->duty & DUTY_PP);
3071     }
3072
3073     pme->nthread = nthread;
3074
3075     if (ir->ePBC == epbcSCREW)
3076     {
3077         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3078     }
3079
3080     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3081     pme->nkx         = ir->nkx;
3082     pme->nky         = ir->nky;
3083     pme->nkz         = ir->nkz;
3084     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3085     pme->pme_order   = ir->pme_order;
3086     pme->epsilon_r   = ir->epsilon_r;
3087
3088     if (pme->pme_order > PME_ORDER_MAX)
3089     {
3090         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3091                   pme->pme_order,PME_ORDER_MAX);
3092     }
3093
3094     /* Currently pme.c supports only the fft5d FFT code.
3095      * Therefore the grid always needs to be divisible by nnodes.
3096      * When the old 1D code is also supported again, change this check.
3097      *
3098      * This check should be done before calling gmx_pme_init
3099      * and fplog should be passed instead of stderr.
3100      *
3101     if (pme->ndecompdim >= 2)
3102     */
3103     if (pme->ndecompdim >= 1)
3104     {
3105         /*
3106         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3107                                         'x',nnodes_major,&pme->nkx);
3108         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3109                                         'y',nnodes_minor,&pme->nky);
3110         */
3111     }
3112
3113     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3114         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3115         pme->nkz <= pme->pme_order)
3116     {
3117         gmx_fatal(FARGS,"The PME grid sizes need to be larger than pme_order (%d), and in dimensions with domain decomposition they need to be larger than 2*pme_order",pme->pme_order);
3118     }
3119
3120     if (pme->nnodes > 1) {
3121         double imbal;
3122
3123 #ifdef GMX_MPI
3124         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3125         MPI_Type_commit(&(pme->rvec_mpi));
3126 #endif
3127
3128         /* Note that the charge spreading and force gathering, which usually
3129          * takes about the same amount of time as FFT+solve_pme,
3130          * is always fully load balanced
3131          * (unless the charge distribution is inhomogeneous).
3132          */
3133
3134         imbal = pme_load_imbalance(pme);
3135         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3136         {
3137             fprintf(stderr,
3138                     "\n"
3139                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3140                     "      For optimal PME load balancing\n"
3141                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3142                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3143                     "\n",
3144                     (int)((imbal-1)*100 + 0.5),
3145                     pme->nkx,pme->nky,pme->nnodes_major,
3146                     pme->nky,pme->nkz,pme->nnodes_minor);
3147         }
3148     }
3149
3150     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3151     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3152      * y is always copied through a buffer: we don't need padding in z,
3153      * but we do need the overlap in x because of the communication order.
3154      */
3155     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3156 #ifdef GMX_MPI
3157                       pme->mpi_comm_d[0],
3158 #endif
3159                       pme->nnodes_major,pme->nodeid_major,
3160                       pme->nkx,
3161                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3162
3163     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3164      * We do this with an offset buffer of equal size, so we need to allocate
3165      * extra for the offset. That's what the (+1)*pme->nkz is for.
3166      */
3167     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3168 #ifdef GMX_MPI
3169                       pme->mpi_comm_d[1],
3170 #endif
3171                       pme->nnodes_minor,pme->nodeid_minor,
3172                       pme->nky,
3173                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3174
3175     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3176      * We only allow multiple communication pulses in dim 1, not in dim 0.
3177      */
3178     if (pme->nthread > 1 && (pme->overlap[0].noverlap_nodes > 1 ||
3179                              pme->nkx < pme->nnodes_major*pme->pme_order))
3180     {
3181         gmx_fatal(FARGS,"The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d). To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3182                   pme->nkx/(double)pme->nnodes_major,pme->pme_order);
3183     }
3184
3185     snew(pme->bsp_mod[XX],pme->nkx);
3186     snew(pme->bsp_mod[YY],pme->nky);
3187     snew(pme->bsp_mod[ZZ],pme->nkz);
3188
3189     /* The required size of the interpolation grid, including overlap.
3190      * The allocated size (pmegrid_n?) might be slightly larger.
3191      */
3192     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3193                       pme->overlap[0].s2g0[pme->nodeid_major];
3194     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3195                       pme->overlap[1].s2g0[pme->nodeid_minor];
3196     pme->pmegrid_nz_base = pme->nkz;
3197     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3198     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3199
3200     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3201     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3202     pme->pmegrid_start_iz = 0;
3203
3204     make_gridindex5_to_localindex(pme->nkx,
3205                                   pme->pmegrid_start_ix,
3206                                   pme->pmegrid_nx - (pme->pme_order-1),
3207                                   &pme->nnx,&pme->fshx);
3208     make_gridindex5_to_localindex(pme->nky,
3209                                   pme->pmegrid_start_iy,
3210                                   pme->pmegrid_ny - (pme->pme_order-1),
3211                                   &pme->nny,&pme->fshy);
3212     make_gridindex5_to_localindex(pme->nkz,
3213                                   pme->pmegrid_start_iz,
3214                                   pme->pmegrid_nz_base,
3215                                   &pme->nnz,&pme->fshz);
3216
3217     pmegrids_init(&pme->pmegridA,
3218                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3219                   pme->pmegrid_nz_base,
3220                   pme->pme_order,
3221                   pme->nthread,
3222                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3223                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3224
3225     pme->spline_work = make_pme_spline_work(pme->pme_order);
3226
3227     ndata[0] = pme->nkx;
3228     ndata[1] = pme->nky;
3229     ndata[2] = pme->nkz;
3230
3231     /* This routine will allocate the grid data to fit the FFTs */
3232     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3233                             &pme->fftgridA,&pme->cfftgridA,
3234                             pme->mpi_comm_d,
3235                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3236                             bReproducible,pme->nthread);
3237
3238     if (bFreeEnergy)
3239     {
3240         pmegrids_init(&pme->pmegridB,
3241                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3242                       pme->pmegrid_nz_base,
3243                       pme->pme_order,
3244                       pme->nthread,
3245                       pme->nkx % pme->nnodes_major != 0,
3246                       pme->nky % pme->nnodes_minor != 0);
3247
3248         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3249                                 &pme->fftgridB,&pme->cfftgridB,
3250                                 pme->mpi_comm_d,
3251                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3252                                 bReproducible,pme->nthread);
3253     }
3254     else
3255     {
3256         pme->pmegridB.grid.grid = NULL;
3257         pme->fftgridB           = NULL;
3258         pme->cfftgridB          = NULL;
3259     }
3260
3261     if (!pme->bP3M)
3262     {
3263         /* Use plain SPME B-spline interpolation */
3264         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3265     }
3266     else
3267     {
3268         /* Use the P3M grid-optimized influence function */
3269         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3270     }
3271
3272     /* Use atc[0] for spreading */
3273     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3274     if (pme->ndecompdim >= 2)
3275     {
3276         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3277     }
3278
3279     if (pme->nnodes == 1) {
3280         pme->atc[0].n = homenr;
3281         pme_realloc_atomcomm_things(&pme->atc[0]);
3282     }
3283
3284     {
3285         int thread;
3286
3287         /* Use fft5d: after the FFT the data is ordered y (major), z, x (minor) */
3288
3289         snew(pme->work,pme->nthread);
3290         for(thread=0; thread<pme->nthread; thread++)
3291         {
3292             realloc_work(&pme->work[thread],pme->nkx);
3293         }
3294     }
3295
3296     *pmedata = pme;
3297     
3298     return 0;
3299 }
3300
3301 static void reuse_pmegrids(const pmegrids_t *old,pmegrids_t *new)
3302 {
3303     int d,t;
3304
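    /* Only reuse the old allocation when the new grid fits inside it
     * in every dimension; otherwise keep the newly allocated grids.
     */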
3305     for(d=0; d<DIM; d++)
3306     {
3307         if (new->grid.n[d] > old->grid.n[d])
3308         {
3309             return;
3310         }
3311     }
3312
3313     sfree_aligned(new->grid.grid);
3314     new->grid.grid = old->grid.grid;
3315
3316     if (new->nthread > 1 && new->nthread == old->nthread)
3317     {
3318         sfree_aligned(new->grid_all);
3319         for(t=0; t<new->nthread; t++)
3320         {
3321             new->grid_th[t].grid = old->grid_th[t].grid;
3322         }
3323     }
3324 }
3325
3326 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3327                    t_commrec *         cr,
3328                    gmx_pme_t           pme_src,
3329                    const t_inputrec *  ir,
3330                    ivec                grid_size)
3331 {
3332     t_inputrec irc;
3333     int homenr;
3334     int ret;
3335
3336     irc = *ir;
3337     irc.nkx = grid_size[XX];
3338     irc.nky = grid_size[YY];
3339     irc.nkz = grid_size[ZZ];
3340
3341     if (pme_src->nnodes == 1)
3342     {
3343         homenr = pme_src->atc[0].n;
3344     }
3345     else
3346     {
3347         homenr = -1;
3348     }
3349
3350     ret = gmx_pme_init(pmedata,cr,pme_src->nnodes_major,pme_src->nnodes_minor,
3351                        &irc,homenr,pme_src->bFEP,FALSE,pme_src->nthread);
3352
3353     if (ret == 0)
3354     {
3355         /* We can easily reuse the allocated pme grids in pme_src */
3356         reuse_pmegrids(&pme_src->pmegridA,&(*pmedata)->pmegridA);
3357         /* We would like to reuse the fft grids, but that's harder */
3358     }
3359
3360     return ret;
3361 }
3362
3363
3364 static void copy_local_grid(gmx_pme_t pme,
3365                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3366 {
3367     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3368     int  fft_my,fft_mz;
3369     int  nsx,nsy,nsz;
3370     ivec nf;
3371     int  offx,offy,offz,x,y,z,i0,i0t;
3372     int  d;
3373     pmegrid_t *pmegrid;
3374     real *grid_th;
3375
3376     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3377                                    local_fft_ndata,
3378                                    local_fft_offset,
3379                                    local_fft_size);
3380     fft_my = local_fft_size[YY];
3381     fft_mz = local_fft_size[ZZ];
3382
3383     pmegrid = &pmegrids->grid_th[thread];
3384
3385     nsx = pmegrid->s[XX];
3386     nsy = pmegrid->s[YY];
3387     nsz = pmegrid->s[ZZ];
3388
3389     for(d=0; d<DIM; d++)
3390     {
3391         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3392                     local_fft_ndata[d] - pmegrid->offset[d]);
3393     }
3394
3395     offx = pmegrid->offset[XX];
3396     offy = pmegrid->offset[YY];
3397     offz = pmegrid->offset[ZZ];
3398
3399     /* Directly copy the non-overlapping parts of the local grids.
3400      * This also initializes the full grid.
3401      */
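    /* i0 is the flat index into the padded node FFT grid (row strides fft_my
     * and fft_mz), i0t the flat index into this thread's local spread grid
     * (row strides nsy and nsz).
     */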
3402     grid_th = pmegrid->grid;
3403     for(x=0; x<nf[XX]; x++)
3404     {
3405         for(y=0; y<nf[YY]; y++)
3406         {
3407             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3408             i0t = (x*nsy + y)*nsz;
3409             for(z=0; z<nf[ZZ]; z++)
3410             {
3411                 fftgrid[i0+z] = grid_th[i0t+z];
3412             }
3413         }
3414     }
3415 }
3416
3417 static void
3418 reduce_threadgrid_overlap(gmx_pme_t pme,
3419                           const pmegrids_t *pmegrids,int thread,
3420                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3421 {
3422     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3423     int  fft_nx,fft_ny,fft_nz;
3424     int  fft_my,fft_mz;
3425     int  buf_my=-1;
3426     int  nsx,nsy,nsz;
3427     ivec ne;
3428     int  offx,offy,offz,x,y,z,i0,i0t;
3429     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3430     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3431     gmx_bool bCommX,bCommY;
3432     int  d;
3433     int  thread_f;
3434     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3435     const real *grid_th;
3436     real *commbuf=NULL;
3437
3438     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3439                                    local_fft_ndata,
3440                                    local_fft_offset,
3441                                    local_fft_size);
3442     fft_nx = local_fft_ndata[XX];
3443     fft_ny = local_fft_ndata[YY];
3444     fft_nz = local_fft_ndata[ZZ];
3445
3446     fft_my = local_fft_size[YY];
3447     fft_mz = local_fft_size[ZZ];
3448
3449     /* This routine is called when all threads have finished spreading.
3450      * Here each thread sums the grid contributions calculated by other threads
3451      * into its thread-local grid volume.
3452      * To minimize the number of grid copying operations,
3453      * this routine sums directly from the pmegrid to the fftgrid.
3454      */
3455
3456     /* Determine which part of the full node grid we should operate on:
3457      * this is our thread-local part of the full grid.
3458      */
3459     pmegrid = &pmegrids->grid_th[thread];
3460
3461     for(d=0; d<DIM; d++)
3462     {
3463         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3464                     local_fft_ndata[d]);
3465     }
3466
3467     offx = pmegrid->offset[XX];
3468     offy = pmegrid->offset[YY];
3469     offz = pmegrid->offset[ZZ];
3470
3471
3472     bClearBufX  = TRUE;
3473     bClearBufY  = TRUE;
3474     bClearBufXY = TRUE;
3475
3476     /* Now loop over all the thread data blocks that contribute
3477      * to the grid region we (our thread) are operating on.
3478      */
3479     /* Note that fft_nx/fft_ny is equal to the number of grid points
3480      * between the first point of our node grid and that of the next node.
3481      */
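    /* The shifts sx,sy,sz below run from 0 downwards: each thread pulls in
     * the overlap contributions of the thread blocks preceding it in each
     * dimension, wrapping periodically. Where the wrap crosses a node
     * boundary (bCommX/bCommY), the data is accumulated into a communication
     * buffer instead of directly into fftgrid.
     */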
3482     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3483     {
3484         fx = pmegrid->ci[XX] + sx;
3485         ox = 0;
3486         bCommX = FALSE;
3487         if (fx < 0) {
3488             fx += pmegrids->nc[XX];
3489             ox -= fft_nx;
3490             bCommX = (pme->nnodes_major > 1);
3491         }
3492         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3493         ox += pmegrid_g->offset[XX];
3494         if (!bCommX)
3495         {
3496             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3497         }
3498         else
3499         {
3500             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3501         }
3502
3503         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3504         {
3505             fy = pmegrid->ci[YY] + sy;
3506             oy = 0;
3507             bCommY = FALSE;
3508             if (fy < 0) {
3509                 fy += pmegrids->nc[YY];
3510                 oy -= fft_ny;
3511                 bCommY = (pme->nnodes_minor > 1);
3512             }
3513             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3514             oy += pmegrid_g->offset[YY];
3515             if (!bCommY)
3516             {
3517                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3518             }
3519             else
3520             {
3521                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3522             }
3523
3524             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3525             {
3526                 fz = pmegrid->ci[ZZ] + sz;
3527                 oz = 0;
3528                 if (fz < 0)
3529                 {
3530                     fz += pmegrids->nc[ZZ];
3531                     oz -= fft_nz;
3532                 }
3533                 pmegrid_g = &pmegrids->grid_th[fz];
3534                 oz += pmegrid_g->offset[ZZ];
3535                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3536
3537                 if (sx == 0 && sy == 0 && sz == 0)
3538                 {
3539                     /* We have already added our local contribution
3540                      * before calling this routine, so skip it here.
3541                      */
3542                     continue;
3543                 }
3544
3545                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3546
3547                 pmegrid_f = &pmegrids->grid_th[thread_f];
3548
3549                 grid_th = pmegrid_f->grid;
3550
3551                 nsx = pmegrid_f->s[XX];
3552                 nsy = pmegrid_f->s[YY];
3553                 nsz = pmegrid_f->s[ZZ];
3554
3555 #ifdef DEBUG_PME_REDUCE
3556                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3557                        pme->nodeid,thread,thread_f,
3558                        pme->pmegrid_start_ix,
3559                        pme->pmegrid_start_iy,
3560                        pme->pmegrid_start_iz,
3561                        sx,sy,sz,
3562                        offx-ox,tx1-ox,offx,tx1,
3563                        offy-oy,ty1-oy,offy,ty1,
3564                        offz-oz,tz1-oz,offz,tz1);
3565 #endif
3566
3567                 if (!(bCommX || bCommY))
3568                 {
3569                     /* Copy from the thread local grid to the node grid */
3570                     for(x=offx; x<tx1; x++)
3571                     {
3572                         for(y=offy; y<ty1; y++)
3573                         {
3574                             i0  = (x*fft_my + y)*fft_mz;
3575                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3576                             for(z=offz; z<tz1; z++)
3577                             {
3578                                 fftgrid[i0+z] += grid_th[i0t+z];
3579                             }
3580                         }
3581                     }
3582                 }
3583                 else
3584                 {
3585                     /* The order of this conditional decides
3586                      * where the corner volume gets stored with x+y decomp.
3587                      */
3588                     if (bCommY)
3589                     {
3590                         commbuf = commbuf_y;
3591                         buf_my  = ty1 - offy;
3592                         if (bCommX)
3593                         {
3594                             /* Store the x+y corner volume behind the y-overlap part of the buffer */
3595                             commbuf += buf_my*fft_nx*fft_nz;
3596
3597                             bClearBuf  = bClearBufXY;
3598                             bClearBufXY = FALSE;
3599                         }
3600                         else
3601                         {
3602                             bClearBuf  = bClearBufY;
3603                             bClearBufY = FALSE;
3604                         }
3605                     }
3606                     else
3607                     {
3608                         commbuf = commbuf_x;
3609                         buf_my  = fft_ny;
3610                         bClearBuf  = bClearBufX;
3611                         bClearBufX = FALSE;
3612                     }
3613
3614                     /* Copy to the communication buffer */
3615                     for(x=offx; x<tx1; x++)
3616                     {
3617                         for(y=offy; y<ty1; y++)
3618                         {
3619                             i0  = (x*buf_my + y)*fft_nz;
3620                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3621
3622                             if (bClearBuf)
3623                             {
3624                                 /* First access of commbuf, initialize it */
3625                                 for(z=offz; z<tz1; z++)
3626                                 {
3627                                     commbuf[i0+z]  = grid_th[i0t+z];
3628                                 }
3629                             }
3630                             else
3631                             {
3632                                 for(z=offz; z<tz1; z++)
3633                                 {
3634                                     commbuf[i0+z] += grid_th[i0t+z];
3635                                 }
3636                             }
3637                         }
3638                     }
3639                 }
3640             }
3641         }
3642     }
3643 }
3644
3645
3646 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3647 {
3648     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3649     pme_overlap_t *overlap;
3650     int  send_index0,send_nindex;
3651     int  recv_nindex;
3652 #ifdef GMX_MPI
3653     MPI_Status stat;
3654 #endif
3655     int  send_size_y,recv_size_y;
3656     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3657     real *sendptr,*recvptr;
3658     int  x,y,z,indg,indb;
3659
3660     /* Note that this routine is only used for forward communication.
3661      * Since the force gathering, unlike the charge spreading,
3662      * can be trivially parallelized over the particles,
3663      * the backwards process is much simpler and can use the "old"
3664      * communication setup.
3665      */
3666
3667     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3668                                    local_fft_ndata,
3669                                    local_fft_offset,
3670                                    local_fft_size);
3671
3672     if (pme->nnodes_minor > 1)
3673     {
3674         /* Minor dimension */
3675         overlap = &pme->overlap[1];
3676
3677         if (pme->nnodes_major > 1)
3678         {
3679              size_yx = pme->overlap[0].comm_data[0].send_nindex;
3680         }
3681         else
3682         {
3683             size_yx = 0;
3684         }
3685         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
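        /* With 2D decomposition the y-overlap buffers also carry size_yx extra
         * x-planes that belong to the neighbor in the major dimension; after
         * the exchange below, these are added into overlap[0].sendbuf so that
         * the dim 0 communication further down can forward them.
         */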
3686
3687         send_size_y = overlap->send_size;
3688
3689         for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
3690         {
3691             send_id = overlap->send_id[ipulse];
3692             recv_id = overlap->recv_id[ipulse];
3693             send_index0   =
3694                 overlap->comm_data[ipulse].send_index0 -
3695                 overlap->comm_data[0].send_index0;
3696             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3697             /* We don't use recv_index0, as we always receive starting at 0 */
3698             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3699             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3700
3701             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3702             recvptr = overlap->recvbuf;
3703
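            /* Exchange the overlapping y-planes with the neighboring node in
             * the minor dimension; the send buffer was filled in
             * reduce_threadgrid_overlap().
             */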
3704 #ifdef GMX_MPI
3705             MPI_Sendrecv(sendptr,send_size_y*datasize,GMX_MPI_REAL,
3706                          send_id,ipulse,
3707                          recvptr,recv_size_y*datasize,GMX_MPI_REAL,
3708                          recv_id,ipulse,
3709                          overlap->mpi_comm,&stat);
3710 #endif
3711
3712             for(x=0; x<local_fft_ndata[XX]; x++)
3713             {
3714                 for(y=0; y<recv_nindex; y++)
3715                 {
3716                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3717                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3718                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3719                     {
3720                         fftgrid[indg+z] += recvptr[indb+z];
3721                     }
3722                 }
3723             }
3724
3725             if (pme->nnodes_major > 1)
3726             {
3727                 /* Add the received corner data into the send buffer for dim 0 */
3728                 sendptr = pme->overlap[0].sendbuf;
3729                 for(x=0; x<size_yx; x++)
3730                 {
3731                     for(y=0; y<recv_nindex; y++)
3732                     {
3733                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3734                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3735                         for(z=0; z<local_fft_ndata[ZZ]; z++)
3736                         {
3737                             sendptr[indg+z] += recvptr[indb+z];
3738                         }
3739                     }
3740                 }
3741             }
3742         }
3743     }
3744
3745     /* We only support a single pulse here.
3746      * This is not a severe limitation, as this code is only used
3747      * with OpenMP, where the (PME) domains can be larger.
3748      */
3749     if (pme->nnodes_major > 1)
3750     {
3751         /* Major dimension */
3752         overlap = &pme->overlap[0];
3753
3754         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3755         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3756
3757         ipulse = 0;
3758
3759         send_id = overlap->send_id[ipulse];
3760         recv_id = overlap->recv_id[ipulse];
3761         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3762         /* We don't use recv_index0, as we always receive starting at 0 */
3763         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3764
3765         sendptr = overlap->sendbuf;
3766         recvptr = overlap->recvbuf;
3767
3768         if (debug != NULL)
3769         {
3770             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3771                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3772         }
3773
3774 #ifdef GMX_MPI
3775         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3776                      send_id,ipulse,
3777                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3778                      recv_id,ipulse,
3779                      overlap->mpi_comm,&stat);
3780 #endif
3781
3782         for(x=0; x<recv_nindex; x++)
3783         {
3784             for(y=0; y<local_fft_ndata[YY]; y++)
3785             {
3786                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3787                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3788                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3789                 {
3790                     fftgrid[indg+z] += recvptr[indb+z];
3791                 }
3792             }
3793         }
3794     }
3795 }
3796
3797
3798 static void spread_on_grid(gmx_pme_t pme,
3799                            pme_atomcomm_t *atc,pmegrids_t *grids,
3800                            gmx_bool bCalcSplines,gmx_bool bSpread,
3801                            real *fftgrid)
3802 {
3803     int nthread,thread;
3804 #ifdef PME_TIME_THREADS
3805     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3806     static double cs1=0,cs2=0,cs3=0;
3807     static double cs1a[6]={0,0,0,0,0,0};
3808     static int cnt=0;
3809 #endif
3810
3811     nthread = pme->nthread;
3812     assert(nthread>0);
3813
3814 #ifdef PME_TIME_THREADS
3815     c1 = omp_cyc_start();
3816 #endif
3817     if (bCalcSplines)
3818     {
3819 #pragma omp parallel for num_threads(nthread) schedule(static)
3820         for(thread=0; thread<nthread; thread++)
3821         {
3822             int start,end;
3823
3824             start = atc->n* thread   /nthread;
3825             end   = atc->n*(thread+1)/nthread;
3826
3827             /* Compute the fftgrid index for all atoms, and store the
3828              * fractional coordinates needed later for the spline spreading.
3829              */
3830             calc_interpolation_idx(pme,atc,start,end,thread);
3831         }
3832     }
3833 #ifdef PME_TIME_THREADS
3834     c1 = omp_cyc_end(c1);
3835     cs1 += (double)c1;
3836 #endif
3837
3838 #ifdef PME_TIME_THREADS
3839     c2 = omp_cyc_start();
3840 #endif
3841 #pragma omp parallel for num_threads(nthread) schedule(static)
3842     for(thread=0; thread<nthread; thread++)
3843     {
3844         splinedata_t *spline;
3845         pmegrid_t *grid;
3846
3847         /* make local bsplines  */
3848         if (grids == NULL || grids->nthread == 1)
3849         {
3850             spline = &atc->spline[0];
3851
3852             spline->n = atc->n;
3853
3854             grid = &grids->grid;
3855         }
3856         else
3857         {
3858             spline = &atc->spline[thread];
3859
3860             make_thread_local_ind(atc,thread,spline);
3861
3862             grid = &grids->grid_th[thread];
3863         }
3864
3865         if (bCalcSplines)
3866         {
3867             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3868                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3869         }
3870
3871         if (bSpread)
3872         {
3873             /* put local atoms on grid. */
3874 #ifdef PME_TIME_SPREAD
3875             ct1a = omp_cyc_start();
3876 #endif
3877             spread_q_bsplines_thread(grid,atc,spline,pme->spline_work);
3878
3879             if (grids->nthread > 1)
3880             {
3881                 copy_local_grid(pme,grids,thread,fftgrid);
3882             }
3883 #ifdef PME_TIME_SPREAD
3884             ct1a = omp_cyc_end(ct1a);
3885             cs1a[thread] += (double)ct1a;
3886 #endif
3887         }
3888     }
3889 #ifdef PME_TIME_THREADS
3890     c2 = omp_cyc_end(c2);
3891     cs2 += (double)c2;
3892 #endif
3893
3894     if (bSpread && grids->nthread > 1)
3895     {
3896 #ifdef PME_TIME_THREADS
3897         c3 = omp_cyc_start();
3898 #endif
3899 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3900         for(thread=0; thread<grids->nthread; thread++)
3901         {
3902             reduce_threadgrid_overlap(pme,grids,thread,
3903                                       fftgrid,
3904                                       pme->overlap[0].sendbuf,
3905                                       pme->overlap[1].sendbuf);
3906         }
3907 #ifdef PME_TIME_THREADS
3908         c3 = omp_cyc_end(c3);
3909         cs3 += (double)c3;
3910 #endif
3911
3912         if (pme->nnodes > 1)
3913         {
3914             /* Communicate the overlapping part of the fftgrid */
3915             sum_fftgrid_dd(pme,fftgrid);
3916         }
3917     }
3918
3919 #ifdef PME_TIME_THREADS
3920     cnt++;
3921     if (cnt % 20 == 0)
3922     {
3923         printf("idx %.2f spread %.2f red %.2f",
3924                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3925 #ifdef PME_TIME_SPREAD
3926         for(thread=0; thread<nthread; thread++)
3927             printf(" %.2f",cs1a[thread]*1e-9);
3928 #endif
3929         printf("\n");
3930     }
3931 #endif
3932 }
3933
3934
3935 static void dump_grid(FILE *fp,
3936                       int sx,int sy,int sz,int nx,int ny,int nz,
3937                       int my,int mz,const real *g)
3938 {
3939     int x,y,z;
3940
3941     for(x=0; x<nx; x++)
3942     {
3943         for(y=0; y<ny; y++)
3944         {
3945             for(z=0; z<nz; z++)
3946             {
3947                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3948                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3949             }
3950         }
3951     }
3952 }
3953
3954 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3955 {
3956     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3957
3958     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3959                                    local_fft_ndata,
3960                                    local_fft_offset,
3961                                    local_fft_size);
3962
3963     dump_grid(stderr,
3964               pme->pmegrid_start_ix,
3965               pme->pmegrid_start_iy,
3966               pme->pmegrid_start_iz,
3967               pme->pmegrid_nx-pme->pme_order+1,
3968               pme->pmegrid_ny-pme->pme_order+1,
3969               pme->pmegrid_nz-pme->pme_order+1,
3970               local_fft_size[YY],
3971               local_fft_size[ZZ],
3972               fftgrid);
3973 }
3974
3975
3976 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3977 {
3978     pme_atomcomm_t *atc;
3979     pmegrids_t *grid;
3980
3981     if (pme->nnodes > 1)
3982     {
3983         gmx_incons("gmx_pme_calc_energy called in parallel");
3984     }
3985     if (pme->bFEP)
3986     {
3987         gmx_incons("gmx_pme_calc_energy with free energy");
3988     }
3989
3990     atc = &pme->atc_energy;
3991     atc->nthread   = 1;
3992     if (atc->spline == NULL)
3993     {
3994         snew(atc->spline,atc->nthread);
3995     }
3996     atc->nslab     = 1;
3997     atc->bSpread   = TRUE;
3998     atc->pme_order = pme->pme_order;
3999     atc->n         = n;
4000     pme_realloc_atomcomm_things(atc);
4001     atc->x         = x;
4002     atc->q         = q;
4003
4004     /* We only use the A-charges grid */
4005     grid = &pme->pmegridA;
4006
4007     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
4008
4009     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
4010 }
4011
4012
4013 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
4014         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
4015 {
4016     /* Reset all the counters related to performance over the run */
4017     wallcycle_stop(wcycle,ewcRUN);
4018     wallcycle_reset_all(wcycle);
4019     init_nrnb(nrnb);
4020     ir->init_step += step_rel;
4021     ir->nsteps    -= step_rel;
4022     wallcycle_start(wcycle,ewcRUN);
4023 }
4024
4025
4026 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4027                                ivec grid_size,
4028                                t_commrec *cr, t_inputrec *ir,
4029                                gmx_pme_t *pme_ret)
4030 {
4031     int ind;
4032     gmx_pme_t pme = NULL;
4033
4034     ind = 0;
4035     while (ind < *npmedata)
4036     {
4037         pme = (*pmedata)[ind];
4038         if (pme->nkx == grid_size[XX] &&
4039             pme->nky == grid_size[YY] &&
4040             pme->nkz == grid_size[ZZ])
4041         {
4042             *pme_ret = pme;
4043
4044             return;
4045         }
4046
4047         ind++;
4048     }
4049
4050     (*npmedata)++;
4051     srenew(*pmedata,*npmedata);
4052
4053     /* Generate a new PME data structure, copying part of the old pointers */
4054     gmx_pme_reinit(&((*pmedata)[ind]),cr,pme,ir,grid_size);
4055
4056     *pme_ret = (*pmedata)[ind];
4057 }
4058
4059
4060 int gmx_pmeonly(gmx_pme_t pme,
4061                 t_commrec *cr,    t_nrnb *nrnb,
4062                 gmx_wallcycle_t wcycle,
4063                 real ewaldcoeff,  gmx_bool bGatherOnly,
4064                 t_inputrec *ir)
4065 {
4066     int npmedata;
4067     gmx_pme_t *pmedata;
4068     gmx_pme_pp_t pme_pp;
4069     int  natoms;
4070     matrix box;
4071     rvec *x_pp=NULL,*f_pp=NULL;
4072     real *chargeA=NULL,*chargeB=NULL;
4073     real lambda=0;
4074     int  maxshift_x=0,maxshift_y=0;
4075     real energy,dvdlambda;
4076     matrix vir;
4077     float cycles;
4078     int  count;
4079     gmx_bool bEnerVir;
4080     gmx_large_int_t step,step_rel;
4081     ivec grid_switch;
4082
4083     /* This data is only used with PME tuning, i.e. switching of PME grids */
4084     npmedata = 1;
4085     snew(pmedata,npmedata);
4086     pmedata[0] = pme;
4087
4088     pme_pp = gmx_pme_pp_init(cr);
4089
4090     init_nrnb(nrnb);
4091
4092     count = 0;
4093     do /****** this is a quasi-loop over time steps! */
4094     {
4095         /* The reason for having a loop here is PME grid tuning/switching */
4096         do
4097         {
4098             /* Domain decomposition */
4099             natoms = gmx_pme_recv_q_x(pme_pp,
4100                                       &chargeA,&chargeB,box,&x_pp,&f_pp,
4101                                       &maxshift_x,&maxshift_y,
4102                                       &pme->bFEP,&lambda,
4103                                       &bEnerVir,
4104                                       &step,
4105                                       grid_switch,&ewaldcoeff);
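            /* natoms >= 0 signals a normal step, natoms == -2 a request to
             * switch to the PME grid in grid_switch, natoms == -1 termination.
             */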
4106
4107             if (natoms == -2)
4108             {
4109                 /* Switch the PME grid to grid_switch */
4110                 gmx_pmeonly_switch(&npmedata,&pmedata,grid_switch,cr,ir,&pme);
4111             }
4112         }
4113         while (natoms == -2);
4114
4115         if (natoms == -1)
4116         {
4117             /* We should stop: break out of the loop */
4118             break;
4119         }
4120
4121         step_rel = step - ir->init_step;
4122
4123         if (count == 0)
4124             wallcycle_start(wcycle,ewcRUN);
4125
4126         wallcycle_start(wcycle,ewcPMEMESH);
4127
4128         dvdlambda = 0;
4129         clear_mat(vir);
4130         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
4131                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
4132                    &energy,lambda,&dvdlambda,
4133                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4134
4135         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
4136
4137         gmx_pme_send_force_vir_ener(pme_pp,
4138                                     f_pp,vir,energy,dvdlambda,
4139                                     cycles);
4140
4141         count++;
4142
4143         if (step_rel == wcycle_get_reset_counters(wcycle))
4144         {
4145             /* Reset all the counters related to performance over the run */
4146             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4147             wcycle_set_reset_counters(wcycle, 0);
4148         }
4149
4150     } /***** end of quasi-loop, we stop with the break above */
4151     while (TRUE);
4152
4153     return 0;
4154 }
4155
4156 int gmx_pme_do(gmx_pme_t pme,
4157                int start,       int homenr,
4158                rvec x[],        rvec f[],
4159                real *chargeA,   real *chargeB,
4160                matrix box, t_commrec *cr,
4161                int  maxshift_x, int maxshift_y,
4162                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4163                matrix vir,      real ewaldcoeff,
4164                real *energy,    real lambda,
4165                real *dvdlambda, int flags)
4166 {
4167     int     q,d,i,j,ntot,npme;
4168     int     nx,ny,nz;
4169     int     n_d,local_ny;
4170     pme_atomcomm_t *atc=NULL;
4171     pmegrids_t *pmegrid=NULL;
4172     real    *grid=NULL;
4173     real    *ptr;
4174     rvec    *x_d,*f_d;
4175     real    *charge=NULL,*q_d;
4176     real    energy_AB[2];
4177     matrix  vir_AB[2];
4178     gmx_bool bClearF;
4179     gmx_parallel_3dfft_t pfft_setup;
4180     real *  fftgrid;
4181     t_complex * cfftgrid;
4182     int     thread;
4183     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4184     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4185
4186     assert(pme->nnodes > 0);
4187     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4188
4189     if (pme->nnodes > 1) {
4190         atc = &pme->atc[0];
4191         atc->npd = homenr;
4192         if (atc->npd > atc->pd_nalloc) {
4193             atc->pd_nalloc = over_alloc_dd(atc->npd);
4194             srenew(atc->pd,atc->pd_nalloc);
4195         }
4196         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4197     }
4198     else
4199     {
4200         /* This could be necessary for TPI */
4201         pme->atc[0].n = homenr;
4202     }
4203
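    /* With free energy (bFEP) the loop below runs twice: the A and B charge
     * sets are spread on separate grids and solved separately; the energies
     * and virials are combined with lambda weights at the end of this routine.
     */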
4204     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4205         if (q == 0) {
4206             pmegrid = &pme->pmegridA;
4207             fftgrid = pme->fftgridA;
4208             cfftgrid = pme->cfftgridA;
4209             pfft_setup = pme->pfft_setupA;
4210             charge = chargeA+start;
4211         } else {
4212             pmegrid = &pme->pmegridB;
4213             fftgrid = pme->fftgridB;
4214             cfftgrid = pme->cfftgridB;
4215             pfft_setup = pme->pfft_setupB;
4216             charge = chargeB+start;
4217         }
4218         grid = pmegrid->grid.grid;
4219         /* Unpack structure */
4220         if (debug) {
4221             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4222                     cr->nnodes,cr->nodeid);
4223             fprintf(debug,"Grid = %p\n",(void*)grid);
4224             if (grid == NULL)
4225                 gmx_fatal(FARGS,"No grid!");
4226         }
4227         where();
4228
4229         m_inv_ur0(box,pme->recipbox);
4230
4231         if (pme->nnodes == 1) {
4232             atc = &pme->atc[0];
4233             if (DOMAINDECOMP(cr)) {
4234                 atc->n = homenr;
4235                 pme_realloc_atomcomm_things(atc);
4236             }
4237             atc->x = x;
4238             atc->q = charge;
4239             atc->f = f;
4240         } else {
4241             wallcycle_start(wcycle,ewcPME_REDISTXF);
4242             for(d=pme->ndecompdim-1; d>=0; d--)
4243             {
4244                 if (d == pme->ndecompdim-1)
4245                 {
4246                     n_d = homenr;
4247                     x_d = x + start;
4248                     q_d = charge;
4249                 }
4250                 else
4251                 {
4252                     n_d = pme->atc[d+1].n;
4253                     x_d = atc->x;
4254                     q_d = atc->q;
4255                 }
4256                 atc = &pme->atc[d];
4257                 atc->npd = n_d;
4258                 if (atc->npd > atc->pd_nalloc) {
4259                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4260                     srenew(atc->pd,atc->pd_nalloc);
4261                 }
4262                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4263                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4264                 where();
4265
4266                 GMX_BARRIER(cr->mpi_comm_mygroup);
4267                 /* Redistribute x (only once) and qA or qB */
4268                 if (DOMAINDECOMP(cr)) {
4269                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4270                 } else {
4271                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4272                 }
4273             }
4274             where();
4275
4276             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4277         }
4278
4279         if (debug)
4280             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4281                     cr->nodeid,atc->n);
4282
4283         if (flags & GMX_PME_SPREAD_Q)
4284         {
4285             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4286
4287             /* Spread the charges on a grid */
4288             GMX_MPE_LOG(ev_spread_on_grid_start);
4289
4291             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4292             GMX_MPE_LOG(ev_spread_on_grid_finish);
4293
4294             if (q == 0)
4295             {
4296                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4297             }
4298             inc_nrnb(nrnb,eNR_SPREADQBSP,
4299                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4300
4301             if (pme->nthread == 1)
4302             {
4303                 wrap_periodic_pmegrid(pme,grid);
4304
4305                 /* sum contributions to local grid from other nodes */
4306 #ifdef GMX_MPI
4307                 if (pme->nnodes > 1)
4308                 {
4309                     GMX_BARRIER(cr->mpi_comm_mygroup);
4310                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4311                     where();
4312                 }
4313 #endif
4314
4315                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4316             }
4317
4318             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4319
4320             /*
4321             dump_local_fftgrid(pme,fftgrid);
4322             exit(0);
4323             */
4324         }
4325
4326         /* Here we start a large thread parallel region */
4327 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4328         for(thread=0; thread<pme->nthread; thread++)
4329         {
4330             if (flags & GMX_PME_SOLVE)
4331             {
4332                 int loop_count;
4333
4334                 /* do 3d-fft */
4335                 if (thread == 0)
4336                 {
4337                     GMX_BARRIER(cr->mpi_comm_mygroup);
4338                     GMX_MPE_LOG(ev_gmxfft3d_start);
4339                     wallcycle_start(wcycle,ewcPME_FFT);
4340                 }
4341                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4342                                            fftgrid,cfftgrid,thread,wcycle);
4343                 if (thread == 0)
4344                 {
4345                     wallcycle_stop(wcycle,ewcPME_FFT);
4346                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4347                 }
4348                 where();
4349
4350                 /* solve in k-space for our local cells */
4351                 if (thread == 0)
4352                 {
4353                     GMX_BARRIER(cr->mpi_comm_mygroup);
4354                     GMX_MPE_LOG(ev_solve_pme_start);
4355                     wallcycle_start(wcycle,ewcPME_SOLVE);
4356                 }
4357                 loop_count =
4358                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4359                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4360                                   bCalcEnerVir,
4361                                   pme->nthread,thread);
4362                 if (thread == 0)
4363                 {
4364                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4365                     where();
4366                     GMX_MPE_LOG(ev_solve_pme_finish);
4367                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4368                 }
4369             }
4370
4371             if (bCalcF)
4372             {
4373                 /* do 3d-invfft */
4374                 if (thread == 0)
4375                 {
4376                     GMX_BARRIER(cr->mpi_comm_mygroup);
4377                     GMX_MPE_LOG(ev_gmxfft3d_start);
4378                     where();
4379                     wallcycle_start(wcycle,ewcPME_FFT);
4380                 }
4381                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4382                                            cfftgrid,fftgrid,thread,wcycle);
4383                 if (thread == 0)
4384                 {
4385                     wallcycle_stop(wcycle,ewcPME_FFT);
4386                     
4387                     where();
4388                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4389
4390                     if (pme->nodeid == 0)
4391                     {
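                        /* Rough flop-count estimate for the 3D FFT: ~N*log2(N),
                         * doubled, presumably to cover both the forward and the
                         * backward transform.
                         */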
4392                         ntot = pme->nkx*pme->nky*pme->nkz;
4393                         npme  = ntot*log((real)ntot)/log(2.0);
4394                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4395                     }
4396
4397                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4398                 }
4399
4400                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4401             }
4402         }
4403         /* End of thread parallel section.
4404          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4405          */
4406
4407         if (bCalcF)
4408         {
4409             /* distribute local grid to all nodes */
4410 #ifdef GMX_MPI
4411             if (pme->nnodes > 1) {
4412                 GMX_BARRIER(cr->mpi_comm_mygroup);
4413                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4414             }
4415 #endif
4416             where();
4417
4418             unwrap_periodic_pmegrid(pme,grid);
4419
4420             /* interpolate forces for our local atoms */
4421             GMX_BARRIER(cr->mpi_comm_mygroup);
4422             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4423
4424             where();
4425
4426             /* If we are running without parallelization,
4427              * atc->f is the actual force array, not a buffer,
4428              * therefore we should not clear it.
4429              */
4430             bClearF = (q == 0 && PAR(cr));
4431 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4432             for(thread=0; thread<pme->nthread; thread++)
4433             {
4434                 gather_f_bsplines(pme,grid,bClearF,atc,
4435                                   &atc->spline[thread],
4436                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4437             }
4438
4439             where();
4440
4441             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4442
4443             inc_nrnb(nrnb,eNR_GATHERFBSP,
4444                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4445             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4446         }
4447
4448         if (bCalcEnerVir)
4449         {
4450             /* This should only be called on the master thread
4451              * and after the threads have synchronized.
4452              */
4453             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4454         }
4455     } /* of q-loop */
4456
4457     if (bCalcF && pme->nnodes > 1) {
4458         wallcycle_start(wcycle,ewcPME_REDISTXF);
4459         for(d=0; d<pme->ndecompdim; d++)
4460         {
4461             atc = &pme->atc[d];
4462             if (d == pme->ndecompdim - 1)
4463             {
4464                 n_d = homenr;
4465                 f_d = f + start;
4466             }
4467             else
4468             {
4469                 n_d = pme->atc[d+1].n;
4470                 f_d = pme->atc[d+1].f;
4471             }
4472             GMX_BARRIER(cr->mpi_comm_mygroup);
4473             if (DOMAINDECOMP(cr)) {
4474                 dd_pmeredist_f(pme,atc,n_d,f_d,
4475                                d==pme->ndecompdim-1 && pme->bPPnode);
4476             } else {
4477                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4478             }
4479         }
4480
4481         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4482     }
4483     where();
4484
4485     if (bCalcEnerVir)
4486     {
4487         if (!pme->bFEP) {
4488             *energy = energy_AB[0];
4489             m_add(vir,vir_AB[0],vir);
4490         } else {
4491             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4492             *dvdlambda += energy_AB[1] - energy_AB[0];
4493             for(i=0; i<DIM; i++)
4494             {
4495                 for(j=0; j<DIM; j++)
4496                 {
4497                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4498                         lambda*vir_AB[1][i][j];
4499                 }
4500             }
4501         }
4502     }
4503     else
4504     {
4505         *energy = 0;
4506     }
4507
4508     if (debug)
4509     {
4510         fprintf(debug,"PME mesh energy: %g\n",*energy);
4511     }
4512
4513     return 0;
4514 }