src/mdlib/pme.c
1 /* -*- mode: c; tab-width: 4; indent-tabs-mode: nil; c-basic-offset: 4; c-file-style: "stroustrup"; -*-
2  *
3  *
4  *                This source code is part of
5  *
6  *                 G   R   O   M   A   C   S
7  *
8  *          GROningen MAchine for Chemical Simulations
9  *
10  *                        VERSION 3.2.0
11  * Written by David van der Spoel, Erik Lindahl, Berk Hess, and others.
12  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
13  * Copyright (c) 2001-2004, The GROMACS development team,
14  * check out http://www.gromacs.org for more information.
15  *
16  * This program is free software; you can redistribute it and/or
17  * modify it under the terms of the GNU General Public License
18  * as published by the Free Software Foundation; either version 2
19  * of the License, or (at your option) any later version.
20  *
21  * If you want to redistribute modifications, please consider that
22  * scientific software is very special. Version control is crucial -
23  * bugs must be traceable. We will be happy to consider code for
24  * inclusion in the official distribution, but derived work must not
25  * be called official GROMACS. Details are found in the README & COPYING
26  * files - if they are missing, get the official version at www.gromacs.org.
27  *
28  * To help us fund GROMACS development, we humbly ask that you cite
29  * the papers on the package - you can find them in the top README file.
30  *
31  * For more info, check our website at http://www.gromacs.org
32  *
33  * And Hey:
34  * GROwing Monsters And Cloning Shrimps
35  */
36 /* IMPORTANT FOR DEVELOPERS:
37  *
38  * Triclinic pme stuff isn't entirely trivial, and we've experienced
39  * some bugs during development (many of them due to me). To avoid
40  * this in the future, please check the following things if you make
41  * changes in this file:
42  *
43  * 1. You should obtain identical (at least to the PME precision)
44  *    energies, forces, and virial for
45  *    a rectangular box and a triclinic one where the z (or y) axis is
46  *    tilted by a whole box side. For instance you could use these boxes:
47  *
48  *    rectangular       triclinic
49  *     2  0  0           2  0  0
50  *     0  2  0           0  2  0
51  *     0  0  6           2  2  6
52  *
53  * 2. You should check the energy conservation in a triclinic box.
54  *
55  * It might seem like overkill, but better safe than sorry.
56  * /Erik 001109
57  */
58
59 #ifdef HAVE_CONFIG_H
60 #include <config.h>
61 #endif
62
63 #ifdef GMX_LIB_MPI
64 #include <mpi.h>
65 #endif
66 #ifdef GMX_THREAD_MPI
67 #include "tmpi.h"
68 #endif
69
70 #include <stdio.h>
71 #include <string.h>
72 #include <math.h>
73 #include <assert.h>
74 #include "typedefs.h"
75 #include "txtdump.h"
76 #include "vec.h"
77 #include "gmxcomplex.h"
78 #include "smalloc.h"
79 #include "futil.h"
80 #include "coulomb.h"
81 #include "gmx_fatal.h"
82 #include "pme.h"
83 #include "network.h"
84 #include "physics.h"
85 #include "nrnb.h"
86 #include "copyrite.h"
87 #include "gmx_wallcycle.h"
88 #include "gmx_parallel_3dfft.h"
89 #include "pdbio.h"
90 #include "gmx_cyclecounter.h"
91 #include "gmx_omp.h"
92
93 /* Single precision, with SSE2 or higher available */
94 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
95
96 #include "gmx_x86_simd_single.h"
97
98 #define PME_SSE
99 /* Some old AMD processors could have problems with unaligned loads+stores */
100 #ifndef GMX_FAHCORE
101 #define PME_SSE_UNALIGNED
102 #endif
103 #endif
104
105 #include "mpelogging.h"
106
107 #define DFT_TOL 1e-7
108 /* #define PRT_FORCE */
109 /* conditions for on-the-fly time measurement */
110 /* #define TAKETIME (step > 1 && timesteps < 10) */
111 #define TAKETIME FALSE
112
113 /* #define PME_TIME_THREADS */
114
115 #ifdef GMX_DOUBLE
116 #define mpi_type MPI_DOUBLE
117 #else
118 #define mpi_type MPI_FLOAT
119 #endif
120
121 /* GMX_CACHE_SEP should be a multiple of 16 to preserve alignment */
122 #define GMX_CACHE_SEP 64
123
124 /* We only define a maximum to be able to use local arrays without allocation.
125  * An order larger than 12 should never be needed, even for test cases.
126  * If needed it can be changed here.
127  */
128 #define PME_ORDER_MAX 12
129
130 /* Internal datastructures */
131 typedef struct {
132     int send_index0;
133     int send_nindex;
134     int recv_index0;
135     int recv_nindex;
136     int recv_size;   /* Receive buffer width, used with OpenMP */
137 } pme_grid_comm_t;
138
139 typedef struct {
140 #ifdef GMX_MPI
141     MPI_Comm mpi_comm;
142 #endif
143     int  nnodes,nodeid;
144     int  *s2g0;
145     int  *s2g1;
146     int  noverlap_nodes;
147     int  *send_id,*recv_id;
148     int  send_size;             /* Send buffer width, used with OpenMP */
149     pme_grid_comm_t *comm_data;
150     real *sendbuf;
151     real *recvbuf;
152 } pme_overlap_t;
153
154 typedef struct {
155     int *n;     /* Cumulative counts of the number of particles per thread */
156     int nalloc; /* Allocation size of i */
157     int *i;     /* Particle indices ordered on thread index (n) */
158 } thread_plist_t;
159
160 typedef struct {
161     int  *thread_one;
162     int  n;
163     int  *ind;
164     splinevec theta;
165     real *ptr_theta_z;
166     splinevec dtheta;
167     real *ptr_dtheta_z;
168 } splinedata_t;
169
170 typedef struct {
171     int  dimind;            /* The index of the dimension, 0=x, 1=y */
172     int  nslab;
173     int  nodeid;
174 #ifdef GMX_MPI
175     MPI_Comm mpi_comm;
176 #endif
177
178     int  *node_dest;        /* The nodes to send x and q to with DD */
179     int  *node_src;         /* The nodes to receive x and q from with DD */
180     int  *buf_index;        /* Index for commnode into the buffers */
181
182     int  maxshift;
183
184     int  npd;
185     int  pd_nalloc;
186     int  *pd;
187     int  *count;            /* The number of atoms to send to each node */
188     int  **count_thread;
189     int  *rcount;           /* The number of atoms to receive */
190
191     int  n;
192     int  nalloc;
193     rvec *x;
194     real *q;
195     rvec *f;
196     gmx_bool bSpread;       /* These coordinates are used for spreading */
197     int  pme_order;
198     ivec *idx;
199     rvec *fractx;            /* Fractional coordinate relative to the
200                               * lower cell boundary
201                               */
202     int  nthread;
203     int  *thread_idx;        /* Which thread should spread which charge */
204     thread_plist_t *thread_plist;
205     splinedata_t *spline;
206 } pme_atomcomm_t;
207
208 #define FLBS  3
209 #define FLBSZ 4
210
211 typedef struct {
212     ivec ci;     /* The spatial location of this grid         */
213     ivec n;      /* The used size of *grid, including order-1 */
214     ivec offset; /* The grid offset from the full node grid   */
215     int  order;  /* PME spreading order                       */
216     ivec s;      /* The allocated size of *grid, s >= n       */
217     real *grid;  /* The grid of the local thread, size n      */
218 } pmegrid_t;
219
220 typedef struct {
221     pmegrid_t grid;     /* The full node grid (non thread-local)            */
222     int  nthread;       /* The number of threads operating on this grid     */
223     ivec nc;            /* The local spatial decomposition over the threads */
224     pmegrid_t *grid_th; /* Array of grids for each thread                   */
225     real *grid_all;     /* Allocated array for the grids in *grid_th        */
226     int  **g2t;         /* The grid to thread index                         */
227     ivec nthread_comm;  /* The number of threads to communicate with        */
228 } pmegrids_t;
229
230
231 typedef struct {
232 #ifdef PME_SSE
233     /* Masks for SSE aligned spreading and gathering */
234     __m128 mask_SSE0[6],mask_SSE1[6];
235 #else
236     int dummy; /* C89 requires that struct has at least one member */
237 #endif
238 } pme_spline_work_t;
239
240 typedef struct {
241     /* work data for solve_pme */
242     int      nalloc;
243     real *   mhx;
244     real *   mhy;
245     real *   mhz;
246     real *   m2;
247     real *   denom;
248     real *   tmp1_alloc;
249     real *   tmp1;
250     real *   eterm;
251     real *   m2inv;
252
253     real     energy;
254     matrix   vir;
255 } pme_work_t;
256
257 typedef struct gmx_pme {
258     int  ndecompdim;         /* The number of decomposition dimensions */
259     int  nodeid;             /* Our nodeid in mpi->mpi_comm */
260     int  nodeid_major;
261     int  nodeid_minor;
262     int  nnodes;             /* The number of nodes doing PME */
263     int  nnodes_major;
264     int  nnodes_minor;
265
266     MPI_Comm mpi_comm;
267     MPI_Comm mpi_comm_d[2];  /* Indexed on dimension, 0=x, 1=y */
268 #ifdef GMX_MPI
269     MPI_Datatype  rvec_mpi;  /* the pme vector's MPI type */
270 #endif
271
272     int  nthread;            /* The number of threads doing PME */
273
274     gmx_bool bPPnode;        /* Node also does particle-particle forces */
275     gmx_bool bFEP;           /* Compute Free energy contribution */
276     int nkx,nky,nkz;         /* Grid dimensions */
277     gmx_bool bP3M;           /* Do P3M: optimize the influence function */
278     int pme_order;
279     real epsilon_r;
280
281     pmegrids_t pmegridA;  /* Grids on which we do spreading/interpolation, includes overlap */
282     pmegrids_t pmegridB;
283     /* The PME charge spreading grid sizes/strides, includes pme_order-1 */
284     int     pmegrid_nx,pmegrid_ny,pmegrid_nz;
285     /* pmegrid_nz might be larger than strictly necessary to ensure
286      * memory alignment, pmegrid_nz_base gives the real base size.
287      */
288     int     pmegrid_nz_base;
289     /* The local PME grid starting indices */
290     int     pmegrid_start_ix,pmegrid_start_iy,pmegrid_start_iz;
291
292     /* Work data for spreading and gathering */
293     pme_spline_work_t *spline_work;
294
295     real *fftgridA;             /* Grids for FFT. With 1D FFT decomposition this can be a pointer */
296     real *fftgridB;             /* inside the interpolation grid, but separate for 2D PME decomp. */
297     int   fftgrid_nx,fftgrid_ny,fftgrid_nz;
298
299     t_complex *cfftgridA;             /* Grids for complex FFT data */
300     t_complex *cfftgridB;
301     int   cfftgrid_nx,cfftgrid_ny,cfftgrid_nz;
302
303     gmx_parallel_3dfft_t  pfft_setupA;
304     gmx_parallel_3dfft_t  pfft_setupB;
305
306     int  *nnx,*nny,*nnz;
307     real *fshx,*fshy,*fshz;
308
309     pme_atomcomm_t atc[2];  /* Indexed on decomposition index */
310     matrix    recipbox;
311     splinevec bsp_mod;
312
313     pme_overlap_t overlap[2]; /* Indexed on dimension, 0=x, 1=y */
314
315     pme_atomcomm_t atc_energy; /* Only for gmx_pme_calc_energy */
316
317     rvec *bufv;             /* Communication buffer */
318     real *bufr;             /* Communication buffer */
319     int  buf_nalloc;        /* The communication buffer size */
320
321     /* thread local work data for solve_pme */
322     pme_work_t *work;
323
324     /* Work data for PME_redist */
325     gmx_bool redist_init;
326     int *    scounts;
327     int *    rcounts;
328     int *    sdispls;
329     int *    rdispls;
330     int *    sidx;
331     int *    idxa;
332     real *   redist_buf;
333     int      redist_buf_nalloc;
334
335     /* Work data for sum_qgrid */
336     real *   sum_qgrid_tmp;
337     real *   sum_qgrid_dd_tmp;
338 } t_gmx_pme;
339
340
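/* Compute, for the atoms in [start,end), the fractional coordinate within
 * their grid cell (atc->fractx) and the wrapped integer grid indices
 * (atc->idx). When more than one thread is used, also assign each charge to
 * the thread that will spread it and build the per-thread particle list.
 */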
341 static void calc_interpolation_idx(gmx_pme_t pme,pme_atomcomm_t *atc,
342                                    int start,int end,int thread)
343 {
344     int  i;
345     int  *idxptr,tix,tiy,tiz;
346     real *xptr,*fptr,tx,ty,tz;
347     real rxx,ryx,ryy,rzx,rzy,rzz;
348     int  nx,ny,nz;
349     int  start_ix,start_iy,start_iz;
350     int  *g2tx,*g2ty,*g2tz;
351     gmx_bool bThreads;
352     int  *thread_idx=NULL;
353     thread_plist_t *tpl=NULL;
354     int  *tpl_n=NULL;
355     int  thread_i;
356
357     nx  = pme->nkx;
358     ny  = pme->nky;
359     nz  = pme->nkz;
360
361     start_ix = pme->pmegrid_start_ix;
362     start_iy = pme->pmegrid_start_iy;
363     start_iz = pme->pmegrid_start_iz;
364
365     rxx = pme->recipbox[XX][XX];
366     ryx = pme->recipbox[YY][XX];
367     ryy = pme->recipbox[YY][YY];
368     rzx = pme->recipbox[ZZ][XX];
369     rzy = pme->recipbox[ZZ][YY];
370     rzz = pme->recipbox[ZZ][ZZ];
371
372     g2tx = pme->pmegridA.g2t[XX];
373     g2ty = pme->pmegridA.g2t[YY];
374     g2tz = pme->pmegridA.g2t[ZZ];
375
376     bThreads = (atc->nthread > 1);
377     if (bThreads)
378     {
379         thread_idx = atc->thread_idx;
380
381         tpl   = &atc->thread_plist[thread];
382         tpl_n = tpl->n;
383         for(i=0; i<atc->nthread; i++)
384         {
385             tpl_n[i] = 0;
386         }
387     }
388
389     for(i=start; i<end; i++) {
390         xptr   = atc->x[i];
391         idxptr = atc->idx[i];
392         fptr   = atc->fractx[i];
393
394         /* Fractional coordinates along box vectors, add 2.0 to make 100% sure we are positive for triclinic boxes */
395         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + 2.0 );
396         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + 2.0 );
397         tz = nz * (                                   xptr[ZZ] * rzz + 2.0 );
398
399         tix = (int)(tx);
400         tiy = (int)(ty);
401         tiz = (int)(tz);
402
403         /* Because decomposition only occurs in x and y,
404          * we never have a fraction correction in z.
405          */
406         fptr[XX] = tx - tix + pme->fshx[tix];
407         fptr[YY] = ty - tiy + pme->fshy[tiy];
408         fptr[ZZ] = tz - tiz;
409
410         idxptr[XX] = pme->nnx[tix];
411         idxptr[YY] = pme->nny[tiy];
412         idxptr[ZZ] = pme->nnz[tiz];
413
414 #ifdef DEBUG
415         range_check(idxptr[XX],0,pme->pmegrid_nx);
416         range_check(idxptr[YY],0,pme->pmegrid_ny);
417         range_check(idxptr[ZZ],0,pme->pmegrid_nz);
418 #endif
419
420         if (bThreads)
421         {
422             thread_i = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
423             thread_idx[i] = thread_i;
424             tpl_n[thread_i]++;
425         }
426     }
427
428     if (bThreads)
429     {
430         /* Make a list of particle indices sorted on thread */
431
432         /* Get the cumulative count */
433         for(i=1; i<atc->nthread; i++)
434         {
435             tpl_n[i] += tpl_n[i-1];
436         }
437         /* The current implementation distributes particles equally
438          * over the threads, so we could actually allocate for that
439          * in pme_realloc_atomcomm_things.
440          */
441         if (tpl_n[atc->nthread-1] > tpl->nalloc)
442         {
443             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
444             srenew(tpl->i,tpl->nalloc);
445         }
446         /* Set tpl_n to the cumulative start */
447         for(i=atc->nthread-1; i>=1; i--)
448         {
449             tpl_n[i] = tpl_n[i-1];
450         }
451         tpl_n[0] = 0;
452
453         /* Fill our thread local array with indices sorted on thread */
454         for(i=start; i<end; i++)
455         {
456             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
457         }
458         /* Now tpl_n contains the cumulative count again */
459     }
460 }
461
462 static void make_thread_local_ind(pme_atomcomm_t *atc,
463                                   int thread,splinedata_t *spline)
464 {
465     int  n,t,i,start,end;
466     thread_plist_t *tpl;
467
468     /* Combine the indices made by each thread into one index */
469
470     n = 0;
471     start = 0;
472     for(t=0; t<atc->nthread; t++)
473     {
474         tpl = &atc->thread_plist[t];
475         /* Copy our part (start - end) from the list of thread t */
476         if (thread > 0)
477         {
478             start = tpl->n[thread-1];
479         }
480         end = tpl->n[thread];
481         for(i=start; i<end; i++)
482         {
483             spline->ind[n++] = tpl->i[i];
484         }
485     }
486
487     spline->n = n;
488 }
489
490
491 static void pme_calc_pidx(int start, int end,
492                           matrix recipbox, rvec x[],
493                           pme_atomcomm_t *atc, int *count)
494 {
495     int  nslab,i;
496     int  si;
497     real *xptr,s;
498     real rxx,ryx,rzx,ryy,rzy;
499     int *pd;
500
501     /* Calculate the PME task index (pidx) for each atom.
502      * Here we always assign equally sized slabs to each node
503      * for load balancing reasons (the PME grid spacing is not used).
504      */
505
506     nslab = atc->nslab;
507     pd    = atc->pd;
508
509     /* Reset the count */
510     for(i=0; i<nslab; i++)
511     {
512         count[i] = 0;
513     }
514
515     if (atc->dimind == 0)
516     {
517         rxx = recipbox[XX][XX];
518         ryx = recipbox[YY][XX];
519         rzx = recipbox[ZZ][XX];
520         /* Calculate the node index in x-dimension */
521         for(i=start; i<end; i++)
522         {
523             xptr   = x[i];
524             /* Fractional coordinates along box vectors */
525             s = nslab*(xptr[XX]*rxx + xptr[YY]*ryx + xptr[ZZ]*rzx);
526             si = (int)(s + 2*nslab) % nslab;
527             pd[i] = si;
528             count[si]++;
529         }
530     }
531     else
532     {
533         ryy = recipbox[YY][YY];
534         rzy = recipbox[ZZ][YY];
535         /* Calculate the node index in y-dimension */
536         for(i=start; i<end; i++)
537         {
538             xptr   = x[i];
539             /* Fractional coordinates along box vectors */
540             s = nslab*(xptr[YY]*ryy + xptr[ZZ]*rzy);
541             si = (int)(s + 2*nslab) % nslab;
542             pd[i] = si;
543             count[si]++;
544         }
545     }
546 }
547
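/* Thread-parallel wrapper around pme_calc_pidx: each thread handles an equal
 * share of the atoms, after which the per-slab counts are reduced into
 * atc->count_thread[0].
 */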
548 static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
549                                   pme_atomcomm_t *atc)
550 {
551     int nthread,thread,slab;
552
553     nthread = atc->nthread;
554
555 #pragma omp parallel for num_threads(nthread) schedule(static)
556     for(thread=0; thread<nthread; thread++)
557     {
558         pme_calc_pidx(natoms* thread   /nthread,
559                       natoms*(thread+1)/nthread,
560                       recipbox,x,atc,atc->count_thread[thread]);
561     }
562     /* Non-parallel reduction, since nslab is small */
563
564     for(thread=1; thread<nthread; thread++)
565     {
566         for(slab=0; slab<atc->nslab; slab++)
567         {
568             atc->count_thread[0][slab] += atc->count_thread[thread][slab];
569         }
570     }
571 }
572
573 static void realloc_splinevec(splinevec th,real **ptr_z,int nalloc)
574 {
575     const int padding=4;
576     int i;
577
578     srenew(th[XX],nalloc);
579     srenew(th[YY],nalloc);
580     /* In z we add padding; this is only required for the aligned SSE code */
581     srenew(*ptr_z,nalloc+2*padding);
582     th[ZZ] = *ptr_z + padding;
583
584     for(i=0; i<padding; i++)
585     {
586         (*ptr_z)[               i] = 0;
587         (*ptr_z)[padding+nalloc+i] = 0;
588     }
589 }
590
591 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
592 {
593     int i,d;
594
595     srenew(spline->ind,atc->nalloc);
596     /* Initialize the index to identity so it works without threads */
597     for(i=0; i<atc->nalloc; i++)
598     {
599         spline->ind[i] = i;
600     }
601
602     realloc_splinevec(spline->theta,&spline->ptr_theta_z,
603                       atc->pme_order*atc->nalloc);
604     realloc_splinevec(spline->dtheta,&spline->ptr_dtheta_z,
605                       atc->pme_order*atc->nalloc);
606 }
607
608 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
609 {
610     int nalloc_old,i,j,nalloc_tpl;
611
612     /* We must avoid a NULL pointer for atc->x to prevent
613      * possible fatal errors in MPI routines.
614      */
615     if (atc->n > atc->nalloc || atc->nalloc == 0)
616     {
617         nalloc_old = atc->nalloc;
618         atc->nalloc = over_alloc_dd(max(atc->n,1));
619
620         if (atc->nslab > 1) {
621             srenew(atc->x,atc->nalloc);
622             srenew(atc->q,atc->nalloc);
623             srenew(atc->f,atc->nalloc);
624             for(i=nalloc_old; i<atc->nalloc; i++)
625             {
626                 clear_rvec(atc->f[i]);
627             }
628         }
629         if (atc->bSpread) {
630             srenew(atc->fractx,atc->nalloc);
631             srenew(atc->idx   ,atc->nalloc);
632
633             if (atc->nthread > 1)
634             {
635                 srenew(atc->thread_idx,atc->nalloc);
636             }
637
638             for(i=0; i<atc->nthread; i++)
639             {
640                 pme_realloc_splinedata(&atc->spline[i],atc);
641             }
642         }
643     }
644 }
645
646 static void pmeredist_pd(gmx_pme_t pme, gmx_bool forw,
647                          int n, gmx_bool bXF, rvec *x_f, real *charge,
648                          pme_atomcomm_t *atc)
649 /* Redistribute particle data for PME calculation */
650 /* domain decomposition by x coordinate           */
651 {
652     int *idxa;
653     int i, ii;
654
655     if(FALSE == pme->redist_init) {
656         snew(pme->scounts,atc->nslab);
657         snew(pme->rcounts,atc->nslab);
658         snew(pme->sdispls,atc->nslab);
659         snew(pme->rdispls,atc->nslab);
660         snew(pme->sidx,atc->nslab);
661         pme->redist_init = TRUE;
662     }
663     if (n > pme->redist_buf_nalloc) {
664         pme->redist_buf_nalloc = over_alloc_dd(n);
665         srenew(pme->redist_buf,pme->redist_buf_nalloc*DIM);
666     }
667
668     pme->idxa = atc->pd;
669
670 #ifdef GMX_MPI
671     if (forw && bXF) {
672         /* forward, redistribution from pp to pme */
673
674         /* Calculate send counts and exchange them with other nodes */
675         for(i=0; (i<atc->nslab); i++) pme->scounts[i]=0;
676         for(i=0; (i<n); i++) pme->scounts[pme->idxa[i]]++;
677         MPI_Alltoall( pme->scounts, 1, MPI_INT, pme->rcounts, 1, MPI_INT, atc->mpi_comm);
678
679         /* Calculate send and receive displacements and index into send
680            buffer */
681         pme->sdispls[0]=0;
682         pme->rdispls[0]=0;
683         pme->sidx[0]=0;
684         for(i=1; i<atc->nslab; i++) {
685             pme->sdispls[i]=pme->sdispls[i-1]+pme->scounts[i-1];
686             pme->rdispls[i]=pme->rdispls[i-1]+pme->rcounts[i-1];
687             pme->sidx[i]=pme->sdispls[i];
688         }
689         /* Total # of particles to be received */
690         atc->n = pme->rdispls[atc->nslab-1] + pme->rcounts[atc->nslab-1];
691
692         pme_realloc_atomcomm_things(atc);
693
694         /* Copy particle coordinates into send buffer and exchange */
695         for(i=0; (i<n); i++) {
696             ii=DIM*pme->sidx[pme->idxa[i]];
697             pme->sidx[pme->idxa[i]]++;
698             pme->redist_buf[ii+XX]=x_f[i][XX];
699             pme->redist_buf[ii+YY]=x_f[i][YY];
700             pme->redist_buf[ii+ZZ]=x_f[i][ZZ];
701         }
702         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls,
703                       pme->rvec_mpi, atc->x, pme->rcounts, pme->rdispls,
704                       pme->rvec_mpi, atc->mpi_comm);
705     }
706     if (forw) {
707         /* Copy charge into send buffer and exchange */
708         for(i=0; i<atc->nslab; i++) pme->sidx[i]=pme->sdispls[i];
709         for(i=0; (i<n); i++) {
710             ii=pme->sidx[pme->idxa[i]];
711             pme->sidx[pme->idxa[i]]++;
712             pme->redist_buf[ii]=charge[i];
713         }
714         MPI_Alltoallv(pme->redist_buf, pme->scounts, pme->sdispls, mpi_type,
715                       atc->q, pme->rcounts, pme->rdispls, mpi_type,
716                       atc->mpi_comm);
717     }
718     else { /* backward, redistribution from pme to pp */
719         MPI_Alltoallv(atc->f, pme->rcounts, pme->rdispls, pme->rvec_mpi,
720                       pme->redist_buf, pme->scounts, pme->sdispls,
721                       pme->rvec_mpi, atc->mpi_comm);
722
723         /* Copy data from receive buffer */
724         for(i=0; i<atc->nslab; i++)
725             pme->sidx[i] = pme->sdispls[i];
726         for(i=0; (i<n); i++) {
727             ii = DIM*pme->sidx[pme->idxa[i]];
728             x_f[i][XX] += pme->redist_buf[ii+XX];
729             x_f[i][YY] += pme->redist_buf[ii+YY];
730             x_f[i][ZZ] += pme->redist_buf[ii+ZZ];
731             pme->sidx[pme->idxa[i]]++;
732         }
733     }
734 #endif
735 }
736
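/* Exchange nbyte_s/nbyte_r bytes with the neighbouring PME node at pulse
 * 'shift' along this decomposition dimension. Falls back to a plain send or
 * receive when one of the byte counts is zero, and is a no-op without MPI.
 */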
737 static void pme_dd_sendrecv(pme_atomcomm_t *atc,
738                             gmx_bool bBackward,int shift,
739                             void *buf_s,int nbyte_s,
740                             void *buf_r,int nbyte_r)
741 {
742 #ifdef GMX_MPI
743     int dest,src;
744     MPI_Status stat;
745
746     if (bBackward == FALSE) {
747         dest = atc->node_dest[shift];
748         src  = atc->node_src[shift];
749     } else {
750         dest = atc->node_src[shift];
751         src  = atc->node_dest[shift];
752     }
753
754     if (nbyte_s > 0 && nbyte_r > 0) {
755         MPI_Sendrecv(buf_s,nbyte_s,MPI_BYTE,
756                      dest,shift,
757                      buf_r,nbyte_r,MPI_BYTE,
758                      src,shift,
759                      atc->mpi_comm,&stat);
760     } else if (nbyte_s > 0) {
761         MPI_Send(buf_s,nbyte_s,MPI_BYTE,
762                  dest,shift,
763                  atc->mpi_comm);
764     } else if (nbyte_r > 0) {
765         MPI_Recv(buf_r,nbyte_r,MPI_BYTE,
766                  src,shift,
767                  atc->mpi_comm,&stat);
768     }
769 #endif
770 }
771
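/* Redistribute the coordinates (when bX) and charges of the n home atoms to
 * the PME nodes according to the slab assignment stored in atc->pd,
 * updating atc->n and the atom buffers when coordinates are sent.
 */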
772 static void dd_pmeredist_x_q(gmx_pme_t pme,
773                              int n, gmx_bool bX, rvec *x, real *charge,
774                              pme_atomcomm_t *atc)
775 {
776     int *commnode,*buf_index;
777     int nnodes_comm,i,nsend,local_pos,buf_pos,node,scount,rcount;
778
779     commnode  = atc->node_dest;
780     buf_index = atc->buf_index;
781
782     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
783
784     nsend = 0;
785     for(i=0; i<nnodes_comm; i++) {
786         buf_index[commnode[i]] = nsend;
787         nsend += atc->count[commnode[i]];
788     }
789     if (bX) {
790         if (atc->count[atc->nodeid] + nsend != n)
791             gmx_fatal(FARGS,"%d particles communicated to PME node %d are more than 2/3 times the cut-off out of the domain decomposition cell of their charge group in dimension %c.\n"
792                       "This usually means that your system is not well equilibrated.",
793                       n - (atc->count[atc->nodeid] + nsend),
794                       pme->nodeid,'x'+atc->dimind);
795
796         if (nsend > pme->buf_nalloc) {
797             pme->buf_nalloc = over_alloc_dd(nsend);
798             srenew(pme->bufv,pme->buf_nalloc);
799             srenew(pme->bufr,pme->buf_nalloc);
800         }
801
802         atc->n = atc->count[atc->nodeid];
803         for(i=0; i<nnodes_comm; i++) {
804             scount = atc->count[commnode[i]];
805             /* Communicate the count */
806             if (debug)
807                 fprintf(debug,"dimind %d PME node %d send to node %d: %d\n",
808                         atc->dimind,atc->nodeid,commnode[i],scount);
809             pme_dd_sendrecv(atc,FALSE,i,
810                             &scount,sizeof(int),
811                             &atc->rcount[i],sizeof(int));
812             atc->n += atc->rcount[i];
813         }
814
815         pme_realloc_atomcomm_things(atc);
816     }
817
818     local_pos = 0;
819     for(i=0; i<n; i++) {
820         node = atc->pd[i];
821         if (node == atc->nodeid) {
822             /* Copy direct to the receive buffer */
823             if (bX) {
824                 copy_rvec(x[i],atc->x[local_pos]);
825             }
826             atc->q[local_pos] = charge[i];
827             local_pos++;
828         } else {
829             /* Copy to the send buffer */
830             if (bX) {
831                 copy_rvec(x[i],pme->bufv[buf_index[node]]);
832             }
833             pme->bufr[buf_index[node]] = charge[i];
834             buf_index[node]++;
835         }
836     }
837
838     buf_pos = 0;
839     for(i=0; i<nnodes_comm; i++) {
840         scount = atc->count[commnode[i]];
841         rcount = atc->rcount[i];
842         if (scount > 0 || rcount > 0) {
843             if (bX) {
844                 /* Communicate the coordinates */
845                 pme_dd_sendrecv(atc,FALSE,i,
846                                 pme->bufv[buf_pos],scount*sizeof(rvec),
847                                 atc->x[local_pos],rcount*sizeof(rvec));
848             }
849             /* Communicate the charges */
850             pme_dd_sendrecv(atc,FALSE,i,
851                             pme->bufr+buf_pos,scount*sizeof(real),
852                             atc->q+local_pos,rcount*sizeof(real));
853             buf_pos   += scount;
854             local_pos += atc->rcount[i];
855         }
856     }
857 }
858
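/* Communicate the PME forces back along the atom-communication dimension and
 * either add them to (bAddF) or copy them into the output force array f.
 */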
859 static void dd_pmeredist_f(gmx_pme_t pme, pme_atomcomm_t *atc,
860                            int n, rvec *f,
861                            gmx_bool bAddF)
862 {
863     int *commnode,*buf_index;
864     int nnodes_comm,local_pos,buf_pos,i,scount,rcount,node;
865
866     commnode  = atc->node_dest;
867     buf_index = atc->buf_index;
868
869     nnodes_comm = min(2*atc->maxshift,atc->nslab-1);
870
871     local_pos = atc->count[atc->nodeid];
872     buf_pos = 0;
873     for(i=0; i<nnodes_comm; i++) {
874         scount = atc->rcount[i];
875         rcount = atc->count[commnode[i]];
876         if (scount > 0 || rcount > 0) {
877             /* Communicate the forces */
878             pme_dd_sendrecv(atc,TRUE,i,
879                             atc->f[local_pos],scount*sizeof(rvec),
880                             pme->bufv[buf_pos],rcount*sizeof(rvec));
881             local_pos += scount;
882         }
883         buf_index[commnode[i]] = buf_pos;
884         buf_pos   += rcount;
885     }
886
887     local_pos = 0;
888     if (bAddF)
889     {
890         for(i=0; i<n; i++)
891         {
892             node = atc->pd[i];
893             if (node == atc->nodeid)
894             {
895                 /* Add from the local force array */
896                 rvec_inc(f[i],atc->f[local_pos]);
897                 local_pos++;
898             }
899             else
900             {
901                 /* Add from the receive buffer */
902                 rvec_inc(f[i],pme->bufv[buf_index[node]]);
903                 buf_index[node]++;
904             }
905         }
906     }
907     else
908     {
909         for(i=0; i<n; i++)
910         {
911             node = atc->pd[i];
912             if (node == atc->nodeid)
913             {
914                 /* Copy from the local force array */
915                 copy_rvec(atc->f[local_pos],f[i]);
916                 local_pos++;
917             }
918             else
919             {
920                 /* Copy from the receive buffer */
921                 copy_rvec(pme->bufv[buf_index[node]],f[i]);
922                 buf_index[node]++;
923             }
924         }
925     }
926 }
927
928 #ifdef GMX_MPI
929 static void
930 gmx_sum_qgrid_dd(gmx_pme_t pme, real *grid, int direction)
931 {
932     pme_overlap_t *overlap;
933     int send_index0,send_nindex;
934     int recv_index0,recv_nindex;
935     MPI_Status stat;
936     int i,j,k,ix,iy,iz,icnt;
937     int ipulse,send_id,recv_id,datasize;
938     real *p;
939     real *sendptr,*recvptr;
940
941     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
942     overlap = &pme->overlap[1];
943
944     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
945     {
946         /* Since we have already (un)wrapped the overlap in the z-dimension,
947          * we only have to communicate 0 to nkz (not pmegrid_nz).
948          */
949         if (direction==GMX_SUM_QGRID_FORWARD)
950         {
951             send_id = overlap->send_id[ipulse];
952             recv_id = overlap->recv_id[ipulse];
953             send_index0   = overlap->comm_data[ipulse].send_index0;
954             send_nindex   = overlap->comm_data[ipulse].send_nindex;
955             recv_index0   = overlap->comm_data[ipulse].recv_index0;
956             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
957         }
958         else
959         {
960             send_id = overlap->recv_id[ipulse];
961             recv_id = overlap->send_id[ipulse];
962             send_index0   = overlap->comm_data[ipulse].recv_index0;
963             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
964             recv_index0   = overlap->comm_data[ipulse].send_index0;
965             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
966         }
967
968         /* Copy data to contiguous send buffer */
969         if (debug)
970         {
971             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
972                     pme->nodeid,overlap->nodeid,send_id,
973                     pme->pmegrid_start_iy,
974                     send_index0-pme->pmegrid_start_iy,
975                     send_index0-pme->pmegrid_start_iy+send_nindex);
976         }
977         icnt = 0;
978         for(i=0;i<pme->pmegrid_nx;i++)
979         {
980             ix = i;
981             for(j=0;j<send_nindex;j++)
982             {
983                 iy = j + send_index0 - pme->pmegrid_start_iy;
984                 for(k=0;k<pme->nkz;k++)
985                 {
986                     iz = k;
987                     overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
988                 }
989             }
990         }
991
992         datasize      = pme->pmegrid_nx * pme->nkz;
993
994         MPI_Sendrecv(overlap->sendbuf,send_nindex*datasize,GMX_MPI_REAL,
995                      send_id,ipulse,
996                      overlap->recvbuf,recv_nindex*datasize,GMX_MPI_REAL,
997                      recv_id,ipulse,
998                      overlap->mpi_comm,&stat);
999
1000         /* Get data from contiguous recv buffer */
1001         if (debug)
1002         {
1003             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1004                     pme->nodeid,overlap->nodeid,recv_id,
1005                     pme->pmegrid_start_iy,
1006                     recv_index0-pme->pmegrid_start_iy,
1007                     recv_index0-pme->pmegrid_start_iy+recv_nindex);
1008         }
1009         icnt = 0;
1010         for(i=0;i<pme->pmegrid_nx;i++)
1011         {
1012             ix = i;
1013             for(j=0;j<recv_nindex;j++)
1014             {
1015                 iy = j + recv_index0 - pme->pmegrid_start_iy;
1016                 for(k=0;k<pme->nkz;k++)
1017                 {
1018                     iz = k;
1019                     if(direction==GMX_SUM_QGRID_FORWARD)
1020                     {
1021                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
1022                     }
1023                     else
1024                     {
1025                         grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
1026                     }
1027                 }
1028             }
1029         }
1030     }
1031
1032     /* Major dimension is easier, no copying required,
1033      * but we might have to sum to a separate array.
1034      * Since we don't copy, we have to communicate up to pmegrid_nz,
1035      * not nkz as for the minor direction.
1036      */
1037     overlap = &pme->overlap[0];
1038
1039     for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
1040     {
1041         if(direction==GMX_SUM_QGRID_FORWARD)
1042         {
1043             send_id = overlap->send_id[ipulse];
1044             recv_id = overlap->recv_id[ipulse];
1045             send_index0   = overlap->comm_data[ipulse].send_index0;
1046             send_nindex   = overlap->comm_data[ipulse].send_nindex;
1047             recv_index0   = overlap->comm_data[ipulse].recv_index0;
1048             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
1049             recvptr   = overlap->recvbuf;
1050         }
1051         else
1052         {
1053             send_id = overlap->recv_id[ipulse];
1054             recv_id = overlap->send_id[ipulse];
1055             send_index0   = overlap->comm_data[ipulse].recv_index0;
1056             send_nindex   = overlap->comm_data[ipulse].recv_nindex;
1057             recv_index0   = overlap->comm_data[ipulse].send_index0;
1058             recv_nindex   = overlap->comm_data[ipulse].send_nindex;
1059             recvptr   = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1060         }
1061
1062         sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1063         datasize      = pme->pmegrid_ny * pme->pmegrid_nz;
1064
1065         if (debug)
1066         {
1067             fprintf(debug,"PME send node %d %d -> %d grid start %d Communicating %d to %d\n",
1068                     pme->nodeid,overlap->nodeid,send_id,
1069                     pme->pmegrid_start_ix,
1070                     send_index0-pme->pmegrid_start_ix,
1071                     send_index0-pme->pmegrid_start_ix+send_nindex);
1072             fprintf(debug,"PME recv node %d %d <- %d grid start %d Communicating %d to %d\n",
1073                     pme->nodeid,overlap->nodeid,recv_id,
1074                     pme->pmegrid_start_ix,
1075                     recv_index0-pme->pmegrid_start_ix,
1076                     recv_index0-pme->pmegrid_start_ix+recv_nindex);
1077         }
1078
1079         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
1080                      send_id,ipulse,
1081                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
1082                      recv_id,ipulse,
1083                      overlap->mpi_comm,&stat);
1084
1085         /* ADD data from contiguous recv buffer */
1086         if(direction==GMX_SUM_QGRID_FORWARD)
1087         {
1088             p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
1089             for(i=0;i<recv_nindex*datasize;i++)
1090             {
1091                 p[i] += overlap->recvbuf[i];
1092             }
1093         }
1094     }
1095 }
1096 #endif
1097
1098
1099 static int
1100 copy_pmegrid_to_fftgrid(gmx_pme_t pme, real *pmegrid, real *fftgrid)
1101 {
1102     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1103     ivec    local_pme_size;
1104     int     i,ix,iy,iz;
1105     int     pmeidx,fftidx;
1106
1107     /* Dimensions should be identical for A/B grid, so we just use A here */
1108     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1109                                    local_fft_ndata,
1110                                    local_fft_offset,
1111                                    local_fft_size);
1112
1113     local_pme_size[0] = pme->pmegrid_nx;
1114     local_pme_size[1] = pme->pmegrid_ny;
1115     local_pme_size[2] = pme->pmegrid_nz;
1116
1117     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1118      * the offset is identical, and the PME grid always has more data (due to overlap)
1119      */
1120     {
1121 #ifdef DEBUG_PME
1122         FILE *fp,*fp2;
1123         char fn[STRLEN],format[STRLEN];
1124         real val;
1125         sprintf(fn,"pmegrid%d.pdb",pme->nodeid);
1126         fp = ffopen(fn,"w");
1127         sprintf(fn,"pmegrid%d.txt",pme->nodeid);
1128         fp2 = ffopen(fn,"w");
1129      sprintf(format,"%s%s\n",pdbformat,"%6.2f%6.2f");
1130 #endif
1131
1132     for(ix=0;ix<local_fft_ndata[XX];ix++)
1133     {
1134         for(iy=0;iy<local_fft_ndata[YY];iy++)
1135         {
1136             for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1137             {
1138                 pmeidx = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
1139                 fftidx = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
1140                 fftgrid[fftidx] = pmegrid[pmeidx];
1141 #ifdef DEBUG_PME
1142                 val = 100*pmegrid[pmeidx];
1143                 if (pmegrid[pmeidx] != 0)
1144                 fprintf(fp,format,"ATOM",pmeidx,"CA","GLY",' ',pmeidx,' ',
1145                         5.0*ix,5.0*iy,5.0*iz,1.0,val);
1146                 if (pmegrid[pmeidx] != 0)
1147                     fprintf(fp2,"%-12s  %5d  %5d  %5d  %12.5e\n",
1148                             "qgrid",
1149                             pme->pmegrid_start_ix + ix,
1150                             pme->pmegrid_start_iy + iy,
1151                             pme->pmegrid_start_iz + iz,
1152                             pmegrid[pmeidx]);
1153 #endif
1154             }
1155         }
1156     }
1157 #ifdef DEBUG_PME
1158     ffclose(fp);
1159     ffclose(fp2);
1160 #endif
1161     }
1162     return 0;
1163 }
1164
1165
1166 static gmx_cycles_t omp_cyc_start()
1167 {
1168     return gmx_cycles_read();
1169 }
1170
1171 static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
1172 {
1173     return gmx_cycles_read() - c;
1174 }
1175
1176
1177 static int
1178 copy_fftgrid_to_pmegrid(gmx_pme_t pme, const real *fftgrid, real *pmegrid,
1179                         int nthread,int thread)
1180 {
1181     ivec    local_fft_ndata,local_fft_offset,local_fft_size;
1182     ivec    local_pme_size;
1183     int     ixy0,ixy1,ixy,ix,iy,iz;
1184     int     pmeidx,fftidx;
1185 #ifdef PME_TIME_THREADS
1186     gmx_cycles_t c1;
1187     static double cs1=0;
1188     static int cnt=0;
1189 #endif
1190
1191 #ifdef PME_TIME_THREADS
1192     c1 = omp_cyc_start();
1193 #endif
1194     /* Dimensions should be identical for A/B grid, so we just use A here */
1195     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
1196                                    local_fft_ndata,
1197                                    local_fft_offset,
1198                                    local_fft_size);
1199
1200     local_pme_size[0] = pme->pmegrid_nx;
1201     local_pme_size[1] = pme->pmegrid_ny;
1202     local_pme_size[2] = pme->pmegrid_nz;
1203
1204     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
1205      * the offset is identical, and the PME grid always has more data (due to overlap)
1206      */
1207     ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1208     ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
1209
1210     for(ixy=ixy0;ixy<ixy1;ixy++)
1211     {
1212         ix = ixy/local_fft_ndata[YY];
1213         iy = ixy - ix*local_fft_ndata[YY];
1214
1215         pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
1216         fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
1217         for(iz=0;iz<local_fft_ndata[ZZ];iz++)
1218         {
1219             pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
1220         }
1221     }
1222
1223 #ifdef PME_TIME_THREADS
1224     c1 = omp_cyc_end(c1);
1225     cs1 += (double)c1;
1226     cnt++;
1227     if (cnt % 20 == 0)
1228     {
1229         printf("copy %.2f\n",cs1*1e-9);
1230     }
1231 #endif
1232
1233     return 0;
1234 }
1235
1236
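/* Fold the charge spread into the pme_order-1 periodic overlap region back
 * onto the principal part of the grid: always in z, and in y/x only when
 * that dimension is not decomposed over nodes.
 */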
1237 static void
1238 wrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1239 {
1240     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix,iy,iz;
1241
1242     nx = pme->nkx;
1243     ny = pme->nky;
1244     nz = pme->nkz;
1245
1246     pnx = pme->pmegrid_nx;
1247     pny = pme->pmegrid_ny;
1248     pnz = pme->pmegrid_nz;
1249
1250     overlap = pme->pme_order - 1;
1251
1252     /* Add periodic overlap in z */
1253     for(ix=0; ix<pme->pmegrid_nx; ix++)
1254     {
1255         for(iy=0; iy<pme->pmegrid_ny; iy++)
1256         {
1257             for(iz=0; iz<overlap; iz++)
1258             {
1259                 pmegrid[(ix*pny+iy)*pnz+iz] +=
1260                     pmegrid[(ix*pny+iy)*pnz+nz+iz];
1261             }
1262         }
1263     }
1264
1265     if (pme->nnodes_minor == 1)
1266     {
1267        for(ix=0; ix<pme->pmegrid_nx; ix++)
1268        {
1269            for(iy=0; iy<overlap; iy++)
1270            {
1271                for(iz=0; iz<nz; iz++)
1272                {
1273                    pmegrid[(ix*pny+iy)*pnz+iz] +=
1274                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
1275                }
1276            }
1277        }
1278     }
1279
1280     if (pme->nnodes_major == 1)
1281     {
1282         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1283
1284         for(ix=0; ix<overlap; ix++)
1285         {
1286             for(iy=0; iy<ny_x; iy++)
1287             {
1288                 for(iz=0; iz<nz; iz++)
1289                 {
1290                     pmegrid[(ix*pny+iy)*pnz+iz] +=
1291                         pmegrid[((nx+ix)*pny+iy)*pnz+iz];
1292                 }
1293             }
1294         }
1295     }
1296 }
1297
1298
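/* Inverse of wrap_periodic_pmegrid: copy values from the principal part of
 * the grid out into the periodic overlap region, again in z always and in
 * y/x only when that dimension is not decomposed over nodes.
 */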
1299 static void
1300 unwrap_periodic_pmegrid(gmx_pme_t pme, real *pmegrid)
1301 {
1302     int     nx,ny,nz,pnx,pny,pnz,ny_x,overlap,ix;
1303
1304     nx = pme->nkx;
1305     ny = pme->nky;
1306     nz = pme->nkz;
1307
1308     pnx = pme->pmegrid_nx;
1309     pny = pme->pmegrid_ny;
1310     pnz = pme->pmegrid_nz;
1311
1312     overlap = pme->pme_order - 1;
1313
1314     if (pme->nnodes_major == 1)
1315     {
1316         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
1317
1318         for(ix=0; ix<overlap; ix++)
1319         {
1320             int iy,iz;
1321
1322             for(iy=0; iy<ny_x; iy++)
1323             {
1324                 for(iz=0; iz<nz; iz++)
1325                 {
1326                     pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
1327                         pmegrid[(ix*pny+iy)*pnz+iz];
1328                 }
1329             }
1330         }
1331     }
1332
1333     if (pme->nnodes_minor == 1)
1334     {
1335 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1336        for(ix=0; ix<pme->pmegrid_nx; ix++)
1337        {
1338            int iy,iz;
1339
1340            for(iy=0; iy<overlap; iy++)
1341            {
1342                for(iz=0; iz<nz; iz++)
1343                {
1344                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
1345                        pmegrid[(ix*pny+iy)*pnz+iz];
1346                }
1347            }
1348        }
1349     }
1350
1351     /* Copy periodic overlap in z */
1352 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
1353     for(ix=0; ix<pme->pmegrid_nx; ix++)
1354     {
1355         int iy,iz;
1356
1357         for(iy=0; iy<pme->pmegrid_ny; iy++)
1358         {
1359             for(iz=0; iz<overlap; iz++)
1360             {
1361                 pmegrid[(ix*pny+iy)*pnz+nz+iz] =
1362                     pmegrid[(ix*pny+iy)*pnz+iz];
1363             }
1364         }
1365     }
1366 }
1367
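/* Zero the grid contents of the FLBS x FLBS x FLBSZ blocks that an
 * order-point spreading region starting in block (fx,fy,fz) can touch;
 * the flag array records blocks that have already been cleared so they
 * are not zeroed twice.
 */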
1368 static void clear_grid(int nx,int ny,int nz,real *grid,
1369                        ivec fs,int *flag,
1370                        int fx,int fy,int fz,
1371                        int order)
1372 {
1373     int nc,ncz;
1374     int fsx,fsy,fsz,gx,gy,gz,g0x,g0y,x,y,z;
1375     int flind;
1376
1377     nc  = 2 + (order - 2)/FLBS;
1378     ncz = 2 + (order - 2)/FLBSZ;
1379
1380     for(fsx=fx; fsx<fx+nc; fsx++)
1381     {
1382         for(fsy=fy; fsy<fy+nc; fsy++)
1383         {
1384             for(fsz=fz; fsz<fz+ncz; fsz++)
1385             {
1386                 flind = (fsx*fs[YY] + fsy)*fs[ZZ] + fsz;
1387                 if (flag[flind] == 0)
1388                 {
1389                     gx = fsx*FLBS;
1390                     gy = fsy*FLBS;
1391                     gz = fsz*FLBSZ;
1392                     g0x = (gx*ny + gy)*nz + gz;
1393                     for(x=0; x<FLBS; x++)
1394                     {
1395                         g0y = g0x;
1396                         for(y=0; y<FLBS; y++)
1397                         {
1398                             for(z=0; z<FLBSZ; z++)
1399                             {
1400                                 grid[g0y+z] = 0;
1401                             }
1402                             g0y += nz;
1403                         }
1404                         g0x += ny*nz;
1405                     }
1406
1407                     flag[flind] = 1;
1408                 }
1409             }
1410         }
1411     }
1412 }
1413
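/* DO_BSPLINE spreads the charge qn of a single atom onto the order^3 grid
 * points starting at (i0,j0,k0), weighted by the tensor product of the
 * B-spline coefficients thx, thy and thz.
 */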
1414 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
1415 #define DO_BSPLINE(order)                            \
1416 for(ithx=0; (ithx<order); ithx++)                    \
1417 {                                                    \
1418     index_x = (i0+ithx)*pny*pnz;                     \
1419     valx    = qn*thx[ithx];                          \
1420                                                      \
1421     for(ithy=0; (ithy<order); ithy++)                \
1422     {                                                \
1423         valxy    = valx*thy[ithy];                   \
1424         index_xy = index_x+(j0+ithy)*pnz;            \
1425                                                      \
1426         for(ithz=0; (ithz<order); ithz++)            \
1427         {                                            \
1428             index_xyz        = index_xy+(k0+ithz);   \
1429             grid[index_xyz] += valxy*thz[ithz];      \
1430         }                                            \
1431     }                                                \
1432 }
1433
1434
1435 static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
1436                                      pme_atomcomm_t *atc, splinedata_t *spline,
1437                                      pme_spline_work_t *work)
1438 {
1439
1440     /* spread charges from home atoms to local grid */
1441     real     *grid;
1442     pme_overlap_t *ol;
1443     int      b,i,nn,n,ithx,ithy,ithz,i0,j0,k0;
1444     int *    idxptr;
1445     int      order,norder,index_x,index_xy,index_xyz;
1446     real     valx,valxy,qn;
1447     real     *thx,*thy,*thz;
1448     int      localsize, bndsize;
1449     int      pnx,pny,pnz,ndatatot;
1450     int      offx,offy,offz;
1451
1452     pnx = pmegrid->s[XX];
1453     pny = pmegrid->s[YY];
1454     pnz = pmegrid->s[ZZ];
1455
1456     offx = pmegrid->offset[XX];
1457     offy = pmegrid->offset[YY];
1458     offz = pmegrid->offset[ZZ];
1459
1460     ndatatot = pnx*pny*pnz;
1461     grid = pmegrid->grid;
1462     for(i=0;i<ndatatot;i++)
1463     {
1464         grid[i] = 0;
1465     }
1466     
1467     order = pmegrid->order;
1468
1469     for(nn=0; nn<spline->n; nn++)
1470     {
1471         n  = spline->ind[nn];
1472         qn = atc->q[n];
1473
1474         if (qn != 0)
1475         {
1476             idxptr = atc->idx[n];
1477             norder = nn*order;
1478
1479             i0   = idxptr[XX] - offx;
1480             j0   = idxptr[YY] - offy;
1481             k0   = idxptr[ZZ] - offz;
1482
1483             thx = spline->theta[XX] + norder;
1484             thy = spline->theta[YY] + norder;
1485             thz = spline->theta[ZZ] + norder;
1486
1487             switch (order) {
1488             case 4:
1489 #ifdef PME_SSE
1490 #ifdef PME_SSE_UNALIGNED
1491 #define PME_SPREAD_SSE_ORDER4
1492 #else
1493 #define PME_SPREAD_SSE_ALIGNED
1494 #define PME_ORDER 4
1495 #endif
1496 #include "pme_sse_single.h"
1497 #else
1498                 DO_BSPLINE(4);
1499 #endif
1500                 break;
1501             case 5:
1502 #ifdef PME_SSE
1503 #define PME_SPREAD_SSE_ALIGNED
1504 #define PME_ORDER 5
1505 #include "pme_sse_single.h"
1506 #else
1507                 DO_BSPLINE(5);
1508 #endif
1509                 break;
1510             default:
1511                 DO_BSPLINE(order);
1512                 break;
1513             }
1514         }
1515     }
1516 }
1517
1518 static void set_grid_alignment(int *pmegrid_nz,int pme_order)
1519 {
1520 #ifdef PME_SSE
1521     if (pme_order == 5
1522 #ifndef PME_SSE_UNALIGNED
1523         || pme_order == 4
1524 #endif
1525         )
1526     {
1527         /* Round nz up to a multiple of 4 to ensure alignment */
1528         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
1529     }
1530 #endif
1531 }
1532
1533 static void set_gridsize_alignment(int *gridsize,int pme_order)
1534 {
1535 #ifdef PME_SSE
1536 #ifndef PME_SSE_UNALIGNED
1537     if (pme_order == 4)
1538     {
1539         /* Add extra elements to ensure that aligned operations do not go
1540          * beyond the allocated grid size.
1541          * Note that for pme_order=5, the pme grid z-size alignment
1542          * ensures that we will not go beyond the grid size.
1543          */
1544          *gridsize += 4;
1545     }
1546 #endif
1547 #endif
1548 }
1549
1550 static void pmegrid_init(pmegrid_t *grid,
1551                          int cx, int cy, int cz,
1552                          int x0, int y0, int z0,
1553                          int x1, int y1, int z1,
1554                          gmx_bool set_alignment,
1555                          int pme_order,
1556                          real *ptr)
1557 {
1558     int nz,gridsize;
1559
1560     grid->ci[XX] = cx;
1561     grid->ci[YY] = cy;
1562     grid->ci[ZZ] = cz;
1563     grid->offset[XX] = x0;
1564     grid->offset[YY] = y0;
1565     grid->offset[ZZ] = z0;
1566     grid->n[XX]      = x1 - x0 + pme_order - 1;
1567     grid->n[YY]      = y1 - y0 + pme_order - 1;
1568     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
1569     copy_ivec(grid->n,grid->s);
1570
1571     nz = grid->s[ZZ];
1572     set_grid_alignment(&nz,pme_order);
1573     if (set_alignment)
1574     {
1575         grid->s[ZZ] = nz;
1576     }
1577     else if (nz != grid->s[ZZ])
1578     {
1579         gmx_incons("pmegrid_init call with an unaligned z size");
1580     }
1581
1582     grid->order = pme_order;
1583     if (ptr == NULL)
1584     {
1585         gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
1586         set_gridsize_alignment(&gridsize,pme_order);
1587         snew_aligned(grid->grid,gridsize,16);
1588     }
1589     else
1590     {
1591         grid->grid = ptr;
1592     }
1593 }
1594
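/* Integer division rounded upwards (for positive arguments) */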
1595 static int div_round_up(int enumerator,int denominator)
1596 {
1597     return (enumerator + denominator - 1)/denominator;
1598 }
1599
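/* Choose a thread-grid factorization nsub[XX]*nsub[YY]*nsub[ZZ] = nthread
 * that minimizes the number of grid points (including the overlap ovl) per
 * thread; the choice can be overridden with the GMX_PME_THREAD_DIVISION
 * environment variable.
 */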
1600 static void make_subgrid_division(const ivec n,int ovl,int nthread,
1601                                   ivec nsub)
1602 {
1603     int gsize_opt,gsize;
1604     int nsx,nsy,nsz;
1605     char *env;
1606
1607     gsize_opt = -1;
1608     for(nsx=1; nsx<=nthread; nsx++)
1609     {
1610         if (nthread % nsx == 0)
1611         {
1612             for(nsy=1; nsy<=nthread; nsy++)
1613             {
1614                 if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
1615                 {
1616                     nsz = nthread/(nsx*nsy);
1617
1618                     /* Determine the number of grid points per thread */
1619                     gsize =
1620                         (div_round_up(n[XX],nsx) + ovl)*
1621                         (div_round_up(n[YY],nsy) + ovl)*
1622                         (div_round_up(n[ZZ],nsz) + ovl);
1623
1624                     /* Minimize the number of grid points per thread
1625                      * and, secondarily, the number of cuts in minor dimensions.
1626                      */
1627                     if (gsize_opt == -1 ||
1628                         gsize < gsize_opt ||
1629                         (gsize == gsize_opt &&
1630                          (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
1631                     {
1632                         nsub[XX] = nsx;
1633                         nsub[YY] = nsy;
1634                         nsub[ZZ] = nsz;
1635                         gsize_opt = gsize;
1636                     }
1637                 }
1638             }
1639         }
1640     }
1641
1642     env = getenv("GMX_PME_THREAD_DIVISION");
1643     if (env != NULL)
1644     {
1645         sscanf(env,"%d %d %d",&nsub[XX],&nsub[YY],&nsub[ZZ]);
1646     }
1647
1648     if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
1649     {
1650         gmx_fatal(FARGS,"PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)",nsub[XX],nsub[YY],nsub[ZZ],nthread);
1651     }
1652 }
1653
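/* Initialize the full node grid and, with more than one thread, the
 * per-thread subgrids, together with the g2t tables that map grid lines in
 * each dimension to a thread index.
 */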
1654 static void pmegrids_init(pmegrids_t *grids,
1655                           int nx,int ny,int nz,int nz_base,
1656                           int pme_order,
1657                           int nthread,
1658                           int overlap_x,
1659                           int overlap_y)
1660 {
1661     ivec n,n_base,g0,g1;
1662     int t,x,y,z,d,i,tfac;
1663     int max_comm_lines=-1;
1664
1665     n[XX] = nx - (pme_order - 1);
1666     n[YY] = ny - (pme_order - 1);
1667     n[ZZ] = nz - (pme_order - 1);
1668
1669     copy_ivec(n,n_base);
1670     n_base[ZZ] = nz_base;
1671
1672     pmegrid_init(&grids->grid,0,0,0,0,0,0,n[XX],n[YY],n[ZZ],FALSE,pme_order,
1673                  NULL);
1674
1675     grids->nthread = nthread;
1676
1677     make_subgrid_division(n_base,pme_order-1,grids->nthread,grids->nc);
1678
1679     if (grids->nthread > 1)
1680     {
1681         ivec nst;
1682         int gridsize;
1683
1684         for(d=0; d<DIM; d++)
1685         {
1686             nst[d] = div_round_up(n[d],grids->nc[d]) + pme_order - 1;
1687         }
1688         set_grid_alignment(&nst[ZZ],pme_order);
1689
1690         if (debug)
1691         {
1692             fprintf(debug,"pmegrid thread local division: %d x %d x %d\n",
1693                     grids->nc[XX],grids->nc[YY],grids->nc[ZZ]);
1694             fprintf(debug,"pmegrid %d %d %d max thread pmegrid %d %d %d\n",
1695                     nx,ny,nz,
1696                     nst[XX],nst[YY],nst[ZZ]);
1697         }
1698
1699         snew(grids->grid_th,grids->nthread);
1700         t = 0;
1701         gridsize = nst[XX]*nst[YY]*nst[ZZ];
1702         set_gridsize_alignment(&gridsize,pme_order);
1703         snew_aligned(grids->grid_all,
1704                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
1705                      16);
1706
1707         for(x=0; x<grids->nc[XX]; x++)
1708         {
1709             for(y=0; y<grids->nc[YY]; y++)
1710             {
1711                 for(z=0; z<grids->nc[ZZ]; z++)
1712                 {
1713                     pmegrid_init(&grids->grid_th[t],
1714                                  x,y,z,
1715                                  (n[XX]*(x  ))/grids->nc[XX],
1716                                  (n[YY]*(y  ))/grids->nc[YY],
1717                                  (n[ZZ]*(z  ))/grids->nc[ZZ],
1718                                  (n[XX]*(x+1))/grids->nc[XX],
1719                                  (n[YY]*(y+1))/grids->nc[YY],
1720                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
1721                                  TRUE,
1722                                  pme_order,
1723                                  grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
1724                     t++;
1725                 }
1726             }
1727         }
1728     }
1729
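         /* Build the grid-line to thread lookup g2t: for each dimension d,
          * grid line i belongs to thread block t along d, scaled by tfac so
          * that summing g2t[XX][ix]+g2t[YY][iy]+g2t[ZZ][iz] yields the
          * flattened thread index (z running fastest).
          */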
1730     snew(grids->g2t,DIM);
1731     tfac = 1;
1732     for(d=DIM-1; d>=0; d--)
1733     {
1734         snew(grids->g2t[d],n[d]);
1735         t = 0;
1736         for(i=0; i<n[d]; i++)
1737         {
1738             /* The second check should match the parameters
1739              * of the pmegrid_init call above.
1740              */
1741             while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
1742             {
1743                 t++;
1744             }
1745             grids->g2t[d][i] = t*tfac;
1746         }
1747
1748         tfac *= grids->nc[d];
1749
1750         switch (d)
1751         {
1752         case XX: max_comm_lines = overlap_x;     break;
1753         case YY: max_comm_lines = overlap_y;     break;
1754         case ZZ: max_comm_lines = pme_order - 1; break;
1755         }
1756         grids->nthread_comm[d] = 0;
1757         while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
1758                grids->nthread_comm[d] < grids->nc[d])
1759         {
1760             grids->nthread_comm[d]++;
1761         }
1762         if (debug != NULL)
1763         {
1764             fprintf(debug,"pmegrid thread grid communication range in %c: %d\n",
1765                     'x'+d,grids->nthread_comm[d]);
1766         }
1767         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
1768          * work, but this is not a problematic restriction.
1769          */
1770         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
1771         {
1772             gmx_fatal(FARGS,"Too many threads for PME (%d) compared to the number of grid lines; reduce the number of threads doing PME",grids->nthread);
1773         }
1774     }
1775 }
1776
1777
1778 static void pmegrids_destroy(pmegrids_t *grids)
1779 {
1780     int t;
1781
1782     if (grids->grid.grid != NULL)
1783     {
1784         sfree(grids->grid.grid);
1785
1786         if (grids->nthread > 0)
1787         {
1788             for(t=0; t<grids->nthread; t++)
1789             {
1790                 sfree(grids->grid_th[t].grid);
1791             }
1792             sfree(grids->grid_th);
1793         }
1794     }
1795 }
1796
1797
1798 static void realloc_work(pme_work_t *work,int nkx)
1799 {
1800     if (nkx > work->nalloc)
1801     {
1802         work->nalloc = nkx;
1803         srenew(work->mhx  ,work->nalloc);
1804         srenew(work->mhy  ,work->nalloc);
1805         srenew(work->mhz  ,work->nalloc);
1806         srenew(work->m2   ,work->nalloc);
1807         /* Allocate an aligned pointer for SSE operations, including 3 extra
1808          * elements at the end since SSE operates on 4 elements at a time.
1809          */
1810         sfree_aligned(work->denom);
1811         sfree_aligned(work->tmp1);
1812         sfree_aligned(work->eterm);
1813         snew_aligned(work->denom,work->nalloc+3,16);
1814         snew_aligned(work->tmp1 ,work->nalloc+3,16);
1815         snew_aligned(work->eterm,work->nalloc+3,16);
1816         srenew(work->m2inv,work->nalloc);
1817     }
1818 }
1819
1820
1821 static void free_work(pme_work_t *work)
1822 {
1823     sfree(work->mhx);
1824     sfree(work->mhy);
1825     sfree(work->mhz);
1826     sfree(work->m2);
1827     sfree_aligned(work->denom);
1828     sfree_aligned(work->tmp1);
1829     sfree_aligned(work->eterm);
1830     sfree(work->m2inv);
1831 }
1832
1833
1834 #ifdef PME_SSE
1835     /* Calculate exponentials through SSE in float precision */
1836 inline static void calc_exponentials(int start, int end, real f, real *d_aligned, real *r_aligned, real *e_aligned)
1837 {
1838     {
1839         const __m128 two = _mm_set_ps(2.0f,2.0f,2.0f,2.0f);
1840         __m128 f_sse;
1841         __m128 lu;
1842         __m128 tmp_d1,d_inv,tmp_r,tmp_e;
1843         int kx;
1844         f_sse = _mm_load1_ps(&f);
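             /* _mm_rcp_ps only gives an ~12-bit reciprocal approximation;
              * the lu*(2 - lu*d) step below is one Newton-Raphson iteration
              * that refines it to (nearly) full single precision before
              * combining with exp(r) and the prefactor f.
              */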
1845         for(kx=0; kx<end; kx+=4)
1846         {
1847             tmp_d1   = _mm_load_ps(d_aligned+kx);
1848             lu       = _mm_rcp_ps(tmp_d1);
1849             d_inv    = _mm_mul_ps(lu,_mm_sub_ps(two,_mm_mul_ps(lu,tmp_d1)));
1850             tmp_r    = _mm_load_ps(r_aligned+kx);
1851             tmp_r    = gmx_mm_exp_ps(tmp_r);
1852             tmp_e    = _mm_mul_ps(f_sse,d_inv);
1853             tmp_e    = _mm_mul_ps(tmp_e,tmp_r);
1854             _mm_store_ps(e_aligned+kx,tmp_e);
1855         }
1856     }
1857 }
1858 #else
1859 inline static void calc_exponentials(int start, int end, real f, real *d, real *r, real *e)
1860 {
1861     int kx;
1862     for(kx=start; kx<end; kx++)
1863     {
1864         d[kx] = 1.0/d[kx];
1865     }
1866     for(kx=start; kx<end; kx++)
1867     {
1868         r[kx] = exp(r[kx]);
1869     }
1870     for(kx=start; kx<end; kx++)
1871     {
1872         e[kx] = f*r[kx]*d[kx];
1873     }
1874 }
1875 #endif
1876
1877
1878 static int solve_pme_yzx(gmx_pme_t pme,t_complex *grid,
1879                          real ewaldcoeff,real vol,
1880                          gmx_bool bEnerVir,
1881                          int nthread,int thread)
1882 {
1883     /* do recip sum over local cells in grid */
1884     /* y major, z middle, x minor or continuous */
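         /* For reference: the smooth-PME reciprocal-space energy is
          *   E_rec = elfac/(2*pi*V) * sum_{m!=0} exp(-pi^2 m^2/beta^2)/m^2
          *           * B(mx,my,mz) * |S(m)|^2
          * (Essmann et al., J. Chem. Phys. 103, 8577 (1995)), with
          * beta = ewaldcoeff and B containing the B-spline moduli
          * (pme->bsp_mod). The loop below applies the m-dependent factor
          * (eterm) to the structure factor on the grid and, if requested,
          * accumulates the energy and virial.
          */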
1885     t_complex *p0;
1886     int     kx,ky,kz,maxkx,maxky,maxkz;
1887     int     nx,ny,nz,iyz0,iyz1,iyz,iy,iz,kxstart,kxend;
1888     real    mx,my,mz;
1889     real    factor=M_PI*M_PI/(ewaldcoeff*ewaldcoeff);
1890     real    ets2,struct2,vfactor,ets2vf;
1891     real    d1,d2,energy=0;
1892     real    by,bz;
1893     real    virxx=0,virxy=0,virxz=0,viryy=0,viryz=0,virzz=0;
1894     real    rxx,ryx,ryy,rzx,rzy,rzz;
1895     pme_work_t *work;
1896     real    *mhx,*mhy,*mhz,*m2,*denom,*tmp1,*eterm,*m2inv;
1897     real    mhxk,mhyk,mhzk,m2k;
1898     real    corner_fac;
1899     ivec    complex_order;
1900     ivec    local_ndata,local_offset,local_size;
1901     real    elfac;
1902
1903     elfac = ONE_4PI_EPS0/pme->epsilon_r;
1904
1905     nx = pme->nkx;
1906     ny = pme->nky;
1907     nz = pme->nkz;
1908
1909     /* Dimensions should be identical for A/B grid, so we just use A here */
1910     gmx_parallel_3dfft_complex_limits(pme->pfft_setupA,
1911                                       complex_order,
1912                                       local_ndata,
1913                                       local_offset,
1914                                       local_size);
1915
1916     rxx = pme->recipbox[XX][XX];
1917     ryx = pme->recipbox[YY][XX];
1918     ryy = pme->recipbox[YY][YY];
1919     rzx = pme->recipbox[ZZ][XX];
1920     rzy = pme->recipbox[ZZ][YY];
1921     rzz = pme->recipbox[ZZ][ZZ];
1922
1923     maxkx = (nx+1)/2;
1924     maxky = (ny+1)/2;
1925     maxkz = nz/2+1;
1926
1927     work = &pme->work[thread];
1928     mhx   = work->mhx;
1929     mhy   = work->mhy;
1930     mhz   = work->mhz;
1931     m2    = work->m2;
1932     denom = work->denom;
1933     tmp1  = work->tmp1;
1934     eterm = work->eterm;
1935     m2inv = work->m2inv;
1936
1937     iyz0 = local_ndata[YY]*local_ndata[ZZ]* thread   /nthread;
1938     iyz1 = local_ndata[YY]*local_ndata[ZZ]*(thread+1)/nthread;
1939
1940     for(iyz=iyz0; iyz<iyz1; iyz++)
1941     {
1942         iy = iyz/local_ndata[ZZ];
1943         iz = iyz - iy*local_ndata[ZZ];
1944
1945         ky = iy + local_offset[YY];
1946
1947         if (ky < maxky)
1948         {
1949             my = ky;
1950         }
1951         else
1952         {
1953             my = (ky - ny);
1954         }
1955
1956         by = M_PI*vol*pme->bsp_mod[YY][ky];
1957
1958         kz = iz + local_offset[ZZ];
1959
1960         mz = kz;
1961
1962         bz = pme->bsp_mod[ZZ][kz];
1963
1964         /* 0.5 correction for corner points */
1965         corner_fac = 1;
1966         if (kz == 0 || kz == (nz+1)/2)
1967         {
1968             corner_fac = 0.5;
1969         }
1970
1971         p0 = grid + iy*local_size[ZZ]*local_size[XX] + iz*local_size[XX];
1972
1973         /* We should skip the k-space point (0,0,0) */
1974         if (local_offset[XX] > 0 || ky > 0 || kz > 0)
1975         {
1976             kxstart = local_offset[XX];
1977         }
1978         else
1979         {
1980             kxstart = local_offset[XX] + 1;
1981             p0++;
1982         }
1983         kxend = local_offset[XX] + local_ndata[XX];
1984
1985         if (bEnerVir)
1986         {
1987             /* More expensive inner loop, especially because of the storage
1988              * of the mh elements in arrays.
1989              * Because x is the minor grid index, all mh elements
1990              * depend on kx for triclinic unit cells.
1991              */
1992
1993             /* Two explicit loops to avoid a conditional inside the loop */
1994             for(kx=kxstart; kx<maxkx; kx++)
1995             {
1996                 mx = kx;
1997
1998                 mhxk      = mx * rxx;
1999                 mhyk      = mx * ryx + my * ryy;
2000                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2001                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2002                 mhx[kx]   = mhxk;
2003                 mhy[kx]   = mhyk;
2004                 mhz[kx]   = mhzk;
2005                 m2[kx]    = m2k;
2006                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2007                 tmp1[kx]  = -factor*m2k;
2008             }
2009
2010             for(kx=maxkx; kx<kxend; kx++)
2011             {
2012                 mx = (kx - nx);
2013
2014                 mhxk      = mx * rxx;
2015                 mhyk      = mx * ryx + my * ryy;
2016                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2017                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2018                 mhx[kx]   = mhxk;
2019                 mhy[kx]   = mhyk;
2020                 mhz[kx]   = mhzk;
2021                 m2[kx]    = m2k;
2022                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2023                 tmp1[kx]  = -factor*m2k;
2024             }
2025
2026             for(kx=kxstart; kx<kxend; kx++)
2027             {
2028                 m2inv[kx] = 1.0/m2[kx];
2029             }
2030
2031             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2032
2033             for(kx=kxstart; kx<kxend; kx++,p0++)
2034             {
2035                 d1      = p0->re;
2036                 d2      = p0->im;
2037
2038                 p0->re  = d1*eterm[kx];
2039                 p0->im  = d2*eterm[kx];
2040
2041                 struct2 = 2.0*(d1*d1+d2*d2);
2042
2043                 tmp1[kx] = eterm[kx]*struct2;
2044             }
2045
2046             for(kx=kxstart; kx<kxend; kx++)
2047             {
2048                 ets2     = corner_fac*tmp1[kx];
2049                 vfactor  = (factor*m2[kx] + 1.0)*2.0*m2inv[kx];
2050                 energy  += ets2;
2051
2052                 ets2vf   = ets2*vfactor;
2053                 virxx   += ets2vf*mhx[kx]*mhx[kx] - ets2;
2054                 virxy   += ets2vf*mhx[kx]*mhy[kx];
2055                 virxz   += ets2vf*mhx[kx]*mhz[kx];
2056                 viryy   += ets2vf*mhy[kx]*mhy[kx] - ets2;
2057                 viryz   += ets2vf*mhy[kx]*mhz[kx];
2058                 virzz   += ets2vf*mhz[kx]*mhz[kx] - ets2;
2059             }
2060         }
2061         else
2062         {
2063             /* We don't need to calculate the energy and the virial.
2064              * In this case the triclinic overhead is small.
2065              */
2066
2067             /* Two explicit loops to avoid a conditional inside the loop */
2068
2069             for(kx=kxstart; kx<maxkx; kx++)
2070             {
2071                 mx = kx;
2072
2073                 mhxk      = mx * rxx;
2074                 mhyk      = mx * ryx + my * ryy;
2075                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2076                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2077                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2078                 tmp1[kx]  = -factor*m2k;
2079             }
2080
2081             for(kx=maxkx; kx<kxend; kx++)
2082             {
2083                 mx = (kx - nx);
2084
2085                 mhxk      = mx * rxx;
2086                 mhyk      = mx * ryx + my * ryy;
2087                 mhzk      = mx * rzx + my * rzy + mz * rzz;
2088                 m2k       = mhxk*mhxk + mhyk*mhyk + mhzk*mhzk;
2089                 denom[kx] = m2k*bz*by*pme->bsp_mod[XX][kx];
2090                 tmp1[kx]  = -factor*m2k;
2091             }
2092
2093             calc_exponentials(kxstart,kxend,elfac,denom,tmp1,eterm);
2094
2095             for(kx=kxstart; kx<kxend; kx++,p0++)
2096             {
2097                 d1      = p0->re;
2098                 d2      = p0->im;
2099
2100                 p0->re  = d1*eterm[kx];
2101                 p0->im  = d2*eterm[kx];
2102             }
2103         }
2104     }
2105
2106     if (bEnerVir)
2107     {
2108         /* Update virial with local values.
2109          * The virial is symmetric by definition.
2110          * This virial seems OK for isotropic scaling, but I'm
2111          * experiencing problems on semi-isotropic membranes.
2112          * IS THAT COMMENT STILL VALID??? (DvdS, 2001/02/07).
2113          */
2114         work->vir[XX][XX] = 0.25*virxx;
2115         work->vir[YY][YY] = 0.25*viryy;
2116         work->vir[ZZ][ZZ] = 0.25*virzz;
2117         work->vir[XX][YY] = work->vir[YY][XX] = 0.25*virxy;
2118         work->vir[XX][ZZ] = work->vir[ZZ][XX] = 0.25*virxz;
2119         work->vir[YY][ZZ] = work->vir[ZZ][YY] = 0.25*viryz;
2120
2121         /* This energy should be corrected for a charged system */
2122         work->energy = 0.5*energy;
2123     }
2124
2125     /* Return the loop count */
2126     return local_ndata[YY]*local_ndata[XX];
2127 }
2128
2129 static void get_pme_ener_vir(const gmx_pme_t pme,int nthread,
2130                              real *mesh_energy,matrix vir)
2131 {
2132     /* This function sums output over threads
2133      * and should therefore only be called after thread synchronization.
2134      */
2135     int thread;
2136
2137     *mesh_energy = pme->work[0].energy;
2138     copy_mat(pme->work[0].vir,vir);
2139
2140     for(thread=1; thread<nthread; thread++)
2141     {
2142         *mesh_energy += pme->work[thread].energy;
2143         m_add(vir,pme->work[thread].vir,vir);
2144     }
2145 }
2146
2147 #define DO_FSPLINE(order)                      \
2148 for(ithx=0; (ithx<order); ithx++)              \
2149 {                                              \
2150     index_x = (i0+ithx)*pny*pnz;               \
2151     tx      = thx[ithx];                       \
2152     dx      = dthx[ithx];                      \
2153                                                \
2154     for(ithy=0; (ithy<order); ithy++)          \
2155     {                                          \
2156         index_xy = index_x+(j0+ithy)*pnz;      \
2157         ty       = thy[ithy];                  \
2158         dy       = dthy[ithy];                 \
2159         fxy1     = fz1 = 0;                    \
2160                                                \
2161         for(ithz=0; (ithz<order); ithz++)      \
2162         {                                      \
2163             gval  = grid[index_xy+(k0+ithz)];  \
2164             fxy1 += thz[ithz]*gval;            \
2165             fz1  += dthz[ithz]*gval;           \
2166         }                                      \
2167         fx += dx*ty*fxy1;                      \
2168         fy += tx*dy*fxy1;                      \
2169         fz += tx*ty*fz1;                       \
2170     }                                          \
2171 }
2172
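     /* DO_FSPLINE accumulates, for one atom, the gradient of the
      * interpolated potential with respect to the fractional coordinates:
      *   fx = sum_ijk dthx[i]*thy[j]*thz[k]*grid[...],
      * and similarly fy, fz with the derivative taken along y or z.
      * gather_f_bsplines then converts these to Cartesian forces using
      * the reciprocal box vectors and multiplies by the (scaled) charge -qn.
      */
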
2173
2174 static void gather_f_bsplines(gmx_pme_t pme,real *grid,
2175                               gmx_bool bClearF,pme_atomcomm_t *atc,
2176                               splinedata_t *spline,
2177                               real scale)
2178 {
2179     /* sum forces for local particles */
2180     int     nn,n,ithx,ithy,ithz,i0,j0,k0;
2181     int     index_x,index_xy;
2182     int     nx,ny,nz,pnx,pny,pnz;
2183     int *   idxptr;
2184     real    tx,ty,dx,dy,qn;
2185     real    fx,fy,fz,gval;
2186     real    fxy1,fz1;
2187     real    *thx,*thy,*thz,*dthx,*dthy,*dthz;
2188     int     norder;
2189     real    rxx,ryx,ryy,rzx,rzy,rzz;
2190     int     order;
2191
2192     pme_spline_work_t *work;
2193
2194     work = pme->spline_work;
2195
2196     order = pme->pme_order;
2197     thx   = spline->theta[XX];
2198     thy   = spline->theta[YY];
2199     thz   = spline->theta[ZZ];
2200     dthx  = spline->dtheta[XX];
2201     dthy  = spline->dtheta[YY];
2202     dthz  = spline->dtheta[ZZ];
2203     nx    = pme->nkx;
2204     ny    = pme->nky;
2205     nz    = pme->nkz;
2206     pnx   = pme->pmegrid_nx;
2207     pny   = pme->pmegrid_ny;
2208     pnz   = pme->pmegrid_nz;
2209
2210     rxx   = pme->recipbox[XX][XX];
2211     ryx   = pme->recipbox[YY][XX];
2212     ryy   = pme->recipbox[YY][YY];
2213     rzx   = pme->recipbox[ZZ][XX];
2214     rzy   = pme->recipbox[ZZ][YY];
2215     rzz   = pme->recipbox[ZZ][ZZ];
2216
2217     for(nn=0; nn<spline->n; nn++)
2218     {
2219         n  = spline->ind[nn];
2220         qn = scale*atc->q[n];
2221
2222         if (bClearF)
2223         {
2224             atc->f[n][XX] = 0;
2225             atc->f[n][YY] = 0;
2226             atc->f[n][ZZ] = 0;
2227         }
2228         if (qn != 0)
2229         {
2230             fx     = 0;
2231             fy     = 0;
2232             fz     = 0;
2233             idxptr = atc->idx[n];
2234             norder = nn*order;
2235
2236             i0   = idxptr[XX];
2237             j0   = idxptr[YY];
2238             k0   = idxptr[ZZ];
2239
2240             /* Pointer arithmetic alert, next six statements */
2241             thx  = spline->theta[XX] + norder;
2242             thy  = spline->theta[YY] + norder;
2243             thz  = spline->theta[ZZ] + norder;
2244             dthx = spline->dtheta[XX] + norder;
2245             dthy = spline->dtheta[YY] + norder;
2246             dthz = spline->dtheta[ZZ] + norder;
2247
2248             switch (order) {
2249             case 4:
2250 #ifdef PME_SSE
2251 #ifdef PME_SSE_UNALIGNED
2252 #define PME_GATHER_F_SSE_ORDER4
2253 #else
2254 #define PME_GATHER_F_SSE_ALIGNED
2255 #define PME_ORDER 4
2256 #endif
2257 #include "pme_sse_single.h"
2258 #else
2259                 DO_FSPLINE(4);
2260 #endif
2261                 break;
2262             case 5:
2263 #ifdef PME_SSE
2264 #define PME_GATHER_F_SSE_ALIGNED
2265 #define PME_ORDER 5
2266 #include "pme_sse_single.h"
2267 #else
2268                 DO_FSPLINE(5);
2269 #endif
2270                 break;
2271             default:
2272                 DO_FSPLINE(order);
2273                 break;
2274             }
2275
2276             atc->f[n][XX] += -qn*( fx*nx*rxx );
2277             atc->f[n][YY] += -qn*( fx*nx*ryx + fy*ny*ryy );
2278             atc->f[n][ZZ] += -qn*( fx*nx*rzx + fy*ny*rzy + fz*nz*rzz );
2279         }
2280     }
2281     /* Since the energy and not the forces are interpolated,
2282      * the net force might not be exactly zero.
2283      * This can be solved by also interpolating F, but
2284      * that comes at a cost.
2285      * A better hack is to remove the net force every
2286      * step, but that must be done at a higher level
2287      * since this routine doesn't see all atoms if running
2288      * in parallel. It is unclear how important this is.  EL 990726
2289      */
2290 }
2291
2292
2293 static real gather_energy_bsplines(gmx_pme_t pme,real *grid,
2294                                    pme_atomcomm_t *atc)
2295 {
2296     splinedata_t *spline;
2297     int     n,ithx,ithy,ithz,i0,j0,k0;
2298     int     index_x,index_xy;
2299     int *   idxptr;
2300     real    energy,pot,tx,ty,qn,gval;
2301     real    *thx,*thy,*thz;
2302     int     norder;
2303     int     order;
2304
2305     spline = &atc->spline[0];
2306
2307     order = pme->pme_order;
2308
2309     energy = 0;
2310     for(n=0; (n<atc->n); n++) {
2311         qn      = atc->q[n];
2312
2313         if (qn != 0) {
2314             idxptr = atc->idx[n];
2315             norder = n*order;
2316
2317             i0   = idxptr[XX];
2318             j0   = idxptr[YY];
2319             k0   = idxptr[ZZ];
2320
2321             /* Pointer arithmetic alert, next three statements */
2322             thx  = spline->theta[XX] + norder;
2323             thy  = spline->theta[YY] + norder;
2324             thz  = spline->theta[ZZ] + norder;
2325
2326             pot = 0;
2327             for(ithx=0; (ithx<order); ithx++)
2328             {
2329                 index_x = (i0+ithx)*pme->pmegrid_ny*pme->pmegrid_nz;
2330                 tx      = thx[ithx];
2331
2332                 for(ithy=0; (ithy<order); ithy++)
2333                 {
2334                     index_xy = index_x+(j0+ithy)*pme->pmegrid_nz;
2335                     ty       = thy[ithy];
2336
2337                     for(ithz=0; (ithz<order); ithz++)
2338                     {
2339                         gval  = grid[index_xy+(k0+ithz)];
2340                         pot  += tx*ty*thz[ithz]*gval;
2341                     }
2342
2343                 }
2344             }
2345
2346             energy += pot*qn;
2347         }
2348     }
2349
2350     return energy;
2351 }
2352
2353 /* Macro to force loop unrolling by fixing order.
2354  * This gives a significant performance gain.
2355  */
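     /* CALC_SPLINE evaluates the cardinal B-spline recursion
      *   M_n(u) = u/(n-1) * M_{n-1}(u) + (n-u)/(n-1) * M_{n-1}(u-1)
      * and its derivative
      *   dM_n(u)/du = M_{n-1}(u) - M_{n-1}(u-1)
      * in each dimension for atom i, storing order values per dimension
      * in theta and dtheta.
      */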
2356 #define CALC_SPLINE(order)                     \
2357 {                                              \
2358     int j,k,l;                                 \
2359     real dr,div;                               \
2360     real data[PME_ORDER_MAX];                  \
2361     real ddata[PME_ORDER_MAX];                 \
2362                                                \
2363     for(j=0; (j<DIM); j++)                     \
2364     {                                          \
2365         dr  = xptr[j];                         \
2366                                                \
2367         /* dr is relative offset from lower cell limit */ \
2368         data[order-1] = 0;                     \
2369         data[1] = dr;                          \
2370         data[0] = 1 - dr;                      \
2371                                                \
2372         for(k=3; (k<order); k++)               \
2373         {                                      \
2374             div = 1.0/(k - 1.0);               \
2375             data[k-1] = div*dr*data[k-2];      \
2376             for(l=1; (l<(k-1)); l++)           \
2377             {                                  \
2378                 data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
2379                                    data[k-l-1]);                \
2380             }                                  \
2381             data[0] = div*(1-dr)*data[0];      \
2382         }                                      \
2383         /* differentiate */                    \
2384         ddata[0] = -data[0];                   \
2385         for(k=1; (k<order); k++)               \
2386         {                                      \
2387             ddata[k] = data[k-1] - data[k];    \
2388         }                                      \
2389                                                \
2390         div = 1.0/(order - 1);                 \
2391         data[order-1] = div*dr*data[order-2];  \
2392         for(l=1; (l<(order-1)); l++)           \
2393         {                                      \
2394             data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
2395                                (order-l-dr)*data[order-l-1]); \
2396         }                                      \
2397         data[0] = div*(1 - dr)*data[0];        \
2398                                                \
2399         for(k=0; k<order; k++)                 \
2400         {                                      \
2401             theta[j][i*order+k]  = data[k];    \
2402             dtheta[j][i*order+k] = ddata[k];   \
2403         }                                      \
2404     }                                          \
2405 }
2406
2407 void make_bsplines(splinevec theta,splinevec dtheta,int order,
2408                    rvec fractx[],int nr,int ind[],real charge[],
2409                    gmx_bool bFreeEnergy)
2410 {
2411     /* construct splines for local atoms */
2412     int  i,ii;
2413     real *xptr;
2414
2415     for(i=0; i<nr; i++)
2416     {
2417         /* With free energy we do not use the charge check.
2418          * In most cases this will be more efficient than calling make_bsplines
2419          * twice, since usually more than half the particles have charges.
2420          */
2421         ii = ind[i];
2422         if (bFreeEnergy || charge[ii] != 0.0) {
2423             xptr = fractx[ii];
2424             switch(order) {
2425             case 4:  CALC_SPLINE(4);     break;
2426             case 5:  CALC_SPLINE(5);     break;
2427             default: CALC_SPLINE(order); break;
2428             }
2429         }
2430     }
2431 }
2432
2433
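     /* make_dft_mod computes the squared modulus of the discrete Fourier
      * transform of the B-spline values:
      *   mod[i] = |sum_j data[j] * exp(2*pi*I*i*j/ndata)|^2
      * These moduli end up in the denominator of the reciprocal-space
      * influence function (see solve_pme_yzx).
      */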
2434 void make_dft_mod(real *mod,real *data,int ndata)
2435 {
2436   int i,j;
2437   real sc,ss,arg;
2438
2439   for(i=0;i<ndata;i++) {
2440     sc=ss=0;
2441     for(j=0;j<ndata;j++) {
2442       arg=(2.0*M_PI*i*j)/ndata;
2443       sc+=data[j]*cos(arg);
2444       ss+=data[j]*sin(arg);
2445     }
2446     mod[i]=sc*sc+ss*ss;
2447   }
2448   for(i=0;i<ndata;i++)
2449     if(mod[i]<1e-7)
2450       mod[i]=(mod[i-1]+mod[i+1])*0.5;
2451 }
2452
2453
2454 static void make_bspline_moduli(splinevec bsp_mod,
2455                                 int nx,int ny,int nz,int order)
2456 {
2457   int nmax=max(nx,max(ny,nz));
2458   real *data,*ddata,*bsp_data;
2459   int i,k,l;
2460   real div;
2461
2462   snew(data,order);
2463   snew(ddata,order);
2464   snew(bsp_data,nmax);
2465
2466   data[order-1]=0;
2467   data[1]=0;
2468   data[0]=1;
2469
2470   for(k=3;k<order;k++) {
2471     div=1.0/(k-1.0);
2472     data[k-1]=0;
2473     for(l=1;l<(k-1);l++)
2474       data[k-l-1]=div*(l*data[k-l-2]+(k-l)*data[k-l-1]);
2475     data[0]=div*data[0];
2476   }
2477   /* differentiate */
2478   ddata[0]=-data[0];
2479   for(k=1;k<order;k++)
2480     ddata[k]=data[k-1]-data[k];
2481   div=1.0/(order-1);
2482   data[order-1]=0;
2483   for(l=1;l<(order-1);l++)
2484     data[order-l-1]=div*(l*data[order-l-2]+(order-l)*data[order-l-1]);
2485   data[0]=div*data[0];
2486
2487   for(i=0;i<nmax;i++)
2488     bsp_data[i]=0;
2489   for(i=1;i<=order;i++)
2490     bsp_data[i]=data[i-1];
2491
2492   make_dft_mod(bsp_mod[XX],bsp_data,nx);
2493   make_dft_mod(bsp_mod[YY],bsp_data,ny);
2494   make_dft_mod(bsp_mod[ZZ],bsp_data,nz);
2495
2496   sfree(data);
2497   sfree(ddata);
2498   sfree(bsp_data);
2499 }
2500
2501
2502 /* Return the P3M optimal influence function */
2503 static double do_p3m_influence(double z, int order)
2504 {
2505     double z2,z4;
2506
2507     z2 = z*z;
2508     z4 = z2*z2;
2509
2510     /* The formula and most constants can be found in:
2511      * Ballenegger et al., JCTC 8, 936 (2012)
2512      */
2513     switch(order)
2514     {
2515     case 2:
2516         return 1.0 - 2.0*z2/3.0;
2517         break;
2518     case 3:
2519         return 1.0 - z2 + 2.0*z4/15.0;
2520         break;
2521     case 4:
2522         return 1.0 - 4.0*z2/3.0 + 2.0*z4/5.0 + 4.0*z2*z4/315.0;
2523         break;
2524     case 5:
2525         return 1.0 - 5.0*z2/3.0 + 7.0*z4/9.0 - 17.0*z2*z4/189.0 + 2.0*z4*z4/2835.0;
2526         break;
2527     case 6:
2528         return 1.0 - 2.0*z2 + 19.0*z4/15.0 - 256.0*z2*z4/945.0 + 62.0*z4*z4/4725.0 + 4.0*z2*z4*z4/155925.0;
2529         break;
2530     case 7:
2531         return 1.0 - 7.0*z2/3.0 + 28.0*z4/15.0 - 16.0*z2*z4/27.0 + 26.0*z4*z4/405.0 - 2.0*z2*z4*z4/1485.0 + 4.0*z4*z4*z4/6081075.0;
2532     case 8:
2533         return 1.0 - 8.0*z2/3.0 + 116.0*z4/45.0 - 344.0*z2*z4/315.0 + 914.0*z4*z4/4725.0 - 248.0*z4*z4*z2/22275.0 + 21844.0*z4*z4*z4/212837625.0 - 8.0*z4*z4*z4*z2/638512875.0;
2534         break;
2535     }
2536
2537     return 0.0;
2538 }
2539
2540 /* Calculate the P3M B-spline moduli for one dimension */
2541 static void make_p3m_bspline_moduli_dim(real *bsp_mod,int n,int order)
2542 {
2543     double zarg,zai,sinzai,infl;
2544     int    maxk,i;
2545
2546     if (order > 8)
2547     {
2548         gmx_fatal(FARGS,"The current P3M code only supports orders up to 8");
2549     }
2550
2551     zarg = M_PI/n;
2552
2553     maxk = (n + 1)/2;
2554
2555     for(i=-maxk; i<0; i++)
2556     {
2557         zai    = zarg*i;
2558         sinzai = sin(zai);
2559         infl   = do_p3m_influence(sinzai,order);
2560         bsp_mod[n+i] = infl*infl*pow(sinzai/zai,-2.0*order);
2561     }
2562     bsp_mod[0] = 1.0;
2563     for(i=1; i<maxk; i++)
2564     {
2565         zai    = zarg*i;
2566         sinzai = sin(zai);
2567         infl   = do_p3m_influence(sinzai,order);
2568         bsp_mod[i] = infl*infl*pow(sinzai/zai,-2.0*order);
2569     }
2570 }
2571
2572 /* Calculate the P3M B-spline moduli */
2573 static void make_p3m_bspline_moduli(splinevec bsp_mod,
2574                                     int nx,int ny,int nz,int order)
2575 {
2576     make_p3m_bspline_moduli_dim(bsp_mod[XX],nx,order);
2577     make_p3m_bspline_moduli_dim(bsp_mod[YY],ny,order);
2578     make_p3m_bspline_moduli_dim(bsp_mod[ZZ],nz,order);
2579 }
2580
2581
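     /* Set up the send/receive order for redistributing atom coordinates
      * between the PME slabs: neighbours are paired at increasing distance,
      * alternating the forward (+i) and backward (-i) direction, which
      * gives nslab-1 communication stages.
      */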
2582 static void setup_coordinate_communication(pme_atomcomm_t *atc)
2583 {
2584   int nslab,n,i;
2585   int fw,bw;
2586
2587   nslab = atc->nslab;
2588
2589   n = 0;
2590   for(i=1; i<=nslab/2; i++) {
2591     fw = (atc->nodeid + i) % nslab;
2592     bw = (atc->nodeid - i + nslab) % nslab;
2593     if (n < nslab - 1) {
2594       atc->node_dest[n] = fw;
2595       atc->node_src[n]  = bw;
2596       n++;
2597     }
2598     if (n < nslab - 1) {
2599       atc->node_dest[n] = bw;
2600       atc->node_src[n]  = fw;
2601       n++;
2602     }
2603   }
2604 }
2605
2606 int gmx_pme_destroy(FILE *log,gmx_pme_t *pmedata)
2607 {
2608     int thread;
2609
2610     if(NULL != log)
2611     {
2612         fprintf(log,"Destroying PME data structures.\n");
2613     }
2614
2615     sfree((*pmedata)->nnx);
2616     sfree((*pmedata)->nny);
2617     sfree((*pmedata)->nnz);
2618
2619     pmegrids_destroy(&(*pmedata)->pmegridA);
2620
2621     sfree((*pmedata)->fftgridA);
2622     sfree((*pmedata)->cfftgridA);
2623     gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupA);
2624
2625     if ((*pmedata)->pmegridB.grid.grid != NULL)
2626     {
2627         pmegrids_destroy(&(*pmedata)->pmegridB);
2628         sfree((*pmedata)->fftgridB);
2629         sfree((*pmedata)->cfftgridB);
2630         gmx_parallel_3dfft_destroy((*pmedata)->pfft_setupB);
2631     }
2632     for(thread=0; thread<(*pmedata)->nthread; thread++)
2633     {
2634         free_work(&(*pmedata)->work[thread]);
2635     }
2636     sfree((*pmedata)->work);
2637
2638     sfree(*pmedata);
2639     *pmedata = NULL;
2640
2641     return 0;
2642 }
2643
2644 static int mult_up(int n,int f)
2645 {
2646     return ((n + f - 1)/f)*f;
2647 }
2648
2649
2650 static double pme_load_imbalance(gmx_pme_t pme)
2651 {
2652     int    nma,nmi;
2653     double n1,n2,n3;
2654
2655     nma = pme->nnodes_major;
2656     nmi = pme->nnodes_minor;
2657
2658     n1 = mult_up(pme->nkx,nma)*mult_up(pme->nky,nmi)*pme->nkz;
2659     n2 = mult_up(pme->nkx,nma)*mult_up(pme->nkz,nmi)*pme->nky;
2660     n3 = mult_up(pme->nky,nma)*mult_up(pme->nkz,nmi)*pme->nkx;
2661
2662     /* pme_solve is roughly double the cost of an fft */
2663
2664     return (n1 + n2 + 3*n3)/(double)(6*pme->nkx*pme->nky*pme->nkz);
2665 }
2666
2667 static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
2668                           int dimind,gmx_bool bSpread)
2669 {
2670     int nk,k,s,thread;
2671
2672     atc->dimind = dimind;
2673     atc->nslab  = 1;
2674     atc->nodeid = 0;
2675     atc->pd_nalloc = 0;
2676 #ifdef GMX_MPI
2677     if (pme->nnodes > 1)
2678     {
2679         atc->mpi_comm = pme->mpi_comm_d[dimind];
2680         MPI_Comm_size(atc->mpi_comm,&atc->nslab);
2681         MPI_Comm_rank(atc->mpi_comm,&atc->nodeid);
2682     }
2683     if (debug)
2684     {
2685         fprintf(debug,"For PME atom communication in dimind %d: nslab %d rank %d\n",atc->dimind,atc->nslab,atc->nodeid);
2686     }
2687 #endif
2688
2689     atc->bSpread   = bSpread;
2690     atc->pme_order = pme->pme_order;
2691
2692     if (atc->nslab > 1)
2693     {
2694         /* These three allocations are not required for particle decomp. */
2695         snew(atc->node_dest,atc->nslab);
2696         snew(atc->node_src,atc->nslab);
2697         setup_coordinate_communication(atc);
2698
2699         snew(atc->count_thread,pme->nthread);
2700         for(thread=0; thread<pme->nthread; thread++)
2701         {
2702             snew(atc->count_thread[thread],atc->nslab);
2703         }
2704         atc->count = atc->count_thread[0];
2705         snew(atc->rcount,atc->nslab);
2706         snew(atc->buf_index,atc->nslab);
2707     }
2708
2709     atc->nthread = pme->nthread;
2710     if (atc->nthread > 1)
2711     {
2712         snew(atc->thread_plist,atc->nthread);
2713     }
2714     snew(atc->spline,atc->nthread);
2715     for(thread=0; thread<atc->nthread; thread++)
2716     {
2717         if (atc->nthread > 1)
2718         {
2719             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
2720             atc->thread_plist[thread].n += GMX_CACHE_SEP;
2721         }
2722         snew(atc->spline[thread].thread_one,pme->nthread);
2723         atc->spline[thread].thread_one[thread] = 1;
2724     }
2725 }
2726
2727 static void
2728 init_overlap_comm(pme_overlap_t *  ol,
2729                   int              norder,
2730 #ifdef GMX_MPI
2731                   MPI_Comm         comm,
2732 #endif
2733                   int              nnodes,
2734                   int              nodeid,
2735                   int              ndata,
2736                   int              commplainsize)
2737 {
2738     int lbnd,rbnd,maxlr,b,i;
2739     int exten;
2740     int nn,nk;
2741     pme_grid_comm_t *pgc;
2742     gmx_bool bCont;
2743     int fft_start,fft_end,send_index1,recv_index1;
2744 #ifdef GMX_MPI
2745     MPI_Status stat;
2746
2747     ol->mpi_comm = comm;
2748 #endif
2749
2750     ol->nnodes = nnodes;
2751     ol->nodeid = nodeid;
2752
2753     /* Linear translation of the PME grid won't affect reciprocal space
2754      * calculations, so to optimize we only interpolate "upwards",
2755      * which also means we only have to consider overlap in one direction.
2756      * I.e., particles on this node might also be spread to grid indices
2757      * that belong to higher nodes (modulo nnodes)
2758      */
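     /* Small worked example (illustrative numbers only): for ndata = 10,
      * nnodes = 3 and norder = 4 the loop below gives
      *   s2g0 = {0, 3, 6, 10}   (FFT slab starts, rounded down)
      *   s2g1 = {7, 10, 13}     (interpolation ends, extended by norder-1)
      * so slab 0 interpolates onto grid lines 0..6 and must communicate
      * lines 3..6 forward to slabs 1 and 2.
      */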
2759
2760     snew(ol->s2g0,ol->nnodes+1);
2761     snew(ol->s2g1,ol->nnodes);
2762     if (debug) { fprintf(debug,"PME slab boundaries:"); }
2763     for(i=0; i<nnodes; i++)
2764     {
2765         /* s2g0 is the local interpolation grid start.
2766          * s2g1 is the local interpolation grid end.
2767          * Because grid overlap communication only goes forward,
2768          * the grid slabs for the FFTs should be rounded down.
2769          */
2770         ol->s2g0[i] = ( i   *ndata + 0       )/nnodes;
2771         ol->s2g1[i] = ((i+1)*ndata + nnodes-1)/nnodes + norder - 1;
2772
2773         if (debug)
2774         {
2775             fprintf(debug,"  %3d %3d",ol->s2g0[i],ol->s2g1[i]);
2776         }
2777     }
2778     ol->s2g0[nnodes] = ndata;
2779     if (debug) { fprintf(debug,"\n"); }
2780
2781     /* Determine with how many nodes we need to communicate the grid overlap */
2782     b = 0;
2783     do
2784     {
2785         b++;
2786         bCont = FALSE;
2787         for(i=0; i<nnodes; i++)
2788         {
2789             if ((i+b <  nnodes && ol->s2g1[i] > ol->s2g0[i+b]) ||
2790                 (i+b >= nnodes && ol->s2g1[i] > ol->s2g0[i+b-nnodes] + ndata))
2791             {
2792                 bCont = TRUE;
2793             }
2794         }
2795     }
2796     while (bCont && b < nnodes);
2797     ol->noverlap_nodes = b - 1;
2798
2799     snew(ol->send_id,ol->noverlap_nodes);
2800     snew(ol->recv_id,ol->noverlap_nodes);
2801     for(b=0; b<ol->noverlap_nodes; b++)
2802     {
2803         ol->send_id[b] = (ol->nodeid + (b + 1)) % ol->nnodes;
2804         ol->recv_id[b] = (ol->nodeid - (b + 1) + ol->nnodes) % ol->nnodes;
2805     }
2806     snew(ol->comm_data, ol->noverlap_nodes);
2807
2808     ol->send_size = 0;
2809     for(b=0; b<ol->noverlap_nodes; b++)
2810     {
2811         pgc = &ol->comm_data[b];
2812         /* Send */
2813         fft_start        = ol->s2g0[ol->send_id[b]];
2814         fft_end          = ol->s2g0[ol->send_id[b]+1];
2815         if (ol->send_id[b] < nodeid)
2816         {
2817             fft_start += ndata;
2818             fft_end   += ndata;
2819         }
2820         send_index1      = ol->s2g1[nodeid];
2821         send_index1      = min(send_index1,fft_end);
2822         pgc->send_index0 = fft_start;
2823         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
2824         ol->send_size    += pgc->send_nindex;
2825
2826         /* We always start receiving to the first index of our slab */
2827         fft_start        = ol->s2g0[ol->nodeid];
2828         fft_end          = ol->s2g0[ol->nodeid+1];
2829         recv_index1      = ol->s2g1[ol->recv_id[b]];
2830         if (ol->recv_id[b] > nodeid)
2831         {
2832             recv_index1 -= ndata;
2833         }
2834         recv_index1      = min(recv_index1,fft_end);
2835         pgc->recv_index0 = fft_start;
2836         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
2837     }
2838
2839 #ifdef GMX_MPI
2840     /* Communicate the buffer sizes to receive */
2841     for(b=0; b<ol->noverlap_nodes; b++)
2842     {
2843         MPI_Sendrecv(&ol->send_size             ,1,MPI_INT,ol->send_id[b],b,
2844                      &ol->comm_data[b].recv_size,1,MPI_INT,ol->recv_id[b],b,
2845                      ol->mpi_comm,&stat);
2846     }
2847 #endif
2848
2849     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
2850     snew(ol->sendbuf,norder*commplainsize);
2851     snew(ol->recvbuf,norder*commplainsize);
2852 }
2853
2854 static void
2855 make_gridindex5_to_localindex(int n,int local_start,int local_range,
2856                               int **global_to_local,
2857                               real **fraction_shift)
2858 {
2859     int i;
2860     int * gtl;
2861     real * fsh;
2862
2863     snew(gtl,5*n);
2864     snew(fsh,5*n);
2865     for(i=0; (i<5*n); i++)
2866     {
2867         /* Determine the global to local grid index */
2868         gtl[i] = (i - local_start + n) % n;
2869         /* For coordinates that fall within the local grid the fraction
2870          * is correct; we don't need to shift it.
2871          */
2872         fsh[i] = 0;
2873         if (local_range < n)
2874         {
2875             /* Due to rounding issues i could be 1 beyond the lower or
2876              * upper boundary of the local grid. Correct the index for this.
2877              * If we shift the index, we need to shift the fraction by
2878              * the same amount in the other direction to not affect
2879              * the weights.
2880              * Note that due to this shifting the weights at the end of
2881              * the spline might change, but that will only involve values
2882              * between zero and values close to the precision of a real,
2883              * which is anyhow the accuracy of the whole mesh calculation.
2884              */
2885             /* With local_range=0 we should not change i=local_start */
2886             if (i % n != local_start)
2887             {
2888                 if (gtl[i] == n-1)
2889                 {
2890                     gtl[i] = 0;
2891                     fsh[i] = -1;
2892                 }
2893                 else if (gtl[i] == local_range)
2894                 {
2895                     gtl[i] = local_range - 1;
2896                     fsh[i] = 1;
2897                 }
2898             }
2899         }
2900     }
2901
2902     *global_to_local = gtl;
2903     *fraction_shift  = fsh;
2904 }
2905
2906 static pme_spline_work_t *make_pme_spline_work(int order)
2907 {
2908     pme_spline_work_t *work;
2909
2910 #ifdef PME_SSE
2911     float  tmp[8];
2912     __m128 zero_SSE;
2913     int    of,i;
2914
2915     snew_aligned(work,1,16);
2916
2917     zero_SSE = _mm_setzero_ps();
2918
2919     /* Generate bit masks to mask out the unused grid entries,
2920      * as we only operate on 'order' of the 8 grid entries that are
2921      * loaded into 2 SSE float registers.
2922      */
2923     for(of=0; of<8-(order-1); of++)
2924     {
2925         for(i=0; i<8; i++)
2926         {
2927             tmp[i] = (i >= of && i < of+order ? 1 : 0);
2928         }
2929         work->mask_SSE0[of] = _mm_loadu_ps(tmp);
2930         work->mask_SSE1[of] = _mm_loadu_ps(tmp+4);
2931         work->mask_SSE0[of] = _mm_cmpgt_ps(work->mask_SSE0[of],zero_SSE);
2932         work->mask_SSE1[of] = _mm_cmpgt_ps(work->mask_SSE1[of],zero_SSE);
2933     }
2934 #else
2935     work = NULL;
2936 #endif
2937
2938     return work;
2939 }
2940
2941 static void
2942 gmx_pme_check_grid_restrictions(FILE *fplog,char dim,int nnodes,int *nk)
2943 {
2944     int nk_new;
2945
2946     if (*nk % nnodes != 0)
2947     {
2948         nk_new = nnodes*(*nk/nnodes + 1);
2949
2950         if (2*nk_new >= 3*(*nk))
2951         {
2952             gmx_fatal(FARGS,"The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). The grid size would have to be increased by more than 50%% to make the grid divisible. Change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).",
2953                       dim,*nk,dim,nnodes,dim);
2954         }
2955
2956         if (fplog != NULL)
2957         {
2958             fprintf(fplog,"\nNOTE: The PME grid size in dim %c (%d) is not divisible by the number of nodes doing PME in dim %c (%d). Increasing the PME grid size in dim %c to %d. This will increase the accuracy and will not decrease the performance significantly on this number of nodes. For optimal performance change the total number of nodes or the number of domain decomposition cells in x or the PME grid %c dimension (and the cut-off).\n\n",
2959                     dim,*nk,dim,nnodes,dim,nk_new,dim);
2960         }
2961
2962         *nk = nk_new;
2963     }
2964 }
2965
2966 int gmx_pme_init(gmx_pme_t *         pmedata,
2967                  t_commrec *         cr,
2968                  int                 nnodes_major,
2969                  int                 nnodes_minor,
2970                  t_inputrec *        ir,
2971                  int                 homenr,
2972                  gmx_bool            bFreeEnergy,
2973                  gmx_bool            bReproducible,
2974                  int                 nthread)
2975 {
2976     gmx_pme_t pme=NULL;
2977
2978     pme_atomcomm_t *atc;
2979     ivec ndata;
2980
2981     if (debug)
2982         fprintf(debug,"Creating PME data structures.\n");
2983     snew(pme,1);
2984
2985     pme->redist_init         = FALSE;
2986     pme->sum_qgrid_tmp       = NULL;
2987     pme->sum_qgrid_dd_tmp    = NULL;
2988     pme->buf_nalloc          = 0;
2989     pme->redist_buf_nalloc   = 0;
2990
2991     pme->nnodes              = 1;
2992     pme->bPPnode             = TRUE;
2993
2994     pme->nnodes_major        = nnodes_major;
2995     pme->nnodes_minor        = nnodes_minor;
2996
2997 #ifdef GMX_MPI
2998     if (nnodes_major*nnodes_minor > 1)
2999     {
3000         pme->mpi_comm = cr->mpi_comm_mygroup;
3001
3002         MPI_Comm_rank(pme->mpi_comm,&pme->nodeid);
3003         MPI_Comm_size(pme->mpi_comm,&pme->nnodes);
3004         if (pme->nnodes != nnodes_major*nnodes_minor)
3005         {
3006             gmx_incons("PME node count mismatch");
3007         }
3008     }
3009     else
3010     {
3011         pme->mpi_comm = MPI_COMM_NULL;
3012     }
3013 #endif
3014
3015     if (pme->nnodes == 1)
3016     {
3017 #ifdef GMX_MPI
3018         pme->mpi_comm_d[0] = MPI_COMM_NULL;
3019         pme->mpi_comm_d[1] = MPI_COMM_NULL;
3020 #endif
3021         pme->ndecompdim = 0;
3022         pme->nodeid_major = 0;
3023         pme->nodeid_minor = 0;
3024 #ifdef GMX_MPI
3025         pme->mpi_comm_d[0] = pme->mpi_comm_d[1] = MPI_COMM_NULL;
3026 #endif
3027     }
3028     else
3029     {
3030         if (nnodes_minor == 1)
3031         {
3032 #ifdef GMX_MPI
3033             pme->mpi_comm_d[0] = pme->mpi_comm;
3034             pme->mpi_comm_d[1] = MPI_COMM_NULL;
3035 #endif
3036             pme->ndecompdim = 1;
3037             pme->nodeid_major = pme->nodeid;
3038             pme->nodeid_minor = 0;
3039
3040         }
3041         else if (nnodes_major == 1)
3042         {
3043 #ifdef GMX_MPI
3044             pme->mpi_comm_d[0] = MPI_COMM_NULL;
3045             pme->mpi_comm_d[1] = pme->mpi_comm;
3046 #endif
3047             pme->ndecompdim = 1;
3048             pme->nodeid_major = 0;
3049             pme->nodeid_minor = pme->nodeid;
3050         }
3051         else
3052         {
3053             if (pme->nnodes % nnodes_major != 0)
3054             {
3055                 gmx_incons("For 2D PME decomposition, #PME nodes must be divisible by the number of nodes in the major dimension");
3056             }
3057             pme->ndecompdim = 2;
3058
3059 #ifdef GMX_MPI
3060             MPI_Comm_split(pme->mpi_comm,pme->nodeid % nnodes_minor,
3061                            pme->nodeid,&pme->mpi_comm_d[0]);  /* My communicator along major dimension */
3062             MPI_Comm_split(pme->mpi_comm,pme->nodeid/nnodes_minor,
3063                            pme->nodeid,&pme->mpi_comm_d[1]);  /* My communicator along minor dimension */
3064
3065             MPI_Comm_rank(pme->mpi_comm_d[0],&pme->nodeid_major);
3066             MPI_Comm_size(pme->mpi_comm_d[0],&pme->nnodes_major);
3067             MPI_Comm_rank(pme->mpi_comm_d[1],&pme->nodeid_minor);
3068             MPI_Comm_size(pme->mpi_comm_d[1],&pme->nnodes_minor);
3069 #endif
3070         }
3071         pme->bPPnode = (cr->duty & DUTY_PP);
3072     }
3073
3074     pme->nthread = nthread;
3075
3076     if (ir->ePBC == epbcSCREW)
3077     {
3078         gmx_fatal(FARGS,"pme does not (yet) work with pbc = screw");
3079     }
3080
3081     pme->bFEP        = ((ir->efep != efepNO) && bFreeEnergy);
3082     pme->nkx         = ir->nkx;
3083     pme->nky         = ir->nky;
3084     pme->nkz         = ir->nkz;
3085     pme->bP3M        = (ir->coulombtype == eelP3M_AD || getenv("GMX_PME_P3M") != NULL);
3086     pme->pme_order   = ir->pme_order;
3087     pme->epsilon_r   = ir->epsilon_r;
3088
3089     if (pme->pme_order > PME_ORDER_MAX)
3090     {
3091         gmx_fatal(FARGS,"pme_order (%d) is larger than the maximum allowed value (%d). Modify and recompile the code if you really need such a high order.",
3092                   pme->pme_order,PME_ORDER_MAX);
3093     }
3094
3095     /* Currently pme.c supports only the fft5d FFT code.
3096      * Therefore the grid always needs to be divisible by nnodes.
3097      * When the old 1D code is also supported again, change this check.
3098      *
3099      * This check should be done before calling gmx_pme_init
3100      * and fplog should be passed instead of stderr.
3101      *
3102     if (pme->ndecompdim >= 2)
3103     */
3104     if (pme->ndecompdim >= 1)
3105     {
3106         /*
3107         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3108                                         'x',nnodes_major,&pme->nkx);
3109         gmx_pme_check_grid_restrictions(pme->nodeid==0 ? stderr : NULL,
3110                                         'y',nnodes_minor,&pme->nky);
3111         */
3112     }
3113
3114     if (pme->nkx <= pme->pme_order*(pme->nnodes_major > 1 ? 2 : 1) ||
3115         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
3116         pme->nkz <= pme->pme_order)
3117     {
3118         gmx_fatal(FARGS,"The PME grid sizes need to be larger than pme_order (%d) and, for dimensions with domain decomposition, larger than 2*pme_order",pme->pme_order);
3119     }
3120
3121     if (pme->nnodes > 1) {
3122         double imbal;
3123
3124 #ifdef GMX_MPI
3125         MPI_Type_contiguous(DIM, mpi_type, &(pme->rvec_mpi));
3126         MPI_Type_commit(&(pme->rvec_mpi));
3127 #endif
3128
3129         /* Note that the charge spreading and force gathering, which usually
3130          * takes about the same amount of time as FFT+solve_pme,
3131          * is always fully load balanced
3132          * (unless the charge distribution is inhomogeneous).
3133          */
3134
3135         imbal = pme_load_imbalance(pme);
3136         if (imbal >= 1.2 && pme->nodeid_major == 0 && pme->nodeid_minor == 0)
3137         {
3138             fprintf(stderr,
3139                     "\n"
3140                     "NOTE: The load imbalance in PME FFT and solve is %d%%.\n"
3141                     "      For optimal PME load balancing\n"
3142                     "      PME grid_x (%d) and grid_y (%d) should be divisible by #PME_nodes_x (%d)\n"
3143                     "      and PME grid_y (%d) and grid_z (%d) should be divisible by #PME_nodes_y (%d)\n"
3144                     "\n",
3145                     (int)((imbal-1)*100 + 0.5),
3146                     pme->nkx,pme->nky,pme->nnodes_major,
3147                     pme->nky,pme->nkz,pme->nnodes_minor);
3148         }
3149     }
3150
3151     /* For a non-divisible grid we need pme_order instead of pme_order-1 */
3152     /* In sum_qgrid_dd x overlap is copied in place: take padding into account.
3153      * y is always copied through a buffer: we don't need padding in z,
3154      * but we do need the overlap in x because of the communication order.
3155      */
3156     init_overlap_comm(&pme->overlap[0],pme->pme_order,
3157 #ifdef GMX_MPI
3158                       pme->mpi_comm_d[0],
3159 #endif
3160                       pme->nnodes_major,pme->nodeid_major,
3161                       pme->nkx,
3162                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
3163
3164     /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
3165      * We do this with an offset buffer of equal size, so we need to allocate
3166      * extra for the offset. That's what the (+1)*pme->nkz is for.
3167      */
3168     init_overlap_comm(&pme->overlap[1],pme->pme_order,
3169 #ifdef GMX_MPI
3170                       pme->mpi_comm_d[1],
3171 #endif
3172                       pme->nnodes_minor,pme->nodeid_minor,
3173                       pme->nky,
3174                       (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
3175
3176     /* Check for a limitation of the (current) sum_fftgrid_dd code.
3177      * We only allow multiple communication pulses in dim 1, not in dim 0.
3178      */
3179     if (pme->nthread > 1 && (pme->overlap[0].noverlap_nodes > 1 ||
3180                              pme->nkx < pme->nnodes_major*pme->pme_order))
3181     {
3182         gmx_fatal(FARGS,"The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x should be >= pme_order (%d). To resolve this issue, use fewer nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
3183                   pme->nkx/(double)pme->nnodes_major,pme->pme_order);
3184     }
3185
3186     snew(pme->bsp_mod[XX],pme->nkx);
3187     snew(pme->bsp_mod[YY],pme->nky);
3188     snew(pme->bsp_mod[ZZ],pme->nkz);
3189
3190     /* The required size of the interpolation grid, including overlap.
3191      * The allocated size (pmegrid_n?) might be slightly larger.
3192      */
3193     pme->pmegrid_nx = pme->overlap[0].s2g1[pme->nodeid_major] -
3194                       pme->overlap[0].s2g0[pme->nodeid_major];
3195     pme->pmegrid_ny = pme->overlap[1].s2g1[pme->nodeid_minor] -
3196                       pme->overlap[1].s2g0[pme->nodeid_minor];
3197     pme->pmegrid_nz_base = pme->nkz;
3198     pme->pmegrid_nz = pme->pmegrid_nz_base + pme->pme_order - 1;
3199     set_grid_alignment(&pme->pmegrid_nz,pme->pme_order);
3200
3201     pme->pmegrid_start_ix = pme->overlap[0].s2g0[pme->nodeid_major];
3202     pme->pmegrid_start_iy = pme->overlap[1].s2g0[pme->nodeid_minor];
3203     pme->pmegrid_start_iz = 0;
3204
3205     make_gridindex5_to_localindex(pme->nkx,
3206                                   pme->pmegrid_start_ix,
3207                                   pme->pmegrid_nx - (pme->pme_order-1),
3208                                   &pme->nnx,&pme->fshx);
3209     make_gridindex5_to_localindex(pme->nky,
3210                                   pme->pmegrid_start_iy,
3211                                   pme->pmegrid_ny - (pme->pme_order-1),
3212                                   &pme->nny,&pme->fshy);
3213     make_gridindex5_to_localindex(pme->nkz,
3214                                   pme->pmegrid_start_iz,
3215                                   pme->pmegrid_nz_base,
3216                                   &pme->nnz,&pme->fshz);
3217
3218     pmegrids_init(&pme->pmegridA,
3219                   pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3220                   pme->pmegrid_nz_base,
3221                   pme->pme_order,
3222                   pme->nthread,
3223                   pme->overlap[0].s2g1[pme->nodeid_major]-pme->overlap[0].s2g0[pme->nodeid_major+1],
3224                   pme->overlap[1].s2g1[pme->nodeid_minor]-pme->overlap[1].s2g0[pme->nodeid_minor+1]);
3225
3226     pme->spline_work = make_pme_spline_work(pme->pme_order);
3227
3228     ndata[0] = pme->nkx;
3229     ndata[1] = pme->nky;
3230     ndata[2] = pme->nkz;
3231
3232     /* This routine will allocate the grid data to fit the FFTs */
3233     gmx_parallel_3dfft_init(&pme->pfft_setupA,ndata,
3234                             &pme->fftgridA,&pme->cfftgridA,
3235                             pme->mpi_comm_d,
3236                             pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3237                             bReproducible,pme->nthread);
3238
3239     if (bFreeEnergy)
3240     {
3241         pmegrids_init(&pme->pmegridB,
3242                       pme->pmegrid_nx,pme->pmegrid_ny,pme->pmegrid_nz,
3243                       pme->pmegrid_nz_base,
3244                       pme->pme_order,
3245                       pme->nthread,
3246                       pme->nkx % pme->nnodes_major != 0,
3247                       pme->nky % pme->nnodes_minor != 0);
3248
3249         gmx_parallel_3dfft_init(&pme->pfft_setupB,ndata,
3250                                 &pme->fftgridB,&pme->cfftgridB,
3251                                 pme->mpi_comm_d,
3252                                 pme->overlap[0].s2g0,pme->overlap[1].s2g0,
3253                                 bReproducible,pme->nthread);
3254     }
3255     else
3256     {
3257         pme->pmegridB.grid.grid = NULL;
3258         pme->fftgridB           = NULL;
3259         pme->cfftgridB          = NULL;
3260     }
3261
3262     if (!pme->bP3M)
3263     {
3264         /* Use plain SPME B-spline interpolation */
3265         make_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3266     }
3267     else
3268     {
3269         /* Use the P3M grid-optimized influence function */
3270         make_p3m_bspline_moduli(pme->bsp_mod,pme->nkx,pme->nky,pme->nkz,pme->pme_order);
3271     }
3272
3273     /* Use atc[0] for spreading */
3274     init_atomcomm(pme,&pme->atc[0],cr,nnodes_major > 1 ? 0 : 1,TRUE);
3275     if (pme->ndecompdim >= 2)
3276     {
3277         init_atomcomm(pme,&pme->atc[1],cr,1,FALSE);
3278     }
3279
3280     if (pme->nnodes == 1) {
3281         pme->atc[0].n = homenr;
3282         pme_realloc_atomcomm_things(&pme->atc[0]);
3283     }
3284
3285     {
3286         int thread;
3287
3288         /* Use fft5d, order after FFT is y major, z, x minor */
3289
3290         snew(pme->work,pme->nthread);
3291         for(thread=0; thread<pme->nthread; thread++)
3292         {
3293             realloc_work(&pme->work[thread],pme->nkx);
3294         }
3295     }
3296
3297     *pmedata = pme;
3298     
3299     return 0;
3300 }
3301
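     /* Hand the grid storage of *old over to *new when the new grid fits
      * within the old allocation, so switching PME grids (e.g. during PME
      * tuning) does not have to reallocate the large real-space grids.
      */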
3302 static void reuse_pmegrids(const pmegrids_t *old,pmegrids_t *new)
3303 {
3304     int d,t;
3305
3306     for(d=0; d<DIM; d++)
3307     {
3308         if (new->grid.n[d] > old->grid.n[d])
3309         {
3310             return;
3311         }
3312     }
3313
3314     sfree_aligned(new->grid.grid);
3315     new->grid.grid = old->grid.grid;
3316
3317     if (new->nthread > 1 && new->nthread == old->nthread)
3318     {
3319         sfree_aligned(new->grid_all);
3320         for(t=0; t<new->nthread; t++)
3321         {
3322             new->grid_th[t].grid = old->grid_th[t].grid;
3323         }
3324     }
3325 }
3326
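     /* Reinitialize PME for the new grid size given in grid_size, taking the
      * node/thread decomposition and the free-energy setting from pme_src
      * and, where possible, reusing the grid storage of pme_src.
      */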
3327 int gmx_pme_reinit(gmx_pme_t *         pmedata,
3328                    t_commrec *         cr,
3329                    gmx_pme_t           pme_src,
3330                    const t_inputrec *  ir,
3331                    ivec                grid_size)
3332 {
3333     t_inputrec irc;
3334     int homenr;
3335     int ret;
3336
3337     irc = *ir;
3338     irc.nkx = grid_size[XX];
3339     irc.nky = grid_size[YY];
3340     irc.nkz = grid_size[ZZ];
3341
3342     if (pme_src->nnodes == 1)
3343     {
3344         homenr = pme_src->atc[0].n;
3345     }
3346     else
3347     {
3348         homenr = -1;
3349     }
3350
3351     ret = gmx_pme_init(pmedata,cr,pme_src->nnodes_major,pme_src->nnodes_minor,
3352                        &irc,homenr,pme_src->bFEP,FALSE,pme_src->nthread);
3353
3354     if (ret == 0)
3355     {
3356         /* We can easily reuse the allocated pme grids in pme_src */
3357         reuse_pmegrids(&pme_src->pmegridA,&(*pmedata)->pmegridA);
3358         /* We would like to reuse the fft grids, but that's harder */
3359     }
3360
3361     return ret;
3362 }
3363
3364
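     /* Copy the interior (non-overlapping) part of this thread's spread grid
      * into the node-local FFT grid; the overlapping borders are summed in
      * afterwards by reduce_threadgrid_overlap().
      */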
3365 static void copy_local_grid(gmx_pme_t pme,
3366                             pmegrids_t *pmegrids,int thread,real *fftgrid)
3367 {
3368     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3369     int  fft_my,fft_mz;
3370     int  nsx,nsy,nsz;
3371     ivec nf;
3372     int  offx,offy,offz,x,y,z,i0,i0t;
3373     int  d;
3374     pmegrid_t *pmegrid;
3375     real *grid_th;
3376
3377     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3378                                    local_fft_ndata,
3379                                    local_fft_offset,
3380                                    local_fft_size);
3381     fft_my = local_fft_size[YY];
3382     fft_mz = local_fft_size[ZZ];
3383
3384     pmegrid = &pmegrids->grid_th[thread];
3385
3386     nsx = pmegrid->s[XX];
3387     nsy = pmegrid->s[YY];
3388     nsz = pmegrid->s[ZZ];
3389
3390     for(d=0; d<DIM; d++)
3391     {
3392         nf[d] = min(pmegrid->n[d] - (pmegrid->order - 1),
3393                     local_fft_ndata[d] - pmegrid->offset[d]);
3394     }
3395
3396     offx = pmegrid->offset[XX];
3397     offy = pmegrid->offset[YY];
3398     offz = pmegrid->offset[ZZ];
3399
3400     /* Directly copy the non-overlapping parts of the local grids.
3401      * This also initializes the full grid.
3402      */
3403     grid_th = pmegrid->grid;
3404     for(x=0; x<nf[XX]; x++)
3405     {
3406         for(y=0; y<nf[YY]; y++)
3407         {
3408             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
3409             i0t = (x*nsy + y)*nsz;
3410             for(z=0; z<nf[ZZ]; z++)
3411             {
3412                 fftgrid[i0+z] = grid_th[i0t+z];
3413             }
3414         }
3415     }
3416 }
3417
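     /* Sum the contributions that other threads spread into our thread-local
      * part of the FFT grid. Contributions belonging to neighboring PME nodes
      * are packed into commbuf_x/commbuf_y instead; they are sent later by
      * sum_fftgrid_dd().
      */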
3418 static void
3419 reduce_threadgrid_overlap(gmx_pme_t pme,
3420                           const pmegrids_t *pmegrids,int thread,
3421                           real *fftgrid,real *commbuf_x,real *commbuf_y)
3422 {
3423     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3424     int  fft_nx,fft_ny,fft_nz;
3425     int  fft_my,fft_mz;
3426     int  buf_my=-1;
3427     int  nsx,nsy,nsz;
3428     ivec ne;
3429     int  offx,offy,offz,x,y,z,i0,i0t;
3430     int  sx,sy,sz,fx,fy,fz,tx1,ty1,tz1,ox,oy,oz;
3431     gmx_bool bClearBufX,bClearBufY,bClearBufXY,bClearBuf;
3432     gmx_bool bCommX,bCommY;
3433     int  d;
3434     int  thread_f;
3435     const pmegrid_t *pmegrid,*pmegrid_g,*pmegrid_f;
3436     const real *grid_th;
3437     real *commbuf=NULL;
3438
3439     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3440                                    local_fft_ndata,
3441                                    local_fft_offset,
3442                                    local_fft_size);
3443     fft_nx = local_fft_ndata[XX];
3444     fft_ny = local_fft_ndata[YY];
3445     fft_nz = local_fft_ndata[ZZ];
3446
3447     fft_my = local_fft_size[YY];
3448     fft_mz = local_fft_size[ZZ];
3449
3450     /* This routine is called when all threads have finished spreading.
3451      * Here each thread sums grid contributions calculated by other threads
3452      * into its thread-local grid volume.
3453      * To minimize the number of grid copying operations,
3454      * this routine sums directly from the pmegrid to the fftgrid.
3455      */
3456
3457     /* Determine which part of the full node grid we should operate on;
3458      * this is our thread-local part of the full grid.
3459      */
3460     pmegrid = &pmegrids->grid_th[thread];
3461
3462     for(d=0; d<DIM; d++)
3463     {
3464         ne[d] = min(pmegrid->offset[d]+pmegrid->n[d]-(pmegrid->order-1),
3465                     local_fft_ndata[d]);
3466     }
3467
3468     offx = pmegrid->offset[XX];
3469     offy = pmegrid->offset[YY];
3470     offz = pmegrid->offset[ZZ];
3471
3472
3473     bClearBufX  = TRUE;
3474     bClearBufY  = TRUE;
3475     bClearBufXY = TRUE;
3476
3477     /* Now loop over all the thread data blocks that contribute
3478      * to the grid region we (our thread) are operating on.
3479      */
3480     /* Note that fft_nx/y is equal to the number of grid points
3481      * between the first point of our node grid and the one of the next node.
3482      */
3483     for(sx=0; sx>=-pmegrids->nthread_comm[XX]; sx--)
3484     {
3485         fx = pmegrid->ci[XX] + sx;
3486         ox = 0;
3487         bCommX = FALSE;
3488         if (fx < 0) {
3489             fx += pmegrids->nc[XX];
3490             ox -= fft_nx;
3491             bCommX = (pme->nnodes_major > 1);
3492         }
3493         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
3494         ox += pmegrid_g->offset[XX];
3495         if (!bCommX)
3496         {
3497             tx1 = min(ox + pmegrid_g->n[XX],ne[XX]);
3498         }
3499         else
3500         {
3501             tx1 = min(ox + pmegrid_g->n[XX],pme->pme_order);
3502         }
3503
3504         for(sy=0; sy>=-pmegrids->nthread_comm[YY]; sy--)
3505         {
3506             fy = pmegrid->ci[YY] + sy;
3507             oy = 0;
3508             bCommY = FALSE;
3509             if (fy < 0) {
3510                 fy += pmegrids->nc[YY];
3511                 oy -= fft_ny;
3512                 bCommY = (pme->nnodes_minor > 1);
3513             }
3514             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
3515             oy += pmegrid_g->offset[YY];
3516             if (!bCommY)
3517             {
3518                 ty1 = min(oy + pmegrid_g->n[YY],ne[YY]);
3519             }
3520             else
3521             {
3522                 ty1 = min(oy + pmegrid_g->n[YY],pme->pme_order);
3523             }
3524
3525             for(sz=0; sz>=-pmegrids->nthread_comm[ZZ]; sz--)
3526             {
3527                 fz = pmegrid->ci[ZZ] + sz;
3528                 oz = 0;
3529                 if (fz < 0)
3530                 {
3531                     fz += pmegrids->nc[ZZ];
3532                     oz -= fft_nz;
3533                 }
3534                 pmegrid_g = &pmegrids->grid_th[fz];
3535                 oz += pmegrid_g->offset[ZZ];
3536                 tz1 = min(oz + pmegrid_g->n[ZZ],ne[ZZ]);
3537
3538                 if (sx == 0 && sy == 0 && sz == 0)
3539                 {
3540                     /* We have already added our local contribution
3541                      * before calling this routine, so skip it here.
3542                      */
3543                     continue;
3544                 }
3545
3546                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
3547
3548                 pmegrid_f = &pmegrids->grid_th[thread_f];
3549
3550                 grid_th = pmegrid_f->grid;
3551
3552                 nsx = pmegrid_f->s[XX];
3553                 nsy = pmegrid_f->s[YY];
3554                 nsz = pmegrid_f->s[ZZ];
3555
3556 #ifdef DEBUG_PME_REDUCE
3557                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
3558                        pme->nodeid,thread,thread_f,
3559                        pme->pmegrid_start_ix,
3560                        pme->pmegrid_start_iy,
3561                        pme->pmegrid_start_iz,
3562                        sx,sy,sz,
3563                        offx-ox,tx1-ox,offx,tx1,
3564                        offy-oy,ty1-oy,offy,ty1,
3565                        offz-oz,tz1-oz,offz,tz1);
3566 #endif
3567
3568                 if (!(bCommX || bCommY))
3569                 {
3570                     /* Copy from the thread local grid to the node grid */
3571                     for(x=offx; x<tx1; x++)
3572                     {
3573                         for(y=offy; y<ty1; y++)
3574                         {
3575                             i0  = (x*fft_my + y)*fft_mz;
3576                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3577                             for(z=offz; z<tz1; z++)
3578                             {
3579                                 fftgrid[i0+z] += grid_th[i0t+z];
3580                             }
3581                         }
3582                     }
3583                 }
3584                 else
3585                 {
3586                     /* The order of this conditional decides
3587                      * where the corner volume gets stored with x+y decomp.
3588                      */
3589                     if (bCommY)
3590                     {
3591                         commbuf = commbuf_y;
3592                         buf_my  = ty1 - offy;
3593                         if (bCommX)
3594                         {
3595                             /* We index commbuf modulo the local grid size */
3596                             commbuf += buf_my*fft_nx*fft_nz;
3597
3598                             bClearBuf  = bClearBufXY;
3599                             bClearBufXY = FALSE;
3600                         }
3601                         else
3602                         {
3603                             bClearBuf  = bClearBufY;
3604                             bClearBufY = FALSE;
3605                         }
3606                     }
3607                     else
3608                     {
3609                         commbuf = commbuf_x;
3610                         buf_my  = fft_ny;
3611                         bClearBuf  = bClearBufX;
3612                         bClearBufX = FALSE;
3613                     }
3614
3615                     /* Copy to the communication buffer */
3616                     for(x=offx; x<tx1; x++)
3617                     {
3618                         for(y=offy; y<ty1; y++)
3619                         {
3620                             i0  = (x*buf_my + y)*fft_nz;
3621                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
3622
3623                             if (bClearBuf)
3624                             {
3625                                 /* First access of commbuf, initialize it */
3626                                 for(z=offz; z<tz1; z++)
3627                                 {
3628                                     commbuf[i0+z]  = grid_th[i0t+z];
3629                                 }
3630                             }
3631                             else
3632                             {
3633                                 for(z=offz; z<tz1; z++)
3634                                 {
3635                                     commbuf[i0+z] += grid_th[i0t+z];
3636                                 }
3637                             }
3638                         }
3639                     }
3640                 }
3641             }
3642         }
3643     }
3644 }
3645
3646
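     /* Communicate and sum the FFT grid overlap regions packed by
      * reduce_threadgrid_overlap(). The minor dimension (dim 1) is handled
      * first; the corner volume that also has to travel along the major
      * dimension (dim 0) is accumulated into the dim 0 send buffer.
      */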
3647 static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
3648 {
3649     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3650     pme_overlap_t *overlap;
3651     int  send_index0,send_nindex;
3652     int  recv_nindex;
3653 #ifdef GMX_MPI
3654     MPI_Status stat;
3655 #endif
3656     int  send_size_y,recv_size_y;
3657     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
3658     real *sendptr,*recvptr;
3659     int  x,y,z,indg,indb;
3660
3661     /* Note that this routine is only used for forward communication.
3662      * Since the force gathering, unlike the charge spreading,
3663      * can be trivially parallelized over the particles,
3664      * the backwards process is much simpler and can use the "old"
3665      * communication setup.
3666      */
3667
3668     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3669                                    local_fft_ndata,
3670                                    local_fft_offset,
3671                                    local_fft_size);
3672
3673     if (pme->nnodes_minor > 1)
3674     {
3675         /* Minor dimension */
3676         overlap = &pme->overlap[1];
3677
3678         if (pme->nnodes_major > 1)
3679         {
3680             size_yx = pme->overlap[0].comm_data[0].send_nindex;
3681         }
3682         else
3683         {
3684             size_yx = 0;
3685         }
3686         datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
3687
3688         send_size_y = overlap->send_size;
3689
3690         for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
3691         {
3692             send_id = overlap->send_id[ipulse];
3693             recv_id = overlap->recv_id[ipulse];
3694             send_index0   =
3695                 overlap->comm_data[ipulse].send_index0 -
3696                 overlap->comm_data[0].send_index0;
3697             send_nindex   = overlap->comm_data[ipulse].send_nindex;
3698             /* We don't use recv_index0, as we always receive starting at 0 */
3699             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3700             recv_size_y   = overlap->comm_data[ipulse].recv_size;
3701
3702             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
3703             recvptr = overlap->recvbuf;
3704
3705 #ifdef GMX_MPI
3706             MPI_Sendrecv(sendptr,send_size_y*datasize,GMX_MPI_REAL,
3707                          send_id,ipulse,
3708                          recvptr,recv_size_y*datasize,GMX_MPI_REAL,
3709                          recv_id,ipulse,
3710                          overlap->mpi_comm,&stat);
3711 #endif
3712
3713             for(x=0; x<local_fft_ndata[XX]; x++)
3714             {
3715                 for(y=0; y<recv_nindex; y++)
3716                 {
3717                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
3718                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
3719                     for(z=0; z<local_fft_ndata[ZZ]; z++)
3720                     {
3721                         fftgrid[indg+z] += recvptr[indb+z];
3722                     }
3723                 }
3724             }
3725
3726             if (pme->nnodes_major > 1)
3727             {
3728                 /* Copy from the received buffer to the send buffer for dim 0 */
3729                 sendptr = pme->overlap[0].sendbuf;
3730                 for(x=0; x<size_yx; x++)
3731                 {
3732                     for(y=0; y<recv_nindex; y++)
3733                     {
3734                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3735                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
3736                         for(z=0; z<local_fft_ndata[ZZ]; z++)
3737                         {
3738                             sendptr[indg+z] += recvptr[indb+z];
3739                         }
3740                     }
3741                 }
3742             }
3743         }
3744     }
3745
3746     /* We only support a single pulse here.
3747      * This is not a severe limitation, as this code is only used
3748      * with OpenMP, where the (PME) domains can be larger.
3749      */
3750     if (pme->nnodes_major > 1)
3751     {
3752         /* Major dimension */
3753         overlap = &pme->overlap[0];
3754
3755         datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
3756         gridsize = local_fft_size[YY] *local_fft_size[ZZ];
3757
3758         ipulse = 0;
3759
3760         send_id = overlap->send_id[ipulse];
3761         recv_id = overlap->recv_id[ipulse];
3762         send_nindex   = overlap->comm_data[ipulse].send_nindex;
3763         /* We don't use recv_index0, as we always receive starting at 0 */
3764         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
3765
3766         sendptr = overlap->sendbuf;
3767         recvptr = overlap->recvbuf;
3768
3769         if (debug != NULL)
3770         {
3771             fprintf(debug,"PME fftgrid comm %2d x %2d x %2d\n",
3772                    send_nindex,local_fft_ndata[YY],local_fft_ndata[ZZ]);
3773         }
3774
3775 #ifdef GMX_MPI
3776         MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
3777                      send_id,ipulse,
3778                      recvptr,recv_nindex*datasize,GMX_MPI_REAL,
3779                      recv_id,ipulse,
3780                      overlap->mpi_comm,&stat);
3781 #endif
3782
3783         for(x=0; x<recv_nindex; x++)
3784         {
3785             for(y=0; y<local_fft_ndata[YY]; y++)
3786             {
3787                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
3788                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
3789                 for(z=0; z<local_fft_ndata[ZZ]; z++)
3790                 {
3791                     fftgrid[indg+z] += recvptr[indb+z];
3792                 }
3793             }
3794         }
3795     }
3796 }
3797
3798
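     /* Spread the charges in atc onto the PME grid(s): compute the grid
      * indices and B-spline coefficients (when bCalcSplines is set) and do
      * the actual spreading (when bSpread is set). With multiple threads the
      * per-thread grids are reduced onto fftgrid and, in parallel runs, the
      * node overlap is communicated with sum_fftgrid_dd().
      */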
3799 static void spread_on_grid(gmx_pme_t pme,
3800                            pme_atomcomm_t *atc,pmegrids_t *grids,
3801                            gmx_bool bCalcSplines,gmx_bool bSpread,
3802                            real *fftgrid)
3803 {
3804     int nthread,thread;
3805 #ifdef PME_TIME_THREADS
3806     gmx_cycles_t c1,c2,c3,ct1a,ct1b,ct1c;
3807     static double cs1=0,cs2=0,cs3=0;
3808     static double cs1a[6]={0,0,0,0,0,0};
3809     static int cnt=0;
3810 #endif
3811
3812     nthread = pme->nthread;
3813     assert(nthread>0);
3814
3815 #ifdef PME_TIME_THREADS
3816     c1 = omp_cyc_start();
3817 #endif
3818     if (bCalcSplines)
3819     {
3820 #pragma omp parallel for num_threads(nthread) schedule(static)
3821         for(thread=0; thread<nthread; thread++)
3822         {
3823             int start,end;
3824
3825             start = atc->n* thread   /nthread;
3826             end   = atc->n*(thread+1)/nthread;
3827
3828             /* Compute fftgrid index for all atoms,
3829              * with the help of some extra variables.
3830              */
3831             calc_interpolation_idx(pme,atc,start,end,thread);
3832         }
3833     }
3834 #ifdef PME_TIME_THREADS
3835     c1 = omp_cyc_end(c1);
3836     cs1 += (double)c1;
3837 #endif
3838
3839 #ifdef PME_TIME_THREADS
3840     c2 = omp_cyc_start();
3841 #endif
3842 #pragma omp parallel for num_threads(nthread) schedule(static)
3843     for(thread=0; thread<nthread; thread++)
3844     {
3845         splinedata_t *spline;
3846         pmegrid_t *grid;
3847
3848         /* make local bsplines  */
3849         if (grids == NULL || grids->nthread == 1)
3850         {
3851             spline = &atc->spline[0];
3852
3853             spline->n = atc->n;
3854
3855             grid = &grids->grid;
3856         }
3857         else
3858         {
3859             spline = &atc->spline[thread];
3860
3861             make_thread_local_ind(atc,thread,spline);
3862
3863             grid = &grids->grid_th[thread];
3864         }
3865
3866         if (bCalcSplines)
3867         {
3868             make_bsplines(spline->theta,spline->dtheta,pme->pme_order,
3869                           atc->fractx,spline->n,spline->ind,atc->q,pme->bFEP);
3870         }
3871
3872         if (bSpread)
3873         {
3874             /* put local atoms on grid. */
3875 #ifdef PME_TIME_SPREAD
3876             ct1a = omp_cyc_start();
3877 #endif
3878             spread_q_bsplines_thread(grid,atc,spline,pme->spline_work);
3879
3880             if (grids->nthread > 1)
3881             {
3882                 copy_local_grid(pme,grids,thread,fftgrid);
3883             }
3884 #ifdef PME_TIME_SPREAD
3885             ct1a = omp_cyc_end(ct1a);
3886             cs1a[thread] += (double)ct1a;
3887 #endif
3888         }
3889     }
3890 #ifdef PME_TIME_THREADS
3891     c2 = omp_cyc_end(c2);
3892     cs2 += (double)c2;
3893 #endif
3894
3895     if (bSpread && grids->nthread > 1)
3896     {
3897 #ifdef PME_TIME_THREADS
3898         c3 = omp_cyc_start();
3899 #endif
3900 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
3901         for(thread=0; thread<grids->nthread; thread++)
3902         {
3903             reduce_threadgrid_overlap(pme,grids,thread,
3904                                       fftgrid,
3905                                       pme->overlap[0].sendbuf,
3906                                       pme->overlap[1].sendbuf);
3907         }
3908 #ifdef PME_TIME_THREADS
3909         c3 = omp_cyc_end(c3);
3910         cs3 += (double)c3;
3911 #endif
3912
3913         if (pme->nnodes > 1)
3914         {
3915             /* Communicate the overlapping part of the fftgrid */
3916             sum_fftgrid_dd(pme,fftgrid);
3917         }
3918     }
3919
3920 #ifdef PME_TIME_THREADS
3921     cnt++;
3922     if (cnt % 20 == 0)
3923     {
3924         printf("idx %.2f spread %.2f red %.2f",
3925                cs1*1e-9,cs2*1e-9,cs3*1e-9);
3926 #ifdef PME_TIME_SPREAD
3927         for(thread=0; thread<nthread; thread++)
3928             printf(" %.2f",cs1a[thread]*1e-9);
3929 #endif
3930         printf("\n");
3931     }
3932 #endif
3933 }
3934
3935
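     /* Debug helper: print an nx*ny*nz block of grid values, starting at
      * global grid indices (sx,sy,sz), with row strides my and mz.
      */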
3936 static void dump_grid(FILE *fp,
3937                       int sx,int sy,int sz,int nx,int ny,int nz,
3938                       int my,int mz,const real *g)
3939 {
3940     int x,y,z;
3941
3942     for(x=0; x<nx; x++)
3943     {
3944         for(y=0; y<ny; y++)
3945         {
3946             for(z=0; z<nz; z++)
3947             {
3948                 fprintf(fp,"%2d %2d %2d %6.3f\n",
3949                         sx+x,sy+y,sz+z,g[(x*my + y)*mz + z]);
3950             }
3951         }
3952     }
3953 }
3954
3955 static void dump_local_fftgrid(gmx_pme_t pme,const real *fftgrid)
3956 {
3957     ivec local_fft_ndata,local_fft_offset,local_fft_size;
3958
3959     gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
3960                                    local_fft_ndata,
3961                                    local_fft_offset,
3962                                    local_fft_size);
3963
3964     dump_grid(stderr,
3965               pme->pmegrid_start_ix,
3966               pme->pmegrid_start_iy,
3967               pme->pmegrid_start_iz,
3968               pme->pmegrid_nx-pme->pme_order+1,
3969               pme->pmegrid_ny-pme->pme_order+1,
3970               pme->pmegrid_nz-pme->pme_order+1,
3971               local_fft_size[YY],
3972               local_fft_size[ZZ],
3973               fftgrid);
3974 }
3975
3976
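     /* Calculate the PME mesh energy V of n charges q at positions x.
      * Only supported for serial PME without free energy; a minimal usage
      * sketch (assuming pme was set up for a single node elsewhere):
      *
      *   real V;
      *   gmx_pme_calc_energy(pme, nLocalAtoms, xLocal, qLocal, &V);
      */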
3977 void gmx_pme_calc_energy(gmx_pme_t pme,int n,rvec *x,real *q,real *V)
3978 {
3979     pme_atomcomm_t *atc;
3980     pmegrids_t *grid;
3981
3982     if (pme->nnodes > 1)
3983     {
3984         gmx_incons("gmx_pme_calc_energy called in parallel");
3985     }
3986     if (pme->bFEP)
3987     {
3988         gmx_incons("gmx_pme_calc_energy with free energy");
3989     }
3990
3991     atc = &pme->atc_energy;
3992     atc->nthread   = 1;
3993     if (atc->spline == NULL)
3994     {
3995         snew(atc->spline,atc->nthread);
3996     }
3997     atc->nslab     = 1;
3998     atc->bSpread   = TRUE;
3999     atc->pme_order = pme->pme_order;
4000     atc->n         = n;
4001     pme_realloc_atomcomm_things(atc);
4002     atc->x         = x;
4003     atc->q         = q;
4004
4005     /* We only use the A-charges grid */
4006     grid = &pme->pmegridA;
4007
4008     spread_on_grid(pme,atc,NULL,TRUE,FALSE,pme->fftgridA);
4009
4010     *V = gather_energy_bsplines(pme,grid->grid.grid,atc);
4011 }
4012
4013
4014 static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
4015         t_nrnb *nrnb,t_inputrec *ir, gmx_large_int_t step_rel)
4016 {
4017     /* Reset all the counters related to performance over the run */
4018     wallcycle_stop(wcycle,ewcRUN);
4019     wallcycle_reset_all(wcycle);
4020     init_nrnb(nrnb);
4021     ir->init_step += step_rel;
4022     ir->nsteps    -= step_rel;
4023     wallcycle_start(wcycle,ewcRUN);
4024 }
4025
4026
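     /* Return in *pme_ret a PME structure matching grid_size: reuse an entry
      * of *pmedata when the grid dimensions match, otherwise create a new one
      * with gmx_pme_reinit() and append it to the list.
      */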
4027 static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
4028                                ivec grid_size,
4029                                t_commrec *cr, t_inputrec *ir,
4030                                gmx_pme_t *pme_ret)
4031 {
4032     int ind;
4033     gmx_pme_t pme = NULL;
4034
4035     ind = 0;
4036     while (ind < *npmedata)
4037     {
4038         pme = (*pmedata)[ind];
4039         if (pme->nkx == grid_size[XX] &&
4040             pme->nky == grid_size[YY] &&
4041             pme->nkz == grid_size[ZZ])
4042         {
4043             *pme_ret = pme;
4044
4045             return;
4046         }
4047
4048         ind++;
4049     }
4050
4051     (*npmedata)++;
4052     srenew(*pmedata,*npmedata);
4053
4054     /* Generate a new PME data structure, copying part of the old pointers */
4055     gmx_pme_reinit(&((*pmedata)[ind]),cr,pme,ir,grid_size);
4056
4057     *pme_ret = (*pmedata)[ind];
4058 }
4059
4060
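     /* The main loop for PME-only nodes: receive coordinates and charges from
      * the PP nodes, do the PME mesh work and send back forces, energy and
      * virial. A received atom count of -2 requests a switch to a different
      * PME grid size, -1 signals that we should stop.
      */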
4061 int gmx_pmeonly(gmx_pme_t pme,
4062                 t_commrec *cr,    t_nrnb *nrnb,
4063                 gmx_wallcycle_t wcycle,
4064                 real ewaldcoeff,  gmx_bool bGatherOnly,
4065                 t_inputrec *ir)
4066 {
4067     int npmedata;
4068     gmx_pme_t *pmedata;
4069     gmx_pme_pp_t pme_pp;
4070     int  natoms;
4071     matrix box;
4072     rvec *x_pp=NULL,*f_pp=NULL;
4073     real *chargeA=NULL,*chargeB=NULL;
4074     real lambda=0;
4075     int  maxshift_x=0,maxshift_y=0;
4076     real energy,dvdlambda;
4077     matrix vir;
4078     float cycles;
4079     int  count;
4080     gmx_bool bEnerVir;
4081     gmx_large_int_t step,step_rel;
4082     ivec grid_switch;
4083
4084     /* This data will only be used with PME tuning, i.e. switching PME grids */
4085     npmedata = 1;
4086     snew(pmedata,npmedata);
4087     pmedata[0] = pme;
4088
4089     pme_pp = gmx_pme_pp_init(cr);
4090
4091     init_nrnb(nrnb);
4092
4093     count = 0;
4094     do /****** this is a quasi-loop over time steps! */
4095     {
4096         /* The reason for having a loop here is PME grid tuning/switching */
4097         do
4098         {
4099             /* Domain decomposition */
4100             natoms = gmx_pme_recv_q_x(pme_pp,
4101                                       &chargeA,&chargeB,box,&x_pp,&f_pp,
4102                                       &maxshift_x,&maxshift_y,
4103                                       &pme->bFEP,&lambda,
4104                                       &bEnerVir,
4105                                       &step,
4106                                       grid_switch,&ewaldcoeff);
4107
4108             if (natoms == -2)
4109             {
4110                 /* Switch the PME grid to grid_switch */
4111                 gmx_pmeonly_switch(&npmedata,&pmedata,grid_switch,cr,ir,&pme);
4112             }
4113         }
4114         while (natoms == -2);
4115
4116         if (natoms == -1)
4117         {
4118             /* We should stop: break out of the loop */
4119             break;
4120         }
4121
4122         step_rel = step - ir->init_step;
4123
4124         if (count == 0)
4125             wallcycle_start(wcycle,ewcRUN);
4126
4127         wallcycle_start(wcycle,ewcPMEMESH);
4128
4129         dvdlambda = 0;
4130         clear_mat(vir);
4131         gmx_pme_do(pme,0,natoms,x_pp,f_pp,chargeA,chargeB,box,
4132                    cr,maxshift_x,maxshift_y,nrnb,wcycle,vir,ewaldcoeff,
4133                    &energy,lambda,&dvdlambda,
4134                    GMX_PME_DO_ALL_F | (bEnerVir ? GMX_PME_CALC_ENER_VIR : 0));
4135
4136         cycles = wallcycle_stop(wcycle,ewcPMEMESH);
4137
4138         gmx_pme_send_force_vir_ener(pme_pp,
4139                                     f_pp,vir,energy,dvdlambda,
4140                                     cycles);
4141
4142         count++;
4143
4144         if (step_rel == wcycle_get_reset_counters(wcycle))
4145         {
4146             /* Reset all the counters related to performance over the run */
4147             reset_pmeonly_counters(cr,wcycle,nrnb,ir,step_rel);
4148             wcycle_set_reset_counters(wcycle, 0);
4149         }
4150
4151     } /***** end of quasi-loop, we stop with the break above */
4152     while (TRUE);
4153
4154     return 0;
4155 }
4156
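     /* Do a full PME mesh calculation for the local atoms: redistribute
      * coordinates and charges over the PME nodes when decomposed, spread the
      * charges, forward 3D-FFT, solve in reciprocal space, backward 3D-FFT
      * and gather the forces, for charge state A and, with free energy, B.
      */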
4157 int gmx_pme_do(gmx_pme_t pme,
4158                int start,       int homenr,
4159                rvec x[],        rvec f[],
4160                real *chargeA,   real *chargeB,
4161                matrix box, t_commrec *cr,
4162                int  maxshift_x, int maxshift_y,
4163                t_nrnb *nrnb,    gmx_wallcycle_t wcycle,
4164                matrix vir,      real ewaldcoeff,
4165                real *energy,    real lambda,
4166                real *dvdlambda, int flags)
4167 {
4168     int     q,d,i,j,ntot,npme;
4169     int     nx,ny,nz;
4170     int     n_d,local_ny;
4171     pme_atomcomm_t *atc=NULL;
4172     pmegrids_t *pmegrid=NULL;
4173     real    *grid=NULL;
4174     real    *ptr;
4175     rvec    *x_d,*f_d;
4176     real    *charge=NULL,*q_d;
4177     real    energy_AB[2];
4178     matrix  vir_AB[2];
4179     gmx_bool bClearF;
4180     gmx_parallel_3dfft_t pfft_setup;
4181     real *  fftgrid;
4182     t_complex * cfftgrid;
4183     int     thread;
4184     const gmx_bool bCalcEnerVir = flags & GMX_PME_CALC_ENER_VIR;
4185     const gmx_bool bCalcF = flags & GMX_PME_CALC_F;
4186
4187     assert(pme->nnodes > 0);
4188     assert(pme->nnodes == 1 || pme->ndecompdim > 0);
4189
4190     if (pme->nnodes > 1) {
4191         atc = &pme->atc[0];
4192         atc->npd = homenr;
4193         if (atc->npd > atc->pd_nalloc) {
4194             atc->pd_nalloc = over_alloc_dd(atc->npd);
4195             srenew(atc->pd,atc->pd_nalloc);
4196         }
4197         atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4198     }
4199     else
4200     {
4201         /* This could be necessary for TPI */
4202         pme->atc[0].n = homenr;
4203     }
4204
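         /* Loop over the charge states: only A, or A and B when free-energy
          * perturbation is used; energies and virials are interpolated in
          * lambda after the loop.
          */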
4205     for(q=0; q<(pme->bFEP ? 2 : 1); q++) {
4206         if (q == 0) {
4207             pmegrid = &pme->pmegridA;
4208             fftgrid = pme->fftgridA;
4209             cfftgrid = pme->cfftgridA;
4210             pfft_setup = pme->pfft_setupA;
4211             charge = chargeA+start;
4212         } else {
4213             pmegrid = &pme->pmegridB;
4214             fftgrid = pme->fftgridB;
4215             cfftgrid = pme->cfftgridB;
4216             pfft_setup = pme->pfft_setupB;
4217             charge = chargeB+start;
4218         }
4219         grid = pmegrid->grid.grid;
4220         /* Unpack structure */
4221         if (debug) {
4222             fprintf(debug,"PME: nnodes = %d, nodeid = %d\n",
4223                     cr->nnodes,cr->nodeid);
4224             fprintf(debug,"Grid = %p\n",(void*)grid);
4225             if (grid == NULL)
4226                 gmx_fatal(FARGS,"No grid!");
4227         }
4228         where();
4229
4230         m_inv_ur0(box,pme->recipbox);
4231
4232         if (pme->nnodes == 1) {
4233             atc = &pme->atc[0];
4234             if (DOMAINDECOMP(cr)) {
4235                 atc->n = homenr;
4236                 pme_realloc_atomcomm_things(atc);
4237             }
4238             atc->x = x;
4239             atc->q = charge;
4240             atc->f = f;
4241         } else {
4242             wallcycle_start(wcycle,ewcPME_REDISTXF);
4243             for(d=pme->ndecompdim-1; d>=0; d--)
4244             {
4245                 if (d == pme->ndecompdim-1)
4246                 {
4247                     n_d = homenr;
4248                     x_d = x + start;
4249                     q_d = charge;
4250                 }
4251                 else
4252                 {
4253                     n_d = pme->atc[d+1].n;
4254                     x_d = atc->x;
4255                     q_d = atc->q;
4256                 }
4257                 atc = &pme->atc[d];
4258                 atc->npd = n_d;
4259                 if (atc->npd > atc->pd_nalloc) {
4260                     atc->pd_nalloc = over_alloc_dd(atc->npd);
4261                     srenew(atc->pd,atc->pd_nalloc);
4262                 }
4263                 atc->maxshift = (atc->dimind==0 ? maxshift_x : maxshift_y);
4264                 pme_calc_pidx_wrapper(n_d,pme->recipbox,x_d,atc);
4265                 where();
4266
4267                 GMX_BARRIER(cr->mpi_comm_mygroup);
4268                 /* Redistribute x (only once) and qA or qB */
4269                 if (DOMAINDECOMP(cr)) {
4270                     dd_pmeredist_x_q(pme, n_d, q==0, x_d, q_d, atc);
4271                 } else {
4272                     pmeredist_pd(pme, TRUE, n_d, q==0, x_d, q_d, atc);
4273                 }
4274             }
4275             where();
4276
4277             wallcycle_stop(wcycle,ewcPME_REDISTXF);
4278         }
4279
4280         if (debug)
4281             fprintf(debug,"Node= %6d, pme local particles=%6d\n",
4282                     cr->nodeid,atc->n);
4283
4284         if (flags & GMX_PME_SPREAD_Q)
4285         {
4286             wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4287
4288             /* Spread the charges on a grid */
4289             GMX_MPE_LOG(ev_spread_on_grid_start);
4290
4291             /* Spread the charges on a grid */
4292             spread_on_grid(pme,&pme->atc[0],pmegrid,q==0,TRUE,fftgrid);
4293             GMX_MPE_LOG(ev_spread_on_grid_finish);
4294
4295             if (q == 0)
4296             {
4297                 inc_nrnb(nrnb,eNR_WEIGHTS,DIM*atc->n);
4298             }
4299             inc_nrnb(nrnb,eNR_SPREADQBSP,
4300                      pme->pme_order*pme->pme_order*pme->pme_order*atc->n);
4301
4302             if (pme->nthread == 1)
4303             {
4304                 wrap_periodic_pmegrid(pme,grid);
4305
4306                 /* sum contributions to local grid from other nodes */
4307 #ifdef GMX_MPI
4308                 if (pme->nnodes > 1)
4309                 {
4310                     GMX_BARRIER(cr->mpi_comm_mygroup);
4311                     gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_FORWARD);
4312                     where();
4313                 }
4314 #endif
4315
4316                 copy_pmegrid_to_fftgrid(pme,grid,fftgrid);
4317             }
4318
4319             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4320
4321             /*
4322             dump_local_fftgrid(pme,fftgrid);
4323             exit(0);
4324             */
4325         }
4326
4327         /* Here we start a large thread parallel region */
4328 #pragma omp parallel num_threads(pme->nthread) private(thread)
4329         {
4330             thread=gmx_omp_get_thread_num();
4331             if (flags & GMX_PME_SOLVE)
4332             {
4333                 int loop_count;
4334
4335                 /* do 3d-fft */
4336                 if (thread == 0)
4337                 {
4338                     GMX_BARRIER(cr->mpi_comm_mygroup);
4339                     GMX_MPE_LOG(ev_gmxfft3d_start);
4340                     wallcycle_start(wcycle,ewcPME_FFT);
4341                 }
4342                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_REAL_TO_COMPLEX,
4343                                            fftgrid,cfftgrid,thread,wcycle);
4344                 if (thread == 0)
4345                 {
4346                     wallcycle_stop(wcycle,ewcPME_FFT);
4347                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4348                 }
4349                 where();
4350
4351                 /* solve in k-space for our local cells */
4352                 if (thread == 0)
4353                 {
4354                     GMX_BARRIER(cr->mpi_comm_mygroup);
4355                     GMX_MPE_LOG(ev_solve_pme_start);
4356                     wallcycle_start(wcycle,ewcPME_SOLVE);
4357                 }
4358                 loop_count =
4359                     solve_pme_yzx(pme,cfftgrid,ewaldcoeff,
4360                                   box[XX][XX]*box[YY][YY]*box[ZZ][ZZ],
4361                                   bCalcEnerVir,
4362                                   pme->nthread,thread);
4363                 if (thread == 0)
4364                 {
4365                     wallcycle_stop(wcycle,ewcPME_SOLVE);
4366                     where();
4367                     GMX_MPE_LOG(ev_solve_pme_finish);
4368                     inc_nrnb(nrnb,eNR_SOLVEPME,loop_count);
4369                 }
4370             }
4371
4372             if (bCalcF)
4373             {
4374                 /* do 3d-invfft */
4375                 if (thread == 0)
4376                 {
4377                     GMX_BARRIER(cr->mpi_comm_mygroup);
4378                     GMX_MPE_LOG(ev_gmxfft3d_start);
4379                     where();
4380                     wallcycle_start(wcycle,ewcPME_FFT);
4381                 }
4382                 gmx_parallel_3dfft_execute(pfft_setup,GMX_FFT_COMPLEX_TO_REAL,
4383                                            cfftgrid,fftgrid,thread,wcycle);
4384                 if (thread == 0)
4385                 {
4386                     wallcycle_stop(wcycle,ewcPME_FFT);
4387                     
4388                     where();
4389                     GMX_MPE_LOG(ev_gmxfft3d_finish);
4390
4391                     if (pme->nodeid == 0)
4392                     {
4393                         ntot = pme->nkx*pme->nky*pme->nkz;
4394                         npme  = ntot*log((real)ntot)/log(2.0);
4395                         inc_nrnb(nrnb,eNR_FFT,2*npme);
4396                     }
4397
4398                     wallcycle_start(wcycle,ewcPME_SPREADGATHER);
4399                 }
4400
4401                 copy_fftgrid_to_pmegrid(pme,fftgrid,grid,pme->nthread,thread);
4402             }
4403         }
4404         /* End of thread parallel section.
4405          * With MPI we have to synchronize here before gmx_sum_qgrid_dd.
4406          */
4407
4408         if (bCalcF)
4409         {
4410             /* distribute local grid to all nodes */
4411 #ifdef GMX_MPI
4412             if (pme->nnodes > 1) {
4413                 GMX_BARRIER(cr->mpi_comm_mygroup);
4414                 gmx_sum_qgrid_dd(pme,grid,GMX_SUM_QGRID_BACKWARD);
4415             }
4416 #endif
4417             where();
4418
4419             unwrap_periodic_pmegrid(pme,grid);
4420
4421             /* interpolate forces for our local atoms */
4422             GMX_BARRIER(cr->mpi_comm_mygroup);
4423             GMX_MPE_LOG(ev_gather_f_bsplines_start);
4424
4425             where();
4426
4427             /* If we are running without parallelization,
4428              * atc->f is the actual force array, not a buffer,
4429              * therefore we should not clear it.
4430              */
4431             bClearF = (q == 0 && PAR(cr));
4432 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
4433             for(thread=0; thread<pme->nthread; thread++)
4434             {
4435                 gather_f_bsplines(pme,grid,bClearF,atc,
4436                                   &atc->spline[thread],
4437                                   pme->bFEP ? (q==0 ? 1.0-lambda : lambda) : 1.0);
4438             }
4439
4440             where();
4441
4442             GMX_MPE_LOG(ev_gather_f_bsplines_finish);
4443
4444             inc_nrnb(nrnb,eNR_GATHERFBSP,
4445                      pme->pme_order*pme->pme_order*pme->pme_order*pme->atc[0].n);
4446             wallcycle_stop(wcycle,ewcPME_SPREADGATHER);
4447         }
4448
4449         if (bCalcEnerVir)
4450         {
4451             /* This should only be called on the master thread
4452              * and after the threads have synchronized.
4453              */
4454             get_pme_ener_vir(pme,pme->nthread,&energy_AB[q],vir_AB[q]);
4455         }
4456     } /* of q-loop */
4457
4458     if (bCalcF && pme->nnodes > 1) {
4459         wallcycle_start(wcycle,ewcPME_REDISTXF);
4460         for(d=0; d<pme->ndecompdim; d++)
4461         {
4462             atc = &pme->atc[d];
4463             if (d == pme->ndecompdim - 1)
4464             {
4465                 n_d = homenr;
4466                 f_d = f + start;
4467             }
4468             else
4469             {
4470                 n_d = pme->atc[d+1].n;
4471                 f_d = pme->atc[d+1].f;
4472             }
4473             GMX_BARRIER(cr->mpi_comm_mygroup);
4474             if (DOMAINDECOMP(cr)) {
4475                 dd_pmeredist_f(pme,atc,n_d,f_d,
4476                                d==pme->ndecompdim-1 && pme->bPPnode);
4477             } else {
4478                 pmeredist_pd(pme, FALSE, n_d, TRUE, f_d, NULL, atc);
4479             }
4480         }
4481
4482         wallcycle_stop(wcycle,ewcPME_REDISTXF);
4483     }
4484     where();
4485
4486     if (bCalcEnerVir)
4487     {
4488         if (!pme->bFEP) {
4489             *energy = energy_AB[0];
4490             m_add(vir,vir_AB[0],vir);
4491         } else {
4492             *energy = (1.0-lambda)*energy_AB[0] + lambda*energy_AB[1];
4493             *dvdlambda += energy_AB[1] - energy_AB[0];
4494             for(i=0; i<DIM; i++)
4495             {
4496                 for(j=0; j<DIM; j++)
4497                 {
4498                     vir[i][j] += (1.0-lambda)*vir_AB[0][i][j] + 
4499                         lambda*vir_AB[1][i][j];
4500                 }
4501             }
4502         }
4503     }
4504     else
4505     {
4506         *energy = 0;
4507     }
4508
4509     if (debug)
4510     {
4511         fprintf(debug,"PME mesh energy: %g\n",*energy);
4512     }
4513
4514     return 0;
4515 }