1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
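/* A worked example for the triclinic test box above (a sketch, assuming the
 * usual GROMACS convention that the box rows are the box vectors a, b, c):
 * with a=(2,0,0), b=(0,2,0), c=(2,2,6), the fractional coordinates used in
 * calc_interpolation_idx() below,
 *   s_a = x*rxx + y*ryx + z*rzx,  s_b = y*ryy + z*rzy,  s_c = z*rzz,
 * require
 *   rxx = 1/2, ryx = 0, ryy = 1/2, rzx = -1/6, rzy = -1/6, rzz = 1/6.
 * Check: the far corner a+b+c = (4,4,6) gives s_a = 4/2 - 6/6 = 1,
 * s_b = 4/2 - 6/6 = 1 and s_c = 6/6 = 1, as it should.
 */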
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #ifdef GMX_OPENMP
71 #include <omp.h>
72 #endif
73
74 #include <stdio.h>
75 #include <string.h>
76 #include <math.h>
77 #include "typedefs.h"
78 #include "txtdump.h"
79 #include "vec.h"
80 #include "gmxcomplex.h"
81 #include "smalloc.h"
82 #include "futil.h"
83 #include "coulomb.h"
84 #include "gmx_fatal.h"
85 #include "pme.h"
86 #include "network.h"
87 #include "physics.h"
88 #include "nrnb.h"
89 #include "copyrite.h"
90 #include "gmx_wallcycle.h"
91 #include "gmx_parallel_3dfft.h"
92 #include "pdbio.h"
93 #include "gmx_cyclecounter.h"
94
95 #if ( !defined(GMX_DOUBLE) && ( defined(GMX_IA32_SSE) || defined(GMX_X86_64_SSE) || defined(GMX_X86_64_SSE2) ) )
96 #include "gmx_sse2_single.h"
97
98 #define PME_SSE
99 /* Some old AMD processors could have problems with unaligned loads+stores */
100 #ifndef GMX_FAHCORE
101 #define PME_SSE_UNALIGNED
102 #endif
103 #endif
104
105 #include "mpelogging.h"
106
107 #define DFT_TOL 1e-7
108 /* #define PRT_FORCE */
109 /* conditions for on-the-fly time measurement */
110 /* #define TAKETIME (step > 1 && timesteps < 10) */
111 #define TAKETIME FALSE
112
113 /* #define PME_TIME_THREADS */
114
115 #ifdef GMX_DOUBLE
116 #define mpi_type MPI_DOUBLE
117 #else
118 #define mpi_type MPI_FLOAT
119 #endif
120
121 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
122 #define GMX_CACHE_SEP 64
123
124 /* We only define a maximum to be able to use local arrays without allocation.
125  * An order larger than 12 should never be needed, even for test cases.
126  * If needed it can be changed here.
127  */
128 #define PME_ORDER_MAX 12
129
130 /* Internal datastructures */
131 typedef struct {
132     int send_index0;
133     int send_nindex;
134     int recv_index0;
135     int recv_nindex;
136 } pme_grid_comm_t;
137
138 typedef struct {
139 #ifdef GMX_MPI
140     MPI_Comm mpi_comm;
141 #endif
142     int  nnodes,nodeid;
143     int  *s2g0;
144     int  *s2g1;
145     int  noverlap_nodes;
146     int  *send_id,*recv_id;
147     pme_grid_comm_t *comm_data;
148     real *sendbuf;
149     real *recvbuf;
150 } pme_overlap_t;
151
152 typedef struct {
153     int *n;     /* Cumulative counts of the number of particles per thread */
154     int nalloc; /* Allocation size of i */
155     int *i;     /* Particle indices ordered on thread index (n) */
156 } thread_plist_t;
157
158 typedef struct {
159     int  n;
160     int  *ind;
161     splinevec theta;
162     splinevec dtheta;
163 } splinedata_t;
164
165 typedef struct {
166     int  dimind;            /* The index of the dimension, 0=x, 1=y */
167     int  nslab;
168     int  nodeid;
169 #ifdef GMX_MPI
170     MPI_Comm mpi_comm;
171 #endif
172
173     int  *node_dest;        /* The nodes to send x and q to with DD */
174     int  *node_src;         /* The nodes to receive x and q from with DD */
175     int  *buf_index;        /* Index for commnode into the buffers */
176
177     int  maxshift;
178
179     int  npd;
180     int  pd_nalloc;
181     int  *pd;
182     int  *count;            /* The number of atoms to send to each node */
183     int  **count_thread;
184     int  *rcount;           /* The number of atoms to receive */
185
186     int  n;
187     int  nalloc;
188     rvec *x;
189     real *q;
190     rvec *f;
191     gmx_bool bSpread;       /* These coordinates are used for spreading */
192     int  pme_order;
193     ivec *idx;
194     rvec *fractx;            /* Fractional coordinate relative to the
195                               * lower cell boundary
196                               */
197     int  nthread;
198     int  *thread_idx;        /* Which thread should spread which charge */
199     thread_plist_t *thread_plist;
200     splinedata_t *spline;
201 } pme_atomcomm_t;
202
203 #define FLBS  3
204 #define FLBSZ 4
205
206 typedef struct {
207     ivec ci;     /* The spatial location of this grid       */
208     ivec n;      /* The size of *grid, including order-1    */
209     ivec offset; /* The grid offset from the full node grid */
210     int  order;  /* PME spreading order                     */
211     real *grid;  /* The local grid for this thread, size n  */
212 } pmegrid_t;
213
214 typedef struct {
215     pmegrid_t grid;     /* The full node grid (non thread-local)            */
216     int  nthread;       /* The number of threads operating on this grid     */
217     ivec nc;            /* The local spatial decomposition over the threads */
218     pmegrid_t *grid_th; /* Array of grids for each thread                   */
219     int  **g2t;         /* The grid to thread index                         */
220     ivec nthread_comm;  /* The number of threads to communicate with        */
221 } pmegrids_t;
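/* Note on g2t: the per-dimension tables are pre-multiplied with the thread
 * grid strides in pmegrids_init() (tfac = 1 for z, nc[ZZ] for y,
 * nc[YY]*nc[ZZ] for x), so a flat thread index is obtained by simply summing
 * three lookups, as done in calc_interpolation_idx():
 *   thread = g2t[XX][ix] + g2t[YY][iy] + g2t[ZZ][iz]
 * For example, assuming nc = {2,2,1}, a grid line owned by thread cell
 * (1,0,0) contributes 1*2*1 = 2, giving thread index 2.
 */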
222
223
224 typedef struct {
225 #ifdef PME_SSE
226     /* Masks for SSE aligned spreading and gathering */
227     __m128 mask_SSE0[6],mask_SSE1[6];
228 #else
229     int dummy; /* C89 requires that struct has at least one member */
230 #endif
231 } pme_spline_work_t;
232
233 typedef struct {
234     /* work data for solve_pme */
235     int      nalloc;
236     real *   mhx;
237     real *   mhy;
238     real *   mhz;
239     real *   m2;
240     real *   denom;
241     real *   tmp1_alloc;
242     real *   tmp1;
243     real *   eterm;
244     real *   m2inv;
245
246     real     energy;
247     matrix   vir;
248 } pme_work_t;
249
250 typedef struct gmx_pme {
251     int  ndecompdim;         /* The number of decomposition dimensions */
252     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
253     int  nodeid_major;
254     int  nodeid_minor;
255     int  nnodes;             /* The number of nodes doing PME */
256     int  nnodes_major;
257     int  nnodes_minor;
258
259     MPI_Comm mpi_comm;
260     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
261 #ifdef GMX_MPI
262     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
263 #endif
264
265     int  nthread;            /* The number of threads doing PME */
266
267     gmx_bool bPPnode;        /* Node also does particle-particle forces */
268     gmx_bool bFEP;           /* Compute Free energy contribution */
269     int nkx,nky,nkz;         /* Grid dimensions */
270     int pme_order;
271     real epsilon_r;
272
273     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
274     pmegrids_t pmegridB;
275     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
276     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
277     /* pmegrid_nz might be larger than strictly necessary to ensure
278      * memory alignment, pmegrid_nz_base gives the real base size.
279      */
280     int     pmegrid_nz_base;
281     /* The local PME grid starting indices */
282     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
283
284     /* Work data for spreading and gathering */
285     pme_spline_work_t spline_work;
286
287     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
288     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
289     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
290
291     t_complex *cfftgridA;             /* Grids for complex FFT data */
292     t_complex *cfftgridB;
293     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
294
295     gmx_parallel_3dfft_t  pfft_setupA;
296     gmx_parallel_3dfft_t  pfft_setupB;
297
298     int  *nnx,*nny,*nnz;
299     real *fshx,*fshy,*fshz;
300
301     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
302     matrix    recipbox;
303     splinevec bsp_mod;
304
305     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
306
307     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
308
309     rvec *bufv;             /* Communication buffer */
310     real *bufr;             /* Communication buffer */
311     int  buf_nalloc;        /* The communication buffer size */
312
313     /* thread local work data for solve_pme */
314     pme_work_t *work;
315
316     /* Work data for PME_redist */
317     gmx_bool redist_init;
318     int *    scounts;
319     int *    rcounts;
320     int *    sdispls;
321     int *    rdispls;
322     int *    sidx;
323     int *    idxa;
324     real *   redist_buf;
325     int      redist_buf_nalloc;
326
327     /* Work data for sum_qgrid */
328     real *   sum_qgrid_tmp;
329     real *   sum_qgrid_dd_tmp;
330 } t_gmx_pme;
331
332
333 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
334                                    int start,int end,int thread)
335 {
336     int  i;
337     int  *idxptr,tix,tiy,tiz;
338     real *xptr,*fptr,tx,ty,tz;
339     real rxx,ryx,ryy,rzx,rzy,rzz;
340     int  nx,ny,nz;
341     int  start_ix,start_iy,start_iz;
342     int  *g2tx,*g2ty,*g2tz;
343     gmx_bool bThreads;
344     int  *thread_idx=NULL;
345     thread_plist_t *tpl=NULL;
346     int  *tpl_n=NULL;
347     int  thread_i;
348
349     nx  = pme->nkx;
350     ny  = pme->nky;
351     nz  = pme->nkz;
352
353     start_ix = pme->pmegrid_start_ix;
354     start_iy = pme->pmegrid_start_iy;
355     start_iz = pme->pmegrid_start_iz;
356
357     rxx = pme->recipbox[XX][XX];
358     ryx = pme->recipbox[YY][XX];
359     ryy = pme->recipbox[YY][YY];
360     rzx = pme->recipbox[ZZ][XX];
361     rzy = pme->recipbox[ZZ][YY];
362     rzz = pme->recipbox[ZZ][ZZ];
363
364     g2tx = pme->pmegridA.g2t[XX];
365     g2ty = pme->pmegridA.g2t[YY];
366     g2tz = pme->pmegridA.g2t[ZZ];
367
368     bThreads = (atc->nthread > 1);
369     if (bThreads)
370     {
371         thread_idx = atc->thread_idx;
372
373         tpl   = &atc->thread_plist[thread];
374         tpl_n = tpl->n;
375         for(i=0; i<atc->nthread; i++)
376         {
377             tpl_n[i] = 0;
378         }
379     }
380
381     for(i=start; i<end; i++) {
382         xptr   = atc->x[i];
383         idxptr = atc->idx[i];
384         fptr   = atc->fractx[i];
385
386         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
387         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
388         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
389         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
390
391         tix = (int)(tx);
392         tiy = (int)(ty);
393         tiz = (int)(tz);
394
395         /* Because decomposition only occurs in x and y,
396          * we never have a fraction correction in z.
397          */
398         fptr[XX] = tx - tix + pme->fshx[tix];
399         fptr[YY] = ty - tiy + pme->fshy[tiy];
400         fptr[ZZ] = tz - tiz;
401
402         idxptr[XX] = pme->nnx[tix];
403         idxptr[YY] = pme->nny[tiy];
404         idxptr[ZZ] = pme->nnz[tiz];
405
406 #ifdef DEBUG
407         range_check(idxptr[XX],0,pme->pmegrid_nx);
408         range_check(idxptr[YY],0,pme->pmegrid_ny);
409         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
410 #endif
411
412         if (bThreads)
413         {
414             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
415             thread_idx[i] = thread_i;
416             tpl_n[thread_i]++;
417         }
418     }
419
420     if (bThreads)
421     {
422         /* Make a list of particle indices sorted on thread */
423
424         /* Get the cumulative count */
425         for(i=1; i<atc->nthread; i++)
426         {
427             tpl_n[i] += tpl_n[i-1];
428         }
429         /* The current implementation distributes particles equally
430          * over the threads, so we could actually allocate for that
431          * in pme_realloc_atomcomm_things.
432          */
433         if (tpl_n[atc->nthread-1] > tpl->nalloc)
434         {
435             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
436             srenew(tpl->i,tpl->nalloc);
437         }
438         /* Set tpl_n to the cumulative start */
439         for(i=atc->nthread-1; i>=1; i--)
440         {
441             tpl_n[i] = tpl_n[i-1];
442         }
443         tpl_n[0] = 0;
444
445         /* Fill our thread local array with indices sorted on thread */
446         for(i=start; i<end; i++)
447         {
448             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
449         }
450         /* Now tpl_n contains the cumulative count again */
451     }
452 }
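/* The thread list built above is a small counting sort: a first pass counts
 * how many charges each thread will spread (tpl_n), a prefix sum turns the
 * counts into cumulative counts (and sizes tpl->i), the counts are then
 * shifted down one slot to act as per-thread write offsets, and a final pass
 * scatters the particle indices. Afterwards tpl_n again holds the cumulative
 * counts, which make_thread_local_ind() below uses as [start,end) ranges.
 */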
453
454 static void make_thread_local_ind(pme_atomcomm_t *atc,
455                                   int thread,splinedata_t *spline)
456 {
457     int  n,t,i,start,end;
458     thread_plist_t *tpl;
459
460     /* Combine the indices made by each thread into one index */
461
462     n = 0;
463     start = 0;
464     for(t=0; t<atc->nthread; t++)
465     {
466         tpl = &atc->thread_plist[t];
467         /* Copy our part (start - end) from the list of thread t */
468         if (thread > 0)
469         {
470             start = tpl->n[thread-1];
471         }
472         end = tpl->n[thread];
473         for(i=start; i<end; i++)
474         {
475             spline->ind[n++] = tpl->i[i];
476         }
477     }
478
479     spline->n = n;
480 }
481
482
483 static void pme_calc_pidx(int start, int end,
484                           matrix recipbox, rvec x[],
485                           pme_atomcomm_t *atc, int *count)
486 {
487     int  nslab,i;
488     int  si;
489     real *xptr,s;
490     real rxx,ryx,rzx,ryy,rzy;
491     int *pd;
492
493     /* Calculate the PME task index (pidx) for each particle.
494      * Here we always assign equally sized slabs to each node
495      * for load balancing reasons (the PME grid spacing is not used).
496      */
497
498     nslab = atc->nslab;
499     pd    = atc->pd;
500
501     /* Reset the count */
502     for(i=0; i<nslab; i++)
503     {
504         count[i] = 0;
505     }
506
507     if (atc->dimind == 0)
508     {
509         rxx = recipbox[XX][XX];
510         ryx = recipbox[YY][XX];
511         rzx = recipbox[ZZ][XX];
512         /* Calculate the node index in x-dimension */
513         for(i=start; i<end; i++)
514         {
515             xptr   = x[i];
516             /* Fractional coordinates along box vectors */
517             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
518             si = (int)(s + 2*nslab) % nslab;
519             pd[i] = si;
520             count[si]++;
521         }
522     }
523     else
524     {
525         ryy = recipbox[YY][YY];
526         rzy = recipbox[ZZ][YY];
527         /* Calculate the node index in y-dimension */
528         for(i=start; i<end; i++)
529         {
530             xptr   = x[i];
531             /* Fractional coordinates along box vectors */
532             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
533             si = (int)(s + 2*nslab) % nslab;
534             pd[i] = si;
535             count[si]++;
536         }
537     }
538 }
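/* Example of the slab assignment arithmetic above: with nslab = 4 and a
 * scaled slab coordinate s = -0.3 (an atom just outside the lower box edge),
 * (int)(-0.3 + 2*4) = 7 and 7 % 4 = 3, i.e. the atom wraps to the last slab;
 * s = 5.0 gives (int)(13.0) % 4 = 1. The +2*nslab shift plays the same role
 * as the +2.0 in calc_interpolation_idx(): it assumes atoms are never more
 * than two box lengths outside the unit cell.
 */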
539
540 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
541                                   pme_atomcomm_t *atc)
542 {
543     int nthread,thread,slab;
544
545     nthread = atc->nthread;
546
547 #pragma omp parallel for num_threads(nthread) schedule(static)
548     for(thread=0; thread<nthread; thread++)
549     {
550         pme_calc_pidx(natoms* thread   /nthread,
551                       natoms*(thread+1)/nthread,
552                       recipbox,x,atc,atc->count_thread[thread]);
553     }
554     /* Non-parallel reduction, since nslab is small */
555
556     for(thread=1; thread<nthread; thread++)
557     {
558         for(slab=0; slab<atc->nslab; slab++)
559         {
560             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
561         }
562     }
563 }
564
565 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
566 {
567     int i,d;
568
569     srenew(spline->ind,atc->nalloc);
570     /* Initialize the index to identity so it works without threads */
571     for(i=0; i<atc->nalloc; i++)
572     {
573         spline->ind[i] = i;
574     }
575
576     for(d=0;d<DIM;d++)
577     {
578         srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
579         srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
580     }
581 }
582
583 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
584 {
585     int nalloc_old,i,j,nalloc_tpl;
586
587     /* We have to avoid a NULL pointer for atc->x to prevent
588      * possible fatal errors in MPI routines.
589      */
590     if (atc->n > atc->nalloc || atc->nalloc == 0)
591     {
592         nalloc_old = atc->nalloc;
593         atc->nalloc = over_alloc_dd(max(atc->n,1));
594
595         if (atc->nslab > 1) {
596             srenew(atc->x,atc->nalloc);
597             srenew(atc->q,atc->nalloc);
598             srenew(atc->f,atc->nalloc);
599             for(i=nalloc_old; i<atc->nalloc; i++)
600             {
601                 clear_rvec(atc->f[i]);
602             }
603         }
604         if (atc->bSpread) {
605             srenew(atc->fractx,atc->nalloc);
606             srenew(atc->idx   ,atc->nalloc);
607
608             if (atc->nthread > 1)
609             {
610                 srenew(atc->thread_idx,atc->nalloc);
611             }
612
613             for(i=0; i<atc->nthread; i++)
614             {
615                 pme_realloc_splinedata(&atc->spline[i],atc);
616             }
617         }
618     }
619 }
620
621 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
622                          int n, gmx_bool bXF, rvec *x_f, real *charge,
623                          pme_atomcomm_t *atc)
624 /* Redistribute particle data for PME calculation */
625 /* domain decomposition by x coordinate           */
626 {
627     int *idxa;
628     int i, ii;
629
630     if(FALSE == pme->redist_init) {
631         snew(pme->scounts,atc->nslab);
632         snew(pme->rcounts,atc->nslab);
633         snew(pme->sdispls,atc->nslab);
634         snew(pme->rdispls,atc->nslab);
635         snew(pme->sidx,atc->nslab);
636         pme->redist_init = TRUE;
637     }
638     if (n > pme->redist_buf_nalloc) {
639         pme->redist_buf_nalloc = over_alloc_dd(n);
640         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
641     }
642
643     pme->idxa = atc->pd;
644
645 #ifdef GMX_MPI
646     if (forw && bXF) {
647         /* forward, redistribution from pp to pme */
648
649         /* Calculate send counts and exchange them with other nodes */
650         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
651         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
652         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
653
654         /* Calculate send and receive displacements and index into send
655            buffer */
656         pme->sdispls[0]=0;
657         pme->rdispls[0]=0;
658         pme->sidx[0]=0;
659         for(i=1; i<atc->nslab; i++) {
660             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
661             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
662             pme->sidx[i]=pme->sdispls[i];
663         }
664         /* Total # of particles to be received */
665         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
666
667         pme_realloc_atomcomm_things(atc);
668
669         /* Copy particle coordinates into send buffer and exchange */
670         for(i=0; (i<n); i++) {
671             ii=DIM*pme->sidx[pme->idxa[i]];
672             pme->sidx[pme->idxa[i]]++;
673             pme->redist_buf[ii+XX]=x_f[i][XX];
674             pme->redist_buf[ii+YY]=x_f[i][YY];
675             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
676         }
677         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
678                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
679                       pme->rvec_mpi, atc->mpi_comm);
680     }
681     if (forw) {
682         /* Copy charge into send buffer and exchange */
683         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
684         for(i=0; (i<n); i++) {
685             ii=pme->sidx[pme->idxa[i]];
686             pme->sidx[pme->idxa[i]]++;
687             pme->redist_buf[ii]=charge[i];
688         }
689         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
690                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
691                       atc->mpi_comm);
692     }
693     else { /* backward, redistribution from pme to pp */
694         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
695                       pme->redist_buf, pme->scounts, pme->sdispls,
696                       pme->rvec_mpi, atc->mpi_comm);
697
698         /* Copy data from receive buffer */
699         for(i=0; i<atc->nslab; i++)
700             pme->sidx[i] = pme->sdispls[i];
701         for(i=0; (i<n); i++) {
702             ii = DIM*pme->sidx[pme->idxa[i]];
703             x_f[i][XX] += pme->redist_buf[ii+XX];
704             x_f[i][YY] += pme->redist_buf[ii+YY];
705             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
706             pme->sidx[pme->idxa[i]]++;
707         }
708     }
709 #endif
710 }
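/* Summary of the redistribution above: in the forward direction the per-slab
 * send counts are exchanged with MPI_Alltoall, displacements are built from
 * them, the coordinates (when bXF is set) and then the charges are packed
 * into redist_buf using pme->sidx as a running write cursor per destination
 * slab, and the data is exchanged with MPI_Alltoallv. In the backward
 * direction the forces travel the opposite way through the same
 * counts/displacements and are accumulated into x_f.
 */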
711
712 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
713                             gmx_bool bBackward,int shift,
714                             void *buf_s,int nbyte_s,
715                             void *buf_r,int nbyte_r)
716 {
717 #ifdef GMX_MPI
718     int dest,src;
719     MPI_Status stat;
720
721     if (bBackward == FALSE) {
722         dest = atc->node_dest[shift];
723         src  = atc->node_src[shift];
724     } else {
725         dest = atc->node_src[shift];
726         src  = atc->node_dest[shift];
727     }
728
729     if (nbyte_s > 0 && nbyte_r > 0) {
730         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
731                      dest,shift,
732                      buf_r,nbyte_r,MPI_BYTE,
733                      src,shift,
734                      atc->mpi_comm,&stat);
735     } else if (nbyte_s > 0) {
736         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
737                  dest,shift,
738                  atc->mpi_comm);
739     } else if (nbyte_r > 0) {
740         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
741                  src,shift,
742                  atc->mpi_comm,&stat);
743     }
744 #endif
745 }
746
747 static void dd_pmeredist_x_q(gmx_pme_t pme,
748                              int n, gmx_bool bX, rvec *x, real *charge,
749                              pme_atomcomm_t *atc)
750 {
751     int *commnode,*buf_index;
752     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
753
754     commnode  = atc->node_dest;
755     buf_index = atc->buf_index;
756
757     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
758
759     nsend = 0;
760     for(i=0; i<nnodes_comm; i++) {
761         buf_index[commnode[i]] = nsend;
762         nsend += atc->count[commnode[i]];
763     }
764     if (bX) {
765         if (atc->count[atc->nodeid] + nsend != n)
766             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
767                       "This usually means that your system is not well equilibrated.",
768                       n - (atc->count[atc->nodeid] + nsend),
769                       pme->nodeid,'x'+atc->dimind);
770
771         if (nsend > pme->buf_nalloc) {
772             pme->buf_nalloc = over_alloc_dd(nsend);
773             srenew(pme->bufv,pme->buf_nalloc);
774             srenew(pme->bufr,pme->buf_nalloc);
775         }
776
777         atc->n = atc->count[atc->nodeid];
778         for(i=0; i<nnodes_comm; i++) {
779             scount = atc->count[commnode[i]];
780             /* Communicate the count */
781             if (debug)
782                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
783                         atc->dimind,atc->nodeid,commnode[i],scount);
784             pme_dd_sendrecv(atc,FALSE,i,
785                             &scount,sizeof(int),
786                             &atc->rcount[i],sizeof(int));
787             atc->n += atc->rcount[i];
788         }
789
790         pme_realloc_atomcomm_things(atc);
791     }
792
793     local_pos = 0;
794     for(i=0; i<n; i++) {
795         node = atc->pd[i];
796         if (node == atc->nodeid) {
797             /* Copy direct to the receive buffer */
798             if (bX) {
799                 copy_rvec(x[i],atc->x[local_pos]);
800             }
801             atc->q[local_pos] = charge[i];
802             local_pos++;
803         } else {
804             /* Copy to the send buffer */
805             if (bX) {
806                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
807             }
808             pme->bufr[buf_index[node]] = charge[i];
809             buf_index[node]++;
810         }
811     }
812
813     buf_pos = 0;
814     for(i=0; i<nnodes_comm; i++) {
815         scount = atc->count[commnode[i]];
816         rcount = atc->rcount[i];
817         if (scount > 0 || rcount > 0) {
818             if (bX) {
819                 /* Communicate the coordinates */
820                 pme_dd_sendrecv(atc,FALSE,i,
821                                 pme->bufv[buf_pos],scount*sizeof(rvec),
822                                 atc->x[local_pos],rcount*sizeof(rvec));
823             }
824             /* Communicate the charges */
825             pme_dd_sendrecv(atc,FALSE,i,
826                             pme->bufr+buf_pos,scount*sizeof(real),
827                             atc->q+local_pos,rcount*sizeof(real));
828             buf_pos   += scount;
829             local_pos += atc->rcount[i];
830         }
831     }
832 }
833
834 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
835                            int n, rvec *f,
836                            gmx_bool bAddF)
837 {
838   int *commnode,*buf_index;
839   int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
840
841   commnode  = atc->node_dest;
842   buf_index = atc->buf_index;
843
844   nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
845
846   local_pos = atc->count[atc->nodeid];
847   buf_pos = 0;
848   for(i=0; i<nnodes_comm; i++) {
849     scount = atc->rcount[i];
850     rcount = atc->count[commnode[i]];
851     if (scount > 0 || rcount > 0) {
852       /* Communicate the forces */
853       pme_dd_sendrecv(atc,TRUE,i,
854                       atc->f[local_pos],scount*sizeof(rvec),
855                       pme->bufv[buf_pos],rcount*sizeof(rvec));
856       local_pos += scount;
857     }
858     buf_index[commnode[i]] = buf_pos;
859     buf_pos   += rcount;
860   }
861
862     local_pos = 0;
863     if (bAddF)
864     {
865         for(i=0; i<n; i++)
866         {
867             node = atc->pd[i];
868             if (node == atc->nodeid)
869             {
870                 /* Add from the local force array */
871                 rvec_inc(f[i],atc->f[local_pos]);
872                 local_pos++;
873             }
874             else
875             {
876                 /* Add from the receive buffer */
877                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
878                 buf_index[node]++;
879             }
880         }
881     }
882     else
883     {
884         for(i=0; i<n; i++)
885         {
886             node = atc->pd[i];
887             if (node == atc->nodeid)
888             {
889                 /* Copy from the local force array */
890                 copy_rvec(atc->f[local_pos],f[i]);
891                 local_pos++;
892             }
893             else
894             {
895                 /* Copy from the receive buffer */
896                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
897                 buf_index[node]++;
898             }
899         }
900     }
901 }
902
903 #ifdef GMX_MPI
904 static void
905 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
906 {
907     pme_overlap_t *overlap;
908     int send_index0,send_nindex;
909     int recv_index0,recv_nindex;
910     MPI_Status stat;
911     int i,j,k,ix,iy,iz,icnt;
912     int ipulse,send_id,recv_id,datasize;
913     real *p;
914     real *sendptr,*recvptr;
915
916     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
917     overlap = &pme->overlap[1];
918
919     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
920     {
921         /* Since we have already (un)wrapped the overlap in the z-dimension,
922          * we only have to communicate 0 to nkz (not pmegrid_nz).
923          */
924         if (direction==GMX_SUM_QGRID_FORWARD)
925         {
926             send_id = overlap->send_id[ipulse];
927             recv_id = overlap->recv_id[ipulse];
928             send_index0   = overlap->comm_data[ipulse].send_index0;
929             send_nindex   = overlap->comm_data[ipulse].send_nindex;
930             recv_index0   = overlap->comm_data[ipulse].recv_index0;
931             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
932         }
933         else
934         {
935             send_id = overlap->recv_id[ipulse];
936             recv_id = overlap->send_id[ipulse];
937             send_index0   = overlap->comm_data[ipulse].recv_index0;
938             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
939             recv_index0   = overlap->comm_data[ipulse].send_index0;
940             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
941         }
942
943         /* Copy data to contiguous send buffer */
944         if (debug)
945         {
946             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
947                     pme->nodeid,overlap->nodeid,send_id,
948                     pme->pmegrid_start_iy,
949                     send_index0-pme->pmegrid_start_iy,
950                     send_index0-pme->pmegrid_start_iy+send_nindex);
951         }
952         icnt = 0;
953         for(i=0;i<pme->pmegrid_nx;i++)
954         {
955             ix = i;
956             for(j=0;j<send_nindex;j++)
957             {
958                 iy = j + send_index0 - pme->pmegrid_start_iy;
959                 for(k=0;k<pme->nkz;k++)
960                 {
961                     iz = k;
962                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
963                 }
964             }
965         }
966
967         datasize      = pme->pmegrid_nx * pme->nkz;
968
969         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
970                      send_id,ipulse,
971                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
972                      recv_id,ipulse,
973                      overlap->mpi_comm,&stat);
974
975         /* Get data from contiguous recv buffer */
976         if (debug)
977         {
978             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
979                     pme->nodeid,overlap->nodeid,recv_id,
980                     pme->pmegrid_start_iy,
981                     recv_index0-pme->pmegrid_start_iy,
982                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
983         }
984         icnt = 0;
985         for(i=0;i<pme->pmegrid_nx;i++)
986         {
987             ix = i;
988             for(j=0;j<recv_nindex;j++)
989             {
990                 iy = j + recv_index0 - pme->pmegrid_start_iy;
991                 for(k=0;k<pme->nkz;k++)
992                 {
993                     iz = k;
994                     if(direction==GMX_SUM_QGRID_FORWARD)
995                     {
996                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
997                     }
998                     else
999                     {
1000                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1001                     }
1002                 }
1003             }
1004         }
1005     }
1006
1007     /* Major dimension is easier, no copying required,
1008      * but we might have to sum to a separate array.
1009      * Since we don't copy, we have to communicate up to pmegrid_nz,
1010      * not nkz as for the minor direction.
1011      */
1012     overlap = &pme->overlap[0];
1013
1014     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1015     {
1016         if(direction==GMX_SUM_QGRID_FORWARD)
1017         {
1018             send_id = overlap->send_id[ipulse];
1019             recv_id = overlap->recv_id[ipulse];
1020             send_index0   = overlap->comm_data[ipulse].send_index0;
1021             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1022             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1023             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1024             recvptr   = overlap->recvbuf;
1025         }
1026         else
1027         {
1028             send_id = overlap->recv_id[ipulse];
1029             recv_id = overlap->send_id[ipulse];
1030             send_index0   = overlap->comm_data[ipulse].recv_index0;
1031             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1032             recv_index0   = overlap->comm_data[ipulse].send_index0;
1033             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1034             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1035         }
1036
1037         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1038         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1039
1040         if (debug)
1041         {
1042             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1043                     pme->nodeid,overlap->nodeid,send_id,
1044                     pme->pmegrid_start_ix,
1045                     send_index0-pme->pmegrid_start_ix,
1046                     send_index0-pme->pmegrid_start_ix+send_nindex);
1047             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1048                     pme->nodeid,overlap->nodeid,recv_id,
1049                     pme->pmegrid_start_ix,
1050                     recv_index0-pme->pmegrid_start_ix,
1051                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1052         }
1053
1054         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1055                      send_id,ipulse,
1056                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1057                      recv_id,ipulse,
1058                      overlap->mpi_comm,&stat);
1059
1060         /* ADD data from contiguous recv buffer */
1061         if(direction==GMX_SUM_QGRID_FORWARD)
1062         {
1063             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1064             for(i=0;i<recv_nindex*datasize;i++)
1065             {
1066                 p[i] += overlap->recvbuf[i];
1067             }
1068         }
1069     }
1070 }
1071 #endif
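/* Direction semantics of gmx_sum_qgrid_dd(): GMX_SUM_QGRID_FORWARD adds the
 * received overlap region onto the local grid (reducing the spread charges),
 * while the backward direction swaps the send/receive roles and overwrites
 * the local overlap region with the remote data (typically to distribute the
 * grid again before force gathering). Only the minor (y) dimension needs the
 * explicit pack/unpack loops, because its overlap region is not contiguous
 * in memory; in the major (x) dimension whole y-z planes are contiguous and
 * can be sent and received in place.
 */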
1072
1073
1074 static int
1075 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1076 {
1077     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1078     ivec    local_pme_size;
1079     int     i,ix,iy,iz;
1080     int     pmeidx,fftidx;
1081
1082     /* Dimensions should be identical for A/B grid, so we just use A here */
1083     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1084                                    local_fft_ndata,
1085                                    local_fft_offset,
1086                                    local_fft_size);
1087
1088     local_pme_size[0] = pme->pmegrid_nx;
1089     local_pme_size[1] = pme->pmegrid_ny;
1090     local_pme_size[2] = pme->pmegrid_nz;
1091
1092     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1093      the offset is identical, and the PME grid always has more data (due to overlap)
1094      */
1095     {
1096 #ifdef DEBUG_PME
1097         FILE *fp,*fp2;
1098         char fn[STRLEN],format[STRLEN];
1099         real val;
1100         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1101         fp = ffopen(fn,"w");
1102         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1103         fp2 = ffopen(fn,"w");
1104         sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1105 #endif
1106
1107     for(ix=0;ix<local_fft_ndata[XX];ix++)
1108     {
1109         for(iy=0;iy<local_fft_ndata[YY];iy++)
1110         {
1111             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1112             {
1113                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1114                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1115                 fftgrid[fftidx] = pmegrid[pmeidx];
1116 #ifdef DEBUG_PME
1117                 val = 100*pmegrid[pmeidx];
1118                 if (pmegrid[pmeidx] != 0)
1119                     fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1120                             5.0*ix,5.0*iy,5.0*iz,1.0,val);
1121                 if (pmegrid[pmeidx] != 0)
1122                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1123                             "qgrid",
1124                             pme->pmegrid_start_ix + ix,
1125                             pme->pmegrid_start_iy + iy,
1126                             pme->pmegrid_start_iz + iz,
1127                             pmegrid[pmeidx]);
1128 #endif
1129             }
1130         }
1131     }
1132 #ifdef DEBUG_PME
1133     fclose(fp);
1134     fclose(fp2);
1135 #endif
1136     }
1137     return 0;
1138 }
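/* Both grids in the copy above are row-major with z fastest; only the
 * strides differ. The PME grid uses pmegrid_nx/ny/nz (which include the
 * pme_order-1 overlap region), the FFT grid uses local_fft_size. Element
 * (ix,iy,iz) therefore lives at
 *   pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ] + iz
 *   fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ] + iz
 * and the loop simply walks the FFT-sized sub-box of the larger PME grid.
 */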
1139
1140
1141 static gmx_cycles_t omp_cyc_start()
1142 {
1143     return gmx_cycles_read();
1144 }
1145
1146 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1147 {
1148     return gmx_cycles_read() - c;
1149 }
1150
1151
1152 static int
1153 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1154                         int nthread,int thread)
1155 {
1156     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1157     ivec    local_pme_size;
1158     int     ixy0,ixy1,ixy,ix,iy,iz;
1159     int     pmeidx,fftidx;
1160 #ifdef PME_TIME_THREADS
1161     gmx_cycles_t c1;
1162     static double cs1=0;
1163     static int cnt=0;
1164 #endif
1165
1166 #ifdef PME_TIME_THREADS
1167     c1 = omp_cyc_start();
1168 #endif
1169     /* Dimensions should be identical for A/B grid, so we just use A here */
1170     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1171                                    local_fft_ndata,
1172                                    local_fft_offset,
1173                                    local_fft_size);
1174
1175     local_pme_size[0] = pme->pmegrid_nx;
1176     local_pme_size[1] = pme->pmegrid_ny;
1177     local_pme_size[2] = pme->pmegrid_nz;
1178
1179     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1180      the offset is identical, and the PME grid always has more data (due to overlap)
1181      */
1182     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1183     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1184
1185     for(ixy=ixy0;ixy<ixy1;ixy++)
1186     {
1187         ix = ixy/local_fft_ndata[YY];
1188         iy = ixy - ix*local_fft_ndata[YY];
1189
1190         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1191         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1192         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1193         {
1194             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1195         }
1196     }
1197
1198 #ifdef PME_TIME_THREADS
1199     c1 = omp_cyc_end(c1);
1200     cs1 += (double)c1;
1201     cnt++;
1202     if (cnt % 20 == 0)
1203     {
1204         printf("copy %.2f\n",cs1*1e-9);
1205     }
1206 #endif
1207
1208     return 0;
1209 }
1210
1211
1212 static void
1213 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1214 {
1215     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1216
1217     nx = pme->nkx;
1218     ny = pme->nky;
1219     nz = pme->nkz;
1220
1221     pnx = pme->pmegrid_nx;
1222     pny = pme->pmegrid_ny;
1223     pnz = pme->pmegrid_nz;
1224
1225     overlap = pme->pme_order - 1;
1226
1227     /* Add periodic overlap in z */
1228     for(ix=0; ix<pme->pmegrid_nx; ix++)
1229     {
1230         for(iy=0; iy<pme->pmegrid_ny; iy++)
1231         {
1232             for(iz=0; iz<overlap; iz++)
1233             {
1234                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1235                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1236             }
1237         }
1238     }
1239
1240     if (pme->nnodes_minor == 1)
1241     {
1242        for(ix=0; ix<pme->pmegrid_nx; ix++)
1243        {
1244            for(iy=0; iy<overlap; iy++)
1245            {
1246                for(iz=0; iz<nz; iz++)
1247                {
1248                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1249                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1250                }
1251            }
1252        }
1253     }
1254
1255     if (pme->nnodes_major == 1)
1256     {
1257         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1258
1259         for(ix=0; ix<overlap; ix++)
1260         {
1261             for(iy=0; iy<ny_x; iy++)
1262             {
1263                 for(iz=0; iz<nz; iz++)
1264                 {
1265                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1266                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1267                 }
1268             }
1269         }
1270     }
1271 }
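/* Worked example of the wrapping above: with pme_order = 4 the overlap is 3,
 * so charge spread past the last z plane lands in planes nz..nz+2 of the
 * (larger) pmegrid_nz-sized grid; the first loop folds plane nz+iz back onto
 * plane iz for iz = 0..2, which is exactly the periodic image. The same
 * folding is done in y and x only when there is a single PME node in that
 * dimension; otherwise the overlap is handled by gmx_sum_qgrid_dd() instead.
 */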
1272
1273
1274 static void
1275 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1276 {
1277     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1278
1279     nx = pme->nkx;
1280     ny = pme->nky;
1281     nz = pme->nkz;
1282
1283     pnx = pme->pmegrid_nx;
1284     pny = pme->pmegrid_ny;
1285     pnz = pme->pmegrid_nz;
1286
1287     overlap = pme->pme_order - 1;
1288
1289     if (pme->nnodes_major == 1)
1290     {
1291         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1292
1293         for(ix=0; ix<overlap; ix++)
1294         {
1295             int iy,iz;
1296
1297             for(iy=0; iy<ny_x; iy++)
1298             {
1299                 for(iz=0; iz<nz; iz++)
1300                 {
1301                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1302                         pmegrid[(ix*pny+iy)*pnz+iz];
1303                 }
1304             }
1305         }
1306     }
1307
1308     if (pme->nnodes_minor == 1)
1309     {
1310 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1311        for(ix=0; ix<pme->pmegrid_nx; ix++)
1312        {
1313            int iy,iz;
1314
1315            for(iy=0; iy<overlap; iy++)
1316            {
1317                for(iz=0; iz<nz; iz++)
1318                {
1319                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1320                        pmegrid[(ix*pny+iy)*pnz+iz];
1321                }
1322            }
1323        }
1324     }
1325
1326     /* Copy periodic overlap in z */
1327 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1328     for(ix=0; ix<pme->pmegrid_nx; ix++)
1329     {
1330         int iy,iz;
1331
1332         for(iy=0; iy<pme->pmegrid_ny; iy++)
1333         {
1334             for(iz=0; iz<overlap; iz++)
1335             {
1336                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1337                     pmegrid[(ix*pny+iy)*pnz+iz];
1338             }
1339         }
1340     }
1341 }
1342
1343 static void clear_grid(int nx,int ny,int nz,real *grid,
1344                        ivec fs,int *flag,
1345                        int fx,int fy,int fz,
1346                        int order)
1347 {
1348     int nc,ncz;
1349     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1350     int flind;
1351
1352     nc  = 2 + (order - 2)/FLBS;
1353     ncz = 2 + (order - 2)/FLBSZ;
1354
1355     for(fsx=fx; fsx<fx+nc; fsx++)
1356     {
1357         for(fsy=fy; fsy<fy+nc; fsy++)
1358         {
1359             for(fsz=fz; fsz<fz+ncz; fsz++)
1360             {
1361                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1362                 if (flag[flind] == 0)
1363                 {
1364                     gx = fsx*FLBS;
1365                     gy = fsy*FLBS;
1366                     gz = fsz*FLBSZ;
1367                     g0x = (gx*ny + gy)*nz + gz;
1368                     for(x=0; x<FLBS; x++)
1369                     {
1370                         g0y = g0x;
1371                         for(y=0; y<FLBS; y++)
1372                         {
1373                             for(z=0; z<FLBSZ; z++)
1374                             {
1375                                 grid[g0y+z] = 0;
1376                             }
1377                             g0y += nz;
1378                         }
1379                         g0x += ny*nz;
1380                     }
1381
1382                     flag[flind] = 1;
1383                 }
1384             }
1385         }
1386     }
1387 }
1388
1389 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1390 #define DO_BSPLINE(order)                            \
1391 for(ithx=0; (ithx<order); ithx++)                    \
1392 {                                                    \
1393     index_x = (i0+ithx)*pny*pnz;                     \
1394     valx    = qn*thx[ithx];                          \
1395                                                      \
1396     for(ithy=0; (ithy<order); ithy++)                \
1397     {                                                \
1398         valxy    = valx*thy[ithy];                   \
1399         index_xy = index_x+(j0+ithy)*pnz;            \
1400                                                      \
1401         for(ithz=0; (ithz<order); ithz++)            \
1402         {                                            \
1403             index_xyz        = index_xy+(k0+ithz);   \
1404             grid[index_xyz] += valxy*thz[ithz];      \
1405         }                                            \
1406     }                                                \
1407 }
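/* Note that DO_BSPLINE is deliberately not a function: it expects these
 * names to exist in the caller's scope (see spread_q_bsplines_thread below):
 *   grid            - the (thread-local) spreading grid
 *   qn              - the charge of the current atom
 *   thx, thy, thz   - pointers to the 'order' B-spline weights in x, y, z
 *   i0, j0, k0      - the starting grid indices for this atom
 *   pny, pnz        - the y and z strides of grid
 *   ithx, ithy, ithz, index_x, index_xy, index_xyz, valx, valxy - temporaries
 * Each invocation performs order^3 updates, adding qn*thx*thy*thz to the
 * corresponding grid points.
 */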
1408
1409
1410 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1411                                      pme_atomcomm_t *atc, splinedata_t *spline,
1412                                      pme_spline_work_t *work)
1413 {
1414
1415     /* spread charges from home atoms to local grid */
1416     real     *grid;
1417     pme_overlap_t *ol;
1418     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1419     int *    idxptr;
1420     int      order,norder,index_x,index_xy,index_xyz;
1421     real     valx,valxy,qn;
1422     real     *thx,*thy,*thz;
1423     int      localsize, bndsize;
1424     int      pnx,pny,pnz,ndatatot;
1425     int      offx,offy,offz;
1426
1427     pnx = pmegrid->n[XX];
1428     pny = pmegrid->n[YY];
1429     pnz = pmegrid->n[ZZ];
1430
1431     offx = pmegrid->offset[XX];
1432     offy = pmegrid->offset[YY];
1433     offz = pmegrid->offset[ZZ];
1434
1435     ndatatot = pnx*pny*pnz;
1436     grid = pmegrid->grid;
1437     for(i=0;i<ndatatot;i++)
1438     {
1439         grid[i] = 0;
1440     }
1441
1442     order = pmegrid->order;
1443
1444     for(nn=0; nn<spline->n; nn++)
1445     {
1446         n  = spline->ind[nn];
1447         qn = atc->q[n];
1448
1449         if (qn != 0)
1450         {
1451             idxptr = atc->idx[n];
1452             norder = nn*order;
1453
1454             i0   = idxptr[XX] - offx;
1455             j0   = idxptr[YY] - offy;
1456             k0   = idxptr[ZZ] - offz;
1457
1458             thx = spline->theta[XX] + norder;
1459             thy = spline->theta[YY] + norder;
1460             thz = spline->theta[ZZ] + norder;
1461
1462             switch (order) {
1463             case 4:
1464 #ifdef PME_SSE
1465 #ifdef PME_SSE_UNALIGNED
1466 #define PME_SPREAD_SSE_ORDER4
1467 #else
1468 #define PME_SPREAD_SSE_ALIGNED
1469 #define PME_ORDER 4
1470 #endif
1471 #include "pme_sse_single.h"
1472 #else
1473                 DO_BSPLINE(4);
1474 #endif
1475                 break;
1476             case 5:
1477 #ifdef PME_SSE
1478 #define PME_SPREAD_SSE_ALIGNED
1479 #define PME_ORDER 5
1480 #include "pme_sse_single.h"
1481 #else
1482                 DO_BSPLINE(5);
1483 #endif
1484                 break;
1485             default:
1486                 DO_BSPLINE(order);
1487                 break;
1488             }
1489         }
1490     }
1491 }
1492
1493
1494 static void alloc_real_aligned(int n,real **ptr_raw,real **ptr)
1495 {
1496     snew(*ptr_raw,n+8);
1497
1498     *ptr = (real *) (((size_t) *ptr_raw + 16) & (~((size_t) 15)));
1499
1500 }
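/* alloc_real_aligned() above allocates n+8 reals of slack and then rounds
 * the pointer up to the next 16-byte boundary: adding 16 and masking with
 * ~15 turns e.g. 0x1004 into 0x1010 and 0x1010 into 0x1020 (an already
 * aligned pointer still advances by a full 16 bytes). The extra 8 reals
 * (32 bytes in single, 64 in double precision) cover that offset.
 */
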
1501 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1502 {
1503 #ifdef PME_SSE
1504     if (pme_order == 5
1505 #ifndef PME_SSE_UNALIGNED
1506         || pme_order == 4
1507 #endif
1508         )
1509     {
1510         /* Round nz up to a multiple of 4 to ensure alignment */
1511         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1512     }
1513 #endif
1514 }
1515
1516 static void set_gridsize_alignment(int *gridsize,int pme_order)
1517 {
1518 #ifdef PME_SSE
1519 #ifndef PME_SSE_UNALIGNED
1520     if (pme_order == 4)
1521     {
1522          * Add extra elements to ensure aligned operations do not go
1523          * beyond the allocated grid size.
1524          * Note that for pme_order=5, the pme grid z-size alignment
1525          * ensures that we will not go beyond the grid size.
1526          */
1527          *gridsize += 4;
1528     }
1529 #endif
1530 #endif
1531 }
1532
1533 static void pmegrid_init(pmegrid_t *grid,
1534                          int cx, int cy, int cz,
1535                          int x0, int y0, int z0,
1536                          int x1, int y1, int z1,
1537                          gmx_bool set_alignment,
1538                          int pme_order,
1539                          real *ptr)
1540 {
1541     int nz,gridsize;
1542
1543     grid->ci[XX] = cx;
1544     grid->ci[YY] = cy;
1545     grid->ci[ZZ] = cz;
1546     grid->offset[XX] = x0;
1547     grid->offset[YY] = y0;
1548     grid->offset[ZZ] = z0;
1549     grid->n[XX]      = x1 - x0 + pme_order - 1;
1550     grid->n[YY]      = y1 - y0 + pme_order - 1;
1551     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1552
1553     nz = grid->n[ZZ];
1554     set_grid_alignment(&nz,pme_order);
1555     if (set_alignment)
1556     {
1557         grid->n[ZZ] = nz;
1558     }
1559     else if (nz != grid->n[ZZ])
1560     {
1561         gmx_incons("pmegrid_init call with an unaligned z size");
1562     }
1563
1564     grid->order = pme_order;
1565     if (ptr == NULL)
1566     {
1567         gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
1568         set_gridsize_alignment(&gridsize,pme_order);
1569         snew_aligned(grid->grid,gridsize,16);
1570     }
1571     else
1572     {
1573         grid->grid = ptr;
1574     }
1575 }
1576
1577 static int div_round_up(int numerator,int denominator)
1578 {
1579     return (numerator + denominator - 1)/denominator;
1580 }
1581
1582 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1583                                   ivec nsub)
1584 {
1585     int gsize_opt,gsize;
1586     int nsx,nsy,nsz;
1587     char *env;
1588
1589     gsize_opt = -1;
1590     for(nsx=1; nsx<=nthread; nsx++)
1591     {
1592         if (nthread % nsx == 0)
1593         {
1594             for(nsy=1; nsy<=nthread; nsy++)
1595             {
1596                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1597                 {
1598                     nsz = nthread/(nsx*nsy);
1599
1600                     /* Determine the number of grid points per thread */
1601                     gsize =
1602                         (div_round_up(n[XX],nsx) + ovl)*
1603                         (div_round_up(n[YY],nsy) + ovl)*
1604                         (div_round_up(n[ZZ],nsz) + ovl);
1605
1606                     /* Minimize the number of grid points per thread
1607                      * and, secondarily, the number of cuts in minor dimensions.
1608                      */
1609                     if (gsize_opt == -1 ||
1610                         gsize < gsize_opt ||
1611                         (gsize == gsize_opt &&
1612                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1613                     {
1614                         nsub[XX] = nsx;
1615                         nsub[YY] = nsy;
1616                         nsub[ZZ] = nsz;
1617                         gsize_opt = gsize;
1618                     }
1619                 }
1620             }
1621         }
1622     }
1623
1624     env = getenv("GMX_PME_THREAD_DIVISION");
1625     if (env != NULL)
1626     {
1627         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1628     }
1629
1630     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1631     {
1632         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1633     }
1634 }
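/* The GMX_PME_THREAD_DIVISION environment variable read above overrides the
 * automatic subgrid division; it is parsed as three integers, so for example
 * setting it to "2 2 1" forces a 2 x 2 x 1 thread grid. The product of the
 * three values must equal the number of PME threads, otherwise the
 * gmx_fatal() check above is triggered.
 */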
1635
1636 static void pmegrids_init(pmegrids_t *grids,
1637                           int nx,int ny,int nz,int nz_base,
1638                           int pme_order,
1639                           int nthread,
1640                           int overlap_x,
1641                           int overlap_y)
1642 {
1643     ivec n,n_base,g0,g1;
1644     int t,x,y,z,d,i,tfac;
1645     int max_comm_lines;
1646
1647     n[XX] = nx - (pme_order - 1);
1648     n[YY] = ny - (pme_order - 1);
1649     n[ZZ] = nz - (pme_order - 1);
1650
1651     copy_ivec(n,n_base);
1652     n_base[ZZ] = nz_base;
1653
1654     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1655                  NULL);
1656
1657     grids->nthread = nthread;
1658
1659     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1660
1661     if (grids->nthread > 1)
1662     {
1663         ivec nst;
1664         int gridsize;
1665         real *grid_all;
1666
1667         for(d=0; d<DIM; d++)
1668         {
1669             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1670         }
1671         set_grid_alignment(&nst[ZZ],pme_order);
1672
1673         if (debug)
1674         {
1675             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1676                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1677             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1678                     nx,ny,nz,
1679                     nst[XX],nst[YY],nst[ZZ]);
1680         }
1681
1682         snew(grids->grid_th,grids->nthread);
1683         t = 0;
1684         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1685         set_gridsize_alignment(&gridsize,pme_order);
1686         snew_aligned(grid_all,
1687                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1688                      16);
1689
1690         for(x=0; x<grids->nc[XX]; x++)
1691         {
1692             for(y=0; y<grids->nc[YY]; y++)
1693             {
1694                 for(z=0; z<grids->nc[ZZ]; z++)
1695                 {
1696                     pmegrid_init(&grids->grid_th[t],
1697                                  x,y,z,
1698                                  (n[XX]*(x  ))/grids->nc[XX],
1699                                  (n[YY]*(y  ))/grids->nc[YY],
1700                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1701                                  (n[XX]*(x+1))/grids->nc[XX],
1702                                  (n[YY]*(y+1))/grids->nc[YY],
1703                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1704                                  TRUE,
1705                                  pme_order,
1706                                  grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1707                     t++;
1708                 }
1709             }
1710         }
1711     }
1712
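         /* Editorial sketch of the mapping set up below: g2t[d][i] stores, for
          * grid line i along dimension d, that dimension's contribution to the
          * flattened index of the thread sub-grid owning the line, so that
          * thread = g2t[XX][ix] + g2t[YY][iy] + g2t[ZZ][iz] (z varies fastest).
          */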
1713     snew(grids->g2t,DIM);
1714     tfac = 1;
1715     for(d=DIM-1; d>=0; d--)
1716     {
1717         snew(grids->g2t[d],n[d]);
1718         t = 0;
1719         for(i=0; i<n[d]; i++)
1720         {
1721             /* The second check should match the parameters
1722              * of the pmegrid_init call above.
1723              */
1724             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1725             {
1726                 t++;
1727             }
1728             grids->g2t[d][i] = t*tfac;
1729         }
1730
1731         tfac *= grids->nc[d];
1732
1733         switch (d)
1734         {
1735         case XX: max_comm_lines = overlap_x;     break;
1736         case YY: max_comm_lines = overlap_y;     break;
1737         case ZZ: max_comm_lines = pme_order - 1; break;
1738         }
1739         grids->nthread_comm[d] = 0;
1740         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
1741         {
1742             grids->nthread_comm[d]++;
1743         }
1744         if (debug != NULL)
1745         {
1746             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1747                     'x'+d,grids->nthread_comm[d]);
1748         }
1749         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1750          * work, but this is not a problematic restriction.
1751          */
1752         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1753         {
1754             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME",grids->nthread);
1755         }
1756     }
1757 }
1758
1759
1760 static void pmegrids_destroy(pmegrids_t *grids)
1761 {
1762     int t;
1763
1764     if (grids->grid.grid != NULL)
1765     {
1766         sfree(grids->grid.grid);
1767
1768         if (grids->nthread > 0)
1769         {
1770             for(t=0; t<grids->nthread; t++)
1771             {
1772                 sfree(grids->grid_th[t].grid);
1773             }
1774             sfree(grids->grid_th);
1775         }
1776     }
1777 }
1778
1779
1780 static void realloc_work(pme_work_t *work,int nkx)
1781 {
1782     if (nkx > work->nalloc)
1783     {
1784         work->nalloc = nkx;
1785         srenew(work->mhx  ,work->nalloc);
1786         srenew(work->mhy  ,work->nalloc);
1787         srenew(work->mhz  ,work->nalloc);
1788         srenew(work->m2   ,work->nalloc);
1789         /* denom, tmp1 and eterm are managed with the aligned allocator only */
1790         /* Allocate an aligned pointer for SSE operations, including 3 extra
1791          * elements at the end since SSE operates on 4 elements at a time.
1792          */
1793         sfree_aligned(work->denom);
1794         sfree_aligned(work->tmp1);
1795         sfree_aligned(work->eterm);
1796         snew_aligned(work->denom,work->nalloc+3,16);
1797         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1798         snew_aligned(work->eterm,work->nalloc+3,16);
1799         srenew(work->m2inv,work->nalloc);
1800     }
1801 }
1802
1803
1804 static void free_work(pme_work_t *work)
1805 {
1806     sfree(work->mhx);
1807     sfree(work->mhy);
1808     sfree(work->mhz);
1809     sfree(work->m2);
1810     sfree_aligned(work->denom);
1811     sfree_aligned(work->tmp1);
1812     sfree_aligned(work->eterm);
1813     sfree(work->m2inv);
1814 }
1815
1816
1817 #ifdef PME_SSE
1818     /* Calculate exponentials through SSE in float precision */
1819 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1820 {
1821     {
1822         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1823         __m128 f_sse;
1824         __m128 lu;
1825         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1826         int kx;
1827         f_sse = _mm_load1_ps(&f);
1828         for(kx=0; kx<end; kx+=4)
1829         {
1830             tmp_d1   = _mm_load_ps(d_aligned+kx);
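                 /* The next two lines perform one Newton-Raphson refinement of
                  * the approximate reciprocal from _mm_rcp_ps:
                  * d_inv = lu*(2 - lu*d).
                  */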
1831             lu       = _mm_rcp_ps(tmp_d1);
1832             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1833             tmp_r    = _mm_load_ps(r_aligned+kx);
1834             tmp_r    = gmx_mm_exp_ps(tmp_r);
1835             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1836             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1837             _mm_store_ps(e_aligned+kx,tmp_e);
1838         }
1839     }
1840 }
1841 #else
1842 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1843 {
1844     int kx;
1845     for(kx=start; kx<end; kx++)
1846     {
1847         d[kx] = 1.0/d[kx];
1848     }
1849     for(kx=start; kx<end; kx++)
1850     {
1851         r[kx] = exp(r[kx]);
1852     }
1853     for(kx=start; kx<end; kx++)
1854     {
1855         e[kx] = f*r[kx]*d[kx];
1856     }
1857 }
1858 #endif
1859
1860
1861 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1862                          real ewaldcoeff,real vol,
1863                          gmx_bool bEnerVir,
1864                          int nthread,int thread)
1865 {
1866     /* do recip sum over local cells in grid */
1867     /* y major, z middle, x minor or continuous */
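         /* Editorial sketch of what this routine evaluates (smooth PME):
          * for every reciprocal vector m != 0 it forms
          *     eterm(m) = elfac*exp(-pi^2 m^2/ewaldcoeff^2) /
          *                (pi*V*m^2 * bsp_mod_x*bsp_mod_y*bsp_mod_z),
          * scales the transformed charge grid by eterm (the convolution) and,
          * if bEnerVir is set, accumulates the reciprocal-space energy
          * ~ 0.5*sum_m corner_fac*eterm(m)*2*|S(m)|^2 and the matching virial.
          * The real-to-complex FFT stores only half the z range, hence the
          * factor 2 in struct2 and the 0.5 corner_fac correction.
          */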
1868     t_complex *p0;
1869     int     kx,ky,kz,maxkx,maxky,maxkz;
1870     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1871     real    mx,my,mz;
1872     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1873     real    ets2,struct2,vfactor,ets2vf;
1874     real    d1,d2,energy=0;
1875     real    by,bz;
1876     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1877     real    rxx,ryx,ryy,rzx,rzy,rzz;
1878     pme_work_t *work;
1879     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1880     real    mhxk,mhyk,mhzk,m2k;
1881     real    corner_fac;
1882     ivec    complex_order;
1883     ivec    local_ndata,local_offset,local_size;
1884     real    elfac;
1885
1886     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1887
1888     nx = pme->nkx;
1889     ny = pme->nky;
1890     nz = pme->nkz;
1891
1892     /* Dimensions should be identical for A/B grid, so we just use A here */
1893     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1894                                       complex_order,
1895                                       local_ndata,
1896                                       local_offset,
1897                                       local_size);
1898
1899     rxx = pme->recipbox[XX][XX];
1900     ryx = pme->recipbox[YY][XX];
1901     ryy = pme->recipbox[YY][YY];
1902     rzx = pme->recipbox[ZZ][XX];
1903     rzy = pme->recipbox[ZZ][YY];
1904     rzz = pme->recipbox[ZZ][ZZ];
1905
1906     maxkx = (nx+1)/2;
1907     maxky = (ny+1)/2;
1908     maxkz = nz/2+1;
1909
1910     work = &pme->work[thread];
1911     mhx   = work->mhx;
1912     mhy   = work->mhy;
1913     mhz   = work->mhz;
1914     m2    = work->m2;
1915     denom = work->denom;
1916     tmp1  = work->tmp1;
1917     eterm = work->eterm;
1918     m2inv = work->m2inv;
1919
1920     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1921     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1922
1923     for(iyz=iyz0; iyz<iyz1; iyz++)
1924     {
1925         iy = iyz/local_ndata[ZZ];
1926         iz = iyz - iy*local_ndata[ZZ];
1927
1928         ky = iy + local_offset[YY];
1929
1930         if (ky < maxky)
1931         {
1932             my = ky;
1933         }
1934         else
1935         {
1936             my = (ky - ny);
1937         }
1938
1939         by = M_PI*vol*pme->bsp_mod[YY][ky];
1940
1941         kz = iz + local_offset[ZZ];
1942
1943         mz = kz;
1944
1945         bz = pme->bsp_mod[ZZ][kz];
1946
1947         /* 0.5 correction for corner points */
1948         corner_fac = 1;
1949         if (kz == 0 || kz == (nz+1)/2)
1950         {
1951             corner_fac = 0.5;
1952         }
1953
1954         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1955
1956         /* We should skip the k-space point (0,0,0) */
1957         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1958         {
1959             kxstart = local_offset[XX];
1960         }
1961         else
1962         {
1963             kxstart = local_offset[XX] + 1;
1964             p0++;
1965         }
1966         kxend = local_offset[XX] + local_ndata[XX];
1967
1968         if (bEnerVir)
1969         {
1970             /* More expensive inner loop, especially because of the storage
1971              * of the mh elements in arrays.
1972              * Because x is the minor grid index, all mh elements
1973              * depend on kx for triclinic unit cells.
1974              */
1975
1976                 /* Two explicit loops to avoid a conditional inside the loop */
1977             for(kx=kxstart; kx<maxkx; kx++)
1978             {
1979                 mx = kx;
1980
1981                 mhxk      = mx * rxx;
1982                 mhyk      = mx * ryx + my * ryy;
1983                 mhzk      = mx * rzx + my * rzy + mz * rzz;
1984                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
1985                 mhx[kx]   = mhxk;
1986                 mhy[kx]   = mhyk;
1987                 mhz[kx]   = mhzk;
1988                 m2[kx]    = m2k;
1989                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
1990                 tmp1[kx]  = -factor*m2k;
1991             }
1992
1993             for(kx=maxkx; kx<kxend; kx++)
1994             {
1995                 mx = (kx - nx);
1996
1997                 mhxk      = mx * rxx;
1998                 mhyk      = mx * ryx + my * ryy;
1999                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2000                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2001                 mhx[kx]   = mhxk;
2002                 mhy[kx]   = mhyk;
2003                 mhz[kx]   = mhzk;
2004                 m2[kx]    = m2k;
2005                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2006                 tmp1[kx]  = -factor*m2k;
2007             }
2008
2009             for(kx=kxstart; kx<kxend; kx++)
2010             {
2011                 m2inv[kx] = 1.0/m2[kx];
2012             }
2013
2014             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2015
2016             for(kx=kxstart; kx<kxend; kx++,p0++)
2017             {
2018                 d1      = p0->re;
2019                 d2      = p0->im;
2020
2021                 p0->re  = d1*eterm[kx];
2022                 p0->im  = d2*eterm[kx];
2023
2024                 struct2 = 2.0*(d1*d1+d2*d2);
2025
2026                 tmp1[kx] = eterm[kx]*struct2;
2027             }
2028
2029             for(kx=kxstart; kx<kxend; kx++)
2030             {
2031                 ets2     = corner_fac*tmp1[kx];
2032                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2033                 energy  += ets2;
2034
2035                 ets2vf   = ets2*vfactor;
2036                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2037                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2038                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2039                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2040                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2041                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2042             }
2043         }
2044         else
2045         {
2046             /* We don't need to calculate the energy and the virial.
2047              * In this case the triclinic overhead is small.
2048              */
2049
2050             /* Two explicit loops to avoid a conditional inside the loop */
2051
2052             for(kx=kxstart; kx<maxkx; kx++)
2053             {
2054                 mx = kx;
2055
2056                 mhxk      = mx * rxx;
2057                 mhyk      = mx * ryx + my * ryy;
2058                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2059                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2060                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2061                 tmp1[kx]  = -factor*m2k;
2062             }
2063
2064             for(kx=maxkx; kx<kxend; kx++)
2065             {
2066                 mx = (kx - nx);
2067
2068                 mhxk      = mx * rxx;
2069                 mhyk      = mx * ryx + my * ryy;
2070                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2071                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2072                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2073                 tmp1[kx]  = -factor*m2k;
2074             }
2075
2076             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2077
2078             for(kx=kxstart; kx<kxend; kx++,p0++)
2079             {
2080                 d1      = p0->re;
2081                 d2      = p0->im;
2082
2083                 p0->re  = d1*eterm[kx];
2084                 p0->im  = d2*eterm[kx];
2085             }
2086         }
2087     }
2088
2089     if (bEnerVir)
2090     {
2091         /* Update virial with local values.
2092          * The virial is symmetric by definition.
2093          * this virial seems ok for isotropic scaling, but I'm
2094          * experiencing problems on semiisotropic membranes.
2095          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2096          */
2097         work->vir[XX][XX] = 0.25*virxx;
2098         work->vir[YY][YY] = 0.25*viryy;
2099         work->vir[ZZ][ZZ] = 0.25*virzz;
2100         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2101         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2102         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2103
2104         /* This energy should be corrected for a charged system */
2105         work->energy = 0.5*energy;
2106     }
2107
2108     /* Return the loop count */
2109     return local_ndata[YY]*local_ndata[XX];
2110 }
2111
2112 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2113                              real *mesh_energy,matrix vir)
2114 {
2115     /* This function sums output over threads
2116      * and should therefore only be called after thread synchronization.
2117      */
2118     int thread;
2119
2120     *mesh_energy = pme->work[0].energy;
2121     copy_mat(pme->work[0].vir,vir);
2122
2123     for(thread=1; thread<nthread; thread++)
2124     {
2125         *mesh_energy += pme->work[thread].energy;
2126         m_add(vir,pme->work[thread].vir,vir);
2127     }
2128 }
2129
2130 static int solve_pme_yzx_wrapper(gmx_pme_t pme,t_complex *grid,
2131                                  real ewaldcoeff,real vol,
2132                                  gmx_bool bEnerVir,real *mesh_energy,matrix vir)
2133 {
2134     int  nthread,thread;
2135     int  nelements=0;
2136
2137     nthread = pme->nthread;
2138
2139 #pragma omp parallel for num_threads(nthread) schedule(static)
2140     for(thread=0; thread<nthread; thread++)
2141     {
2142         int n;
2143
2144         n = solve_pme_yzx(pme,grid,ewaldcoeff,vol,bEnerVir,nthread,thread);
2145         if (thread == 0)
2146         {
2147             nelements = n;
2148         }
2149     }
2150
2151     if (bEnerVir)
2152     {
2153         get_pme_ener_vir(pme,nthread,mesh_energy,vir);
2154     }
2155
2156     return nelements;
2157 }
2158
2159
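     /* Editorial note: DO_FSPLINE accumulates, for a single atom, the gradient
      * of the interpolated potential. fx, fy and fz collect sums of products of
      * the spline weights (thx/thy/thz), their derivatives (dthx/dthy/dthz) and
      * the grid values; the caller converts these fractional-coordinate
      * gradients to Cartesian forces using the reciprocal box vectors.
      */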
2160 #define DO_FSPLINE(order)                      \
2161 for(ithx=0; (ithx<order); ithx++)              \
2162 {                                              \
2163     index_x = (i0+ithx)*pny*pnz;               \
2164     tx      = thx[ithx];                       \
2165     dx      = dthx[ithx];                      \
2166                                                \
2167     for(ithy=0; (ithy<order); ithy++)          \
2168     {                                          \
2169         index_xy = index_x+(j0+ithy)*pnz;      \
2170         ty       = thy[ithy];                  \
2171         dy       = dthy[ithy];                 \
2172         fxy1     = fz1 = 0;                    \
2173                                                \
2174         for(ithz=0; (ithz<order); ithz++)      \
2175         {                                      \
2176             gval  = grid[index_xy+(k0+ithz)];  \
2177             fxy1 += thz[ithz]*gval;            \
2178             fz1  += dthz[ithz]*gval;           \
2179         }                                      \
2180         fx += dx*ty*fxy1;                      \
2181         fy += tx*dy*fxy1;                      \
2182         fz += tx*ty*fz1;                       \
2183     }                                          \
2184 }
2185
2186
2187 void gather_f_bsplines(gmx_pme_t pme,real *grid,
2188                        gmx_bool bClearF,pme_atomcomm_t *atc,
2189                        splinedata_t *spline,
2190                        real scale)
2191 {
2192     /* sum forces for local particles */
2193     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2194     int     index_x,index_xy;
2195     int     nx,ny,nz,pnx,pny,pnz;
2196     int *   idxptr;
2197     real    tx,ty,dx,dy,qn;
2198     real    fx,fy,fz,gval;
2199     real    fxy1,fz1;
2200     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2201     int     norder;
2202     real    rxx,ryx,ryy,rzx,rzy,rzz;
2203     int     order;
2204
2205     pme_spline_work_t *work;
2206
2207     work = &pme->spline_work;
2208
2209     order = pme->pme_order;
2210     thx   = spline->theta[XX];
2211     thy   = spline->theta[YY];
2212     thz   = spline->theta[ZZ];
2213     dthx  = spline->dtheta[XX];
2214     dthy  = spline->dtheta[YY];
2215     dthz  = spline->dtheta[ZZ];
2216     nx    = pme->nkx;
2217     ny    = pme->nky;
2218     nz    = pme->nkz;
2219     pnx   = pme->pmegrid_nx;
2220     pny   = pme->pmegrid_ny;
2221     pnz   = pme->pmegrid_nz;
2222
2223     rxx   = pme->recipbox[XX][XX];
2224     ryx   = pme->recipbox[YY][XX];
2225     ryy   = pme->recipbox[YY][YY];
2226     rzx   = pme->recipbox[ZZ][XX];
2227     rzy   = pme->recipbox[ZZ][YY];
2228     rzz   = pme->recipbox[ZZ][ZZ];
2229
2230     for(nn=0; nn<spline->n; nn++)
2231     {
2232         n  = spline->ind[nn];
2233         qn = atc->q[n];
2234
2235         if (bClearF)
2236         {
2237             atc->f[n][XX] = 0;
2238             atc->f[n][YY] = 0;
2239             atc->f[n][ZZ] = 0;
2240         }
2241         if (qn != 0)
2242         {
2243             fx     = 0;
2244             fy     = 0;
2245             fz     = 0;
2246             idxptr = atc->idx[n];
2247             norder = nn*order;
2248
2249             i0   = idxptr[XX];
2250             j0   = idxptr[YY];
2251             k0   = idxptr[ZZ];
2252
2253             /* Pointer arithmetic alert, next six statements */
2254             thx  = spline->theta[XX] + norder;
2255             thy  = spline->theta[YY] + norder;
2256             thz  = spline->theta[ZZ] + norder;
2257             dthx = spline->dtheta[XX] + norder;
2258             dthy = spline->dtheta[YY] + norder;
2259             dthz = spline->dtheta[ZZ] + norder;
2260
2261             switch (order) {
2262             case 4:
2263 #ifdef PME_SSE
2264 #ifdef PME_SSE_UNALIGNED
2265 #define PME_GATHER_F_SSE_ORDER4
2266 #else
2267 #define PME_GATHER_F_SSE_ALIGNED
2268 #define PME_ORDER 4
2269 #endif
2270 #include "pme_sse_single.h"
2271 #else
2272                 DO_FSPLINE(4);
2273 #endif
2274                 break;
2275             case 5:
2276 #ifdef PME_SSE
2277 #define PME_GATHER_F_SSE_ALIGNED
2278 #define PME_ORDER 5
2279 #include "pme_sse_single.h"
2280 #else
2281                 DO_FSPLINE(5);
2282 #endif
2283                 break;
2284             default:
2285                 DO_FSPLINE(order);
2286                 break;
2287             }
2288
2289             atc->f[n][XX] += -qn*( fx*nx*rxx );
2290             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2291             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2292         }
2293     }
2294     /* Since the energy and not forces are interpolated
2295      * the net force might not be exactly zero.
2296      * This can be solved by also interpolating F, but
2297      * that comes at a cost.
2298      * A better hack is to remove the net force every
2299      * step, but that must be done at a higher level
2300      * since this routine doesn't see all atoms if running
2301      * in parallel. Don't know how important it is?  EL 990726
2302      */
2303 }
2304
2305
2306 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2307                                    pme_atomcomm_t *atc)
2308 {
2309     splinedata_t *spline;
2310     int     n,ithx,ithy,ithz,i0,j0,k0;
2311     int     index_x,index_xy;
2312     int *   idxptr;
2313     real    energy,pot,tx,ty,qn,gval;
2314     real    *thx,*thy,*thz;
2315     int     norder;
2316     int     order;
2317
2318     spline = &atc->spline[0];
2319
2320     order = pme->pme_order;
2321
2322     energy = 0;
2323     for(n=0; (n<atc->n); n++) {
2324         qn      = atc->q[n];
2325
2326         if (qn != 0) {
2327             idxptr = atc->idx[n];
2328             norder = n*order;
2329
2330             i0   = idxptr[XX];
2331             j0   = idxptr[YY];
2332             k0   = idxptr[ZZ];
2333
2334             /* Pointer arithmetic alert, next three statements */
2335             thx  = spline->theta[XX] + norder;
2336             thy  = spline->theta[YY] + norder;
2337             thz  = spline->theta[ZZ] + norder;
2338
2339             pot = 0;
2340             for(ithx=0; (ithx<order); ithx++)
2341             {
2342                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2343                 tx      = thx[ithx];
2344
2345                 for(ithy=0; (ithy<order); ithy++)
2346                 {
2347                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2348                     ty       = thy[ithy];
2349
2350                     for(ithz=0; (ithz<order); ithz++)
2351                     {
2352                         gval  = grid[index_xy+(k0+ithz)];
2353                         pot  += tx*ty*thz[ithz]*gval;
2354                     }
2355
2356                 }
2357             }
2358
2359             energy += pot*qn;
2360         }
2361     }
2362
2363     return energy;
2364 }
2365
2366 /* Macro to force loop unrolling by fixing order.
2367  * This gives a significant performance gain.
2368  */
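     /* Editorial sketch: CALC_SPLINE implements the cardinal B-spline recursion
      * of smooth PME (Essmann et al., J. Chem. Phys. 103, 8577 (1995)),
      *     M_n(u) = [ u*M_{n-1}(u) + (n-u)*M_{n-1}(u-1) ] / (n-1),
      * together with the derivative relation
      *     dM_n(u)/du = M_{n-1}(u) - M_{n-1}(u-1).
      * data[] holds the order spline weights for the grid points surrounding
      * the fractional offset dr and ddata[] the corresponding derivatives.
      */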
2369 #define CALC_SPLINE(order)                     \
2370 {                                              \
2371     int j,k,l;                                 \
2372     real dr,div;                               \
2373     real data[PME_ORDER_MAX];                  \
2374     real ddata[PME_ORDER_MAX];                 \
2375                                                \
2376     for(j=0; (j<DIM); j++)                     \
2377     {                                          \
2378         dr  = xptr[j];                         \
2379                                                \
2380         /* dr is relative offset from lower cell limit */ \
2381         data[order-1] = 0;                     \
2382         data[1] = dr;                          \
2383         data[0] = 1 - dr;                      \
2384                                                \
2385         for(k=3; (k<order); k++)               \
2386         {                                      \
2387             div = 1.0/(k - 1.0);               \
2388             data[k-1] = div*dr*data[k-2];      \
2389             for(l=1; (l<(k-1)); l++)           \
2390             {                                  \
2391                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2392                                    data[k-l-1]);                \
2393             }                                  \
2394             data[0] = div*(1-dr)*data[0];      \
2395         }                                      \
2396         /* differentiate */                    \
2397         ddata[0] = -data[0];                   \
2398         for(k=1; (k<order); k++)               \
2399         {                                      \
2400             ddata[k] = data[k-1] - data[k];    \
2401         }                                      \
2402                                                \
2403         div = 1.0/(order - 1);                 \
2404         data[order-1] = div*dr*data[order-2];  \
2405         for(l=1; (l<(order-1)); l++)           \
2406         {                                      \
2407             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2408                                (order-l-dr)*data[order-l-1]); \
2409         }                                      \
2410         data[0] = div*(1 - dr)*data[0];        \
2411                                                \
2412         for(k=0; k<order; k++)                 \
2413         {                                      \
2414             theta[j][i*order+k]  = data[k];    \
2415             dtheta[j][i*order+k] = ddata[k];   \
2416         }                                      \
2417     }                                          \
2418 }
2419
2420 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2421                    rvec fractx[],int nr,int ind[],real charge[],
2422                    gmx_bool bFreeEnergy)
2423 {
2424     /* construct splines for local atoms */
2425     int  i,ii;
2426     real *xptr;
2427
2428     for(i=0; i<nr; i++)
2429     {
2430         /* With free energy we do not use the charge check.
2431          * In most cases this will be more efficient than calling make_bsplines
2432          * twice, since usually more than half the particles have charges.
2433          */
2434         ii = ind[i];
2435         if (bFreeEnergy || charge[ii] != 0.0) {
2436             xptr = fractx[ii];
2437             switch(order) {
2438             case 4:  CALC_SPLINE(4);     break;
2439             case 5:  CALC_SPLINE(5);     break;
2440             default: CALC_SPLINE(order); break;
2441             }
2442         }
2443     }
2444 }
2445
2446
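     /* Editorial note: make_dft_mod tabulates |sum_j data[j]*exp(2*pi*i*m*j/n)|^2,
      * the squared modulus of the DFT of the B-spline values. solve_pme_yzx
      * divides by these moduli (via bsp_mod), which corresponds to the |b(m)|^2
      * interpolation correction of smooth PME. Near-zero entries are replaced
      * by the average of their neighbours to avoid dividing by zero.
      */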
2447 void make_dft_mod(real *mod,real *data,int ndata)
2448 {
2449   int i,j;
2450   real sc,ss,arg;
2451
2452   for(i=0;i<ndata;i++) {
2453     sc=ss=0;
2454     for(j=0;j<ndata;j++) {
2455       arg=(2.0*M_PI*i*j)/ndata;
2456       sc+=data[j]*cos(arg);
2457       ss+=data[j]*sin(arg);
2458     }
2459     mod[i]=sc*sc+ss*ss;
2460   }
2461   for(i=0;i<ndata;i++)
2462     if(mod[i]<1e-7)
2463       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2464 }
2465
2466
2467
2468 void make_bspline_moduli(splinevec bsp_mod,int nx,int ny,int nz,int order)
2469 {
2470   int nmax=max(nx,max(ny,nz));
2471   real *data,*ddata,*bsp_data;
2472   int i,k,l;
2473   real div;
2474
2475   snew(data,order);
2476   snew(ddata,order);
2477   snew(bsp_data,nmax);
2478
2479   data[order-1]=0;
2480   data[1]=0;
2481   data[0]=1;
2482
2483   for(k=3;k<order;k++) {
2484     div=1.0/(k-1.0);
2485     data[k-1]=0;
2486     for(l=1;l<(k-1);l++)
2487       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2488     data[0]=div*data[0];
2489   }
2490   /* differentiate */
2491   ddata[0]=-data[0];
2492   for(k=1;k<order;k++)
2493     ddata[k]=data[k-1]-data[k];
2494   div=1.0/(order-1);
2495   data[order-1]=0;
2496   for(l=1;l<(order-1);l++)
2497     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2498   data[0]=div*data[0];
2499
2500   for(i=0;i<nmax;i++)
2501     bsp_data[i]=0;
2502   for(i=1;i<=order;i++)
2503     bsp_data[i]=data[i-1];
2504
2505   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2506   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2507   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2508
2509   sfree(data);
2510   sfree(ddata);
2511   sfree(bsp_data);
2512 }
2513
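     /* Editorial note: the loop below alternates forward and backward
      * neighbours at increasing slab distance. For example, with nslab = 4 and
      * nodeid = 0 it yields node_dest = {1,3,2} and node_src = {3,1,2}.
      */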
2514 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2515 {
2516   int nslab,n,i;
2517   int fw,bw;
2518
2519   nslab = atc->nslab;
2520
2521   n = 0;
2522   for(i=1; i<=nslab/2; i++) {
2523     fw = (atc->nodeid + i) % nslab;
2524     bw = (atc->nodeid - i + nslab) % nslab;
2525     if (n < nslab - 1) {
2526       atc->node_dest[n] = fw;
2527       atc->node_src[n]  = bw;
2528       n++;
2529     }
2530     if (n < nslab - 1) {
2531       atc->node_dest[n] = bw;
2532       atc->node_src[n]  = fw;
2533       n++;
2534     }
2535   }
2536 }
2537
2538 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2539 {
2540     int thread;
2541
2542     if(NULL != log)
2543     {
2544         fprintf(log,"Destroying PME data structures.\n");
2545     }
2546
2547     sfree((*pmedata)->nnx);
2548     sfree((*pmedata)->nny);
2549     sfree((*pmedata)->nnz);
2550
2551     pmegrids_destroy(&(*pmedata)->pmegridA);
2552
2553     sfree((*pmedata)->fftgridA);
2554     sfree((*pmedata)->cfftgridA);
2555     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2556
2557     if ((*pmedata)->pmegridB.grid.grid != NULL)
2558     {
2559         pmegrids_destroy(&(*pmedata)->pmegridB);
2560         sfree((*pmedata)->fftgridB);
2561         sfree((*pmedata)->cfftgridB);
2562         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2563     }
2564     for(thread=0; thread<(*pmedata)->nthread; thread++)
2565     {
2566         free_work(&(*pmedata)->work[thread]);
2567     }
2568     sfree((*pmedata)->work);
2569
2570     sfree(*pmedata);
2571     *pmedata = NULL;
2572
2573   return 0;
2574 }
2575
2576 static int mult_up(int n,int f)
2577 {
2578     return ((n + f - 1)/f)*f;
2579 }
2580
2581
2582 static double pme_load_imbalance(gmx_pme_t pme)
2583 {
2584     int    nma,nmi;
2585     double n1,n2,n3;
2586
2587     nma = pme->nnodes_major;
2588     nmi = pme->nnodes_minor;
2589
2590     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2591     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2592     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2593
2594     /* pme_solve is roughly double the cost of an fft */
2595
2596     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2597 }
2598
2599 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2600                           int dimind,gmx_bool bSpread)
2601 {
2602     int nk,k,s,thread;
2603
2604     atc->dimind = dimind;
2605     atc->nslab  = 1;
2606     atc->nodeid = 0;
2607     atc->pd_nalloc = 0;
2608 #ifdef GMX_MPI
2609     if (pme->nnodes > 1)
2610     {
2611         atc->mpi_comm = pme->mpi_comm_d[dimind];
2612         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2613         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2614     }
2615     if (debug)
2616     {
2617         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2618     }
2619 #endif
2620
2621     atc->bSpread   = bSpread;
2622     atc->pme_order = pme->pme_order;
2623
2624     if (atc->nslab > 1)
2625     {
2626         /* These three allocations are not required for particle decomp. */
2627         snew(atc->node_dest,atc->nslab);
2628         snew(atc->node_src,atc->nslab);
2629         setup_coordinate_communication(atc);
2630
2631         snew(atc->count_thread,pme->nthread);
2632         for(thread=0; thread<pme->nthread; thread++)
2633         {
2634             snew(atc->count_thread[thread],atc->nslab);
2635         }
2636         atc->count = atc->count_thread[0];
2637         snew(atc->rcount,atc->nslab);
2638         snew(atc->buf_index,atc->nslab);
2639     }
2640
2641     atc->nthread = pme->nthread;
2642     if (atc->nthread > 1)
2643     {
2644         snew(atc->thread_plist,atc->nthread);
2645     }
2646     snew(atc->spline,atc->nthread);
2647     for(thread=0; thread<atc->nthread; thread++)
2648     {
2649         if (atc->nthread > 1)
2650         {
2651             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2652             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2653         }
2654     }
2655 }
2656
2657 static void
2658 init_overlap_comm(pme_overlap_t *  ol,
2659                   int              norder,
2660 #ifdef GMX_MPI
2661                   MPI_Comm         comm,
2662 #endif
2663                   int              nnodes,
2664                   int              nodeid,
2665                   int              ndata,
2666                   int              commplainsize)
2667 {
2668     int lbnd,rbnd,maxlr,b,i;
2669     int exten;
2670     int nn,nk;
2671     pme_grid_comm_t *pgc;
2672     gmx_bool bCont;
2673     int fft_start,fft_end,send_index1,recv_index1;
2674
2675 #ifdef GMX_MPI
2676     ol->mpi_comm = comm;
2677 #endif
2678
2679     ol->nnodes = nnodes;
2680     ol->nodeid = nodeid;
2681
2682     /* A linear translation of the PME grid won't affect reciprocal space
2683      * calculations, so to optimize we only interpolate "upwards",
2684      * which also means we only have to consider overlap in one direction:
2685      * particles on this node might also be spread to grid indices
2686      * that belong to higher nodes (modulo nnodes).
2687      */
2688
2689     snew(ol->s2g0,ol->nnodes+1);
2690     snew(ol->s2g1,ol->nnodes);
2691     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2692     for(i=0; i<nnodes; i++)
2693     {
2694         /* s2g0 is the local interpolation grid start.
2695          * s2g1 is the local interpolation grid end.
2696          * Because grid overlap communication only goes forward,
2697          * the grid slabs for the FFTs should be rounded down.
2698          */
2699         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2700         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2701
2702         if (debug)
2703         {
2704             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2705         }
2706     }
2707     ol->s2g0[nnodes] = ndata;
2708     if (debug) { fprintf(debug,"\n"); }
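         /* Illustrative example (not from an actual run): with ndata = 30,
          * nnodes = 4 and norder = 4 the formulas above give
          *     s2g0 = {0, 7, 15, 22, 30}   s2g1 = {11, 18, 26, 33},
          * i.e. each slab's interpolation range reaches norder-1 = 3 grid
          * lines beyond the rounded-up end of its own FFT slab.
          */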
2709
2710     /* Determine with how many nodes we need to communicate the grid overlap */
2711     b = 0;
2712     do
2713     {
2714         b++;
2715         bCont = FALSE;
2716         for(i=0; i<nnodes; i++)
2717         {
2718             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2719                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2720             {
2721                 bCont = TRUE;
2722             }
2723         }
2724     }
2725     while (bCont && b < nnodes);
2726     ol->noverlap_nodes = b - 1;
2727
2728     snew(ol->send_id,ol->noverlap_nodes);
2729     snew(ol->recv_id,ol->noverlap_nodes);
2730     for(b=0; b<ol->noverlap_nodes; b++)
2731     {
2732         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2733         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2734     }
2735     snew(ol->comm_data, ol->noverlap_nodes);
2736
2737     for(b=0; b<ol->noverlap_nodes; b++)
2738     {
2739         pgc = &ol->comm_data[b];
2740         /* Send */
2741         fft_start        = ol->s2g0[ol->send_id[b]];
2742         fft_end          = ol->s2g0[ol->send_id[b]+1];
2743         if (ol->send_id[b] < nodeid)
2744         {
2745             fft_start += ndata;
2746             fft_end   += ndata;
2747         }
2748         send_index1      = ol->s2g1[nodeid];
2749         send_index1      = min(send_index1,fft_end);
2750         pgc->send_index0 = fft_start;
2751         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2752
2753         /* We always start receiving to the first index of our slab */
2754         fft_start        = ol->s2g0[ol->nodeid];
2755         fft_end          = ol->s2g0[ol->nodeid+1];
2756         recv_index1      = ol->s2g1[ol->recv_id[b]];
2757         if (ol->recv_id[b] > nodeid)
2758         {
2759             recv_index1 -= ndata;
2760         }
2761         recv_index1      = min(recv_index1,fft_end);
2762         pgc->recv_index0 = fft_start;
2763         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2764     }
2765
2766     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2767     snew(ol->sendbuf,norder*commplainsize);
2768     snew(ol->recvbuf,norder*commplainsize);
2769 }
2770
2771 static void
2772 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2773                               int **global_to_local,
2774                               real **fraction_shift)
2775 {
2776     int i;
2777     int * gtl;
2778     real * fsh;
2779
2780     snew(gtl,5*n);
2781     snew(fsh,5*n);
2782     for(i=0; (i<5*n); i++)
2783     {
2784         /* Determine the global to local grid index */
2785         gtl[i] = (i - local_start + n) % n;
2786         /* For coordinates that fall within the local grid the fraction
2787          * is correct, we don't need to shift it.
2788          */
2789         fsh[i] = 0;
2790         if (local_range < n)
2791         {
2792             /* Due to rounding issues i could be 1 beyond the lower or
2793              * upper boundary of the local grid. Correct the index for this.
2794              * If we shift the index, we need to shift the fraction by
2795              * the same amount in the other direction to not affect
2796              * the weights.
2797              * Note that due to this shifting the weights at the end of
2798              * the spline might change, but that will only involve values
2799              * between zero and values close to the precision of a real,
2800              * which is anyhow the accuracy of the whole mesh calculation.
2801              */
2802             /* With local_range=0 we should not change i=local_start */
2803             if (i % n != local_start)
2804             {
2805                 if (gtl[i] == n-1)
2806                 {
2807                     gtl[i] = 0;
2808                     fsh[i] = -1;
2809                 }
2810                 else if (gtl[i] == local_range)
2811                 {
2812                     gtl[i] = local_range - 1;
2813                     fsh[i] = 1;
2814                 }
2815             }
2816         }
2817     }
2818
2819     *global_to_local = gtl;
2820     *fraction_shift  = fsh;
2821 }
2822
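     /* Editorial note: sse_mask_init builds, for every possible offset 'of'
      * within an 8-float window, a pair of SSE masks whose lanes are all-ones
      * for the 'order' elements starting at 'of' and zero elsewhere. The SSE
      * kernels included from pme_sse_single.h use these masks to ignore lanes
      * beyond the interpolation order.
      */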
2823 static void sse_mask_init(pme_spline_work_t *work,int order)
2824 {
2825 #ifdef PME_SSE
2826     float  tmp[8];
2827     __m128 zero_SSE;
2828     int    of,i;
2829
2830     zero_SSE = _mm_setzero_ps();
2831
2832     for(of=0; of<8-(order-1); of++)
2833     {
2834         for(i=0; i<8; i++)
2835         {
2836             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2837         }
2838         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2839         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2840         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2841         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2842     }
2843 #endif
2844 }
2845
2846 static void
2847 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2848 {
2849     int nk_new;
2850
2851     if (*nk % nnodes != 0)
2852     {
2853         nk_new = nnodes*(*nk/nnodes + 1);
2854
2855         if (2*nk_new >= 3*(*nk))
2856         {
2857             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2858                       dim,*nk,dim,nnodes,dim);
2859         }
2860
2861         if (fplog != NULL)
2862         {
2863             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2864                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2865         }
2866
2867         *nk = nk_new;
2868     }
2869 }
2870
2871 int gmx_pme_init(gmx_pme_t *         pmedata,
2872                  t_commrec *         cr,
2873                  int                 nnodes_major,
2874                  int                 nnodes_minor,
2875                  t_inputrec *        ir,
2876                  int                 homenr,
2877                  gmx_bool            bFreeEnergy,
2878                  gmx_bool            bReproducible,
2879                  int                 nthread)
2880 {
2881     gmx_pme_t pme=NULL;
2882
2883     pme_atomcomm_t *atc;
2884     ivec ndata;
2885
2886     if (debug)
2887         fprintf(debug,"Creating PME data structures.\n");
2888     snew(pme,1);
2889
2890     pme->redist_init         = FALSE;
2891     pme->sum_qgrid_tmp       = NULL;
2892     pme->sum_qgrid_dd_tmp    = NULL;
2893     pme->buf_nalloc          = 0;
2894     pme->redist_buf_nalloc   = 0;
2895
2896     pme->nnodes              = 1;
2897     pme->bPPnode             = TRUE;
2898
2899     pme->nnodes_major        = nnodes_major;
2900     pme->nnodes_minor        = nnodes_minor;
2901
2902 #ifdef GMX_MPI
2903     if (PAR(cr))
2904     {
2905         pme->mpi_comm        = cr->mpi_comm_mygroup;
2906
2907         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
2908         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
2909     }
2910 #endif
2911
2912     if (pme->nnodes == 1)
2913     {
2914         pme->ndecompdim = 0;
2915         pme->nodeid_major = 0;
2916         pme->nodeid_minor = 0;
2917 #ifdef GMX_MPI
2918         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
2919 #endif
2920     }
2921     else
2922     {
2923         if (nnodes_minor == 1)
2924         {
2925 #ifdef GMX_MPI
2926             pme->mpi_comm_d[0] = pme->mpi_comm;
2927             pme->mpi_comm_d[1] = MPI_COMM_NULL;
2928 #endif
2929             pme->ndecompdim = 1;
2930             pme->nodeid_major = pme->nodeid;
2931             pme->nodeid_minor = 0;
2932
2933         }
2934         else if (nnodes_major == 1)
2935         {
2936 #ifdef GMX_MPI
2937             pme->mpi_comm_d[0] = MPI_COMM_NULL;
2938             pme->mpi_comm_d[1] = pme->mpi_comm;
2939 #endif
2940             pme->ndecompdim = 1;
2941             pme->nodeid_major = 0;
2942             pme->nodeid_minor = pme->nodeid;
2943         }
2944         else
2945         {
2946             if (pme->nnodes % nnodes_major != 0)
2947             {
2948                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
2949             }
2950             pme->ndecompdim = 2;
2951
2952 #ifdef GMX_MPI
2953             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
2954                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
2955             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
2956                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
2957
2958             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
2959             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
2960             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
2961             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
2962 #endif
2963         }
2964         pme->bPPnode = (cr->duty & DUTY_PP);
2965     }
2966
2967     pme->nthread = nthread;
2968
2969     if (ir->ePBC == epbcSCREW)
2970     {
2971         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
2972     }
2973
2974     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
2975     pme->nkx         = ir->nkx;
2976     pme->nky         = ir->nky;
2977     pme->nkz         = ir->nkz;
2978     pme->pme_order   = ir->pme_order;
2979     pme->epsilon_r   = ir->epsilon_r;
2980
2981     if (pme->pme_order > PME_ORDER_MAX)
2982     {
2983         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
2984                   pme->pme_order,PME_ORDER_MAX);
2985     }
2986
2987     /* Currently pme.c supports only the fft5d FFT code.
2988      * Therefore the grid always needs to be divisible by nnodes.
2989      * When the old 1D code is also supported again, change this check.
2990      *
2991      * This check should be done before calling gmx_pme_init
2992      * and fplog should be passed instead of stderr.
2993      *
2994     if (pme->ndecompdim >= 2)
2995     */
2996     if (pme->ndecompdim >= 1)
2997     {
2998         /*
2999         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3000                                         'x',nnodes_major,&pme->nkx);
3001         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3002                                         'y',nnodes_minor,&pme->nky);
3003         */
3004     }
3005
3006     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3007         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3008         pme->nkz <= pme->pme_order)
3009     {
3010         gmx_fatal(FARGS,"The pme grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_order for x and/or y",pme->pme_order);
3011     }
3012
3013     if (pme->nnodes > 1) {
3014         double imbal;
3015
3016 #ifdef GMX_MPI
3017         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3018         MPI_Type_commit(&(pme->rvec_mpi));
3019 #endif
3020
3021         /* Note that the charge spreading and force gathering, which usually
3022          * takes about the same amount of time as FFT+solve_pme,
3023          * is always fully load balanced
3024          * (unless the charge distribution is inhomogeneous).
3025          */
3026
3027         imbal = pme_load_imbalance(pme);
3028         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3029         {
3030             fprintf(stderr,
3031                     "\n"
3032                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3033                     "      For optimal PME load balancing\n"
3034                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3035                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3036                     "\n",
3037                     (int)((imbal-1)*100 + 0.5),
3038                     pme->nkx,pme->nky,pme->nnodes_major,
3039                     pme->nky,pme->nkz,pme->nnodes_minor);
3040         }
3041     }
3042
3043     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3044     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3045      * y is always copied through a buffer: we don't need padding in z,
3046      * but we do need the overlap in x because of the communication order.
3047      */
3048     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3049 #ifdef GMX_MPI
3050                       pme->mpi_comm_d[0],
3051 #endif
3052                       pme->nnodes_major,pme->nodeid_major,
3053                       pme->nkx,
3054                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3055
3056     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3057 #ifdef GMX_MPI
3058                       pme->mpi_comm_d[1],
3059 #endif
3060                       pme->nnodes_minor,pme->nodeid_minor,
3061                       pme->nky,
3062                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
3063
3064     /* Check for a limitation of the (current) sum_fftgrid_dd code */
3065     if (pme->nthread > 1 &&
3066         (pme->overlap[0].noverlap_nodes > 1 ||
3067          pme->overlap[1].noverlap_nodes > 1))
3068     {
3069         gmx_fatal(FARGS,"With threads the number of grid lines per node along x and/or y should be pme_order (%d) or more, or exactly pme_order-1",pme->pme_order);
3070     }
3071
3072     snew(pme->bsp_mod[XX],pme->nkx);
3073     snew(pme->bsp_mod[YY],pme->nky);
3074     snew(pme->bsp_mod[ZZ],pme->nkz);
3075
3076     /* The required size of the interpolation grid, including overlap.
3077      * The allocated size (pmegrid_n?) might be slightly larger.
3078      */
3079     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3080                       pme->overlap[0].s2g0[pme->nodeid_major];
3081     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3082                       pme->overlap[1].s2g0[pme->nodeid_minor];
3083     pme->pmegrid_nz_base = pme->nkz;
3084     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3085     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3086
3087     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3088     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3089     pme->pmegrid_start_iz = 0;
3090
3091     make_gridindex5_to_localindex(pme->nkx,
3092                                   pme->pmegrid_start_ix,
3093                                   pme->pmegrid_nx - (pme->pme_order-1),
3094                                   &pme->nnx,&pme->fshx);
3095     make_gridindex5_to_localindex(pme->nky,
3096                                   pme->pmegrid_start_iy,
3097                                   pme->pmegrid_ny - (pme->pme_order-1),
3098                                   &pme->nny,&pme->fshy);
3099     make_gridindex5_to_localindex(pme->nkz,
3100                                   pme->pmegrid_start_iz,
3101                                   pme->pmegrid_nz_base,
3102                                   &pme->nnz,&pme->fshz);
3103
3104     pmegrids_init(&pme->pmegridA,
3105                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3106                   pme->pmegrid_nz_base,
3107                   pme->pme_order,
3108                   pme->nthread,
3109                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3110                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3111
3112     sse_mask_init(&pme->spline_work,pme->pme_order);
3113
3114     ndata[0] = pme->nkx;
3115     ndata[1] = pme->nky;
3116     ndata[2] = pme->nkz;
3117
3118     /* This routine will allocate the grid data to fit the FFTs */
3119     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3120                             &pme->fftgridA,&pme->cfftgridA,
3121                             pme->mpi_comm_d,
3122                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3123                             bReproducible,pme->nthread);
3124
3125     if (bFreeEnergy)
3126     {
3127         pmegrids_init(&pme->pmegridB,
3128                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3129                       pme->pmegrid_nz_base,
3130                       pme->pme_order,
3131                       pme->nthread,
3132                       pme->nkx % pme->nnodes_major != 0,
3133                       pme->nky % pme->nnodes_minor != 0);
3134
3135         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3136                                 &pme->fftgridB,&pme->cfftgridB,
3137                                 pme->mpi_comm_d,
3138                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3139                                 bReproducible,pme->nthread);
3140     }
3141     else
3142     {
3143         pme->pmegridB.grid.grid = NULL;
3144         pme->fftgridB           = NULL;
3145         pme->cfftgridB          = NULL;
3146     }
3147
3148     make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3149
3150     /* Use atc[0] for spreading */
3151     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3152     if (pme->ndecompdim >= 2)
3153     {
3154         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3155     }
3156
3157     if (pme->nnodes == 1) {
3158         pme->atc[0].n = homenr;
3159         pme_realloc_atomcomm_things(&pme->atc[0]);
3160     }
3161
3162     {
3163         int thread;
3164
3165         /* Use fft5d, order after FFT is y major, z, x minor */
3166
3167         snew(pme->work,pme->nthread);
3168         for(thread=0; thread<pme->nthread; thread++)
3169         {
3170             realloc_work(&pme->work[thread],pme->nkx);
3171         }
3172     }
3173
3174     *pmedata = pme;
3175
3176     return 0;
3177 }
3178
3179
3180 static void copy_local_grid(gmx_pme_t pme,
3181                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3182 {
3183     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3184     int  fft_my,fft_mz;
3185     int  nsx,nsy,nsz;
3186     ivec nf;
3187     int  offx,offy,offz,x,y,z,i0,i0t;
3188     int  d;
3189     pmegrid_t *pmegrid;
3190     real *grid_th;
3191
3192     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3193                                    local_fft_ndata,
3194                                    local_fft_offset,
3195                                    local_fft_size);
3196     fft_my = local_fft_size[YY];
3197     fft_mz = local_fft_size[ZZ];
3198
3199     pmegrid = &pmegrids->grid_th[thread];
3200
3201     nsx = pmegrid->n[XX];
3202     nsy = pmegrid->n[YY];
3203     nsz = pmegrid->n[ZZ];
3204
3205     for(d=0; d<DIM; d++)
3206     {
3207         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3208                     local_fft_ndata[d] - pmegrid->offset[d]);
3209     }
3210
3211     offx = pmegrid->offset[XX];
3212     offy = pmegrid->offset[YY];
3213     offz = pmegrid->offset[ZZ];
3214
3215     /* Directly copy the non-overlapping parts of the local grids.
3216      * This also initializes the full grid.
3217      */
3218     grid_th = pmegrid->grid;
3219     for(x=0; x<nf[XX]; x++)
3220     {
3221         for(y=0; y<nf[YY]; y++)
3222         {
3223             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3224             i0t = (x*nsy + y)*nsz;
3225             for(z=0; z<nf[ZZ]; z++)
3226             {
3227                 fftgrid[i0+z] = grid_th[i0t+z];
3228             }
3229         }
3230     }
3231 }
3232
3233 static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
3234 {
3235     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3236     pme_overlap_t *overlap;
3237     int datasize,nind;
3238     int i,x,y,z,n;
3239
3240     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3241                                    local_fft_ndata,
3242                                    local_fft_offset,
3243                                    local_fft_size);
3244     /* Major dimension */
3245     overlap = &pme->overlap[0];
3246
3247     nind   = overlap->comm_data[0].send_nindex;
3248
3249     for(y=0; y<local_fft_ndata[YY]; y++) {
3250          printf(" %2d",y);
3251     }
3252     printf("\n");
3253
3254     i = 0;
3255     for(x=0; x<nind; x++) {
3256         for(y=0; y<local_fft_ndata[YY]; y++) {
3257             n = 0;
3258             for(z=0; z<local_fft_ndata[ZZ]; z++) {
3259                 if (sendbuf[i] != 0) n++;
3260                 i++;
3261             }
3262             printf(" %2d",n);
3263         }
3264         printf("\n");
3265     }
3266 }
3267
3268 static void
3269 reduce_threadgrid_overlap(gmx_pme_t pme,
3270                           const pmegrids_t *pmegrids,int thread,
3271                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3272 {
3273     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3274     int  fft_nx,fft_ny,fft_nz;
3275     int  fft_my,fft_mz;
3276     int  buf_my=-1;
3277     int  nsx,nsy,nsz;
3278     ivec ne;
3279     int  offx,offy,offz,x,y,z,i0,i0t;
3280     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3281     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3282     gmx_bool bCommX,bCommY;
3283     int  d;
3284     int  thread_f;
3285     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3286     const real *grid_th;
3287     real *commbuf=NULL;
3288
3289     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3290                                    local_fft_ndata,
3291                                    local_fft_offset,
3292                                    local_fft_size);
3293     fft_nx = local_fft_ndata[XX];
3294     fft_ny = local_fft_ndata[YY];
3295     fft_nz = local_fft_ndata[ZZ];
3296
3297     fft_my = local_fft_size[YY];
3298     fft_mz = local_fft_size[ZZ];
3299
3300     /* This routine is called when all threads have finished spreading.
3301      * Here each thread sums grid contributions calculated by other threads
3302      * to the thread local grid volume.
3303      * To minimize the number of grid copying operations,
3304      * this routine sums immediately from the pmegrid to the fftgrid.
3305      * A 1D analogue of this reduction is sketched after this function. */
3306
3307     /* Determine which part of the full node grid we should operate on:
3308      * our thread-local part of the full grid.
3309      */
3310     pmegrid = &pmegrids->grid_th[thread];
3311
3312     for(d=0; d<DIM; d++)
3313     {
3314         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3315                     local_fft_ndata[d]);
3316     }
3317
3318     offx = pmegrid->offset[XX];
3319     offy = pmegrid->offset[YY];
3320     offz = pmegrid->offset[ZZ];
3321
3322
3323     bClearBufX  = TRUE;
3324     bClearBufY  = TRUE;
3325     bClearBufXY = TRUE;
3326
3327     /* Now loop over all the thread data blocks that contribute
3328      * to the grid region we (our thread) are operating on.
3329      */
3330     /* Note that fft_nx/fft_ny is equal to the number of grid points
3331      * between the first point of our node grid and that of the next node.
3332      */
3333     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3334     {
3335         fx = pmegrid->ci[XX] + sx;
3336         ox = 0;
3337         bCommX = FALSE;
3338         if (fx < 0) {
3339             fx += pmegrids->nc[XX];
3340             ox -= fft_nx;
3341             bCommX = (pme->nnodes_major > 1);
3342         }
3343         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3344         ox += pmegrid_g->offset[XX];
3345         if (!bCommX)
3346         {
3347             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3348         }
3349         else
3350         {
3351             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3352         }
3353
3354         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3355         {
3356             fy = pmegrid->ci[YY] + sy;
3357             oy = 0;
3358             bCommY = FALSE;
3359             if (fy < 0) {
3360                 fy += pmegrids->nc[YY];
3361                 oy -= fft_ny;
3362                 bCommY = (pme->nnodes_minor > 1);
3363             }
3364             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3365             oy += pmegrid_g->offset[YY];
3366             if (!bCommY)
3367             {
3368                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3369             }
3370             else
3371             {
3372                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3373             }
3374
3375             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3376             {
3377                 fz = pmegrid->ci[ZZ] + sz;
3378                 oz = 0;
3379                 if (fz < 0)
3380                 {
3381                     fz += pmegrids->nc[ZZ];
3382                     oz -= fft_nz;
3383                 }
3384                 pmegrid_g = &pmegrids->grid_th[fz];
3385                 oz += pmegrid_g->offset[ZZ];
3386                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3387
3388                 if (sx == 0 && sy == 0 && sz == 0)
3389                 {
3390                     /* We have already added our local contribution
3391                      * before calling this routine, so skip it here.
3392                      */
3393                     continue;
3394                 }
3395
3396                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3397
3398                 pmegrid_f = &pmegrids->grid_th[thread_f];
3399
3400                 grid_th = pmegrid_f->grid;
3401
3402                 nsx = pmegrid_f->n[XX];
3403                 nsy = pmegrid_f->n[YY];
3404                 nsz = pmegrid_f->n[ZZ];
3405
3406 #ifdef DEBUG_PME_REDUCE
3407                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3408                        pme->nodeid,thread,thread_f,
3409                        pme->pmegrid_start_ix,
3410                        pme->pmegrid_start_iy,
3411                        pme->pmegrid_start_iz,
3412                        sx,sy,sz,
3413                        offx-ox,tx1-ox,offx,tx1,
3414                        offy-oy,ty1-oy,offy,ty1,
3415                        offz-oz,tz1-oz,offz,tz1);
3416 #endif
3417
3418                 if (!(bCommX || bCommY))
3419                 {
3420                     /* Copy from the thread local grid to the node grid */
3421                     for(x=offx; x<tx1; x++)
3422                     {
3423                         for(y=offy; y<ty1; y++)
3424                         {
3425                             i0  = (x*fft_my + y)*fft_mz;
3426                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3427                             for(z=offz; z<tz1; z++)
3428                             {
3429                                 fftgrid[i0+z] += grid_th[i0t+z];
3430                             }
3431                         }
3432                     }
3433                 }
3434                 else
3435                 {
3436                     /* The order of this conditional decides
3437                      * where the corner volume gets stored with x+y decomp.
3438                      */
3439                     if (bCommY)
3440                     {
3441                         commbuf = commbuf_y;
3442                         buf_my  = ty1 - offy;
3443                         if (bCommX)
3444                         {
3445                             /* We index commbuf modulo the local grid size */
3446                             commbuf += buf_my*fft_nx*fft_nz;
3447
3448                             bClearBuf  = bClearBufXY;
3449                             bClearBufXY = FALSE;
3450                         }
3451                         else
3452                         {
3453                             bClearBuf  = bClearBufY;
3454                             bClearBufY = FALSE;
3455                         }
3456                     }
3457                     else
3458                     {
3459                         commbuf = commbuf_x;
3460                         buf_my  = fft_ny;
3461                         bClearBuf  = bClearBufX;
3462                         bClearBufX = FALSE;
3463                     }
3464
3465                     /* Copy to the communication buffer */
3466                     for(x=offx; x<tx1; x++)
3467                     {
3468                         for(y=offy; y<ty1; y++)
3469                         {
3470                             i0  = (x*buf_my + y)*fft_nz;
3471                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3472
3473                             if (bClearBuf)
3474                             {
3475                                 /* First access of commbuf, initialize it */
3476                                 for(z=offz; z<tz1; z++)
3477                                 {
3478                                     commbuf[i0+z]  = grid_th[i0t+z];
3479                                 }
3480                             }
3481                             else
3482                             {
3483                                 for(z=offz; z<tz1; z++)
3484                                 {
3485                                     commbuf[i0+z] += grid_th[i0t+z];
3486                                 }
3487                             }
3488                         }
3489                     }
3490                 }
3491             }
3492         }
3493     }
3494 }
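
/* Minimal 1D analogue of the thread-grid reduction above, kept out of the
 * build with #if 0.  Each thread-local grid holds ncell + (order-1) points
 * so spreading can run past the thread's own region; summing all local
 * grids into a zeroed full grid of ngrid = nthread*ncell points, wrapping
 * the overhang periodically (the fx < 0 / ox -= fft_nx case above), gives
 * the same result as copy_local_grid() plus this reduction.  The names and
 * the uniform offsets t*ncell are assumptions of the sketch.
 */
#if 0
static void reduce_overlap_1d(int nthread, int ncell, int order,
                              double **grid_th, double *grid, int ngrid)
{
    int t, i, ig;

    for (t = 0; t < nthread; t++)
    {
        for (i = 0; i < ncell + order - 1; i++)
        {
            ig = (t*ncell + i) % ngrid;  /* wrap at the periodic boundary */
            grid[ig] += grid_th[t][i];
        }
    }
}
#endif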
3495
3496
3497 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3498 {
3499     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3500     pme_overlap_t *overlap;
3501     int  send_nindex;
3502     int  recv_index0,recv_nindex;
3503 #ifdef GMX_MPI
3504     MPI_Status stat;
3505 #endif
3506     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3507     real *sendptr,*recvptr;
3508     int  x,y,z,indg,indb;
3509
3510     /* Note that this routine is only used for forward communication.
3511      * Since the force gathering, unlike the charge spreading,
3512      * can be trivially parallelized over the particles,
3513      * the backwards process is much simpler and can use the "old"
3514      * communication setup.
3515      */
3516
3517     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3518                                    local_fft_ndata,
3519                                    local_fft_offset,
3520                                    local_fft_size);
3521
3522     /* Currently supports only a single communication pulse */
3523
3524 /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3525     if (pme->nnodes_minor > 1)
3526     {
3527         /* Minor dimension */
3528         overlap = &pme->overlap[1];
3529
3530         if (pme->nnodes_major > 1)
3531         {
3532              size_yx = pme->overlap[0].comm_data[0].send_nindex;
3533         }
3534         else
3535         {
3536             size_yx = 0;
3537         }
3538         datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
3539
3540         ipulse = 0;
3541
3542         send_id = overlap->send_id[ipulse];
3543         recv_id = overlap->recv_id[ipulse];
3544         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3545         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3546         recv_index0 = 0;
3547         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3548
3549         sendptr = overlap->sendbuf;
3550         recvptr = overlap->recvbuf;
3551
3552         /*
3553         printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
3554                local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
3555         printf("node %d send %f, %f\n",pme->nodeid,
3556                sendptr[0],sendptr[send_nindex*datasize-1]);
3557         */
3558
3559 #ifdef GMX_MPI
3560         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3561                      send_id,ipulse,
3562                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3563                      recv_id,ipulse,
3564                      overlap->mpi_comm,&stat);
3565 #endif
3566
3567         for(x=0; x<local_fft_ndata[XX]; x++)
3568         {
3569             for(y=0; y<recv_nindex; y++)
3570             {
3571                 indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3572                 indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
3573                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3574                 {
3575                     fftgrid[indg+z] += recvptr[indb+z];
3576                 }
3577             }
3578         }
3579         if (pme->nnodes_major > 1)
3580         {
3581             sendptr = pme->overlap[0].sendbuf;
3582             for(x=0; x<size_yx; x++)
3583             {
3584                 for(y=0; y<recv_nindex; y++)
3585                 {
3586                     indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3587                     indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
3588                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3589                     {
3590                         sendptr[indg+z] += recvptr[indb+z];
3591                     }
3592                 }
3593             }
3594         }
3595     }
3596
3597     /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
3598     if (pme->nnodes_major > 1)
3599     {
3600         /* Major dimension */
3601         overlap = &pme->overlap[0];
3602
3603         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3604         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3605
3606         ipulse = 0;
3607
3608         send_id = overlap->send_id[ipulse];
3609         recv_id = overlap->recv_id[ipulse];
3610         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3611         /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
3612         recv_index0 = 0;
3613         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3614
3615         sendptr = overlap->sendbuf;
3616         recvptr = overlap->recvbuf;
3617
3618         if (debug != NULL)
3619         {
3620             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3621                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3622         }
3623
3624 #ifdef GMX_MPI
3625         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3626                      send_id,ipulse,
3627                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3628                      recv_id,ipulse,
3629                      overlap->mpi_comm,&stat);
3630 #endif
3631
3632         for(x=0; x<recv_nindex; x++)
3633         {
3634             for(y=0; y<local_fft_ndata[YY]; y++)
3635             {
3636                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3637                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3638                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3639                 {
3640                     fftgrid[indg+z] += recvptr[indb+z];
3641                 }
3642             }
3643         }
3644     }
3645 }
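
/* Minimal 1D analogue of the single-pulse overlap sum above, kept out of
 * the build with #if 0.  Each rank ships its nover overhanging grid values
 * to one neighbour with a single MPI_Sendrecv and adds what it receives
 * onto its own grid.  MPI_COMM_WORLD stands in for the dedicated overlap
 * communicator, and all names are hypothetical.
 */
#if 0
#include <mpi.h>

static void halo_sum_1d(double *grid, double *sendbuf, double *recvbuf,
                        int nover)
{
    int rank, nrank, right, left, i;
    MPI_Status stat;

    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &nrank);
    right = (rank + 1) % nrank;
    left  = (rank - 1 + nrank) % nrank;

    /* Send our overhang right, receive the left neighbour's overhang */
    MPI_Sendrecv(sendbuf, nover, MPI_DOUBLE, right, 0,
                 recvbuf, nover, MPI_DOUBLE, left,  0,
                 MPI_COMM_WORLD, &stat);

    for (i = 0; i < nover; i++)
    {
        grid[i] += recvbuf[i];
    }
}
#endif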
3646
3647
3648 static void spread_on_grid(gmx_pme_t pme,
3649                            pme_atomcomm_t *atc,pmegrids_t *grids,
3650                            gmx_bool bCalcSplines,gmx_bool bSpread,
3651                            real *fftgrid)
3652 {
3653     int nthread,thread;
3654 #ifdef PME_TIME_THREADS
3655     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3656     static double cs1=0,cs2=0,cs3=0;
3657     static double cs1a[6]={0,0,0,0,0,0};
3658     static int cnt=0;
3659 #endif
3660
3661     nthread = pme->nthread;
3662
3663 #ifdef PME_TIME_THREADS
3664     c1 = omp_cyc_start();
3665 #endif
3666     if (bCalcSplines)
3667     {
3668 #pragma omp parallel for num_threads(nthread) schedule(static)
3669         for(thread=0; thread<nthread; thread++)
3670         {
3671             int start,end;
3672
3673             start = atc->n* thread   /nthread;
3674             end   = atc->n*(thread+1)/nthread;
3675
3676             /* Compute fftgrid index for all atoms,
3677              * with help of some extra variables.
3678              */
3679             calc_interpolation_idx(pme,atc,start,end,thread);
3680         }
3681     }
3682 #ifdef PME_TIME_THREADS
3683     c1 = omp_cyc_end(c1);
3684     cs1 += (double)c1;
3685 #endif
3686
3687 #ifdef PME_TIME_THREADS
3688     c2 = omp_cyc_start();
3689 #endif
3690 #pragma omp parallel for num_threads(nthread) schedule(static)
3691     for(thread=0; thread<nthread; thread++)
3692     {
3693         splinedata_t *spline;
3694         pmegrid_t *grid;
3695
3696         /* make local bsplines  */
3697         if (grids == NULL || grids->nthread == 1)  /* grids==NULL: splines only */
3698         {
3699             spline = &atc->spline[0];
3700
3701             spline->n = atc->n;
3702
3703             grid = (grids != NULL ? &grids->grid : NULL);
3704         }
3705         else
3706         {
3707             spline = &atc->spline[thread];
3708
3709             make_thread_local_ind(atc,thread,spline);
3710
3711             grid = &grids->grid_th[thread];
3712         }
3713
3714         if (bCalcSplines)
3715         {
3716             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3717                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3718         }
3719
3720         if (bSpread)
3721         {
3722             /* put local atoms on grid. */
3723 #ifdef PME_TIME_SPREAD
3724             ct1a = omp_cyc_start();
3725 #endif
3726             spread_q_bsplines_thread(grid,atc,spline,&pme->spline_work);
3727
3728             if (grids->nthread > 1)
3729             {
3730                 copy_local_grid(pme,grids,thread,fftgrid);
3731             }
3732 #ifdef PME_TIME_SPREAD
3733             ct1a = omp_cyc_end(ct1a);
3734             cs1a[thread] += (double)ct1a;
3735 #endif
3736         }
3737     }
3738 #ifdef PME_TIME_THREADS
3739     c2 = omp_cyc_end(c2);
3740     cs2 += (double)c2;
3741 #endif
3742
3743     if (grids != NULL && grids->nthread > 1)
3744     {
3745 #ifdef PME_TIME_THREADS
3746         c3 = omp_cyc_start();
3747 #endif
3748 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3749         for(thread=0; thread<grids->nthread; thread++)
3750         {
3751             reduce_threadgrid_overlap(pme,grids,thread,
3752                                       fftgrid,
3753                                       pme->overlap[0].sendbuf,
3754                                       pme->overlap[1].sendbuf);
3755 #ifdef PRINT_PME_SENDBUF
3756             print_sendbuf(pme,pme->overlap[0].sendbuf);
3757 #endif
3758         }
3759 #ifdef PME_TIME_THREADS
3760         c3 = omp_cyc_end(c3);
3761         cs3 += (double)c3;
3762 #endif
3763
3764         if (pme->nnodes > 1)
3765         {
3766             /* Communicate the overlapping part of the fftgrid */
3767             sum_fftgrid_dd(pme,fftgrid);
3768         }
3769     }
3770
3771 #ifdef PME_TIME_THREADS
3772     cnt++;
3773     if (cnt % 20 == 0)
3774     {
3775         printf("idx %.2f spread %.2f red %.2f",
3776                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3777 #ifdef PME_TIME_SPREAD
3778         for(thread=0; thread<nthread; thread++)
3779             printf(" %.2f",cs1a[thread]*1e-9);
3780 #endif
3781         printf("\n");
3782     }
3783 #endif
3784 }
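
/* Minimal check, kept out of the build with #if 0, of the static thread
 * partitioning used in spread_on_grid() above:
 *     start = n* thread   /nthread;  end = n*(thread+1)/nthread;
 * covers [0,n) exactly once with chunk sizes that differ by at most one.
 * The function name is hypothetical.
 */
#if 0
#include <assert.h>

static void check_static_partition(int n, int nthread)
{
    int thread, start, end, covered = 0;

    for (thread = 0; thread < nthread; thread++)
    {
        start = n* thread   /nthread;
        end   = n*(thread+1)/nthread;

        assert(start == covered);              /* no gap, no overlap  */
        assert(end - start <= n/nthread + 1);  /* nearly equal chunks */
        covered = end;
    }
    assert(covered == n);                      /* every atom assigned */
}
#endif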
3785
3786
3787 static void dump_grid(FILE *fp,
3788                       int sx,int sy,int sz,int nx,int ny,int nz,
3789                       int my,int mz,const real *g)
3790 {
3791     int x,y,z;
3792
3793     for(x=0; x<nx; x++)
3794     {
3795         for(y=0; y<ny; y++)
3796         {
3797             for(z=0; z<nz; z++)
3798             {
3799                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3800                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3801             }
3802         }
3803     }
3804 }
3805
3806 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3807 {
3808     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3809
3810     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3811                                    local_fft_ndata,
3812                                    local_fft_offset,
3813                                    local_fft_size);
3814
3815     dump_grid(stderr,
3816               pme->pmegrid_start_ix,
3817               pme->pmegrid_start_iy,
3818               pme->pmegrid_start_iz,
3819               pme->pmegrid_nx-pme->pme_order+1,
3820               pme->pmegrid_ny-pme->pme_order+1,
3821               pme->pmegrid_nz-pme->pme_order+1,
3822               local_fft_size[YY],
3823               local_fft_size[ZZ],
3824               fftgrid);
3825 }
3826
3827
3828 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3829 {
3830     pme_atomcomm_t *atc;
3831     pmegrids_t *grid;
3832
3833     if (pme->nnodes > 1)
3834     {
3835         gmx_incons("gmx_pme_calc_energy called in parallel");
3836     }
3837     if (pme->bFEP)
3838     {
3839         gmx_incons("gmx_pme_calc_energy with free energy");
3840     }
3841
3842     atc = &pme->atc_energy;
3843     atc->nslab     = 1;
3844     atc->bSpread   = TRUE;
3845     atc->pme_order = pme->pme_order;
3846     atc->n         = n;
3847     pme_realloc_atomcomm_things(atc);
3848     atc->x         = x;
3849     atc->q         = q;
3850
3851     /* We only use the A-charges grid */
3852     grid = &pme->pmegridA;
3853
3854     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
3855
3856     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
3857 }
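
/* Minimal usage sketch for gmx_pme_calc_energy(), kept out of the build
 * with #if 0: with a single PME node and no free energy, it computes the
 * B-splines for n test charges and gathers their energy from the existing
 * A-charge grid (no spreading, no forces).  The wrapper name is
 * hypothetical and pme, x and q are assumed to be valid.
 */
#if 0
static real test_particle_mesh_energy(gmx_pme_t pme, int n, rvec *x, real *q)
{
    real V = 0;

    gmx_pme_calc_energy(pme, n, x, q, &V);

    return V;
}
#endif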
3858
3859
3860 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
3861         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
3862 {
3863     /* Reset all the counters related to performance over the run */
3864     wallcycle_stop(wcycle,ewcRUN);
3865     wallcycle_reset_all(wcycle);
3866     init_nrnb(nrnb);
3867     ir->init_step += step_rel;
3868     ir->nsteps    -= step_rel;
3869     wallcycle_start(wcycle,ewcRUN);
3870 }
3871
3872
3873 int gmx_pmeonly(gmx_pme_t pme,
3874                 t_commrec *cr,    t_nrnb *nrnb,
3875                 gmx_wallcycle_t wcycle,
3876                 real ewaldcoeff,  gmx_bool bGatherOnly,
3877                 t_inputrec *ir)
3878 {
3879     gmx_pme_pp_t pme_pp;
3880     int  natoms;
3881     matrix box;
3882     rvec *x_pp=NULL,*f_pp=NULL;
3883     real *chargeA=NULL,*chargeB=NULL;
3884     real lambda=0;
3885     int  maxshift_x=0,maxshift_y=0;
3886     real energy,dvdlambda;
3887     matrix vir;
3888     float cycles;
3889     int  count;
3890     gmx_bool bEnerVir;
3891     gmx_large_int_t step,step_rel;
3892
3893
3894     pme_pp = gmx_pme_pp_init(cr);
3895
3896     init_nrnb(nrnb);
3897
3898     count = 0;
3899     do /****** this is a quasi-loop over time steps! */
3900     {
3901         /* Domain decomposition */
3902         natoms = gmx_pme_recv_q_x(pme_pp,
3903                                   &chargeA,&chargeB,box,&x_pp,&f_pp,
3904                                   &maxshift_x,&maxshift_y,
3905                                   &pme->bFEP,&lambda,
3906                                   &bEnerVir,
3907                                   &step);
3908
3909         if (natoms == -1) {
3910             /* We should stop: break out of the loop */
3911             break;
3912         }
3913
3914         step_rel = step - ir->init_step;
3915
3916         if (count == 0)
3917             wallcycle_start(wcycle,ewcRUN);
3918
3919         wallcycle_start(wcycle,ewcPMEMESH);
3920
3921         dvdlambda = 0;
3922         clear_mat(vir);
3923         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
3924                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
3925                    &energy,lambda,&dvdlambda,
3926                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
3927
3928         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
3929
3930         gmx_pme_send_force_vir_ener(pme_pp,
3931                                     f_pp,vir,energy,dvdlambda,
3932                                     cycles);
3933
3934         count++;
3935
3936         if (step_rel == wcycle_get_reset_counters(wcycle))
3937         {
3938             /* Reset all the counters related to performance over the run */
3939             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
3940             wcycle_set_reset_counters(wcycle, 0);
3941         }
3942
3943     } /***** end of quasi-loop, we stop with the break above */
3944     while (TRUE);
3945
3946     return 0;
3947 }
3948
3949 int gmx_pme_do(gmx_pme_t pme,
3950                int start,       int homenr,
3951                rvec x[],        rvec f[],
3952                real *chargeA,   real *chargeB,
3953                matrix box, t_commrec *cr,
3954                int  maxshift_x, int maxshift_y,
3955                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
3956                matrix vir,      real ewaldcoeff,
3957                real *energy,    real lambda,
3958                real *dvdlambda, int flags)
3959 {
3960     int     q,d,i,j,ntot,npme;
3961     int     nx,ny,nz;
3962     int     n_d;
3963     pme_atomcomm_t *atc=NULL;
3964     pmegrids_t *pmegrid=NULL;
3965     real    *grid=NULL;
3967     rvec    *x_d,*f_d;
3968     real    *charge=NULL,*q_d;
3969     real    energy_AB[2];
3970     matrix  vir_AB[2];
3971     gmx_bool bClearF;
3972     gmx_parallel_3dfft_t pfft_setup;
3973     real *  fftgrid;
3974     t_complex * cfftgrid;
3975     int     thread;
3976
3977     if (pme->nnodes > 1) {
3978         atc = &pme->atc[0];
3979         atc->npd = homenr;
3980         if (atc->npd > atc->pd_nalloc) {
3981             atc->pd_nalloc = over_alloc_dd(atc->npd);
3982             srenew(atc->pd,atc->pd_nalloc);
3983         }
3984         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
3985     }
3986     else
3987     {
3988         /* This could be necessary for TPI */
3989         pme->atc[0].n = homenr;
3990     }
3991
3992     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
3993         if (q == 0) {
3994             pmegrid = &pme->pmegridA;
3995             fftgrid = pme->fftgridA;
3996             cfftgrid = pme->cfftgridA;
3997             pfft_setup = pme->pfft_setupA;
3998             charge = chargeA+start;
3999         } else {
4000             pmegrid = &pme->pmegridB;
4001             fftgrid = pme->fftgridB;
4002             cfftgrid = pme->cfftgridB;
4003             pfft_setup = pme->pfft_setupB;
4004             charge = chargeB+start;
4005         }
4006         grid = pmegrid->grid.grid;
4007         /* Unpack structure */
4008         if (debug) {
4009             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4010                     cr->nnodes,cr->nodeid);
4011             fprintf(debug,"Grid = %p\n",(void*)grid);
4012             if (grid == NULL)
4013                 gmx_fatal(FARGS,"No grid!");
4014         }
4015         where();
4016
4017         m_inv_ur0(box,pme->recipbox);
4018
4019         if (pme->nnodes == 1) {
4020             atc = &pme->atc[0];
4021             if (DOMAINDECOMP(cr)) {
4022                 atc->n = homenr;
4023                 pme_realloc_atomcomm_things(atc);
4024             }
4025             atc->x = x;
4026             atc->q = charge;
4027             atc->f = f;
4028         } else {
4029             wallcycle_start(wcycle,ewcPME_REDISTXF);
4030             for(d=pme->ndecompdim-1; d>=0; d--)
4031             {
4032                 if (d == pme->ndecompdim-1)
4033                 {
4034                     n_d = homenr;
4035                     x_d = x + start;
4036                     q_d = charge;
4037                 }
4038                 else
4039                 {
4040                     n_d = pme->atc[d+1].n;
4041                     x_d = atc->x;
4042                     q_d = atc->q;
4043                 }
4044                 atc = &pme->atc[d];
4045                 atc->npd = n_d;
4046                 if (atc->npd > atc->pd_nalloc) {
4047                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4048                     srenew(atc->pd,atc->pd_nalloc);
4049                 }
4050                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4051                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4052                 where();
4053
4054                 GMX_BARRIER(cr->mpi_comm_mygroup);
4055                 /* Redistribute x (only once) and qA or qB */
4056                 if (DOMAINDECOMP(cr)) {
4057                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4058                 } else {
4059                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4060                 }
4061             }
4062             where();
4063
4064             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4065         }
4066
4067         if (debug)
4068             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4069                     cr->nodeid,atc->n);
4070
4071         if (flags & GMX_PME_SPREAD_Q)
4072         {
4073             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4074
4075             /* Spread the charges on a grid */
4076             GMX_MPE_LOG(ev_spread_on_grid_start);
4079             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4080             GMX_MPE_LOG(ev_spread_on_grid_finish);
4081
4082             if (q == 0)
4083             {
4084                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4085             }
4086             inc_nrnb(nrnb,eNR_SPREADQBSP,
4087                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4088
4089             if (pme->nthread == 1)
4090             {
4091                 wrap_periodic_pmegrid(pme,grid);
4092
4093                 /* sum contributions to local grid from other nodes */
4094 #ifdef GMX_MPI
4095                 if (pme->nnodes > 1)
4096                 {
4097                     GMX_BARRIER(cr->mpi_comm_mygroup);
4098                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4099                     where();
4100                 }
4101 #endif
4102
4103                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4104             }
4105
4106             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4107
4108             /*
4109             dump_local_fftgrid(pme,fftgrid);
4110             exit(0);
4111             */
4112         }
4113
4114         /* Here we start a large thread parallel region */
4115 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4116         for(thread=0; thread<pme->nthread; thread++)
4117         {
4118             if (flags & GMX_PME_SOLVE)
4119             {
4120                 int loop_count;
4121
4122                 /* do 3d-fft */
4123                 if (thread == 0)
4124                 {
4125                     GMX_BARRIER(cr->mpi_comm_mygroup);
4126                     GMX_MPE_LOG(ev_gmxfft3d_start);
4127                     wallcycle_start(wcycle,ewcPME_FFT);
4128                 }
4129                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4130                                            fftgrid,cfftgrid,thread,wcycle);
4131                 if (thread == 0)
4132                 {
4133                     wallcycle_stop(wcycle,ewcPME_FFT);
4134                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4135                 }
4136                 where();
4137
4138                 /* solve in k-space for our local cells */
4139                 if (thread == 0)
4140                 {
4141                     GMX_BARRIER(cr->mpi_comm_mygroup);
4142                     GMX_MPE_LOG(ev_solve_pme_start);
4143                     wallcycle_start(wcycle,ewcPME_SOLVE);
4144                 }
4145                 loop_count =
4146                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4147                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4148                                   flags & GMX_PME_CALC_ENER_VIR,
4149                                   pme->nthread,thread);
4150                 if (thread == 0)
4151                 {
4152                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4153                     where();
4154                     GMX_MPE_LOG(ev_solve_pme_finish);
4155                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4156                 }
4157             }
4158
4159             if (flags & GMX_PME_CALC_F)
4160             {
4161                 /* do 3d-invfft */
4162                 if (thread == 0)
4163                 {
4164                     GMX_BARRIER(cr->mpi_comm_mygroup);
4165                     GMX_MPE_LOG(ev_gmxfft3d_start);
4166                     where();
4167                     wallcycle_start(wcycle,ewcPME_FFT);
4168                 }
4169                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4170                                            cfftgrid,fftgrid,thread,wcycle);
4171                 if (thread == 0)
4172                 {
4173                     wallcycle_stop(wcycle,ewcPME_FFT);
4174
4175                     where();
4176                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4177
4178                     if (pme->nodeid == 0)
4179                     {
4180                         ntot = pme->nkx*pme->nky*pme->nkz;
4181                         npme  = ntot*log((real)ntot)/log(2.0);
4182                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4183                     }
4184
4185                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4186                 }
4187
4188                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4189             }
4190         }
4191         /* End of thread parallel section.
4192          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4193          */
4194
4195         if (flags & GMX_PME_CALC_F)
4196         {
4197             /* distribute local grid to all nodes */
4198 #ifdef GMX_MPI
4199             if (pme->nnodes > 1) {
4200                 GMX_BARRIER(cr->mpi_comm_mygroup);
4201                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4202             }
4203 #endif
4204             where();
4205
4206             unwrap_periodic_pmegrid(pme,grid);
4207
4208             /* interpolate forces for our local atoms */
4209             GMX_BARRIER(cr->mpi_comm_mygroup);
4210             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4211
4212             where();
4213
4214             /* If we are running without parallelization,
4215              * atc->f is the actual force array, not a buffer,
4216              * so we should not clear it.
4217              */
4218             bClearF = (q == 0 && PAR(cr));
4219 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4220             for(thread=0; thread<pme->nthread; thread++)
4221             {
4222                 gather_f_bsplines(pme,grid,bClearF,atc,
4223                                   &atc->spline[thread],
4224                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4225             }
4226
4227             where();
4228
4229             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4230
4231             inc_nrnb(nrnb,eNR_GATHERFBSP,
4232                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4233             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4234         }
4235
4236         if (flags & GMX_PME_CALC_ENER_VIR)
4237         {
4238             /* This should only be called on the master thread
4239              * and after the threads have synchronized.
4240              */
4241             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4242         }
4243     } /* of q-loop */
4244
4245     if ((flags & GMX_PME_CALC_F) && pme->nnodes > 1) {
4246         wallcycle_start(wcycle,ewcPME_REDISTXF);
4247         for(d=0; d<pme->ndecompdim; d++)
4248         {
4249             atc = &pme->atc[d];
4250             if (d == pme->ndecompdim - 1)
4251             {
4252                 n_d = homenr;
4253                 f_d = f + start;
4254             }
4255             else
4256             {
4257                 n_d = pme->atc[d+1].n;
4258                 f_d = pme->atc[d+1].f;
4259             }
4260             GMX_BARRIER(cr->mpi_comm_mygroup);
4261             if (DOMAINDECOMP(cr)) {
4262                 dd_pmeredist_f(pme,atc,n_d,f_d,
4263                                d==pme->ndecompdim-1 && pme->bPPnode);
4264             } else {
4265                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4266             }
4267         }
4268
4269         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4270     }
4271     where();
4272
4273     if (!pme->bFEP) {
4274         *energy = energy_AB[0];
4275         m_add(vir,vir_AB[0],vir);
4276     } else {
4277         *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4278         *dvdlambda += energy_AB[1] - energy_AB[0];
4279         for(i=0; i<DIM; i++)
4280             for(j=0; j<DIM; j++)
4281                 vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + lambda*vir_AB[1][i][j];
4282     }
4283
4284     if (debug)
4285     {
4286         fprintf(debug,"PME mesh energy: %g\n",*energy);
4287     }
4288
4289     return 0;
4290 }
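
/* Stand-alone sketch, kept out of the build with #if 0, of the linear
 * free-energy mixing used at the end of gmx_pme_do() above:
 *     E(lambda)  = (1-lambda)*E_A + lambda*E_B
 *     dE/dlambda =  E_B - E_A
 * with the same weights applied to every virial component.  The function
 * name is hypothetical.
 */
#if 0
static void mix_fep_energy(double e_A, double e_B, double lambda,
                           double *e, double *dvdlambda)
{
    *e          = (1.0 - lambda)*e_A + lambda*e_B;
    *dvdlambda += e_B - e_A;
}
#endif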