1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted by a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
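/* Illustrative sketch, not part of the original source: how the reciprocal
 * box used throughout this file (pme->recipbox) follows from a GROMACS
 * (lower-triangular) box matrix.  Plain doubles keep the example
 * self-contained; the name example_calc_recipbox is hypothetical and the
 * function is not called anywhere.
 * For the two test boxes above this gives
 *     rectangular: diag(1/2, 1/2, 1/6)
 *     triclinic:   the same diagonal plus rzx = rzy = -1/6,
 * so identical energies, forces and virial are expected.
 */
static void example_calc_recipbox(double box[3][3], double recip[3][3])
{
    /* Assumes box[0][1] = box[0][2] = box[1][2] = 0 (GROMACS convention) */
    double tmp = 1.0/(box[0][0]*box[1][1]*box[2][2]);

    recip[0][0] = box[1][1]*box[2][2]*tmp;
    recip[0][1] = 0;
    recip[0][2] = 0;
    recip[1][0] = -box[1][0]*box[2][2]*tmp;
    recip[1][1] = box[0][0]*box[2][2]*tmp;
    recip[1][2] = 0;
    recip[2][0] = (box[1][0]*box[2][1] - box[1][1]*box[2][0])*tmp;
    recip[2][1] = -box[2][1]*box[0][0]*tmp;
    recip[2][2] = box[0][0]*box[1][1]*tmp;
}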
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #ifdef GMX_OPENMP
71 #include <omp.h>
72 #endif
73
74 #include <stdio.h>
75 #include <string.h>
76 #include <math.h>
77 #include <assert.h>
78 #include "typedefs.h"
79 #include "txtdump.h"
80 #include "vec.h"
81 #include "gmxcomplex.h"
82 #include "smalloc.h"
83 #include "futil.h"
84 #include "coulomb.h"
85 #include "gmx_fatal.h"
86 #include "pme.h"
87 #include "network.h"
88 #include "physics.h"
89 #include "nrnb.h"
90 #include "copyrite.h"
91 #include "gmx_wallcycle.h"
92 #include "gmx_parallel_3dfft.h"
93 #include "pdbio.h"
94 #include "gmx_cyclecounter.h"
95 #include "macros.h"
96
97 /* Single precision, with SSE2 or higher available */
98 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
99
100 #include "gmx_x86_sse2.h"
101 #include "gmx_math_x86_sse2_single.h"
102
103 #define PME_SSE
104 /* Some old AMD processors could have problems with unaligned loads+stores */
105 #ifndef GMX_FAHCORE
106 #define PME_SSE_UNALIGNED
107 #endif
108 #endif
109
110 #define DFT_TOL 1e-7
111 /* #define PRT_FORCE */
112 /* conditions for on-the-fly time measurement */
113 /* #define TAKETIME (step > 1 && timesteps < 10) */
114 #define TAKETIME FALSE
115
116 /* #define PME_TIME_THREADS */
117
118 #ifdef GMX_DOUBLE
119 #define mpi_type MPI_DOUBLE
120 #else
121 #define mpi_type MPI_FLOAT
122 #endif
123
124 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
125 #define GMX_CACHE_SEP 64
126
127 /* We only define a maximum to be able to use local arrays without allocation.
128  * An order larger than 12 should never be needed, even for test cases.
129  * If needed it can be changed here.
130  */
131 #define PME_ORDER_MAX 12
132
133 /* Internal datastructures */
134 typedef struct {
135     int send_index0;
136     int send_nindex;
137     int recv_index0;
138     int recv_nindex;
139 } pme_grid_comm_t;
140
141 typedef struct {
142 #ifdef GMX_MPI
143     MPI_Comm mpi_comm;
144 #endif
145     int  nnodes,nodeid;
146     int  *s2g0;
147     int  *s2g1;
148     int  noverlap_nodes;
149     int  *send_id,*recv_id;
150     pme_grid_comm_t *comm_data;
151     real *sendbuf;
152     real *recvbuf;
153 } pme_overlap_t;
154
155 typedef struct {
156     int *n;     /* Cumulative counts of the number of particles per thread */
157     int nalloc; /* Allocation size of i */
158     int *i;     /* Particle indices ordered on thread index (n) */
159 } thread_plist_t;
160
161 typedef struct {
162     int  n;
163     int  *ind;
164     splinevec theta;
165     splinevec dtheta;
166 } splinedata_t;
167
168 typedef struct {
169     int  dimind;            /* The index of the dimension, 0=x, 1=y */
170     int  nslab;
171     int  nodeid;
172 #ifdef GMX_MPI
173     MPI_Comm mpi_comm;
174 #endif
175
176     int  *node_dest;        /* The nodes to send x and q to with DD */
177     int  *node_src;         /* The nodes to receive x and q from with DD */
178     int  *buf_index;        /* Index into the buffers for each commnode */
179
180     int  maxshift;
181
182     int  npd;
183     int  pd_nalloc;
184     int  *pd;
185     int  *count;            /* The number of atoms to send to each node */
186     int  **count_thread;
187     int  *rcount;           /* The number of atoms to receive */
188
189     int  n;
190     int  nalloc;
191     rvec *x;
192     real *q;
193     rvec *f;
194     gmx_bool bSpread;       /* These coordinates are used for spreading */
195     int  pme_order;
196     ivec *idx;
197     rvec *fractx;            /* Fractional coordinate relative to the
198                               * lower cell boundary
199                               */
200     int  nthread;
201     int  *thread_idx;        /* Which thread should spread which charge */
202     thread_plist_t *thread_plist;
203     splinedata_t *spline;
204 } pme_atomcomm_t;
205
206 #define FLBS  3
207 #define FLBSZ 4
208
209 typedef struct {
210     ivec ci;     /* The spatial location of this grid       */
211     ivec n;      /* The size of *grid, including order-1    */
212     ivec offset; /* The grid offset from the full node grid */
213     int  order;  /* PME spreading order                     */
214     real *grid;  /* The grid of the local thread, size n    */
215 } pmegrid_t;
216
217 typedef struct {
218     pmegrid_t grid;     /* The full node grid (non thread-local)            */
219     int  nthread;       /* The number of threads operating on this grid     */
220     ivec nc;            /* The local spatial decomposition over the threads */
221     pmegrid_t *grid_th; /* Array of grids for each thread                   */
222     int  **g2t;         /* The grid to thread index                         */
223     ivec nthread_comm;  /* The number of threads to communicate with        */
224 } pmegrids_t;
225
226
227 typedef struct {
228 #ifdef PME_SSE
229     /* Masks for SSE aligned spreading and gathering */
230     __m128 mask_SSE0[6],mask_SSE1[6];
231 #else
232     int dummy; /* C89 requires that a struct has at least one member */
233 #endif
234 } pme_spline_work_t;
235
236 typedef struct {
237     /* work data for solve_pme */
238     int      nalloc;
239     real *   mhx;
240     real *   mhy;
241     real *   mhz;
242     real *   m2;
243     real *   denom;
244     real *   tmp1_alloc;
245     real *   tmp1;
246     real *   eterm;
247     real *   m2inv;
248
249     real     energy;
250     matrix   vir;
251 } pme_work_t;
252
253 typedef struct gmx_pme {
254     int  ndecompdim;         /* The number of decomposition dimensions */
255     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
256     int  nodeid_major;
257     int  nodeid_minor;
258     int  nnodes;             /* The number of nodes doing PME */
259     int  nnodes_major;
260     int  nnodes_minor;
261
262     MPI_Comm mpi_comm;
263     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
264 #ifdef GMX_MPI
265     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
266 #endif
267
268     int  nthread;            /* The number of threads doing PME */
269
270     gmx_bool bPPnode;        /* Node also does particle-particle forces */
271     gmx_bool bFEP;           /* Compute Free energy contribution */
272     int nkx,nky,nkz;         /* Grid dimensions */
273     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
274     int pme_order;
275     real epsilon_r;
276
277     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
278     pmegrids_t pmegridB;
279     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
280     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
281     /* pmegrid_nz might be larger than strictly necessary to ensure
282      * memory alignment, pmegrid_nz_base gives the real base size.
283      */
284     int     pmegrid_nz_base;
285     /* The local PME grid starting indices */
286     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
287
288     /* Work data for spreading and gathering */
289     pme_spline_work_t spline_work;
290
291     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
292     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
293     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
294
295     t_complex *cfftgridA;             /* Grids for complex FFT data */
296     t_complex *cfftgridB;
297     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
298
299     gmx_parallel_3dfft_t  pfft_setupA;
300     gmx_parallel_3dfft_t  pfft_setupB;
301
302     int  *nnx,*nny,*nnz;
303     real *fshx,*fshy,*fshz;
304
305     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
306     matrix    recipbox;
307     splinevec bsp_mod;
308
309     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
310
311     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
312
313     rvec *bufv;             /* Communication buffer */
314     real *bufr;             /* Communication buffer */
315     int  buf_nalloc;        /* The communication buffer size */
316
317     /* thread local work data for solve_pme */
318     pme_work_t *work;
319
320     /* Work data for PME_redist */
321     gmx_bool redist_init;
322     int *    scounts;
323     int *    rcounts;
324     int *    sdispls;
325     int *    rdispls;
326     int *    sidx;
327     int *    idxa;
328     real *   redist_buf;
329     int      redist_buf_nalloc;
330
331     /* Work data for sum_qgrid */
332     real *   sum_qgrid_tmp;
333     real *   sum_qgrid_dd_tmp;
334 } t_gmx_pme;
335
336
337 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
338                                    int start,int end,int thread)
339 {
340     int  i;
341     int  *idxptr,tix,tiy,tiz;
342     real *xptr,*fptr,tx,ty,tz;
343     real rxx,ryx,ryy,rzx,rzy,rzz;
344     int  nx,ny,nz;
345     int  start_ix,start_iy,start_iz;
346     int  *g2tx,*g2ty,*g2tz;
347     gmx_bool bThreads;
348     int  *thread_idx=NULL;
349     thread_plist_t *tpl=NULL;
350     int  *tpl_n=NULL;
351     int  thread_i;
352
353     nx  = pme->nkx;
354     ny  = pme->nky;
355     nz  = pme->nkz;
356
357     start_ix = pme->pmegrid_start_ix;
358     start_iy = pme->pmegrid_start_iy;
359     start_iz = pme->pmegrid_start_iz;
360
361     rxx = pme->recipbox[XX][XX];
362     ryx = pme->recipbox[YY][XX];
363     ryy = pme->recipbox[YY][YY];
364     rzx = pme->recipbox[ZZ][XX];
365     rzy = pme->recipbox[ZZ][YY];
366     rzz = pme->recipbox[ZZ][ZZ];
367
368     g2tx = pme->pmegridA.g2t[XX];
369     g2ty = pme->pmegridA.g2t[YY];
370     g2tz = pme->pmegridA.g2t[ZZ];
371
372     bThreads = (atc->nthread > 1);
373     if (bThreads)
374     {
375         thread_idx = atc->thread_idx;
376
377         tpl   = &atc->thread_plist[thread];
378         tpl_n = tpl->n;
379         for(i=0; i<atc->nthread; i++)
380         {
381             tpl_n[i] = 0;
382         }
383     }
384
385     for(i=start; i<end; i++) {
386         xptr   = atc->x[i];
387         idxptr = atc->idx[i];
388         fptr   = atc->fractx[i];
389
390         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
391         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
392         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
393         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
394
395         tix = (int)(tx);
396         tiy = (int)(ty);
397         tiz = (int)(tz);
398
399         /* Because decomposition only occurs in x and y,
400          * we never have a fraction correction in z.
401          */
402         fptr[XX] = tx - tix + pme->fshx[tix];
403         fptr[YY] = ty - tiy + pme->fshy[tiy];
404         fptr[ZZ] = tz - tiz;
405
406         idxptr[XX] = pme->nnx[tix];
407         idxptr[YY] = pme->nny[tiy];
408         idxptr[ZZ] = pme->nnz[tiz];
409
410 #ifdef DEBUG
411         range_check(idxptr[XX],0,pme->pmegrid_nx);
412         range_check(idxptr[YY],0,pme->pmegrid_ny);
413         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
414 #endif
415
416         if (bThreads)
417         {
418             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
419             thread_idx[i] = thread_i;
420             tpl_n[thread_i]++;
421         }
422     }
423
424     if (bThreads)
425     {
426         /* Make a list of particle indices sorted on thread */
427
428         /* Get the cumulative count */
429         for(i=1; i<atc->nthread; i++)
430         {
431             tpl_n[i] += tpl_n[i-1];
432         }
433         /* The current implementation distributes particles equally
434          * over the threads, so we could actually allocate for that
435          * in pme_realloc_atomcomm_things.
436          */
437         if (tpl_n[atc->nthread-1] > tpl->nalloc)
438         {
439             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
440             srenew(tpl->i,tpl->nalloc);
441         }
442         /* Set tpl_n to the cumulative start */
443         for(i=atc->nthread-1; i>=1; i--)
444         {
445             tpl_n[i] = tpl_n[i-1];
446         }
447         tpl_n[0] = 0;
448
449         /* Fill our thread local array with indices sorted on thread */
450         for(i=start; i<end; i++)
451         {
452             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
453         }
454         /* Now tpl_n contains the cumulative count again */
455     }
456 }
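/* Minimal sketch (hypothetical helper, not used by the code) of the counting
 * sort used above to build the per-thread particle lists (tpl->i): count the
 * entries per bucket, turn the counts into start offsets, then scatter the
 * indices.  Afterwards count[] holds the cumulative counts again, exactly as
 * noted at the end of calc_interpolation_idx.
 */
static void example_count_sort_by_bucket(const int *bucket, int n,
                                         int nbucket, int *count, int *out)
{
    int i, b, sum, tmp;

    for (b = 0; b < nbucket; b++)
    {
        count[b] = 0;
    }
    for (i = 0; i < n; i++)
    {
        count[bucket[i]]++;
    }
    /* Convert the counts to exclusive cumulative start offsets */
    sum = 0;
    for (b = 0; b < nbucket; b++)
    {
        tmp      = count[b];
        count[b] = sum;
        sum     += tmp;
    }
    /* Scatter the indices, sorted on bucket */
    for (i = 0; i < n; i++)
    {
        out[count[bucket[i]]++] = i;
    }
}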
457
458 static void make_thread_local_ind(pme_atomcomm_t *atc,
459                                   int thread,splinedata_t *spline)
460 {
461     int  n,t,i,start,end;
462     thread_plist_t *tpl;
463
464     /* Combine the indices made by each thread into one index */
465
466     n = 0;
467     start = 0;
468     for(t=0; t<atc->nthread; t++)
469     {
470         tpl = &atc->thread_plist[t];
471         /* Copy our part (start - end) from the list of thread t */
472         if (thread > 0)
473         {
474             start = tpl->n[thread-1];
475         }
476         end = tpl->n[thread];
477         for(i=start; i<end; i++)
478         {
479             spline->ind[n++] = tpl->i[i];
480         }
481     }
482
483     spline->n = n;
484 }
485
486
487 static void pme_calc_pidx(int start, int end,
488                           matrix recipbox, rvec x[],
489                           pme_atomcomm_t *atc, int *count)
490 {
491     int  nslab,i;
492     int  si;
493     real *xptr,s;
494     real rxx,ryx,rzx,ryy,rzy;
495     int *pd;
496
497     /* Calculate the PME task index (pidx) for each atom.
498      * Here we always assign equally sized slabs to each node
499      * for load balancing reasons (the PME grid spacing is not used).
500      */
501
502     nslab = atc->nslab;
503     pd    = atc->pd;
504
505     /* Reset the count */
506     for(i=0; i<nslab; i++)
507     {
508         count[i] = 0;
509     }
510
511     if (atc->dimind == 0)
512     {
513         rxx = recipbox[XX][XX];
514         ryx = recipbox[YY][XX];
515         rzx = recipbox[ZZ][XX];
516         /* Calculate the node index in x-dimension */
517         for(i=start; i<end; i++)
518         {
519             xptr   = x[i];
520             /* Fractional coordinates along box vectors */
521             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
522             si = (int)(s + 2*nslab) % nslab;
523             pd[i] = si;
524             count[si]++;
525         }
526     }
527     else
528     {
529         ryy = recipbox[YY][YY];
530         rzy = recipbox[ZZ][YY];
531         /* Calculate the node index in y-dimension */
532         for(i=start; i<end; i++)
533         {
534             xptr   = x[i];
535             /* Fractional coordinates along box vectors */
536             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
537             si = (int)(s + 2*nslab) % nslab;
538             pd[i] = si;
539             count[si]++;
540         }
541     }
542 }
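/* Sketch (hypothetical helper) of the slab assignment used in pme_calc_pidx:
 * s is the fractional coordinate along the decomposed dimension scaled by
 * nslab; adding 2*nslab before truncation keeps the argument positive for
 * slightly negative (triclinic/wrapped) coordinates, and the modulo maps it
 * back onto 0..nslab-1.  E.g. s = -0.3 with nslab = 4 gives
 * (int)(-0.3 + 8) % 4 = 7 % 4 = 3, i.e. the last slab.
 */
static int example_slab_index(double s, int nslab)
{
    return (int)(s + 2*nslab) % nslab;
}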
543
544 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
545                                   pme_atomcomm_t *atc)
546 {
547     int nthread,thread,slab;
548
549     nthread = atc->nthread;
550
551 #pragma omp parallel for num_threads(nthread) schedule(static)
552     for(thread=0; thread<nthread; thread++)
553     {
554         pme_calc_pidx(natoms* thread   /nthread,
555                       natoms*(thread+1)/nthread,
556                       recipbox,x,atc,atc->count_thread[thread]);
557     }
558     /* Non-parallel reduction, since nslab is small */
559
560     for(thread=1; thread<nthread; thread++)
561     {
562         for(slab=0; slab<atc->nslab; slab++)
563         {
564             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
565         }
566     }
567 }
568
569 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
570 {
571     int i,d;
572
573     srenew(spline->ind,atc->nalloc);
574     /* Initialize the index to identity so it works without threads */
575     for(i=0; i<atc->nalloc; i++)
576     {
577         spline->ind[i] = i;
578     }
579
580     for(d=0;d<DIM;d++)
581     {
582         srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
583         srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
584     }
585 }
586
587 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
588 {
589     int nalloc_old,i,j,nalloc_tpl;
590
591     /* We have to avoid a NULL pointer for atc->x, since that could
592      * cause fatal errors in MPI routines.
593      */
594     if (atc->n > atc->nalloc || atc->nalloc == 0)
595     {
596         nalloc_old = atc->nalloc;
597         atc->nalloc = over_alloc_dd(max(atc->n,1));
598
599         if (atc->nslab > 1) {
600             srenew(atc->x,atc->nalloc);
601             srenew(atc->q,atc->nalloc);
602             srenew(atc->f,atc->nalloc);
603             for(i=nalloc_old; i<atc->nalloc; i++)
604             {
605                 clear_rvec(atc->f[i]);
606             }
607         }
608         if (atc->bSpread) {
609             srenew(atc->fractx,atc->nalloc);
610             srenew(atc->idx   ,atc->nalloc);
611
612             if (atc->nthread > 1)
613             {
614                 srenew(atc->thread_idx,atc->nalloc);
615             }
616
617             for(i=0; i<atc->nthread; i++)
618             {
619                 pme_realloc_splinedata(&atc->spline[i],atc);
620             }
621         }
622     }
623 }
624
625 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
626                          int n, gmx_bool bXF, rvec *x_f, real *charge,
627                          pme_atomcomm_t *atc)
628 /* Redistribute particle data for PME calculation */
629 /* domain decomposition by x coordinate           */
630 {
631     int *idxa;
632     int i, ii;
633
634     if(FALSE == pme->redist_init) {
635         snew(pme->scounts,atc->nslab);
636         snew(pme->rcounts,atc->nslab);
637         snew(pme->sdispls,atc->nslab);
638         snew(pme->rdispls,atc->nslab);
639         snew(pme->sidx,atc->nslab);
640         pme->redist_init = TRUE;
641     }
642     if (n > pme->redist_buf_nalloc) {
643         pme->redist_buf_nalloc = over_alloc_dd(n);
644         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
645     }
646
647     pme->idxa = atc->pd;
648
649 #ifdef GMX_MPI
650     if (forw && bXF) {
651         /* forward, redistribution from pp to pme */
652
653         /* Calculate send counts and exchange them with other nodes */
654         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
655         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
656         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
657
658         /* Calculate send and receive displacements and index into send
659            buffer */
660         pme->sdispls[0]=0;
661         pme->rdispls[0]=0;
662         pme->sidx[0]=0;
663         for(i=1; i<atc->nslab; i++) {
664             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
665             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
666             pme->sidx[i]=pme->sdispls[i];
667         }
668         /* Total # of particles to be received */
669         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
670
671         pme_realloc_atomcomm_things(atc);
672
673         /* Copy particle coordinates into send buffer and exchange */
674         for(i=0; (i<n); i++) {
675             ii=DIM*pme->sidx[pme->idxa[i]];
676             pme->sidx[pme->idxa[i]]++;
677             pme->redist_buf[ii+XX]=x_f[i][XX];
678             pme->redist_buf[ii+YY]=x_f[i][YY];
679             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
680         }
681         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
682                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
683                       pme->rvec_mpi, atc->mpi_comm);
684     }
685     if (forw) {
686         /* Copy charge into send buffer and exchange */
687         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
688         for(i=0; (i<n); i++) {
689             ii=pme->sidx[pme->idxa[i]];
690             pme->sidx[pme->idxa[i]]++;
691             pme->redist_buf[ii]=charge[i];
692         }
693         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
694                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
695                       atc->mpi_comm);
696     }
697     else { /* backward, redistribution from pme to pp */
698         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
699                       pme->redist_buf, pme->scounts, pme->sdispls,
700                       pme->rvec_mpi, atc->mpi_comm);
701
702         /* Copy data from receive buffer */
703         for(i=0; i<atc->nslab; i++)
704             pme->sidx[i] = pme->sdispls[i];
705         for(i=0; (i<n); i++) {
706             ii = DIM*pme->sidx[pme->idxa[i]];
707             x_f[i][XX] += pme->redist_buf[ii+XX];
708             x_f[i][YY] += pme->redist_buf[ii+YY];
709             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
710             pme->sidx[pme->idxa[i]]++;
711         }
712     }
713 #endif
714 }
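/* Sketch of the count/displacement bookkeeping used for MPI_Alltoallv above
 * (hypothetical helper, plain ints, no MPI): the send displacement of slab i
 * is the sum of the send counts of all lower slabs; the same recurrence is
 * used for the receive side and to initialize the running send index sidx.
 */
static void example_counts_to_displs(const int *counts, int nslab, int *displs)
{
    int i;

    displs[0] = 0;
    for (i = 1; i < nslab; i++)
    {
        displs[i] = displs[i-1] + counts[i-1];
    }
}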
715
716 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
717                             gmx_bool bBackward,int shift,
718                             void *buf_s,int nbyte_s,
719                             void *buf_r,int nbyte_r)
720 {
721 #ifdef GMX_MPI
722     int dest,src;
723     MPI_Status stat;
724
725     if (bBackward == FALSE) {
726         dest = atc->node_dest[shift];
727         src  = atc->node_src[shift];
728     } else {
729         dest = atc->node_src[shift];
730         src  = atc->node_dest[shift];
731     }
732
733     if (nbyte_s > 0 && nbyte_r > 0) {
734         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
735                      dest,shift,
736                      buf_r,nbyte_r,MPI_BYTE,
737                      src,shift,
738                      atc->mpi_comm,&stat);
739     } else if (nbyte_s > 0) {
740         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
741                  dest,shift,
742                  atc->mpi_comm);
743     } else if (nbyte_r > 0) {
744         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
745                  src,shift,
746                  atc->mpi_comm,&stat);
747     }
748 #endif
749 }
750
751 static void dd_pmeredist_x_q(gmx_pme_t pme,
752                              int n, gmx_bool bX, rvec *x, real *charge,
753                              pme_atomcomm_t *atc)
754 {
755     int *commnode,*buf_index;
756     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
757
758     commnode  = atc->node_dest;
759     buf_index = atc->buf_index;
760
761     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
762
763     nsend = 0;
764     for(i=0; i<nnodes_comm; i++) {
765         buf_index[commnode[i]] = nsend;
766         nsend += atc->count[commnode[i]];
767     }
768     if (bX) {
769         if (atc->count[atc->nodeid] + nsend != n)
770             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
771                       "This usually means that your system is not well equilibrated.",
772                       n - (atc->count[atc->nodeid] + nsend),
773                       pme->nodeid,'x'+atc->dimind);
774
775         if (nsend > pme->buf_nalloc) {
776             pme->buf_nalloc = over_alloc_dd(nsend);
777             srenew(pme->bufv,pme->buf_nalloc);
778             srenew(pme->bufr,pme->buf_nalloc);
779         }
780
781         atc->n = atc->count[atc->nodeid];
782         for(i=0; i<nnodes_comm; i++) {
783             scount = atc->count[commnode[i]];
784             /* Communicate the count */
785             if (debug)
786                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
787                         atc->dimind,atc->nodeid,commnode[i],scount);
788             pme_dd_sendrecv(atc,FALSE,i,
789                             &scount,sizeof(int),
790                             &atc->rcount[i],sizeof(int));
791             atc->n += atc->rcount[i];
792         }
793
794         pme_realloc_atomcomm_things(atc);
795     }
796
797     local_pos = 0;
798     for(i=0; i<n; i++) {
799         node = atc->pd[i];
800         if (node == atc->nodeid) {
801             /* Copy direct to the receive buffer */
802             if (bX) {
803                 copy_rvec(x[i],atc->x[local_pos]);
804             }
805             atc->q[local_pos] = charge[i];
806             local_pos++;
807         } else {
808             /* Copy to the send buffer */
809             if (bX) {
810                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
811             }
812             pme->bufr[buf_index[node]] = charge[i];
813             buf_index[node]++;
814         }
815     }
816
817     buf_pos = 0;
818     for(i=0; i<nnodes_comm; i++) {
819         scount = atc->count[commnode[i]];
820         rcount = atc->rcount[i];
821         if (scount > 0 || rcount > 0) {
822             if (bX) {
823                 /* Communicate the coordinates */
824                 pme_dd_sendrecv(atc,FALSE,i,
825                                 pme->bufv[buf_pos],scount*sizeof(rvec),
826                                 atc->x[local_pos],rcount*sizeof(rvec));
827             }
828             /* Communicate the charges */
829             pme_dd_sendrecv(atc,FALSE,i,
830                             pme->bufr+buf_pos,scount*sizeof(real),
831                             atc->q+local_pos,rcount*sizeof(real));
832             buf_pos   += scount;
833             local_pos += atc->rcount[i];
834         }
835     }
836 }
837
838 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
839                            int n, rvec *f,
840                            gmx_bool bAddF)
841 {
842     int *commnode,*buf_index;
843     int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
844
845     commnode  = atc->node_dest;
846     buf_index = atc->buf_index;
847
848     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
849
850     local_pos = atc->count[atc->nodeid];
851     buf_pos = 0;
852     for(i=0; i<nnodes_comm; i++) {
853         scount = atc->rcount[i];
854         rcount = atc->count[commnode[i]];
855         if (scount > 0 || rcount > 0) {
856             /* Communicate the forces */
857             pme_dd_sendrecv(atc,TRUE,i,
858                             atc->f[local_pos],scount*sizeof(rvec),
859                             pme->bufv[buf_pos],rcount*sizeof(rvec));
860             local_pos += scount;
861         }
862         buf_index[commnode[i]] = buf_pos;
863         buf_pos   += rcount;
864     }
865
866     local_pos = 0;
867     if (bAddF)
868     {
869         for(i=0; i<n; i++)
870         {
871             node = atc->pd[i];
872             if (node == atc->nodeid)
873             {
874                 /* Add from the local force array */
875                 rvec_inc(f[i],atc->f[local_pos]);
876                 local_pos++;
877             }
878             else
879             {
880                 /* Add from the receive buffer */
881                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
882                 buf_index[node]++;
883             }
884         }
885     }
886     else
887     {
888         for(i=0; i<n; i++)
889         {
890             node = atc->pd[i];
891             if (node == atc->nodeid)
892             {
893                 /* Copy from the local force array */
894                 copy_rvec(atc->f[local_pos],f[i]);
895                 local_pos++;
896             }
897             else
898             {
899                 /* Copy from the receive buffer */
900                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
901                 buf_index[node]++;
902             }
903         }
904     }
905 }
906
907 #ifdef GMX_MPI
908 static void
909 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
910 {
911     pme_overlap_t *overlap;
912     int send_index0,send_nindex;
913     int recv_index0,recv_nindex;
914     MPI_Status stat;
915     int i,j,k,ix,iy,iz,icnt;
916     int ipulse,send_id,recv_id,datasize;
917     real *p;
918     real *sendptr,*recvptr;
919
920     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
921     overlap = &pme->overlap[1];
922
923     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
924     {
925         /* Since we have already (un)wrapped the overlap in the z-dimension,
926          * we only have to communicate 0 to nkz (not pmegrid_nz).
927          */
928         if (direction==GMX_SUM_QGRID_FORWARD)
929         {
930             send_id = overlap->send_id[ipulse];
931             recv_id = overlap->recv_id[ipulse];
932             send_index0   = overlap->comm_data[ipulse].send_index0;
933             send_nindex   = overlap->comm_data[ipulse].send_nindex;
934             recv_index0   = overlap->comm_data[ipulse].recv_index0;
935             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
936         }
937         else
938         {
939             send_id = overlap->recv_id[ipulse];
940             recv_id = overlap->send_id[ipulse];
941             send_index0   = overlap->comm_data[ipulse].recv_index0;
942             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
943             recv_index0   = overlap->comm_data[ipulse].send_index0;
944             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
945         }
946
947         /* Copy data to contiguous send buffer */
948         if (debug)
949         {
950             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
951                     pme->nodeid,overlap->nodeid,send_id,
952                     pme->pmegrid_start_iy,
953                     send_index0-pme->pmegrid_start_iy,
954                     send_index0-pme->pmegrid_start_iy+send_nindex);
955         }
956         icnt = 0;
957         for(i=0;i<pme->pmegrid_nx;i++)
958         {
959             ix = i;
960             for(j=0;j<send_nindex;j++)
961             {
962                 iy = j + send_index0 - pme->pmegrid_start_iy;
963                 for(k=0;k<pme->nkz;k++)
964                 {
965                     iz = k;
966                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
967                 }
968             }
969         }
970
971         datasize      = pme->pmegrid_nx * pme->nkz;
972
973         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
974                      send_id,ipulse,
975                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
976                      recv_id,ipulse,
977                      overlap->mpi_comm,&stat);
978
979         /* Get data from contiguous recv buffer */
980         if (debug)
981         {
982             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
983                     pme->nodeid,overlap->nodeid,recv_id,
984                     pme->pmegrid_start_iy,
985                     recv_index0-pme->pmegrid_start_iy,
986                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
987         }
988         icnt = 0;
989         for(i=0;i<pme->pmegrid_nx;i++)
990         {
991             ix = i;
992             for(j=0;j<recv_nindex;j++)
993             {
994                 iy = j + recv_index0 - pme->pmegrid_start_iy;
995                 for(k=0;k<pme->nkz;k++)
996                 {
997                     iz = k;
998                     if(direction==GMX_SUM_QGRID_FORWARD)
999                     {
1000                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1001                     }
1002                     else
1003                     {
1004                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1005                     }
1006                 }
1007             }
1008         }
1009     }
1010
1011     /* Major dimension is easier, no copying required,
1012      * but we might have to sum into a separate array.
1013      * Since we don't copy, we have to communicate up to pmegrid_nz,
1014      * not nkz as for the minor direction.
1015      */
1016     overlap = &pme->overlap[0];
1017
1018     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1019     {
1020         if(direction==GMX_SUM_QGRID_FORWARD)
1021         {
1022             send_id = overlap->send_id[ipulse];
1023             recv_id = overlap->recv_id[ipulse];
1024             send_index0   = overlap->comm_data[ipulse].send_index0;
1025             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1026             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1027             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1028             recvptr   = overlap->recvbuf;
1029         }
1030         else
1031         {
1032             send_id = overlap->recv_id[ipulse];
1033             recv_id = overlap->send_id[ipulse];
1034             send_index0   = overlap->comm_data[ipulse].recv_index0;
1035             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1036             recv_index0   = overlap->comm_data[ipulse].send_index0;
1037             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1038             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1039         }
1040
1041         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1042         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1043
1044         if (debug)
1045         {
1046             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1047                     pme->nodeid,overlap->nodeid,send_id,
1048                     pme->pmegrid_start_ix,
1049                     send_index0-pme->pmegrid_start_ix,
1050                     send_index0-pme->pmegrid_start_ix+send_nindex);
1051             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1052                     pme->nodeid,overlap->nodeid,recv_id,
1053                     pme->pmegrid_start_ix,
1054                     recv_index0-pme->pmegrid_start_ix,
1055                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1056         }
1057
1058         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1059                      send_id,ipulse,
1060                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1061                      recv_id,ipulse,
1062                      overlap->mpi_comm,&stat);
1063
1064         /* ADD data from contiguous recv buffer */
1065         if(direction==GMX_SUM_QGRID_FORWARD)
1066         {
1067             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1068             for(i=0;i<recv_nindex*datasize;i++)
1069             {
1070                 p[i] += overlap->recvbuf[i];
1071             }
1072         }
1073     }
1074 }
1075 #endif
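/* Sketch of the pack step used for the minor (y) dimension above: a range of
 * ny_send y-planes is not contiguous in the x-major grid layout, so it is
 * copied element by element into a contiguous send buffer.  Hypothetical
 * helper that assumes an unpadded nx*ny*nz row-major grid; the real code
 * above additionally offsets iy by the local grid start and uses the padded
 * pmegrid_nz stride while copying only nkz z-points.
 */
static void example_pack_y_range(const real *grid, real *sendbuf,
                                 int nx, int ny, int nz,
                                 int iy0, int ny_send)
{
    int ix, iy, iz, icnt;

    icnt = 0;
    for (ix = 0; ix < nx; ix++)
    {
        for (iy = iy0; iy < iy0 + ny_send; iy++)
        {
            for (iz = 0; iz < nz; iz++)
            {
                sendbuf[icnt++] = grid[(ix*ny + iy)*nz + iz];
            }
        }
    }
}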
1076
1077
1078 static int
1079 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1080 {
1081     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1082     ivec    local_pme_size;
1083     int     i,ix,iy,iz;
1084     int     pmeidx,fftidx;
1085
1086     /* Dimensions should be identical for A/B grid, so we just use A here */
1087     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1088                                    local_fft_ndata,
1089                                    local_fft_offset,
1090                                    local_fft_size);
1091
1092     local_pme_size[0] = pme->pmegrid_nx;
1093     local_pme_size[1] = pme->pmegrid_ny;
1094     local_pme_size[2] = pme->pmegrid_nz;
1095
1096     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1097      * the offset is identical, and the PME grid always has more data (due to overlap).
1098      */
1099     {
1100 #ifdef DEBUG_PME
1101         FILE *fp,*fp2;
1102         char fn[STRLEN],format[STRLEN];
1103         real val;
1104         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1105         fp = ffopen(fn,"w");
1106         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1107         fp2 = ffopen(fn,"w");
1108         sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1109 #endif
1110
1111     for(ix=0;ix<local_fft_ndata[XX];ix++)
1112     {
1113         for(iy=0;iy<local_fft_ndata[YY];iy++)
1114         {
1115             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1116             {
1117                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1118                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1119                 fftgrid[fftidx] = pmegrid[pmeidx];
1120 #ifdef DEBUG_PME
1121                 val = 100*pmegrid[pmeidx];
1122                 if (pmegrid[pmeidx] != 0)
1123                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1124                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1125                 if (pmegrid[pmeidx] != 0)
1126                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1127                             "qgrid",
1128                             pme->pmegrid_start_ix + ix,
1129                             pme->pmegrid_start_iy + iy,
1130                             pme->pmegrid_start_iz + iz,
1131                             pmegrid[pmeidx]);
1132 #endif
1133             }
1134         }
1135     }
1136 #ifdef DEBUG_PME
1137     ffclose(fp);
1138     ffclose(fp2);
1139 #endif
1140     }
1141     return 0;
1142 }
1143
1144
1145 static gmx_cycles_t omp_cyc_start()
1146 {
1147     return gmx_cycles_read();
1148 }
1149
1150 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1151 {
1152     return gmx_cycles_read() - c;
1153 }
1154
1155
1156 static int
1157 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1158                         int nthread,int thread)
1159 {
1160     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1161     ivec    local_pme_size;
1162     int     ixy0,ixy1,ixy,ix,iy,iz;
1163     int     pmeidx,fftidx;
1164 #ifdef PME_TIME_THREADS
1165     gmx_cycles_t c1;
1166     static double cs1=0;
1167     static int cnt=0;
1168 #endif
1169
1170 #ifdef PME_TIME_THREADS
1171     c1 = omp_cyc_start();
1172 #endif
1173     /* Dimensions should be identical for A/B grid, so we just use A here */
1174     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1175                                    local_fft_ndata,
1176                                    local_fft_offset,
1177                                    local_fft_size);
1178
1179     local_pme_size[0] = pme->pmegrid_nx;
1180     local_pme_size[1] = pme->pmegrid_ny;
1181     local_pme_size[2] = pme->pmegrid_nz;
1182
1183     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1184      * the offset is identical, and the PME grid always has more data (due to overlap).
1185      */
1186     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1187     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1188
1189     for(ixy=ixy0;ixy<ixy1;ixy++)
1190     {
1191         ix = ixy/local_fft_ndata[YY];
1192         iy = ixy - ix*local_fft_ndata[YY];
1193
1194         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1195         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1196         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1197         {
1198             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1199         }
1200     }
1201
1202 #ifdef PME_TIME_THREADS
1203     c1 = omp_cyc_end(c1);
1204     cs1 += (double)c1;
1205     cnt++;
1206     if (cnt % 20 == 0)
1207     {
1208         printf("copy %.2f\n",cs1*1e-9);
1209     }
1210 #endif
1211
1212     return 0;
1213 }
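/* Sketch of the static work division used above and in several other places
 * in this file: thread t of nthread handles the half-open index range
 * [n*t/nthread, n*(t+1)/nthread), which covers 0..n-1 exactly once and
 * differs by at most one element between threads.  Hypothetical helper,
 * not called by the code.
 */
static void example_thread_range(int n, int nthread, int thread,
                                 int *i0, int *i1)
{
    *i0 = (n* thread   )/nthread;
    *i1 = (n*(thread+1))/nthread;
}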
1214
1215
1216 static void
1217 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1218 {
1219     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1220
1221     nx = pme->nkx;
1222     ny = pme->nky;
1223     nz = pme->nkz;
1224
1225     pnx = pme->pmegrid_nx;
1226     pny = pme->pmegrid_ny;
1227     pnz = pme->pmegrid_nz;
1228
1229     overlap = pme->pme_order - 1;
1230
1231     /* Add periodic overlap in z */
1232     for(ix=0; ix<pme->pmegrid_nx; ix++)
1233     {
1234         for(iy=0; iy<pme->pmegrid_ny; iy++)
1235         {
1236             for(iz=0; iz<overlap; iz++)
1237             {
1238                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1239                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1240             }
1241         }
1242     }
1243
1244     if (pme->nnodes_minor == 1)
1245     {
1246        for(ix=0; ix<pme->pmegrid_nx; ix++)
1247        {
1248            for(iy=0; iy<overlap; iy++)
1249            {
1250                for(iz=0; iz<nz; iz++)
1251                {
1252                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1253                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1254                }
1255            }
1256        }
1257     }
1258
1259     if (pme->nnodes_major == 1)
1260     {
1261         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1262
1263         for(ix=0; ix<overlap; ix++)
1264         {
1265             for(iy=0; iy<ny_x; iy++)
1266             {
1267                 for(iz=0; iz<nz; iz++)
1268                 {
1269                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1270                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1271                 }
1272             }
1273         }
1274     }
1275 }
1276
1277
1278 static void
1279 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1280 {
1281     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1282
1283     nx = pme->nkx;
1284     ny = pme->nky;
1285     nz = pme->nkz;
1286
1287     pnx = pme->pmegrid_nx;
1288     pny = pme->pmegrid_ny;
1289     pnz = pme->pmegrid_nz;
1290
1291     overlap = pme->pme_order - 1;
1292
1293     if (pme->nnodes_major == 1)
1294     {
1295         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1296
1297         for(ix=0; ix<overlap; ix++)
1298         {
1299             int iy,iz;
1300
1301             for(iy=0; iy<ny_x; iy++)
1302             {
1303                 for(iz=0; iz<nz; iz++)
1304                 {
1305                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1306                         pmegrid[(ix*pny+iy)*pnz+iz];
1307                 }
1308             }
1309         }
1310     }
1311
1312     if (pme->nnodes_minor == 1)
1313     {
1314 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1315        for(ix=0; ix<pme->pmegrid_nx; ix++)
1316        {
1317            int iy,iz;
1318
1319            for(iy=0; iy<overlap; iy++)
1320            {
1321                for(iz=0; iz<nz; iz++)
1322                {
1323                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1324                        pmegrid[(ix*pny+iy)*pnz+iz];
1325                }
1326            }
1327        }
1328     }
1329
1330     /* Copy periodic overlap in z */
1331 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1332     for(ix=0; ix<pme->pmegrid_nx; ix++)
1333     {
1334         int iy,iz;
1335
1336         for(iy=0; iy<pme->pmegrid_ny; iy++)
1337         {
1338             for(iz=0; iz<overlap; iz++)
1339             {
1340                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1341                     pmegrid[(ix*pny+iy)*pnz+iz];
1342             }
1343         }
1344     }
1345 }
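/* One-dimensional sketch of the wrap/unwrap pair above (hypothetical helper):
 * the spreading grid stores order-1 extra points beyond the nz periodic
 * points of a row.  Wrapping folds the extra points back onto the start of
 * the row (after charge spreading); unwrapping copies the start of the row
 * out into the extra points again (before force gathering).
 */
static void example_wrap_unwrap_1d(real *row, int nz, int overlap,
                                   gmx_bool bWrap)
{
    int iz;

    for (iz = 0; iz < overlap; iz++)
    {
        if (bWrap)
        {
            row[iz] += row[nz + iz];
        }
        else
        {
            row[nz + iz] = row[iz];
        }
    }
}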
1346
1347 static void clear_grid(int nx,int ny,int nz,real *grid,
1348                        ivec fs,int *flag,
1349                        int fx,int fy,int fz,
1350                        int order)
1351 {
1352     int nc,ncz;
1353     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1354     int flind;
1355
1356     nc  = 2 + (order - 2)/FLBS;
1357     ncz = 2 + (order - 2)/FLBSZ;
1358
1359     for(fsx=fx; fsx<fx+nc; fsx++)
1360     {
1361         for(fsy=fy; fsy<fy+nc; fsy++)
1362         {
1363             for(fsz=fz; fsz<fz+ncz; fsz++)
1364             {
1365                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1366                 if (flag[flind] == 0)
1367                 {
1368                     gx = fsx*FLBS;
1369                     gy = fsy*FLBS;
1370                     gz = fsz*FLBSZ;
1371                     g0x = (gx*ny + gy)*nz + gz;
1372                     for(x=0; x<FLBS; x++)
1373                     {
1374                         g0y = g0x;
1375                         for(y=0; y<FLBS; y++)
1376                         {
1377                             for(z=0; z<FLBSZ; z++)
1378                             {
1379                                 grid[g0y+z] = 0;
1380                             }
1381                             g0y += nz;
1382                         }
1383                         g0x += ny*nz;
1384                     }
1385
1386                     flag[flind] = 1;
1387                 }
1388             }
1389         }
1390     }
1391 }
1392
1393 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1394 #define DO_BSPLINE(order)                            \
1395 for(ithx=0; (ithx<order); ithx++)                    \
1396 {                                                    \
1397     index_x = (i0+ithx)*pny*pnz;                     \
1398     valx    = qn*thx[ithx];                          \
1399                                                      \
1400     for(ithy=0; (ithy<order); ithy++)                \
1401     {                                                \
1402         valxy    = valx*thy[ithy];                   \
1403         index_xy = index_x+(j0+ithy)*pnz;            \
1404                                                      \
1405         for(ithz=0; (ithz<order); ithz++)            \
1406         {                                            \
1407             index_xyz        = index_xy+(k0+ithz);   \
1408             grid[index_xyz] += valxy*thz[ithz];      \
1409         }                                            \
1410     }                                                \
1411 }
1412
1413
1414 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1415                                      pme_atomcomm_t *atc, splinedata_t *spline,
1416                                      pme_spline_work_t *work)
1417 {
1418
1419     /* spread charges from home atoms to local grid */
1420     real     *grid;
1421     pme_overlap_t *ol;
1422     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1423     int *    idxptr;
1424     int      order,norder,index_x,index_xy,index_xyz;
1425     real     valx,valxy,qn;
1426     real     *thx,*thy,*thz;
1427     int      localsize, bndsize;
1428     int      pnx,pny,pnz,ndatatot;
1429     int      offx,offy,offz;
1430
1431     pnx = pmegrid->n[XX];
1432     pny = pmegrid->n[YY];
1433     pnz = pmegrid->n[ZZ];
1434
1435     offx = pmegrid->offset[XX];
1436     offy = pmegrid->offset[YY];
1437     offz = pmegrid->offset[ZZ];
1438
1439     ndatatot = pnx*pny*pnz;
1440     grid = pmegrid->grid;
1441     for(i=0;i<ndatatot;i++)
1442     {
1443         grid[i] = 0;
1444     }
1445
1446     order = pmegrid->order;
1447
1448     for(nn=0; nn<spline->n; nn++)
1449     {
1450         n  = spline->ind[nn];
1451         qn = atc->q[n];
1452
1453         if (qn != 0)
1454         {
1455             idxptr = atc->idx[n];
1456             norder = nn*order;
1457
1458             i0   = idxptr[XX] - offx;
1459             j0   = idxptr[YY] - offy;
1460             k0   = idxptr[ZZ] - offz;
1461
1462             thx = spline->theta[XX] + norder;
1463             thy = spline->theta[YY] + norder;
1464             thz = spline->theta[ZZ] + norder;
1465
1466             switch (order) {
1467             case 4:
1468 #ifdef PME_SSE
1469 #ifdef PME_SSE_UNALIGNED
1470 #define PME_SPREAD_SSE_ORDER4
1471 #else
1472 #define PME_SPREAD_SSE_ALIGNED
1473 #define PME_ORDER 4
1474 #endif
1475 #include "pme_sse_single.h"
1476 #else
1477                 DO_BSPLINE(4);
1478 #endif
1479                 break;
1480             case 5:
1481 #ifdef PME_SSE
1482 #define PME_SPREAD_SSE_ALIGNED
1483 #define PME_ORDER 5
1484 #include "pme_sse_single.h"
1485 #else
1486                 DO_BSPLINE(5);
1487 #endif
1488                 break;
1489             default:
1490                 DO_BSPLINE(order);
1491                 break;
1492             }
1493         }
1494     }
1495 }
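/* One-dimensional sketch of what DO_BSPLINE does per atom (hypothetical
 * helper): the charge q is distributed over "order" consecutive grid points
 * starting at i0, weighted by the B-spline coefficients th[0..order-1],
 * which sum to 1 so the total spread charge equals q.  The real macro does
 * this in 3D with the product weight thx[ithx]*thy[ithy]*thz[ithz].
 */
static void example_spread_1d(real *grid, int i0, int order,
                              const real *th, real q)
{
    int ith;

    for (ith = 0; ith < order; ith++)
    {
        grid[i0 + ith] += q*th[ith];
    }
}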
1496
1497 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1498 {
1499 #ifdef PME_SSE
1500     if (pme_order == 5
1501 #ifndef PME_SSE_UNALIGNED
1502         || pme_order == 4
1503 #endif
1504         )
1505     {
1506         /* Round nz up to a multiple of 4 to ensure alignment */
1507         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1508     }
1509 #endif
1510 }
1511
1512 static void set_gridsize_alignment(int *gridsize,int pme_order)
1513 {
1514 #ifdef PME_SSE
1515 #ifndef PME_SSE_UNALIGNED
1516     if (pme_order == 4)
1517     {
1518         /* Add extra elements to ensure that aligned operations do not go
1519          * beyond the allocated grid size.
1520          * Note that for pme_order=5, the pme grid z-size alignment
1521          * ensures that we will not go beyond the grid size.
1522          */
1523          *gridsize += 4;
1524     }
1525 #endif
1526 #endif
1527 }
1528
1529 static void pmegrid_init(pmegrid_t *grid,
1530                          int cx, int cy, int cz,
1531                          int x0, int y0, int z0,
1532                          int x1, int y1, int z1,
1533                          gmx_bool set_alignment,
1534                          int pme_order,
1535                          real *ptr)
1536 {
1537     int nz,gridsize;
1538
1539     grid->ci[XX] = cx;
1540     grid->ci[YY] = cy;
1541     grid->ci[ZZ] = cz;
1542     grid->offset[XX] = x0;
1543     grid->offset[YY] = y0;
1544     grid->offset[ZZ] = z0;
1545     grid->n[XX]      = x1 - x0 + pme_order - 1;
1546     grid->n[YY]      = y1 - y0 + pme_order - 1;
1547     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1548
1549     nz = grid->n[ZZ];
1550     set_grid_alignment(&nz,pme_order);
1551     if (set_alignment)
1552     {
1553         grid->n[ZZ] = nz;
1554     }
1555     else if (nz != grid->n[ZZ])
1556     {
1557         gmx_incons("pmegrid_init call with an unaligned z size");
1558     }
1559
1560     grid->order = pme_order;
1561     if (ptr == NULL)
1562     {
1563         gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1564         set_gridsize_alignment(&gridsize,pme_order);
1565         snew_aligned(grid->grid,gridsize,16);
1566     }
1567     else
1568     {
1569         grid->grid = ptr;
1570     }
1571 }
1572
1573 static int div_round_up(int numerator,int denominator)
1574 {
1575     return (numerator + denominator - 1)/denominator;
1576 }
1577
1578 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1579                                   ivec nsub)
1580 {
1581     int gsize_opt,gsize;
1582     int nsx,nsy,nsz;
1583     char *env;
1584
1585     gsize_opt = -1;
1586     for(nsx=1; nsx<=nthread; nsx++)
1587     {
1588         if (nthread % nsx == 0)
1589         {
1590             for(nsy=1; nsy<=nthread; nsy++)
1591             {
1592                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1593                 {
1594                     nsz = nthread/(nsx*nsy);
1595
1596                     /* Determine the number of grid points per thread */
1597                     gsize =
1598                         (div_round_up(n[XX],nsx) + ovl)*
1599                         (div_round_up(n[YY],nsy) + ovl)*
1600                         (div_round_up(n[ZZ],nsz) + ovl);
1601
1602                     /* Minimize the number of grid points per thread
1603                      * and, secondarily, the number of cuts in minor dimensions.
1604                      */
1605                     if (gsize_opt == -1 ||
1606                         gsize < gsize_opt ||
1607                         (gsize == gsize_opt &&
1608                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1609                     {
1610                         nsub[XX] = nsx;
1611                         nsub[YY] = nsy;
1612                         nsub[ZZ] = nsz;
1613                         gsize_opt = gsize;
1614                     }
1615                 }
1616             }
1617         }
1618     }
1619
1620     env = getenv("GMX_PME_THREAD_DIVISION");
1621     if (env != NULL)
1622     {
1623         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1624     }
1625
1626     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1627     {
1628         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1629     }
1630 }
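/* Sketch of the cost estimate minimized above (hypothetical helper): for a
 * candidate nsx x nsy x nsz thread division each thread owns roughly
 * ceil(n/ns) + overlap grid points per dimension.  The chosen division can
 * be overridden at run time, e.g. GMX_PME_THREAD_DIVISION="2 2 1" for
 * 4 PME threads.
 */
static int example_subgrid_points_per_thread(const ivec n, int ovl,
                                             int nsx, int nsy, int nsz)
{
    return (div_round_up(n[XX], nsx) + ovl)*
           (div_round_up(n[YY], nsy) + ovl)*
           (div_round_up(n[ZZ], nsz) + ovl);
}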
1631
1632 static void pmegrids_init(pmegrids_t *grids,
1633                           int nx,int ny,int nz,int nz_base,
1634                           int pme_order,
1635                           int nthread,
1636                           int overlap_x,
1637                           int overlap_y)
1638 {
1639     ivec n,n_base,g0,g1;
1640     int t,x,y,z,d,i,tfac;
1641     int max_comm_lines;
1642
1643     n[XX] = nx - (pme_order - 1);
1644     n[YY] = ny - (pme_order - 1);
1645     n[ZZ] = nz - (pme_order - 1);
1646
1647     copy_ivec(n,n_base);
1648     n_base[ZZ] = nz_base;
1649
1650     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1651                  NULL);
1652
1653     grids->nthread = nthread;
1654
1655     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1656
1657     if (grids->nthread > 1)
1658     {
1659         ivec nst;
1660         int gridsize;
1661         real *grid_all;
1662
1663         for(d=0; d<DIM; d++)
1664         {
1665             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1666         }
1667         set_grid_alignment(&nst[ZZ],pme_order);
1668
1669         if (debug)
1670         {
1671             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1672                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1673             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1674                     nx,ny,nz,
1675                     nst[XX],nst[YY],nst[ZZ]);
1676         }
1677
1678         snew(grids->grid_th,grids->nthread);
1679         t = 0;
1680         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1681         set_gridsize_alignment(&gridsize,pme_order);
1682         snew_aligned(grid_all,
1683                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1684                      16);
1685
1686         for(x=0; x<grids->nc[XX]; x++)
1687         {
1688             for(y=0; y<grids->nc[YY]; y++)
1689             {
1690                 for(z=0; z<grids->nc[ZZ]; z++)
1691                 {
1692                     pmegrid_init(&grids->grid_th[t],
1693                                  x,y,z,
1694                                  (n[XX]*(x  ))/grids->nc[XX],
1695                                  (n[YY]*(y  ))/grids->nc[YY],
1696                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1697                                  (n[XX]*(x+1))/grids->nc[XX],
1698                                  (n[YY]*(y+1))/grids->nc[YY],
1699                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1700                                  TRUE,
1701                                  pme_order,
1702                                  grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1703                     t++;
1704                 }
1705             }
1706         }
1707     }
1708
1709     snew(grids->g2t,DIM);
1710     tfac = 1;
1711     for(d=DIM-1; d>=0; d--)
1712     {
1713         snew(grids->g2t[d],n[d]);
1714         t = 0;
1715         for(i=0; i<n[d]; i++)
1716         {
1717             /* The second check should match the parameters
1718              * of the pmegrid_init call above.
1719              */
1720             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1721             {
1722                 t++;
1723             }
1724             grids->g2t[d][i] = t*tfac;
1725         }
1726
1727         tfac *= grids->nc[d];
1728
1729         switch (d)
1730         {
1731         case XX: max_comm_lines = overlap_x;     break;
1732         case YY: max_comm_lines = overlap_y;     break;
1733         case ZZ: max_comm_lines = pme_order - 1; break;
1734         }
1735         grids->nthread_comm[d] = 0;
1736         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1737         {
1738             grids->nthread_comm[d]++;
1739         }
1740         if (debug != NULL)
1741         {
1742             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1743                     'x'+d,grids->nthread_comm[d]);
1744         }
1745         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1746          * work, but this is not a problematic restriction.
1747          */
1748         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1749         {
1750             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1751         }
1752     }
1753 }
1754
1755
1756 static void pmegrids_destroy(pmegrids_t *grids)
1757 {
1758     int t;
1759
1760     if (grids->grid.grid != NULL)
1761     {
1762         sfree(grids->grid.grid);
1763
1764         if (grids->nthread > 0)
1765         {
1766             for(t=0; t<grids->nthread; t++)
1767             {
1768                 sfree(grids->grid_th[t].grid);
1769             }
1770             sfree(grids->grid_th);
1771         }
1772     }
1773 }
1774
1775
1776 static void realloc_work(pme_work_t *work,int nkx)
1777 {
1778     if (nkx > work->nalloc)
1779     {
1780         work->nalloc = nkx;
1781         srenew(work->mhx  ,work->nalloc);
1782         srenew(work->mhy  ,work->nalloc);
1783         srenew(work->mhz  ,work->nalloc);
1784         srenew(work->m2   ,work->nalloc);
1785         /* Allocate an aligned pointer for SSE operations, including 3 extra
1786          * elements at the end since SSE operates on 4 elements at a time.
1787          */
1788         sfree_aligned(work->denom);
1789         sfree_aligned(work->tmp1);
1790         sfree_aligned(work->eterm);
1791         snew_aligned(work->denom,work->nalloc+3,16);
1792         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1793         snew_aligned(work->eterm,work->nalloc+3,16);
1794         srenew(work->m2inv,work->nalloc);
1795     }
1796 }
1797
1798
1799 static void free_work(pme_work_t *work)
1800 {
1801     sfree(work->mhx);
1802     sfree(work->mhy);
1803     sfree(work->mhz);
1804     sfree(work->m2);
1805     sfree_aligned(work->denom);
1806     sfree_aligned(work->tmp1);
1807     sfree_aligned(work->eterm);
1808     sfree(work->m2inv);
1809 }
1810
1811
1812 #ifdef PME_SSE
1813     /* Calculate exponentials using SSE in single precision */
1814 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1815 {
1816     {
1817         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1818         __m128 f_sse;
1819         __m128 lu;
1820         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1821         int kx;
1822         f_sse = _mm_load1_ps(&f);
1823         for(kx=0; kx<end; kx+=4)
1824         {
1825             tmp_d1   = _mm_load_ps(d_aligned+kx);
1826             lu       = _mm_rcp_ps(tmp_d1);
1827             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1828             tmp_r    = _mm_load_ps(r_aligned+kx);
1829             tmp_r    = gmx_mm_exp_ps(tmp_r);
1830             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1831             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1832             _mm_store_ps(e_aligned+kx,tmp_e);
1833         }
1834     }
1835 }
1836 #else
1837 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1838 {
1839     int kx;
1840     for(kx=start; kx<end; kx++)
1841     {
1842         d[kx] = 1.0/d[kx];
1843     }
1844     for(kx=start; kx<end; kx++)
1845     {
1846         r[kx] = exp(r[kx]);
1847     }
1848     for(kx=start; kx<end; kx++)
1849     {
1850         e[kx] = f*r[kx]*d[kx];
1851     }
1852 }
1853 #endif
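/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: a scalar view of one lane of the SSE branch above.
 * _mm_rcp_ps only gives an approximate reciprocal, so the code refines it
 * with one Newton-Raphson step, lu*(2 - lu*d), before forming f*exp(r)/d.
 * The demo_* names are assumptions for illustration only.
 */
#if 0
#include <math.h>

/* One Newton-Raphson refinement of an approximate reciprocal lu ~= 1/d,
 * mirroring d_inv = lu*(2 - lu*d) in the SSE loop above.
 */
static float demo_refined_rcp(float d, float lu)
{
    return lu*(2.0f - lu*d);
}

/* Scalar equivalent of one SSE lane: e = f*exp(r)/d */
static float demo_eterm(float f, float d, float r, float lu)
{
    return f*demo_refined_rcp(d, lu)*expf(r);
}
#endif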
1854
1855
1856 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1857                          real ewaldcoeff,real vol,
1858                          gmx_bool bEnerVir,
1859                          int nthread,int thread)
1860 {
1861     /* do recip sum over local cells in grid */
1862     /* y major, z middle, x minor or continuous */
1863     t_complex *p0;
1864     int     kx,ky,kz,maxkx,maxky,maxkz;
1865     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1866     real    mx,my,mz;
1867     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1868     real    ets2,struct2,vfactor,ets2vf;
1869     real    d1,d2,energy=0;
1870     real    by,bz;
1871     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1872     real    rxx,ryx,ryy,rzx,rzy,rzz;
1873     pme_work_t *work;
1874     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1875     real    mhxk,mhyk,mhzk,m2k;
1876     real    corner_fac;
1877     ivec    complex_order;
1878     ivec    local_ndata,local_offset,local_size;
1879     real    elfac;
1880
1881     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1882
1883     nx = pme->nkx;
1884     ny = pme->nky;
1885     nz = pme->nkz;
1886
1887     /* Dimensions should be identical for A/B grid, so we just use A here */
1888     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1889                                       complex_order,
1890                                       local_ndata,
1891                                       local_offset,
1892                                       local_size);
1893
1894     rxx = pme->recipbox[XX][XX];
1895     ryx = pme->recipbox[YY][XX];
1896     ryy = pme->recipbox[YY][YY];
1897     rzx = pme->recipbox[ZZ][XX];
1898     rzy = pme->recipbox[ZZ][YY];
1899     rzz = pme->recipbox[ZZ][ZZ];
1900
1901     maxkx = (nx+1)/2;
1902     maxky = (ny+1)/2;
1903     maxkz = nz/2+1;
1904
1905     work = &pme->work[thread];
1906     mhx   = work->mhx;
1907     mhy   = work->mhy;
1908     mhz   = work->mhz;
1909     m2    = work->m2;
1910     denom = work->denom;
1911     tmp1  = work->tmp1;
1912     eterm = work->eterm;
1913     m2inv = work->m2inv;
1914
1915     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1916     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1917
1918     for(iyz=iyz0; iyz<iyz1; iyz++)
1919     {
1920         iy = iyz/local_ndata[ZZ];
1921         iz = iyz - iy*local_ndata[ZZ];
1922
1923         ky = iy + local_offset[YY];
1924
1925         if (ky < maxky)
1926         {
1927             my = ky;
1928         }
1929         else
1930         {
1931             my = (ky - ny);
1932         }
1933
1934         by = M_PI*vol*pme->bsp_mod[YY][ky];
1935
1936         kz = iz + local_offset[ZZ];
1937
1938         mz = kz;
1939
1940         bz = pme->bsp_mod[ZZ][kz];
1941
1942         /* 0.5 correction for corner points */
1943         corner_fac = 1;
1944         if (kz == 0 || kz == (nz+1)/2)
1945         {
1946             corner_fac = 0.5;
1947         }
1948
1949         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1950
1951         /* We should skip the k-space point (0,0,0) */
1952         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1953         {
1954             kxstart = local_offset[XX];
1955         }
1956         else
1957         {
1958             kxstart = local_offset[XX] + 1;
1959             p0++;
1960         }
1961         kxend = local_offset[XX] + local_ndata[XX];
1962
1963         if (bEnerVir)
1964         {
1965             /* More expensive inner loop, especially because of the storage
1966              * of the mh elements in arrays.
1967              * Because x is the minor grid index, all mh elements
1968              * depend on kx for triclinic unit cells.
1969              */
1970
1971                 /* Two explicit loops to avoid a conditional inside the loop */
1972             for(kx=kxstart; kx<maxkx; kx++)
1973             {
1974                 mx = kx;
1975
1976                 mhxk      = mx * rxx;
1977                 mhyk      = mx * ryx + my * ryy;
1978                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1979                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1980                 mhx[kx]   = mhxk;
1981                 mhy[kx]   = mhyk;
1982                 mhz[kx]   = mhzk;
1983                 m2[kx]    = m2k;
1984                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1985                 tmp1[kx]  = -factor*m2k;
1986             }
1987
1988             for(kx=maxkx; kx<kxend; kx++)
1989             {
1990                 mx = (kx - nx);
1991
1992                 mhxk      = mx * rxx;
1993                 mhyk      = mx * ryx + my * ryy;
1994                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1995                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1996                 mhx[kx]   = mhxk;
1997                 mhy[kx]   = mhyk;
1998                 mhz[kx]   = mhzk;
1999                 m2[kx]    = m2k;
2000                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2001                 tmp1[kx]  = -factor*m2k;
2002             }
2003
2004             for(kx=kxstart; kx<kxend; kx++)
2005             {
2006                 m2inv[kx] = 1.0/m2[kx];
2007             }
2008
2009             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2010
2011             for(kx=kxstart; kx<kxend; kx++,p0++)
2012             {
2013                 d1      = p0->re;
2014                 d2      = p0->im;
2015
2016                 p0->re  = d1*eterm[kx];
2017                 p0->im  = d2*eterm[kx];
2018
2019                 struct2 = 2.0*(d1*d1+d2*d2);
2020
2021                 tmp1[kx] = eterm[kx]*struct2;
2022             }
2023
2024             for(kx=kxstart; kx<kxend; kx++)
2025             {
2026                 ets2     = corner_fac*tmp1[kx];
2027                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2028                 energy  += ets2;
2029
2030                 ets2vf   = ets2*vfactor;
2031                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2032                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2033                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2034                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2035                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2036                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2037             }
2038         }
2039         else
2040         {
2041             /* We don't need to calculate the energy and the virial.
2042              * In this case the triclinic overhead is small.
2043              */
2044
2045             /* Two explicit loops to avoid a conditional inside the loop */
2046
2047             for(kx=kxstart; kx<maxkx; kx++)
2048             {
2049                 mx = kx;
2050
2051                 mhxk      = mx * rxx;
2052                 mhyk      = mx * ryx + my * ryy;
2053                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2054                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2055                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2056                 tmp1[kx]  = -factor*m2k;
2057             }
2058
2059             for(kx=maxkx; kx<kxend; kx++)
2060             {
2061                 mx = (kx - nx);
2062
2063                 mhxk      = mx * rxx;
2064                 mhyk      = mx * ryx + my * ryy;
2065                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2066                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2067                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2068                 tmp1[kx]  = -factor*m2k;
2069             }
2070
2071             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2072
2073             for(kx=kxstart; kx<kxend; kx++,p0++)
2074             {
2075                 d1      = p0->re;
2076                 d2      = p0->im;
2077
2078                 p0->re  = d1*eterm[kx];
2079                 p0->im  = d2*eterm[kx];
2080             }
2081         }
2082     }
2083
2084     if (bEnerVir)
2085     {
2086         /* Update virial with local values.
2087          * The virial is symmetric by definition.
2088          * This virial seems OK for isotropic scaling, but I'm
2089          * experiencing problems with semi-isotropic membranes.
2090          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2091          */
2092         work->vir[XX][XX] = 0.25*virxx;
2093         work->vir[YY][YY] = 0.25*viryy;
2094         work->vir[ZZ][ZZ] = 0.25*virzz;
2095         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2096         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2097         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2098
2099         /* This energy should be corrected for a charged system */
2100         work->energy = 0.5*energy;
2101     }
2102
2103     /* Return the loop count */
2104     return local_ndata[YY]*local_ndata[XX];
2105 }
2106
2107 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2108                              real *mesh_energy,matrix vir)
2109 {
2110     /* This function sums output over threads
2111      * and should therefore only be called after thread synchronization.
2112      */
2113     int thread;
2114
2115     *mesh_energy = pme->work[0].energy;
2116     copy_mat(pme->work[0].vir,vir);
2117
2118     for(thread=1; thread<nthread; thread++)
2119     {
2120         *mesh_energy += pme->work[thread].energy;
2121         m_add(vir,pme->work[thread].vir,vir);
2122     }
2123 }
2124
2125 #define DO_FSPLINE(order)                      \
2126 for(ithx=0; (ithx<order); ithx++)              \
2127 {                                              \
2128     index_x = (i0+ithx)*pny*pnz;               \
2129     tx      = thx[ithx];                       \
2130     dx      = dthx[ithx];                      \
2131                                                \
2132     for(ithy=0; (ithy<order); ithy++)          \
2133     {                                          \
2134         index_xy = index_x+(j0+ithy)*pnz;      \
2135         ty       = thy[ithy];                  \
2136         dy       = dthy[ithy];                 \
2137         fxy1     = fz1 = 0;                    \
2138                                                \
2139         for(ithz=0; (ithz<order); ithz++)      \
2140         {                                      \
2141             gval  = grid[index_xy+(k0+ithz)];  \
2142             fxy1 += thz[ithz]*gval;            \
2143             fz1  += dthz[ithz]*gval;           \
2144         }                                      \
2145         fx += dx*ty*fxy1;                      \
2146         fy += tx*dy*fxy1;                      \
2147         fz += tx*ty*fz1;                       \
2148     }                                          \
2149 }
2150
2151
2152 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2153                               gmx_bool bClearF,pme_atomcomm_t *atc,
2154                               splinedata_t *spline,
2155                               real scale)
2156 {
2157     /* sum forces for local particles */
2158     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2159     int     index_x,index_xy;
2160     int     nx,ny,nz,pnx,pny,pnz;
2161     int *   idxptr;
2162     real    tx,ty,dx,dy,qn;
2163     real    fx,fy,fz,gval;
2164     real    fxy1,fz1;
2165     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2166     int     norder;
2167     real    rxx,ryx,ryy,rzx,rzy,rzz;
2168     int     order;
2169
2170     pme_spline_work_t *work;
2171
2172     work = &pme->spline_work;
2173
2174     order = pme->pme_order;
2175     thx   = spline->theta[XX];
2176     thy   = spline->theta[YY];
2177     thz   = spline->theta[ZZ];
2178     dthx  = spline->dtheta[XX];
2179     dthy  = spline->dtheta[YY];
2180     dthz  = spline->dtheta[ZZ];
2181     nx    = pme->nkx;
2182     ny    = pme->nky;
2183     nz    = pme->nkz;
2184     pnx   = pme->pmegrid_nx;
2185     pny   = pme->pmegrid_ny;
2186     pnz   = pme->pmegrid_nz;
2187
2188     rxx   = pme->recipbox[XX][XX];
2189     ryx   = pme->recipbox[YY][XX];
2190     ryy   = pme->recipbox[YY][YY];
2191     rzx   = pme->recipbox[ZZ][XX];
2192     rzy   = pme->recipbox[ZZ][YY];
2193     rzz   = pme->recipbox[ZZ][ZZ];
2194
2195     for(nn=0; nn<spline->n; nn++)
2196     {
2197         n  = spline->ind[nn];
2198         qn = scale*atc->q[n];
2199
2200         if (bClearF)
2201         {
2202             atc->f[n][XX] = 0;
2203             atc->f[n][YY] = 0;
2204             atc->f[n][ZZ] = 0;
2205         }
2206         if (qn != 0)
2207         {
2208             fx     = 0;
2209             fy     = 0;
2210             fz     = 0;
2211             idxptr = atc->idx[n];
2212             norder = nn*order;
2213
2214             i0   = idxptr[XX];
2215             j0   = idxptr[YY];
2216             k0   = idxptr[ZZ];
2217
2218             /* Pointer arithmetic alert, next six statements */
2219             thx  = spline->theta[XX] + norder;
2220             thy  = spline->theta[YY] + norder;
2221             thz  = spline->theta[ZZ] + norder;
2222             dthx = spline->dtheta[XX] + norder;
2223             dthy = spline->dtheta[YY] + norder;
2224             dthz = spline->dtheta[ZZ] + norder;
2225
2226             switch (order) {
2227             case 4:
2228 #ifdef PME_SSE
2229 #ifdef PME_SSE_UNALIGNED
2230 #define PME_GATHER_F_SSE_ORDER4
2231 #else
2232 #define PME_GATHER_F_SSE_ALIGNED
2233 #define PME_ORDER 4
2234 #endif
2235 #include "pme_sse_single.h"
2236 #else
2237                 DO_FSPLINE(4);
2238 #endif
2239                 break;
2240             case 5:
2241 #ifdef PME_SSE
2242 #define PME_GATHER_F_SSE_ALIGNED
2243 #define PME_ORDER 5
2244 #include "pme_sse_single.h"
2245 #else
2246                 DO_FSPLINE(5);
2247 #endif
2248                 break;
2249             default:
2250                 DO_FSPLINE(order);
2251                 break;
2252             }
2253
2254             atc->f[n][XX] += -qn*( fx*nx*rxx );
2255             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2256             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2257         }
2258     }
2259     /* Since the energy and not the forces are interpolated,
2260      * the net force might not be exactly zero.
2261      * This can be solved by also interpolating F, but
2262      * that comes at a cost.
2263      * A better hack is to remove the net force every
2264      * step, but that must be done at a higher level
2265      * since this routine doesn't see all atoms if running
2266      * in parallel. It is not clear how important this is.  EL 990726
2267      */
2268 }
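/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: the three force lines at the end of the loop above are
 * a chain rule. (fx,fy,fz) is the gradient of the interpolated potential
 * with respect to the fractional grid coordinate, so the Cartesian force is
 * -q times that gradient scaled by the grid dimensions and the (lower
 * triangular) reciprocal box rows. The demo_* names are assumptions.
 */
#if 0
static void demo_frac_to_cartesian_force(double q,
                                         double fx, double fy, double fz,
                                         int nx, int ny, int nz,
                                         const double recip[3][3],
                                         double f_out[3])
{
    /* Same expressions as atc->f[n][XX..ZZ] above, with recip = recipbox */
    f_out[0] = -q*( fx*nx*recip[0][0] );
    f_out[1] = -q*( fx*nx*recip[1][0] + fy*ny*recip[1][1] );
    f_out[2] = -q*( fx*nx*recip[2][0] + fy*ny*recip[2][1] + fz*nz*recip[2][2] );
}
#endif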
2269
2270
2271 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2272                                    pme_atomcomm_t *atc)
2273 {
2274     splinedata_t *spline;
2275     int     n,ithx,ithy,ithz,i0,j0,k0;
2276     int     index_x,index_xy;
2277     int *   idxptr;
2278     real    energy,pot,tx,ty,qn,gval;
2279     real    *thx,*thy,*thz;
2280     int     norder;
2281     int     order;
2282
2283     spline = &atc->spline[0];
2284
2285     order = pme->pme_order;
2286
2287     energy = 0;
2288     for(n=0; (n<atc->n); n++) {
2289         qn      = atc->q[n];
2290
2291         if (qn != 0) {
2292             idxptr = atc->idx[n];
2293             norder = n*order;
2294
2295             i0   = idxptr[XX];
2296             j0   = idxptr[YY];
2297             k0   = idxptr[ZZ];
2298
2299             /* Pointer arithmetic alert, next three statements */
2300             thx  = spline->theta[XX] + norder;
2301             thy  = spline->theta[YY] + norder;
2302             thz  = spline->theta[ZZ] + norder;
2303
2304             pot = 0;
2305             for(ithx=0; (ithx<order); ithx++)
2306             {
2307                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2308                 tx      = thx[ithx];
2309
2310                 for(ithy=0; (ithy<order); ithy++)
2311                 {
2312                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2313                     ty       = thy[ithy];
2314
2315                     for(ithz=0; (ithz<order); ithz++)
2316                     {
2317                         gval  = grid[index_xy+(k0+ithz)];
2318                         pot  += tx*ty*thz[ithz]*gval;
2319                     }
2320
2321                 }
2322             }
2323
2324             energy += pot*qn;
2325         }
2326     }
2327
2328     return energy;
2329 }
2330
2331 /* Macro to force loop unrolling by fixing order.
2332  * This gives a significant performance gain.
2333  */
2334 #define CALC_SPLINE(order)                     \
2335 {                                              \
2336     int j,k,l;                                 \
2337     real dr,div;                               \
2338     real data[PME_ORDER_MAX];                  \
2339     real ddata[PME_ORDER_MAX];                 \
2340                                                \
2341     for(j=0; (j<DIM); j++)                     \
2342     {                                          \
2343         dr  = xptr[j];                         \
2344                                                \
2345         /* dr is relative offset from lower cell limit */ \
2346         data[order-1] = 0;                     \
2347         data[1] = dr;                          \
2348         data[0] = 1 - dr;                      \
2349                                                \
2350         for(k=3; (k<order); k++)               \
2351         {                                      \
2352             div = 1.0/(k - 1.0);               \
2353             data[k-1] = div*dr*data[k-2];      \
2354             for(l=1; (l<(k-1)); l++)           \
2355             {                                  \
2356                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2357                                    data[k-l-1]);                \
2358             }                                  \
2359             data[0] = div*(1-dr)*data[0];      \
2360         }                                      \
2361         /* differentiate */                    \
2362         ddata[0] = -data[0];                   \
2363         for(k=1; (k<order); k++)               \
2364         {                                      \
2365             ddata[k] = data[k-1] - data[k];    \
2366         }                                      \
2367                                                \
2368         div = 1.0/(order - 1);                 \
2369         data[order-1] = div*dr*data[order-2];  \
2370         for(l=1; (l<(order-1)); l++)           \
2371         {                                      \
2372             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2373                                (order-l-dr)*data[order-l-1]); \
2374         }                                      \
2375         data[0] = div*(1 - dr)*data[0];        \
2376                                                \
2377         for(k=0; k<order; k++)                 \
2378         {                                      \
2379             theta[j][i*order+k]  = data[k];    \
2380             dtheta[j][i*order+k] = ddata[k];   \
2381         }                                      \
2382     }                                          \
2383 }
2384
2385 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2386                    rvec fractx[],int nr,int ind[],real charge[],
2387                    gmx_bool bFreeEnergy)
2388 {
2389     /* construct splines for local atoms */
2390     int  i,ii;
2391     real *xptr;
2392
2393     for(i=0; i<nr; i++)
2394     {
2395         /* With free energy we do not use the charge check.
2396          * In most cases this will be more efficient than calling make_bsplines
2397          * twice, since usually more than half the particles have charges.
2398          */
2399         ii = ind[i];
2400         if (bFreeEnergy || charge[ii] != 0.0) {
2401             xptr = fractx[ii];
2402             switch(order) {
2403             case 4:  CALC_SPLINE(4);     break;
2404             case 5:  CALC_SPLINE(5);     break;
2405             default: CALC_SPLINE(order); break;
2406             }
2407         }
2408     }
2409 }
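/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: the closed-form order-4 (cubic) B-spline weights for a
 * fractional offset dr in [0,1). The CALC_SPLINE recursion above produces
 * the same four numbers in data[0..3]; the recursion is used in the code
 * because it works for any order. The demo_* names are assumptions.
 */
#if 0
#include <stdio.h>

static void demo_bspline4(double dr, double w[4])
{
    double u = 1.0 - dr;

    w[0] = u*u*u/6.0;
    w[1] = (3.0*dr*dr*dr - 6.0*dr*dr + 4.0)/6.0;
    w[2] = (-3.0*dr*dr*dr + 3.0*dr*dr + 3.0*dr + 1.0)/6.0;
    w[3] = dr*dr*dr/6.0;
}

int main(void)
{
    double w[4];
    int    k;

    demo_bspline4(0.25, w);
    for (k = 0; k < 4; k++)
    {
        printf("w[%d] = %g\n", k, w[k]); /* the four weights sum to 1 */
    }

    return 0;
}
#endif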
2410
2411
2412 void make_dft_mod(real *mod,real *data,int ndata)
2413 {
2414   int i,j;
2415   real sc,ss,arg;
2416
2417   for(i=0;i<ndata;i++) {
2418     sc=ss=0;
2419     for(j=0;j<ndata;j++) {
2420       arg=(2.0*M_PI*i*j)/ndata;
2421       sc+=data[j]*cos(arg);
2422       ss+=data[j]*sin(arg);
2423     }
2424     mod[i]=sc*sc+ss*ss;
2425   }
2426   for(i=0;i<ndata;i++)
2427     if(mod[i]<1e-7)
2428       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2429 }
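/* Illustrative usage sketch, not part of the original source and kept out
 * of the build with #if 0: make_dft_mod() stores |DFT(data)|^2 for each
 * wave number; near-zero entries are patched by averaging the neighbours.
 * The B-spline values used here are those produced for order 4 in
 * make_bspline_moduli() below; the grid size 8 is an arbitrary assumption.
 */
#if 0
#include <stdio.h>

int main(void)
{
    /* Order-4 B-spline knot values 1/6, 2/3, 1/6 placed at bsp_data[1..3] */
    real bsp_data[8] = { 0, 1.0/6.0, 2.0/3.0, 1.0/6.0, 0, 0, 0, 0 };
    real mod[8];
    int  i;

    make_dft_mod(mod, bsp_data, 8);
    for (i = 0; i < 8; i++)
    {
        printf("mod[%d] = %g\n", i, mod[i]);
    }

    return 0;
}
#endif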
2430
2431
2432 static void make_bspline_moduli(splinevec bsp_mod,
2433                                 int nx,int ny,int nz,int order)
2434 {
2435   int nmax=max(nx,max(ny,nz));
2436   real *data,*ddata,*bsp_data;
2437   int i,k,l;
2438   real div;
2439
2440   snew(data,order);
2441   snew(ddata,order);
2442   snew(bsp_data,nmax);
2443
2444   data[order-1]=0;
2445   data[1]=0;
2446   data[0]=1;
2447
2448   for(k=3;k<order;k++) {
2449     div=1.0/(k-1.0);
2450     data[k-1]=0;
2451     for(l=1;l<(k-1);l++)
2452       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2453     data[0]=div*data[0];
2454   }
2455   /* differentiate */
2456   ddata[0]=-data[0];
2457   for(k=1;k<order;k++)
2458     ddata[k]=data[k-1]-data[k];
2459   div=1.0/(order-1);
2460   data[order-1]=0;
2461   for(l=1;l<(order-1);l++)
2462     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2463   data[0]=div*data[0];
2464
2465   for(i=0;i<nmax;i++)
2466     bsp_data[i]=0;
2467   for(i=1;i<=order;i++)
2468     bsp_data[i]=data[i-1];
2469
2470   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2471   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2472   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2473
2474   sfree(data);
2475   sfree(ddata);
2476   sfree(bsp_data);
2477 }
2478
2479
2480 /* Return the P3M optimal influence function */
2481 static double do_p3m_influence(double z, int order)
2482 {
2483     double z2,z4;
2484
2485     z2 = z*z;
2486     z4 = z2*z2;
2487
2488     /* The formula and most constants can be found in:
2489      * Ballenegger et al., JCTC 8, 936 (2012)
2490      */
2491     switch(order)
2492     {
2493     case 2:
2494         return 1.0 - 2.0*z2/3.0;
2495         break;
2496     case 3:
2497         return 1.0 - z2 + 2.0*z4/15.0;
2498         break;
2499     case 4:
2500         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2501         break;
2502     case 5:
2503         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2504         break;
2505     case 6:
2506         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2507         break;
2508     case 7:
2509         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2510     case 8:
2511         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2512         break;
2513     }
2514
2515     return 0.0;
2516 }
2517
2518 /* Calculate the P3M B-spline moduli for one dimension */
2519 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2520 {
2521     double zarg,zai,sinzai,infl;
2522     int    maxk,i;
2523
2524     if (order > 8)
2525     {
2526         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2527     }
2528
2529     zarg = M_PI/n;
2530
2531     maxk = (n + 1)/2;
2532
2533     for(i=-maxk; i<0; i++)
2534     {
2535         zai    = zarg*i;
2536         sinzai = sin(zai);
2537         infl   = do_p3m_influence(sinzai,order);
2538         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2539     }
2540     bsp_mod[0] = 1.0;
2541     for(i=1; i<maxk; i++)
2542     {
2543         zai    = zarg*i;
2544         sinzai = sin(zai);
2545         infl   = do_p3m_influence(sinzai,order);
2546         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2547     }
2548 }
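/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: one entry of the P3M moduli above is
 * infl(sin(pi*i/n))^2 * (sin(z)/z)^(-2*order) with z = pi*i/n, i.e. the
 * optimal influence function divided by the B-spline Fourier factor.
 * The demo_* names are assumptions; valid for i != 0 (i = 0 is set to 1
 * separately, as in the function above).
 */
#if 0
#include <math.h>

/* Order-4 influence polynomial, same expression as in do_p3m_influence() */
static double demo_p3m_influence4(double z)
{
    double z2 = z*z, z4 = z2*z2;

    return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
}

/* One order-4 modulus entry for 0 < i < (n+1)/2 */
static double demo_p3m_bsp_mod4(int i, int n)
{
    double zai    = M_PI*i/n;
    double sinzai = sin(zai);
    double infl   = demo_p3m_influence4(sinzai);

    return infl*infl*pow(sinzai/zai, -8.0); /* -2*order with order = 4 */
}
#endif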
2549
2550 /* Calculate the P3M B-spline moduli */
2551 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2552                                     int nx,int ny,int nz,int order)
2553 {
2554     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2555     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2556     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2557 }
2558
2559
2560 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2561 {
2562   int nslab,n,i;
2563   int fw,bw;
2564
2565   nslab = atc->nslab;
2566
2567   n = 0;
2568   for(i=1; i<=nslab/2; i++) {
2569     fw = (atc->nodeid + i) % nslab;
2570     bw = (atc->nodeid - i + nslab) % nslab;
2571     if (n < nslab - 1) {
2572       atc->node_dest[n] = fw;
2573       atc->node_src[n]  = bw;
2574       n++;
2575     }
2576     if (n < nslab - 1) {
2577       atc->node_dest[n] = bw;
2578       atc->node_src[n]  = fw;
2579       n++;
2580     }
2581   }
2582 }
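/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: the pairing above alternates forward and backward
 * neighbours at increasing distance around the slab ring, producing
 * nslab-1 communication pulses in total. This prints the resulting order
 * for one hypothetical rank; the demo_* names are assumptions.
 */
#if 0
#include <stdio.h>

static void demo_print_comm_order(int nodeid, int nslab)
{
    int n = 0, i, fw, bw;

    for (i = 1; i <= nslab/2; i++)
    {
        fw = (nodeid + i) % nslab;
        bw = (nodeid - i + nslab) % nslab;
        if (n < nslab - 1)
        {
            printf("pulse %d: dest %d src %d\n", n, fw, bw);
            n++;
        }
        if (n < nslab - 1)
        {
            printf("pulse %d: dest %d src %d\n", n, bw, fw);
            n++;
        }
    }
}

int main(void)
{
    demo_print_comm_order(0, 5); /* rank 0 in a ring of 5 slabs */

    return 0;
}
#endif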
2583
2584 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2585 {
2586     int thread;
2587
2588     if(NULL != log)
2589     {
2590         fprintf(log,"Destroying PME data structures.\n");
2591     }
2592
2593     sfree((*pmedata)->nnx);
2594     sfree((*pmedata)->nny);
2595     sfree((*pmedata)->nnz);
2596
2597     pmegrids_destroy(&(*pmedata)->pmegridA);
2598
2599     sfree((*pmedata)->fftgridA);
2600     sfree((*pmedata)->cfftgridA);
2601     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2602
2603     if ((*pmedata)->pmegridB.grid.grid != NULL)
2604     {
2605         pmegrids_destroy(&(*pmedata)->pmegridB);
2606         sfree((*pmedata)->fftgridB);
2607         sfree((*pmedata)->cfftgridB);
2608         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2609     }
2610     for(thread=0; thread<(*pmedata)->nthread; thread++)
2611     {
2612         free_work(&(*pmedata)->work[thread]);
2613     }
2614     sfree((*pmedata)->work);
2615
2616     sfree(*pmedata);
2617     *pmedata = NULL;
2618
2619     return 0;
2620 }
2621
2622 static int mult_up(int n,int f)
2623 {
2624     return ((n + f - 1)/f)*f;
2625 }
2626
2627
2628 static double pme_load_imbalance(gmx_pme_t pme)
2629 {
2630     int    nma,nmi;
2631     double n1,n2,n3;
2632
2633     nma = pme->nnodes_major;
2634     nmi = pme->nnodes_minor;
2635
2636     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2637     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2638     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2639
2640     /* pme_solve is roughly double the cost of an fft */
2641
2642     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2643 }
2644
2645 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2646                           int dimind,gmx_bool bSpread)
2647 {
2648     int nk,k,s,thread;
2649
2650     atc->dimind = dimind;
2651     atc->nslab  = 1;
2652     atc->nodeid = 0;
2653     atc->pd_nalloc = 0;
2654 #ifdef GMX_MPI
2655     if (pme->nnodes > 1)
2656     {
2657         atc->mpi_comm = pme->mpi_comm_d[dimind];
2658         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2659         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2660     }
2661     if (debug)
2662     {
2663         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2664     }
2665 #endif
2666
2667     atc->bSpread   = bSpread;
2668     atc->pme_order = pme->pme_order;
2669
2670     if (atc->nslab > 1)
2671     {
2672         /* These three allocations are not required for particle decomp. */
2673         snew(atc->node_dest,atc->nslab);
2674         snew(atc->node_src,atc->nslab);
2675         setup_coordinate_communication(atc);
2676
2677         snew(atc->count_thread,pme->nthread);
2678         for(thread=0; thread<pme->nthread; thread++)
2679         {
2680             snew(atc->count_thread[thread],atc->nslab);
2681         }
2682         atc->count = atc->count_thread[0];
2683         snew(atc->rcount,atc->nslab);
2684         snew(atc->buf_index,atc->nslab);
2685     }
2686
2687     atc->nthread = pme->nthread;
2688     if (atc->nthread > 1)
2689     {
2690         snew(atc->thread_plist,atc->nthread);
2691     }
2692     snew(atc->spline,atc->nthread);
2693     for(thread=0; thread<atc->nthread; thread++)
2694     {
2695         if (atc->nthread > 1)
2696         {
2697             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2698             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2699         }
2700     }
2701 }
2702
2703 static void
2704 init_overlap_comm(pme_overlap_t *  ol,
2705                   int              norder,
2706 #ifdef GMX_MPI
2707                   MPI_Comm         comm,
2708 #endif
2709                   int              nnodes,
2710                   int              nodeid,
2711                   int              ndata,
2712                   int              commplainsize)
2713 {
2714     int lbnd,rbnd,maxlr,b,i;
2715     int exten;
2716     int nn,nk;
2717     pme_grid_comm_t *pgc;
2718     gmx_bool bCont;
2719     int fft_start,fft_end,send_index1,recv_index1;
2720
2721 #ifdef GMX_MPI
2722     ol->mpi_comm = comm;
2723 #endif
2724
2725     ol->nnodes = nnodes;
2726     ol->nodeid = nodeid;
2727
2728     /* A linear translation of the PME grid won't affect reciprocal space
2729      * calculations, so to optimize we only interpolate "upwards",
2730      * which also means we only have to consider overlap in one direction.
2731      * I.e., particles on this node might also be spread to grid indices
2732      * that belong to higher nodes (modulo nnodes).
2733      */
2734
2735     snew(ol->s2g0,ol->nnodes+1);
2736     snew(ol->s2g1,ol->nnodes);
2737     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2738     for(i=0; i<nnodes; i++)
2739     {
2740         /* s2g0 the local interpolation grid start.
2741          * s2g1 the local interpolation grid end.
2742          * Because grid overlap communication only goes forward,
2743          * the grid slabs for the FFTs should be rounded down.
2744          */
2745         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2746         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2747
2748         if (debug)
2749         {
2750             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2751         }
2752     }
2753     ol->s2g0[nnodes] = ndata;
2754     if (debug) { fprintf(debug,"\n"); }
2755
2756     /* Determine with how many nodes we need to communicate the grid overlap */
2757     b = 0;
2758     do
2759     {
2760         b++;
2761         bCont = FALSE;
2762         for(i=0; i<nnodes; i++)
2763         {
2764             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2765                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2766             {
2767                 bCont = TRUE;
2768             }
2769         }
2770     }
2771     while (bCont && b < nnodes);
2772     ol->noverlap_nodes = b - 1;
2773
2774     snew(ol->send_id,ol->noverlap_nodes);
2775     snew(ol->recv_id,ol->noverlap_nodes);
2776     for(b=0; b<ol->noverlap_nodes; b++)
2777     {
2778         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2779         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2780     }
2781     snew(ol->comm_data, ol->noverlap_nodes);
2782
2783     for(b=0; b<ol->noverlap_nodes; b++)
2784     {
2785         pgc = &ol->comm_data[b];
2786         /* Send */
2787         fft_start        = ol->s2g0[ol->send_id[b]];
2788         fft_end          = ol->s2g0[ol->send_id[b]+1];
2789         if (ol->send_id[b] < nodeid)
2790         {
2791             fft_start += ndata;
2792             fft_end   += ndata;
2793         }
2794         send_index1      = ol->s2g1[nodeid];
2795         send_index1      = min(send_index1,fft_end);
2796         pgc->send_index0 = fft_start;
2797         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2798
2799         /* We always start receiving to the first index of our slab */
2800         fft_start        = ol->s2g0[ol->nodeid];
2801         fft_end          = ol->s2g0[ol->nodeid+1];
2802         recv_index1      = ol->s2g1[ol->recv_id[b]];
2803         if (ol->recv_id[b] > nodeid)
2804         {
2805             recv_index1 -= ndata;
2806         }
2807         recv_index1      = min(recv_index1,fft_end);
2808         pgc->recv_index0 = fft_start;
2809         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2810     }
2811
2812     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2813     snew(ol->sendbuf,norder*commplainsize);
2814     snew(ol->recvbuf,norder*commplainsize);
2815 }
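/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: the slab boundaries set above are
 * s2g0[i] = floor(i*ndata/nnodes) for the FFT slab start and
 * s2g1[i] = ceil((i+1)*ndata/nnodes) + norder - 1 for the end of the
 * interpolation slab, which extends forward by the spread overlap.
 * The demo_* names and the 20/3/4 numbers are assumptions.
 */
#if 0
#include <stdio.h>

static void demo_print_slab_bounds(int ndata, int nnodes, int norder)
{
    int i, s2g0, s2g1;

    for (i = 0; i < nnodes; i++)
    {
        s2g0 = ( i   *ndata            )/nnodes;              /* FFT slab start */
        s2g1 = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;  /* spread end     */
        printf("node %d: s2g0 %2d s2g1 %2d\n", i, s2g0, s2g1);
    }
}

int main(void)
{
    demo_print_slab_bounds(20, 3, 4); /* prints 0-10, 6-17, 13-23 */

    return 0;
}
#endif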
2816
2817 static void
2818 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2819                               int **global_to_local,
2820                               real **fraction_shift)
2821 {
2822     int i;
2823     int * gtl;
2824     real * fsh;
2825
2826     snew(gtl,5*n);
2827     snew(fsh,5*n);
2828     for(i=0; (i<5*n); i++)
2829     {
2830         /* Determine the global to local grid index */
2831         gtl[i] = (i - local_start + n) % n;
2832         /* For coordinates that fall within the local grid the fraction
2833          * is correct; we don't need to shift it.
2834          */
2835         fsh[i] = 0;
2836         if (local_range < n)
2837         {
2838             /* Due to rounding issues i could be 1 beyond the lower or
2839              * upper boundary of the local grid. Correct the index for this.
2840              * If we shift the index, we need to shift the fraction by
2841              * the same amount in the other direction to not affect
2842              * the weights.
2843              * Note that due to this shifting the weights at the end of
2844              * the spline might change, but that will only involve values
2845              * between zero and values close to the precision of a real,
2846              * which is anyhow the accuracy of the whole mesh calculation.
2847              */
2848             /* With local_range=0 we should not change i=local_start */
2849             if (i % n != local_start)
2850             {
2851                 if (gtl[i] == n-1)
2852                 {
2853                     gtl[i] = 0;
2854                     fsh[i] = -1;
2855                 }
2856                 else if (gtl[i] == local_range)
2857                 {
2858                     gtl[i] = local_range - 1;
2859                     fsh[i] = 1;
2860                 }
2861             }
2862         }
2863     }
2864
2865     *global_to_local = gtl;
2866     *fraction_shift  = fsh;
2867 }
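/* Illustrative sketch, not part of the original source and kept out of the
 * build with #if 0: the tables built above wrap a (shifted) global grid
 * index periodically onto the local slab and, when rounding puts an index
 * one cell outside the local range, nudge it back while recording a +/-1
 * fraction shift so the spline weights stay consistent. The demo_* names
 * and the 12/4/6 numbers are assumptions.
 */
#if 0
#include <stdio.h>

static void demo_print_g2l(int n, int local_start, int local_range)
{
    int i, gtl, fsh;

    for (i = 0; i < 5*n; i++)
    {
        gtl = (i - local_start + n) % n;
        fsh = 0;
        if (local_range < n && i % n != local_start)
        {
            if (gtl == n - 1)
            {
                gtl = 0;
                fsh = -1;
            }
            else if (gtl == local_range)
            {
                gtl = local_range - 1;
                fsh = 1;
            }
        }
        printf("global %3d -> local %2d shift %2d\n", i, gtl, fsh);
    }
}

int main(void)
{
    demo_print_g2l(12, 4, 6); /* 12 grid lines, local slab starting at 4, 6 wide */

    return 0;
}
#endif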
2868
2869 static void sse_mask_init(pme_spline_work_t *work,int order)
2870 {
2871 #ifdef PME_SSE
2872     float  tmp[8];
2873     __m128 zero_SSE;
2874     int    of,i;
2875
2876     zero_SSE = _mm_setzero_ps();
2877
2878     for(of=0; of<8-(order-1); of++)
2879     {
2880         for(i=0; i<8; i++)
2881         {
2882             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2883         }
2884         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2885         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2886         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2887         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2888     }
2889 #endif
2890 }
2891
2892 static void
2893 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2894 {
2895     int nk_new;
2896
2897     if (*nk % nnodes != 0)
2898     {
2899         nk_new = nnodes*(*nk/nnodes + 1);
2900
2901         if (2*nk_new >= 3*(*nk))
2902         {
2903             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2904                       dim,*nk,dim,nnodes,dim);
2905         }
2906
2907         if (fplog != NULL)
2908         {
2909             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2910                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2911         }
2912
2913         *nk = nk_new;
2914     }
2915 }
2916
2917 int gmx_pme_init(gmx_pme_t *         pmedata,
2918                  t_commrec *         cr,
2919                  int                 nnodes_major,
2920                  int                 nnodes_minor,
2921                  t_inputrec *        ir,
2922                  int                 homenr,
2923                  gmx_bool            bFreeEnergy,
2924                  gmx_bool            bReproducible,
2925                  int                 nthread)
2926 {
2927     gmx_pme_t pme=NULL;
2928
2929     pme_atomcomm_t *atc;
2930     ivec ndata;
2931
2932     if (debug)
2933         fprintf(debug,"Creating PME data structures.\n");
2934     snew(pme,1);
2935
2936     pme->redist_init         = FALSE;
2937     pme->sum_qgrid_tmp       = NULL;
2938     pme->sum_qgrid_dd_tmp    = NULL;
2939     pme->buf_nalloc          = 0;
2940     pme->redist_buf_nalloc   = 0;
2941
2942     pme->nnodes              = 1;
2943     pme->bPPnode             = TRUE;
2944
2945     pme->nnodes_major        = nnodes_major;
2946     pme->nnodes_minor        = nnodes_minor;
2947
2948 #ifdef GMX_MPI
2949     if (nnodes_major*nnodes_minor > 1)
2950     {
2951         pme->mpi_comm = cr->mpi_comm_mygroup;
2952
2953         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2954         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2955         if (pme->nnodes != nnodes_major*nnodes_minor)
2956         {
2957             gmx_incons("PME node count mismatch");
2958         }
2959     }
2960     else
2961     {
2962         pme->mpi_comm = MPI_COMM_NULL;
2963     }
2964 #endif
2965
2966     if (pme->nnodes == 1)
2967     {
2968         pme->ndecompdim = 0;
2969         pme->nodeid_major = 0;
2970         pme->nodeid_minor = 0;
2971 #ifdef GMX_MPI
2972         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2973 #endif
2974     }
2975     else
2976     {
2977         if (nnodes_minor == 1)
2978         {
2979 #ifdef GMX_MPI
2980             pme->mpi_comm_d[0] = pme->mpi_comm;
2981             pme->mpi_comm_d[1] = MPI_COMM_NULL;
2982 #endif
2983             pme->ndecompdim = 1;
2984             pme->nodeid_major = pme->nodeid;
2985             pme->nodeid_minor = 0;
2986
2987         }
2988         else if (nnodes_major == 1)
2989         {
2990 #ifdef GMX_MPI
2991             pme->mpi_comm_d[0] = MPI_COMM_NULL;
2992             pme->mpi_comm_d[1] = pme->mpi_comm;
2993 #endif
2994             pme->ndecompdim = 1;
2995             pme->nodeid_major = 0;
2996             pme->nodeid_minor = pme->nodeid;
2997         }
2998         else
2999         {
3000             if (pme->nnodes % nnodes_major != 0)
3001             {
3002                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3003             }
3004             pme->ndecompdim = 2;
3005
3006 #ifdef GMX_MPI
3007             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3008                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3009             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3010                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3011
3012             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3013             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3014             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3015             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3016 #endif
3017         }
3018         pme->bPPnode = (cr->duty & DUTY_PP);
3019     }
3020
3021     pme->nthread = nthread;
3022
3023     if (ir->ePBC == epbcSCREW)
3024     {
3025         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3026     }
3027
3028     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3029     pme->nkx         = ir->nkx;
3030     pme->nky         = ir->nky;
3031     pme->nkz         = ir->nkz;
3032     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3033     pme->pme_order   = ir->pme_order;
3034     pme->epsilon_r   = ir->epsilon_r;
3035
3036     if (pme->pme_order > PME_ORDER_MAX)
3037     {
3038         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3039                   pme->pme_order,PME_ORDER_MAX);
3040     }
3041
3042     /* Currently pme.c supports only the fft5d FFT code.
3043      * Therefore the grid always needs to be divisible by nnodes.
3044      * When the old 1D code is also supported again, change this check.
3045      *
3046      * This check should be done before calling gmx_pme_init
3047      * and fplog should be passed instead of stderr.
3048      *
3049     if (pme->ndecompdim >= 2)
3050     */
3051     if (pme->ndecompdim >= 1)
3052     {
3053         /*
3054         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3055                                         'x',nnodes_major,&pme->nkx);
3056         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3057                                         'y',nnodes_minor,&pme->nky);
3058         */
3059     }
3060
3061     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3062         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3063         pme->nkz <= pme->pme_order)
3064     {
3065         gmx_fatal(FARGS,"The pme grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_order for x and/or y",pme->pme_order);
3066     }
3067
3068     if (pme->nnodes > 1) {
3069         double imbal;
3070
3071 #ifdef GMX_MPI
3072         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3073         MPI_Type_commit(&(pme->rvec_mpi));
3074 #endif
3075
3076         /* Note that the charge spreading and force gathering, which usually
3077          * takes about the same amount of time as FFT+solve_pme,
3078          * is always fully load balanced
3079          * (unless the charge distribution is inhomogeneous).
3080          */
3081
3082         imbal = pme_load_imbalance(pme);
3083         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3084         {
3085             fprintf(stderr,
3086                     "\n"
3087                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3088                     "      For optimal PME load balancing\n"
3089                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3090                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3091                     "\n",
3092                     (int)((imbal-1)*100 + 0.5),
3093                     pme->nkx,pme->nky,pme->nnodes_major,
3094                     pme->nky,pme->nkz,pme->nnodes_minor);
3095         }
3096     }
3097
3098     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3099     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3100      * y is always copied through a buffer: we don't need padding in z,
3101      * but we do need the overlap in x because of the communication order.
3102      */
3103     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3104 #ifdef GMX_MPI
3105                       pme->mpi_comm_d[0],
3106 #endif
3107                       pme->nnodes_major,pme->nodeid_major,
3108                       pme->nkx,
3109                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3110
3111     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3112 #ifdef GMX_MPI
3113                       pme->mpi_comm_d[1],
3114 #endif
3115                       pme->nnodes_minor,pme->nodeid_minor,
3116                       pme->nky,
3117                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3118
3119     /* Check for a limitation of the (current) sum_fftgrid_dd code */
3120     if (pme->nthread > 1 &&
3121         (pme->overlap[0].noverlap_nodes > 1 ||
3122          pme->overlap[1].noverlap_nodes > 1))
3123     {
3124         gmx_fatal(FARGS,"With threads the number of grid lines per node along x and/or y should be pme_order (%d) or more or exactly pme_order-1",pme->pme_order);
3125     }
3126
3127     snew(pme->bsp_mod[XX],pme->nkx);
3128     snew(pme->bsp_mod[YY],pme->nky);
3129     snew(pme->bsp_mod[ZZ],pme->nkz);
3130
3131     /* The required size of the interpolation grid, including overlap.
3132      * The allocated size (pmegrid_n?) might be slightly larger.
3133      */
3134     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3135                       pme->overlap[0].s2g0[pme->nodeid_major];
3136     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3137                       pme->overlap[1].s2g0[pme->nodeid_minor];
3138     pme->pmegrid_nz_base = pme->nkz;
3139     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3140     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3141
3142     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3143     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3144     pme->pmegrid_start_iz = 0;
3145
3146     make_gridindex5_to_localindex(pme->nkx,
3147                                   pme->pmegrid_start_ix,
3148                                   pme->pmegrid_nx - (pme->pme_order-1),
3149                                   &pme->nnx,&pme->fshx);
3150     make_gridindex5_to_localindex(pme->nky,
3151                                   pme->pmegrid_start_iy,
3152                                   pme->pmegrid_ny - (pme->pme_order-1),
3153                                   &pme->nny,&pme->fshy);
3154     make_gridindex5_to_localindex(pme->nkz,
3155                                   pme->pmegrid_start_iz,
3156                                   pme->pmegrid_nz_base,
3157                                   &pme->nnz,&pme->fshz);
3158
3159     pmegrids_init(&pme->pmegridA,
3160                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3161                   pme->pmegrid_nz_base,
3162                   pme->pme_order,
3163                   pme->nthread,
3164                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3165                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3166
3167     sse_mask_init(&pme->spline_work,pme->pme_order);
3168
3169     ndata[0] = pme->nkx;
3170     ndata[1] = pme->nky;
3171     ndata[2] = pme->nkz;
3172
3173     /* This routine will allocate the grid data to fit the FFTs */
3174     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3175                             &pme->fftgridA,&pme->cfftgridA,
3176                             pme->mpi_comm_d,
3177                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3178                             bReproducible,pme->nthread);
3179
3180     if (bFreeEnergy)
3181     {
3182         pmegrids_init(&pme->pmegridB,
3183                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3184                       pme->pmegrid_nz_base,
3185                       pme->pme_order,
3186                       pme->nthread,
3187                       pme->nkx % pme->nnodes_major != 0,
3188                       pme->nky % pme->nnodes_minor != 0);
3189
3190         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3191                                 &pme->fftgridB,&pme->cfftgridB,
3192                                 pme->mpi_comm_d,
3193                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3194                                 bReproducible,pme->nthread);
3195     }
3196     else
3197     {
3198         pme->pmegridB.grid.grid = NULL;
3199         pme->fftgridB           = NULL;
3200         pme->cfftgridB          = NULL;
3201     }
3202
3203     if (!pme->bP3M)
3204     {
3205         /* Use plain SPME B-spline interpolation */
3206         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3207     }
3208     else
3209     {
3210         /* Use the P3M grid-optimized influence function */
3211         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3212     }
3213
3214     /* Use atc[0] for spreading */
3215     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3216     if (pme->ndecompdim >= 2)
3217     {
3218         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3219     }
3220
3221     if (pme->nnodes == 1) {
3222         pme->atc[0].n = homenr;
3223         pme_realloc_atomcomm_things(&pme->atc[0]);
3224     }
3225
3226     {
3227         int thread;
3228
3229         /* Use fft5d, order after FFT is y major, z, x minor */
3230
3231         snew(pme->work,pme->nthread);
3232         for(thread=0; thread<pme->nthread; thread++)
3233         {
3234             realloc_work(&pme->work[thread],pme->nkx);
3235         }
3236     }
3237
3238     *pmedata = pme;
3239
3240     return 0;
3241 }
3242
3243
3244 static void copy_local_grid(gmx_pme_t pme,
3245                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3246 {
3247     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3248     int  fft_my,fft_mz;
3249     int  nsx,nsy,nsz;
3250     ivec nf;
3251     int  offx,offy,offz,x,y,z,i0,i0t;
3252     int  d;
3253     pmegrid_t *pmegrid;
3254     real *grid_th;
3255
3256     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3257                                    local_fft_ndata,
3258                                    local_fft_offset,
3259                                    local_fft_size);
3260     fft_my = local_fft_size[YY];
3261     fft_mz = local_fft_size[ZZ];
3262
3263     pmegrid = &pmegrids->grid_th[thread];
3264
3265     nsx = pmegrid->n[XX];
3266     nsy = pmegrid->n[YY];
3267     nsz = pmegrid->n[ZZ];
3268
3269     for(d=0; d<DIM; d++)
3270     {
3271         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3272                     local_fft_ndata[d] - pmegrid->offset[d]);
3273     }
3274
3275     offx = pmegrid->offset[XX];
3276     offy = pmegrid->offset[YY];
3277     offz = pmegrid->offset[ZZ];
3278
3279     /* Directly copy the non-overlapping parts of the local grids.
3280      * This also initializes the full grid.
3281      */
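    /* For reference: the copy below maps the thread-local index
     *   (x*nsy + y)*nsz + z
     * onto the node fftgrid index
     *   ((offx + x)*fft_my + (offy + y))*fft_mz + (offz + z).
     */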
3282     grid_th = pmegrid->grid;
3283     for(x=0; x<nf[XX]; x++)
3284     {
3285         for(y=0; y<nf[YY]; y++)
3286         {
3287             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3288             i0t = (x*nsy + y)*nsz;
3289             for(z=0; z<nf[ZZ]; z++)
3290             {
3291                 fftgrid[i0+z] = grid_th[i0t+z];
3292             }
3293         }
3294     }
3295 }
3296
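/* Debugging helper: prints, for each send block and y row, the number of
 * non-zero elements along z in the major-dimension overlap send buffer.
 */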
3297 static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3298 {
3299     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3300     pme_overlap_t *overlap;
3301     int datasize,nind;
3302     int i,x,y,z,n;
3303
3304     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3305                                    local_fft_ndata,
3306                                    local_fft_offset,
3307                                    local_fft_size);
3308     /* Major dimension */
3309     overlap = &pme->overlap[0];
3310
3311     nind   = overlap->comm_data[0].send_nindex;
3312
3313     for(y=0; y<local_fft_ndata[YY]; y++) {
3314          printf(" %2d",y);
3315     }
3316     printf("\n");
3317
3318     i = 0;
3319     for(x=0; x<nind; x++) {
3320         for(y=0; y<local_fft_ndata[YY]; y++) {
3321             n = 0;
3322             for(z=0; z<local_fft_ndata[ZZ]; z++) {
3323                 if (sendbuf[i] != 0) n++;
3324                 i++;
3325             }
3326             printf(" %2d",n);
3327         }
3328         printf("\n");
3329     }
3330 }
3331
3332 static void
3333 reduce_threadgrid_overlap(gmx_pme_t pme,
3334                           const pmegrids_t *pmegrids,int thread,
3335                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3336 {
3337     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3338     int  fft_nx,fft_ny,fft_nz;
3339     int  fft_my,fft_mz;
3340     int  buf_my=-1;
3341     int  nsx,nsy,nsz;
3342     ivec ne;
3343     int  offx,offy,offz,x,y,z,i0,i0t;
3344     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3345     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3346     gmx_bool bCommX,bCommY;
3347     int  d;
3348     int  thread_f;
3349     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3350     const real *grid_th;
3351     real *commbuf=NULL;
3352
3353     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3354                                    local_fft_ndata,
3355                                    local_fft_offset,
3356                                    local_fft_size);
3357     fft_nx = local_fft_ndata[XX];
3358     fft_ny = local_fft_ndata[YY];
3359     fft_nz = local_fft_ndata[ZZ];
3360
3361     fft_my = local_fft_size[YY];
3362     fft_mz = local_fft_size[ZZ];
3363
3364     /* This routine is called when all threads have finished spreading.
3365      * Here each thread sums the grid contributions calculated by other
3366      * threads into its own thread-local grid volume.
3367      * To minimize the number of grid copying operations,
3368      * this routine sums immediately from the pmegrid to the fftgrid.
3369      */
3370
3371     /* Determine which part of the full node grid we should operate on:
3372      * this is our thread-local part of the full grid.
3373      */
3374     pmegrid = &pmegrids->grid_th[thread];
3375
3376     for(d=0; d<DIM; d++)
3377     {
3378         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3379                     local_fft_ndata[d]);
3380     }
3381
3382     offx = pmegrid->offset[XX];
3383     offy = pmegrid->offset[YY];
3384     offz = pmegrid->offset[ZZ];
3385
3386
3387     bClearBufX  = TRUE;
3388     bClearBufY  = TRUE;
3389     bClearBufXY = TRUE;
3390
3391     /* Now loop over all the thread data blocks that contribute
3392      * to the grid region we (our thread) are operating on.
3393      */
3394     /* Note that fft_nx/fft_ny equals the number of grid points between
3395      * the first point of our node grid and that of the next node's grid.
3396      */
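    /* The loops below visit the thread blocks at offsets sx,sy,sz (0 or
     * negative) relative to our own block. A block index that wraps below
     * zero (fx or fy < 0) means the contribution crosses the node boundary;
     * if that dimension is decomposed over nodes, bCommX/bCommY select the
     * communication buffer instead of the local fftgrid.
     */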
3397     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3398     {
3399         fx = pmegrid->ci[XX] + sx;
3400         ox = 0;
3401         bCommX = FALSE;
3402         if (fx < 0) {
3403             fx += pmegrids->nc[XX];
3404             ox -= fft_nx;
3405             bCommX = (pme->nnodes_major > 1);
3406         }
3407         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3408         ox += pmegrid_g->offset[XX];
3409         if (!bCommX)
3410         {
3411             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3412         }
3413         else
3414         {
3415             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3416         }
3417
3418         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3419         {
3420             fy = pmegrid->ci[YY] + sy;
3421             oy = 0;
3422             bCommY = FALSE;
3423             if (fy < 0) {
3424                 fy += pmegrids->nc[YY];
3425                 oy -= fft_ny;
3426                 bCommY = (pme->nnodes_minor > 1);
3427             }
3428             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3429             oy += pmegrid_g->offset[YY];
3430             if (!bCommY)
3431             {
3432                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3433             }
3434             else
3435             {
3436                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3437             }
3438
3439             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3440             {
3441                 fz = pmegrid->ci[ZZ] + sz;
3442                 oz = 0;
3443                 if (fz < 0)
3444                 {
3445                     fz += pmegrids->nc[ZZ];
3446                     oz -= fft_nz;
3447                 }
3448                 pmegrid_g = &pmegrids->grid_th[fz];
3449                 oz += pmegrid_g->offset[ZZ];
3450                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3451
3452                 if (sx == 0 && sy == 0 && sz == 0)
3453                 {
3454                     /* We have already added our local contribution
3455                      * before calling this routine, so skip it here.
3456                      */
3457                     continue;
3458                 }
3459
3460                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3461
3462                 pmegrid_f = &pmegrids->grid_th[thread_f];
3463
3464                 grid_th = pmegrid_f->grid;
3465
3466                 nsx = pmegrid_f->n[XX];
3467                 nsy = pmegrid_f->n[YY];
3468                 nsz = pmegrid_f->n[ZZ];
3469
3470 #ifdef DEBUG_PME_REDUCE
3471                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3472                        pme->nodeid,thread,thread_f,
3473                        pme->pmegrid_start_ix,
3474                        pme->pmegrid_start_iy,
3475                        pme->pmegrid_start_iz,
3476                        sx,sy,sz,
3477                        offx-ox,tx1-ox,offx,tx1,
3478                        offy-oy,ty1-oy,offy,ty1,
3479                        offz-oz,tz1-oz,offz,tz1);
3480 #endif
3481
3482                 if (!(bCommX || bCommY))
3483                 {
3484                     /* Copy from the thread local grid to the node grid */
3485                     for(x=offx; x<tx1; x++)
3486                     {
3487                         for(y=offy; y<ty1; y++)
3488                         {
3489                             i0  = (x*fft_my + y)*fft_mz;
3490                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3491                             for(z=offz; z<tz1; z++)
3492                             {
3493                                 fftgrid[i0+z] += grid_th[i0t+z];
3494                             }
3495                         }
3496                     }
3497                 }
3498                 else
3499                 {
3500                     /* The order of the checks in this conditional decides
3501                      * where the corner volume is stored with x+y node decomposition.
3502                      */
3503                     if (bCommY)
3504                     {
3505                         commbuf = commbuf_y;
3506                         buf_my  = ty1 - offy;
3507                         if (bCommX)
3508                         {
3509                             /* We index commbuf modulo the local grid size */
3510                             commbuf += buf_my*fft_nx*fft_nz;
3511
3512                             bClearBuf  = bClearBufXY;
3513                             bClearBufXY = FALSE;
3514                         }
3515                         else
3516                         {
3517                             bClearBuf  = bClearBufY;
3518                             bClearBufY = FALSE;
3519                         }
3520                     }
3521                     else
3522                     {
3523                         commbuf = commbuf_x;
3524                         buf_my  = fft_ny;
3525                         bClearBuf  = bClearBufX;
3526                         bClearBufX = FALSE;
3527                     }
3528
3529                     /* Copy to the communication buffer */
3530                     for(x=offx; x<tx1; x++)
3531                     {
3532                         for(y=offy; y<ty1; y++)
3533                         {
3534                             i0  = (x*buf_my + y)*fft_nz;
3535                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3536
3537                             if (bClearBuf)
3538                             {
3539                                 /* First access of commbuf, initialize it */
3540                                 for(z=offz; z<tz1; z++)
3541                                 {
3542                                     commbuf[i0+z]  = grid_th[i0t+z];
3543                                 }
3544                             }
3545                             else
3546                             {
3547                                 for(z=offz; z<tz1; z++)
3548                                 {
3549                                     commbuf[i0+z] += grid_th[i0t+z];
3550                                 }
3551                             }
3552                         }
3553                     }
3554                 }
3555             }
3556         }
3557     }
3558 }
3559
3560
3561 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3562 {
3563     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3564     pme_overlap_t *overlap;
3565     int  send_nindex;
3566     int  recv_index0,recv_nindex;
3567 #ifdef GMX_MPI
3568     MPI_Status stat;
3569 #endif
3570     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3571     real *sendptr,*recvptr;
3572     int  x,y,z,indg,indb;
3573
3574     /* Note that this routine is only used for forward communication.
3575      * Since the force gathering, unlike the charge spreading,
3576      * can be trivially parallelized over the particles,
3577      * the backwards process is much simpler and can use the "old"
3578      * communication setup.
3579      */
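    /* The reduction proceeds in two phases: first the overlap along the
     * minor (y) dimension is exchanged and summed, with the x+y corner
     * volume carried along in the same message when the major dimension
     * is also decomposed; the received corner part is then added into the
     * major-dimension send buffer so it is forwarded in the second
     * exchange, which sums the overlap along the major (x) dimension.
     */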
3580
3581     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3582                                    local_fft_ndata,
3583                                    local_fft_offset,
3584                                    local_fft_size);
3585
3586     /* Currently supports only a single communication pulse */
3587
3588 /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3589     if (pme->nnodes_minor > 1)
3590     {
3591         /* Minor dimension */
3592         overlap = &pme->overlap[1];
3593
3594         if (pme->nnodes_major > 1)
3595         {
3596             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3597         }
3598         else
3599         {
3600             size_yx = 0;
3601         }
3602         datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3603
3604         ipulse = 0;
3605
3606         send_id = overlap->send_id[ipulse];
3607         recv_id = overlap->recv_id[ipulse];
3608         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3609         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3610         recv_index0 = 0;
3611         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3612
3613         sendptr = overlap->sendbuf;
3614         recvptr = overlap->recvbuf;
3615
3616         /*
3617         printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3618                local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3619         printf("node %d send %f, %f\n",pme->nodeid,
3620                sendptr[0],sendptr[send_nindex*datasize-1]);
3621         */
3622
3623 #ifdef GMX_MPI
3624         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3625                      send_id,ipulse,
3626                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3627                      recv_id,ipulse,
3628                      overlap->mpi_comm,&stat);
3629 #endif
3630
3631         for(x=0; x<local_fft_ndata[XX]; x++)
3632         {
3633             for(y=0; y<recv_nindex; y++)
3634             {
3635                 indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3636                 indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3637                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3638                 {
3639                     fftgrid[indg+z] += recvptr[indb+z];
3640                 }
3641             }
3642         }
3643         if (pme->nnodes_major > 1)
3644         {
3645             sendptr = pme->overlap[0].sendbuf;
3646             for(x=0; x<size_yx; x++)
3647             {
3648                 for(y=0; y<recv_nindex; y++)
3649                 {
3650                     indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3651                     indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3652                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3653                     {
3654                         sendptr[indg+z] += recvptr[indb+z];
3655                     }
3656                 }
3657             }
3658         }
3659     }
3660
3661     /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3662     if (pme->nnodes_major > 1)
3663     {
3664         /* Major dimension */
3665         overlap = &pme->overlap[0];
3666
3667         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3668         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3669
3670         ipulse = 0;
3671
3672         send_id = overlap->send_id[ipulse];
3673         recv_id = overlap->recv_id[ipulse];
3674         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3675         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3676         recv_index0 = 0;
3677         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3678
3679         sendptr = overlap->sendbuf;
3680         recvptr = overlap->recvbuf;
3681
3682         if (debug != NULL)
3683         {
3684             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3685                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3686         }
3687
3688 #ifdef GMX_MPI
3689         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3690                      send_id,ipulse,
3691                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3692                      recv_id,ipulse,
3693                      overlap->mpi_comm,&stat);
3694 #endif
3695
3696         for(x=0; x<recv_nindex; x++)
3697         {
3698             for(y=0; y<local_fft_ndata[YY]; y++)
3699             {
3700                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3701                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3702                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3703                 {
3704                     fftgrid[indg+z] += recvptr[indb+z];
3705                 }
3706             }
3707         }
3708     }
3709 }
3710
3711
3712 static void spread_on_grid(gmx_pme_t pme,
3713                            pme_atomcomm_t *atc,pmegrids_t *grids,
3714                            gmx_bool bCalcSplines,gmx_bool bSpread,
3715                            real *fftgrid)
3716 {
3717     int nthread,thread;
3718 #ifdef PME_TIME_THREADS
3719     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3720     static double cs1=0,cs2=0,cs3=0;
3721     static double cs1a[6]={0,0,0,0,0,0};
3722     static int cnt=0;
3723 #endif
3724
3725     nthread = pme->nthread;
3726     assert(nthread>0);
3727
3728 #ifdef PME_TIME_THREADS
3729     c1 = omp_cyc_start();
3730 #endif
3731     if (bCalcSplines)
3732     {
3733 #pragma omp parallel for num_threads(nthread) schedule(static)
3734         for(thread=0; thread<nthread; thread++)
3735         {
3736             int start,end;
3737
3738             start = atc->n* thread   /nthread;
3739             end   = atc->n*(thread+1)/nthread;
3740
3741             /* Compute the fftgrid index for all atoms in this thread's range,
3742              * along with the fractional coordinates used later for the B-splines.
3743              */
3744             calc_interpolation_idx(pme,atc,start,end,thread);
3745         }
3746     }
3747 #ifdef PME_TIME_THREADS
3748     c1 = omp_cyc_end(c1);
3749     cs1 += (double)c1;
3750 #endif
3751
3752 #ifdef PME_TIME_THREADS
3753     c2 = omp_cyc_start();
3754 #endif
3755 #pragma omp parallel for num_threads(nthread) schedule(static)
3756     for(thread=0; thread<nthread; thread++)
3757     {
3758         splinedata_t *spline;
3759         pmegrid_t *grid;
3760
3761         /* make local bsplines  */
3762         if (grids == NULL || grids->nthread == 1)
3763         {
3764             spline = &atc->spline[0];
3765
3766             spline->n = atc->n;
3767
3768             grid = &grids->grid;
3769         }
3770         else
3771         {
3772             spline = &atc->spline[thread];
3773
3774             make_thread_local_ind(atc,thread,spline);
3775
3776             grid = &grids->grid_th[thread];
3777         }
3778
3779         if (bCalcSplines)
3780         {
3781             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3782                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3783         }
3784
3785         if (bSpread)
3786         {
3787             /* put local atoms on grid. */
3788 #ifdef PME_TIME_SPREAD
3789             ct1a = omp_cyc_start();
3790 #endif
3791             spread_q_bsplines_thread(grid,atc,spline,&pme->spline_work);
3792
3793             if (grids->nthread > 1)
3794             {
3795                 copy_local_grid(pme,grids,thread,fftgrid);
3796             }
3797 #ifdef PME_TIME_SPREAD
3798             ct1a = omp_cyc_end(ct1a);
3799             cs1a[thread] += (double)ct1a;
3800 #endif
3801         }
3802     }
3803 #ifdef PME_TIME_THREADS
3804     c2 = omp_cyc_end(c2);
3805     cs2 += (double)c2;
3806 #endif
3807
3808     if (bSpread && grids->nthread > 1)
3809     {
3810 #ifdef PME_TIME_THREADS
3811         c3 = omp_cyc_start();
3812 #endif
3813 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3814         for(thread=0; thread<grids->nthread; thread++)
3815         {
3816             reduce_threadgrid_overlap(pme,grids,thread,
3817                                       fftgrid,
3818                                       pme->overlap[0].sendbuf,
3819                                       pme->overlap[1].sendbuf);
3820 #ifdef PRINT_PME_SENDBUF
3821             print_sendbuf(pme,pme->overlap[0].sendbuf);
3822 #endif
3823         }
3824 #ifdef PME_TIME_THREADS
3825         c3 = omp_cyc_end(c3);
3826         cs3 += (double)c3;
3827 #endif
3828
3829         if (pme->nnodes > 1)
3830         {
3831             /* Communicate the overlapping part of the fftgrid */
3832             sum_fftgrid_dd(pme,fftgrid);
3833         }
3834     }
3835
3836 #ifdef PME_TIME_THREADS
3837     cnt++;
3838     if (cnt % 20 == 0)
3839     {
3840         printf("idx %.2f spread %.2f red %.2f",
3841                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3842 #ifdef PME_TIME_SPREAD
3843         for(thread=0; thread<nthread; thread++)
3844             printf(" %.2f",cs1a[thread]*1e-9);
3845 #endif
3846         printf("\n");
3847     }
3848 #endif
3849 }
3850
3851
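/* Debugging helpers: dump_grid writes a block of grid values with their
 * global indices to fp; dump_local_fftgrid does this for the local part of
 * the real-space FFT grid (currently only referenced from commented-out code).
 */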
3852 static void dump_grid(FILE *fp,
3853                       int sx,int sy,int sz,int nx,int ny,int nz,
3854                       int my,int mz,const real *g)
3855 {
3856     int x,y,z;
3857
3858     for(x=0; x<nx; x++)
3859     {
3860         for(y=0; y<ny; y++)
3861         {
3862             for(z=0; z<nz; z++)
3863             {
3864                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3865                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3866             }
3867         }
3868     }
3869 }
3870
3871 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3872 {
3873     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3874
3875     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3876                                    local_fft_ndata,
3877                                    local_fft_offset,
3878                                    local_fft_size);
3879
3880     dump_grid(stderr,
3881               pme->pmegrid_start_ix,
3882               pme->pmegrid_start_iy,
3883               pme->pmegrid_start_iz,
3884               pme->pmegrid_nx-pme->pme_order+1,
3885               pme->pmegrid_ny-pme->pme_order+1,
3886               pme->pmegrid_nz-pme->pme_order+1,
3887               local_fft_size[YY],
3888               local_fft_size[ZZ],
3889               fftgrid);
3890 }
3891
3892
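/* Calculate the PME mesh energy of n charges at positions x.
 * This runs on a single node only and computes neither forces nor virial;
 * it is presumably intended for uses such as test particle insertion.
 */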
3893 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3894 {
3895     pme_atomcomm_t *atc;
3896     pmegrids_t *grid;
3897
3898     if (pme->nnodes > 1)
3899     {
3900         gmx_incons("gmx_pme_calc_energy called in parallel");
3901     }
3902     if (pme->bFEP > 1)
3903     {
3904         gmx_incons("gmx_pme_calc_energy with free energy");
3905     }
3906
3907     atc = &pme->atc_energy;
3908     atc->nthread   = 1;
3909     if (atc->spline == NULL)
3910     {
3911         snew(atc->spline,atc->nthread);
3912     }
3913     atc->nslab     = 1;
3914     atc->bSpread   = TRUE;
3915     atc->pme_order = pme->pme_order;
3916     atc->n         = n;
3917     pme_realloc_atomcomm_things(atc);
3918     atc->x         = x;
3919     atc->q         = q;
3920
3921     /* We only use the A-charges grid */
3922     grid = &pme->pmegridA;
3923
3924     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
3925
3926     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3927 }
3928
3929
3930 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3931         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3932 {
3933     /* Reset all the counters related to performance over the run */
3934     wallcycle_stop(wcycle,ewcRUN);
3935     wallcycle_reset_all(wcycle);
3936     init_nrnb(nrnb);
3937     ir->init_step += step_rel;
3938     ir->nsteps    -= step_rel;
3939     wallcycle_start(wcycle,ewcRUN);
3940 }
3941
3942
3943 int gmx_pmeonly(gmx_pme_t pme,
3944                 t_commrec *cr,    t_nrnb *nrnb,
3945                 gmx_wallcycle_t wcycle,
3946                 real ewaldcoeff,  gmx_bool bGatherOnly,
3947                 t_inputrec *ir)
3948 {
3949     gmx_pme_pp_t pme_pp;
3950     int  natoms;
3951     matrix box;
3952     rvec *x_pp=NULL,*f_pp=NULL;
3953     real *chargeA=NULL,*chargeB=NULL;
3954     real lambda=0;
3955     int  maxshift_x=0,maxshift_y=0;
3956     real energy,dvdlambda;
3957     matrix vir;
3958     float cycles;
3959     int  count;
3960     gmx_bool bEnerVir;
3961     gmx_large_int_t step,step_rel;
3962
3963
3964     pme_pp = gmx_pme_pp_init(cr);
3965
3966     init_nrnb(nrnb);
3967
3968     count = 0;
3969     do /****** this is a quasi-loop over time steps! */
3970     {
3971         /* Domain decomposition */
3972         natoms = gmx_pme_recv_q_x(pme_pp,
3973                                   &chargeA,&chargeB,box,&x_pp,&f_pp,
3974                                   &maxshift_x,&maxshift_y,
3975                                   &pme->bFEP,&lambda,
3976                                   &bEnerVir,
3977                                   &step);
3978
3979         if (natoms == -1) {
3980             /* We should stop: break out of the loop */
3981             break;
3982         }
3983
3984         step_rel = step - ir->init_step;
3985
3986         if (count == 0)
3987             wallcycle_start(wcycle,ewcRUN);
3988
3989         wallcycle_start(wcycle,ewcPMEMESH);
3990
3991         dvdlambda = 0;
3992         clear_mat(vir);
3993         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
3994                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
3995                    &energy,lambda,&dvdlambda,
3996                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
3997
3998         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
3999
4000         gmx_pme_send_force_vir_ener(pme_pp,
4001                                     f_pp,vir,energy,dvdlambda,
4002                                     cycles);
4003
4004         count++;
4005
4006         if (step_rel == wcycle_get_reset_counters(wcycle))
4007         {
4008             /* Reset all the counters related to performance over the run */
4009             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4010             wcycle_set_reset_counters(wcycle, 0);
4011         }
4012
4013     } /***** end of quasi-loop, we stop with the break above */
4014     while (TRUE);
4015
4016     return 0;
4017 }
4018
4019 int gmx_pme_do(gmx_pme_t pme,
4020                int start,       int homenr,
4021                rvec x[],        rvec f[],
4022                real *chargeA,   real *chargeB,
4023                matrix box, t_commrec *cr,
4024                int  maxshift_x, int maxshift_y,
4025                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4026                matrix vir,      real ewaldcoeff,
4027                real *energy,    real lambda,
4028                real *dvdlambda, int flags)
4029 {
4030     int     q,d,i,j,ntot,npme;
4031     int     nx,ny,nz;
4032     int     n_d,local_ny;
4033     pme_atomcomm_t *atc=NULL;
4034     pmegrids_t *pmegrid=NULL;
4035     real    *grid=NULL;
4036     real    *ptr;
4037     rvec    *x_d,*f_d;
4038     real    *charge=NULL,*q_d;
4039     real    energy_AB[2];
4040     matrix  vir_AB[2];
4041     gmx_bool bClearF;
4042     gmx_parallel_3dfft_t pfft_setup;
4043     real *  fftgrid;
4044     t_complex * cfftgrid;
4045     int     thread;
4046     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4047     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4048
4049     assert(pme->nnodes > 0);
4050     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4051
4052     if (pme->nnodes > 1) {
4053         atc = &pme->atc[0];
4054         atc->npd = homenr;
4055         if (atc->npd > atc->pd_nalloc) {
4056             atc->pd_nalloc = over_alloc_dd(atc->npd);
4057             srenew(atc->pd,atc->pd_nalloc);
4058         }
4059         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4060     }
4061     else
4062     {
4063         /* This could be necessary for TPI */
4064         pme->atc[0].n = homenr;
4065     }
4066
4067     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4068         if (q == 0) {
4069             pmegrid = &pme->pmegridA;
4070             fftgrid = pme->fftgridA;
4071             cfftgrid = pme->cfftgridA;
4072             pfft_setup = pme->pfft_setupA;
4073             charge = chargeA+start;
4074         } else {
4075             pmegrid = &pme->pmegridB;
4076             fftgrid = pme->fftgridB;
4077             cfftgrid = pme->cfftgridB;
4078             pfft_setup = pme->pfft_setupB;
4079             charge = chargeB+start;
4080         }
4081         grid = pmegrid->grid.grid;
4082         /* Unpack structure */
4083         if (debug) {
4084             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4085                     cr->nnodes,cr->nodeid);
4086             fprintf(debug,"Grid = %p\n",(void*)grid);
4087             if (grid == NULL)
4088                 gmx_fatal(FARGS,"No grid!");
4089         }
4090         where();
4091
4092         m_inv_ur0(box,pme->recipbox);
4093
4094         if (pme->nnodes == 1) {
4095             atc = &pme->atc[0];
4096             if (DOMAINDECOMP(cr)) {
4097                 atc->n = homenr;
4098                 pme_realloc_atomcomm_things(atc);
4099             }
4100             atc->x = x;
4101             atc->q = charge;
4102             atc->f = f;
4103         } else {
4104             wallcycle_start(wcycle,ewcPME_REDISTXF);
4105             for(d=pme->ndecompdim-1; d>=0; d--)
4106             {
4107                 if (d == pme->ndecompdim-1)
4108                 {
4109                     n_d = homenr;
4110                     x_d = x + start;
4111                     q_d = charge;
4112                 }
4113                 else
4114                 {
4115                     n_d = pme->atc[d+1].n;
4116                     x_d = atc->x;
4117                     q_d = atc->q;
4118                 }
4119                 atc = &pme->atc[d];
4120                 atc->npd = n_d;
4121                 if (atc->npd > atc->pd_nalloc) {
4122                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4123                     srenew(atc->pd,atc->pd_nalloc);
4124                 }
4125                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4126                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4127                 where();
4128
4129                 /* Redistribute x (only once) and qA or qB */
4130                 if (DOMAINDECOMP(cr)) {
4131                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4132                 } else {
4133                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4134                 }
4135             }
4136             where();
4137
4138             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4139         }
4140
4141         if (debug)
4142             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4143                     cr->nodeid,atc->n);
4144
4145         if (flags & GMX_PME_SPREAD_Q)
4146         {
4147             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4148
4149             /* Spread the charges on a grid */
4150             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4151
4152             if (q == 0)
4153             {
4154                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4155             }
4156             inc_nrnb(nrnb,eNR_SPREADQBSP,
4157                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4158
4159             if (pme->nthread == 1)
4160             {
4161                 wrap_periodic_pmegrid(pme,grid);
4162
4163                 /* sum contributions to local grid from other nodes */
4164 #ifdef GMX_MPI
4165                 if (pme->nnodes > 1)
4166                 {
4167                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4168                     where();
4169                 }
4170 #endif
4171
4172                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4173             }
4174
4175             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4176
4177             /*
4178             dump_local_fftgrid(pme,fftgrid);
4179             exit(0);
4180             */
4181         }
4182
4183         /* Here we start a large thread parallel region */
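        /* Each thread performs its share of the forward 3D FFT and of the
         * reciprocal-space solve and, when forces are requested, of the
         * backward 3D FFT and the copy back to the (overlapping) pmegrid.
         */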
4184 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4185         for(thread=0; thread<pme->nthread; thread++)
4186         {
4187             if (flags & GMX_PME_SOLVE)
4188             {
4189                 int loop_count;
4190
4191                 /* do 3d-fft */
4192                 if (thread == 0)
4193                 {
4194                     wallcycle_start(wcycle,ewcPME_FFT);
4195                 }
4196                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4197                                            fftgrid,cfftgrid,thread,wcycle);
4198                 if (thread == 0)
4199                 {
4200                     wallcycle_stop(wcycle,ewcPME_FFT);
4201                 }
4202                 where();
4203
4204                 /* solve in k-space for our local cells */
4205                 if (thread == 0)
4206                 {
4207                     wallcycle_start(wcycle,ewcPME_SOLVE);
4208                 }
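                /* With the GROMACS box convention only the lower triangle of
                 * the box matrix is non-zero, so the product of the diagonal
                 * elements passed below equals the box volume.
                 */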
4209                 loop_count =
4210                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4211                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4212                                   bCalcEnerVir,
4213                                   pme->nthread,thread);
4214                 if (thread == 0)
4215                 {
4216                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4217                     where();
4218                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4219                 }
4220             }
4221
4222             if (bCalcF)
4223             {
4224                 /* do 3d-invfft */
4225                 if (thread == 0)
4226                 {
4227                     where();
4228                     wallcycle_start(wcycle,ewcPME_FFT);
4229                 }
4230                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4231                                            cfftgrid,fftgrid,thread,wcycle);
4232                 if (thread == 0)
4233                 {
4234                     wallcycle_stop(wcycle,ewcPME_FFT);
4235
4236                     where();
4237
4238                     if (pme->nodeid == 0)
4239                     {
4240                         ntot = pme->nkx*pme->nky*pme->nkz;
4241                         npme  = ntot*log((real)ntot)/log(2.0);
4242                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4243                     }
4244
4245                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4246                 }
4247
4248                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4249             }
4250         }
4251         /* End of thread parallel section.
4252          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4253          */
4254
4255         if (bCalcF)
4256         {
4257             /* distribute local grid to all nodes */
4258 #ifdef GMX_MPI
4259             if (pme->nnodes > 1) {
4260                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4261             }
4262 #endif
4263             where();
4264
4265             unwrap_periodic_pmegrid(pme,grid);
4266
4267             /* interpolate forces for our local atoms */
4268
4269             where();
4270
4271             /* If we are running without parallelization,
4272              * atc->f is the actual force array, not a buffer,
4273              * therefore we should not clear it.
4274              */
4275             bClearF = (q == 0 && PAR(cr));
4276 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4277             for(thread=0; thread<pme->nthread; thread++)
4278             {
4279                 gather_f_bsplines(pme,grid,bClearF,atc,
4280                                   &atc->spline[thread],
4281                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4282             }
4283
4284             where();
4285
4286             inc_nrnb(nrnb,eNR_GATHERFBSP,
4287                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4288             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4289         }
4290
4291         if (bCalcEnerVir)
4292         {
4293             /* This should only be called on the master thread
4294              * and after the threads have synchronized.
4295              */
4296             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4297         }
4298     } /* of q-loop */
4299
4300     if (bCalcF && pme->nnodes > 1) {
4301         wallcycle_start(wcycle,ewcPME_REDISTXF);
4302         for(d=0; d<pme->ndecompdim; d++)
4303         {
4304             atc = &pme->atc[d];
4305             if (d == pme->ndecompdim - 1)
4306             {
4307                 n_d = homenr;
4308                 f_d = f + start;
4309             }
4310             else
4311             {
4312                 n_d = pme->atc[d+1].n;
4313                 f_d = pme->atc[d+1].f;
4314             }
4315             if (DOMAINDECOMP(cr)) {
4316                 dd_pmeredist_f(pme,atc,n_d,f_d,
4317                                d==pme->ndecompdim-1 && pme->bPPnode);
4318             } else {
4319                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4320             }
4321         }
4322
4323         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4324     }
4325     where();
4326
4327     if (bCalcEnerVir)
4328     {
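        /* With free energy the A and B state results are combined linearly:
         *   E        = (1-lambda)*E_A + lambda*E_B
         *   dE/dl   += E_B - E_A
         * and the virial is interpolated in the same way.
         */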
4329         if (!pme->bFEP) {
4330             *energy = energy_AB[0];
4331             m_add(vir,vir_AB[0],vir);
4332         } else {
4333             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4334             *dvdlambda += energy_AB[1] - energy_AB[0];
4335             for(i=0; i<DIM; i++)
4336             {
4337                 for(j=0; j<DIM; j++)
4338                 {
4339                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4340                         lambda*vir_AB[1][i][j];
4341                 }
4342             }
4343         }
4344     }
4345     else
4346     {
4347         *energy = 0;
4348     }
4349
4350     if (debug)
4351     {
4352         fprintf(debug,"PME mesh energy: %g\n",*energy);
4353     }
4354
4355     return 0;
4356 }