added Verlet scheme and NxN non-bonded functionality
[alexxy/gromacs.git] / src / mdlib / pme.c
index 8342621cad818ecbe0e5fd396580cdf47e2656fa..a9fa5c1d07b2ae4cf7551916c5f61d997ace7a3f 100644 (file)
@@ -92,8 +92,7 @@
 /* Single precision, with SSE2 or higher available */
 #if defined(GMX_X86_SSE2) && !defined(GMX_DOUBLE)
 
-#include "gmx_x86_sse2.h"
-#include "gmx_math_x86_sse2_single.h"
+#include "gmx_x86_simd_single.h"
 
 #define PME_SSE
 /* Some old AMD processors could have problems with unaligned loads+stores */
@@ -133,6 +132,7 @@ typedef struct {
     int send_nindex;
     int recv_index0;
     int recv_nindex;
+    int recv_size;   /* Receive buffer width, used with OpenMP */
 } pme_grid_comm_t;
 
 typedef struct {
@@ -144,6 +144,7 @@ typedef struct {
     int  *s2g1;
     int  noverlap_nodes;
     int  *send_id,*recv_id;
+    int  send_size;             /* Send buffer width, used with OpenMP */
     pme_grid_comm_t *comm_data;
     real *sendbuf;
     real *recvbuf;
@@ -156,10 +157,13 @@ typedef struct {
 } thread_plist_t;
 
 typedef struct {
+    int  *thread_one;
     int  n;
     int  *ind;
     splinevec theta;
+    real *ptr_theta_z;  /* z-allocation behind theta[ZZ], padded for SSE  */
     splinevec dtheta;
+    real *ptr_dtheta_z; /* z-allocation behind dtheta[ZZ], padded for SSE */
 } splinedata_t;
 
 typedef struct {
@@ -204,11 +208,12 @@ typedef struct {
 #define FLBSZ 4
 
 typedef struct {
-    ivec ci;     /* The spatial location of this grid       */
-    ivec n;      /* The size of *grid, including order-1    */
-    ivec offset; /* The grid offset from the full node grid */
-    int  order;  /* PME spreading order                     */
-    real *grid;  /* The grid local thread, size n           */
+    ivec ci;     /* The spatial location of this grid         */
+    ivec n;      /* The used size of *grid, including order-1 */
+    ivec offset; /* The grid offset from the full node grid   */
+    int  order;  /* PME spreading order                       */
+    ivec s;      /* The allocated size of *grid, s >= n       */
+    real *grid;  /* The grid of the local thread, size n      */
 } pmegrid_t;
 
 typedef struct {
@@ -216,6 +221,7 @@ typedef struct {
     int  nthread;       /* The number of threads operating on this grid     */
     ivec nc;            /* The local spatial decomposition over the threads */
     pmegrid_t *grid_th; /* Array of grids for each thread                   */
+    real *grid_all;     /* Allocated array for the grids in *grid_th        */
     int  **g2t;         /* The grid to thread index                         */
     ivec nthread_comm;  /* The number of threads to communicate with        */
 } pmegrids_t;
@@ -563,6 +569,24 @@ static void pme_calc_pidx_wrapper(int natoms, matrix recipbox, rvec x[],
     }
 }
 
+static void realloc_splinevec(splinevec th,real **ptr_z,int nalloc)
+{
+    const int padding=4;
+    int i;
+
+    srenew(th[XX],nalloc);
+    srenew(th[YY],nalloc);
+    /* In z we add padding, this is only required for the aligned SSE code */
+    srenew(*ptr_z,nalloc+2*padding);
+    th[ZZ] = *ptr_z + padding;
+
+    for(i=0; i<padding; i++)
+    {
+        (*ptr_z)[               i] = 0;
+        (*ptr_z)[padding+nalloc+i] = 0;
+    }
+}
+
 static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
 {
     int i,d;
@@ -574,11 +598,10 @@ static void pme_realloc_splinedata(splinedata_t *spline, pme_atomcomm_t *atc)
         spline->ind[i] = i;
     }
 
-    for(d=0;d<DIM;d++)
-    {
-        srenew(spline->theta[d] ,atc->pme_order*atc->nalloc);
-        srenew(spline->dtheta[d],atc->pme_order*atc->nalloc);
-    }
+    realloc_splinevec(spline->theta,&spline->ptr_theta_z,
+                      atc->pme_order*atc->nalloc);
+    realloc_splinevec(spline->dtheta,&spline->ptr_dtheta_z,
+                      atc->pme_order*atc->nalloc);
 }
 
 static void pme_realloc_atomcomm_things(pme_atomcomm_t *atc)
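The padding that realloc_splinevec adds around the z component lets the aligned SSE spread/gather code issue full-width loads and stores on theta[ZZ]/dtheta[ZZ] without special-casing the array ends. A minimal standalone sketch of the same idea, with hypothetical names (alloc_padded_z, free_padded_z) and calloc in place of srenew:

#include <stdlib.h>

typedef float real;             /* stand-in for the GROMACS real type */

/* Allocate n usable elements with 'padding' zeroed elements on each side;
 * *base keeps the pointer that must later be passed to free_padded_z().
 * Out-of-range SIMD lanes then read zeros instead of garbage. */
static real *alloc_padded_z(real **base, int n, int padding)
{
    real *p = (real *)calloc(n + 2*padding, sizeof(real));

    *base = p;
    return p + padding;         /* caller indexes [0..n-1] */
}

static void free_padded_z(real *base)
{
    free(base);
}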
@@ -1425,9 +1448,9 @@ static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
     int      pnx,pny,pnz,ndatatot;
     int      offx,offy,offz;
 
-    pnx = pmegrid->n[XX];
-    pny = pmegrid->n[YY];
-    pnz = pmegrid->n[ZZ];
+    pnx = pmegrid->s[XX];
+    pny = pmegrid->s[YY];
+    pnz = pmegrid->s[ZZ];
 
     offx = pmegrid->offset[XX];
     offy = pmegrid->offset[YY];
@@ -1439,7 +1462,7 @@ static void spread_q_bsplines_thread(pmegrid_t *pmegrid,
     {
         grid[i] = 0;
     }
-
+    
     order = pmegrid->order;
 
     for(nn=0; nn<spline->n; nn++)
@@ -1542,14 +1565,15 @@ static void pmegrid_init(pmegrid_t *grid,
     grid->n[XX]      = x1 - x0 + pme_order - 1;
     grid->n[YY]      = y1 - y0 + pme_order - 1;
     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
+    copy_ivec(grid->n,grid->s);
 
-    nz = grid->n[ZZ];
+    nz = grid->s[ZZ];
     set_grid_alignment(&nz,pme_order);
     if (set_alignment)
     {
-        grid->n[ZZ] = nz;
+        grid->s[ZZ] = nz;
     }
-    else if (nz != grid->n[ZZ])
+    else if (nz != grid->s[ZZ])
     {
         gmx_incons("pmegrid_init call with an unaligned z size");
     }
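With this change grid->n tracks the used size while grid->s tracks the allocated size, with the z dimension rounded up so each (x,y) row starts on a SIMD-friendly boundary. set_grid_alignment itself is not shown in this diff; assuming it simply rounds up to a multiple of the SIMD width when the aligned SSE kernels are in use, it amounts to:

/* Sketch only: round nz up to a multiple of the SIMD width (4 floats for
 * SSE single precision); e.g. 14 -> 16. The real set_grid_alignment also
 * depends on pme_order and is only active when PME_SSE is defined. */
static int round_up_to_simd_width(int nz, int simd_width)
{
    return ((nz + simd_width - 1)/simd_width)*simd_width;
}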
@@ -1557,7 +1581,7 @@ static void pmegrid_init(pmegrid_t *grid,
     grid->order = pme_order;
     if (ptr == NULL)
     {
-        gridsize = grid->n[XX]*grid->n[YY]*grid->n[ZZ];
+        gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
         set_gridsize_alignment(&gridsize,pme_order);
         snew_aligned(grid->grid,gridsize,16);
     }
@@ -1635,7 +1659,7 @@ static void pmegrids_init(pmegrids_t *grids,
 {
     ivec n,n_base,g0,g1;
     int t,x,y,z,d,i,tfac;
-    int max_comm_lines;
+    int max_comm_lines=-1;
 
     n[XX] = nx - (pme_order - 1);
     n[YY] = ny - (pme_order - 1);
@@ -1655,7 +1679,6 @@ static void pmegrids_init(pmegrids_t *grids,
     {
         ivec nst;
         int gridsize;
-        real *grid_all;
 
         for(d=0; d<DIM; d++)
         {
@@ -1676,7 +1699,7 @@ static void pmegrids_init(pmegrids_t *grids,
         t = 0;
         gridsize = nst[XX]*nst[YY]*nst[ZZ];
         set_gridsize_alignment(&gridsize,pme_order);
-        snew_aligned(grid_all,
+        snew_aligned(grids->grid_all,
                      grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                      16);
 
@@ -1696,7 +1719,7 @@ static void pmegrids_init(pmegrids_t *grids,
                                  (n[ZZ]*(z+1))/grids->nc[ZZ],
                                  TRUE,
                                  pme_order,
-                                 grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
+                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                     t++;
                 }
             }
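grid_all is one aligned allocation carved into per-thread grids, with GMX_CACHE_SEP reals before, between and after them so that threads updating the edges of neighbouring grids never touch the same cache line. A self-contained sketch of that layout, using calloc and a hypothetical CACHE_SEP in place of snew_aligned and GMX_CACHE_SEP:

#include <stdlib.h>

#define CACHE_SEP 64   /* separator size in floats, stand-in for GMX_CACHE_SEP */

/* Carve one shared allocation into nthread grids of gridsize floats, each
 * preceded by a CACHE_SEP gap (plus one trailing gap), mirroring the
 * grid_all layout above (the real code also 16-byte aligns the block).
 * Returns the base pointer that must be freed later. */
static float *carve_thread_grids(float **grid_th, int nthread, int gridsize)
{
    float *all = (float *)calloc((size_t)nthread*gridsize +
                                 (size_t)(nthread + 1)*CACHE_SEP,
                                 sizeof(float));
    int t;

    for (t = 0; t < nthread; t++)
    {
        grid_th[t] = all + CACHE_SEP + t*(gridsize + CACHE_SEP);
    }

    return all;
}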
@@ -1730,7 +1753,8 @@ static void pmegrids_init(pmegrids_t *grids,
         case ZZ: max_comm_lines = pme_order - 1; break;
         }
         grids->nthread_comm[d] = 0;
-        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines)
+        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
+               grids->nthread_comm[d] < grids->nc[d])
         {
             grids->nthread_comm[d]++;
         }
@@ -2694,6 +2718,8 @@ static void init_atomcomm(gmx_pme_t pme,pme_atomcomm_t *atc, t_commrec *cr,
             snew(atc->thread_plist[thread].n,atc->nthread+2*GMX_CACHE_SEP);
             atc->thread_plist[thread].n += GMX_CACHE_SEP;
         }
+        snew(atc->spline[thread].thread_one,pme->nthread);
+        atc->spline[thread].thread_one[thread] = 1;
     }
 }
 
@@ -2714,15 +2740,16 @@ init_overlap_comm(pme_overlap_t *  ol,
     pme_grid_comm_t *pgc;
     gmx_bool bCont;
     int fft_start,fft_end,send_index1,recv_index1;
-
 #ifdef GMX_MPI
+    MPI_Status stat;
+
     ol->mpi_comm = comm;
 #endif
 
     ol->nnodes = nnodes;
     ol->nodeid = nodeid;
 
-    /* Linear translation of the PME grid wo'nt affect reciprocal space
+    /* Linear translation of the PME grid won't affect reciprocal space
      * calculations, so to optimize we only interpolate "upwards",
      * which also means we only have to consider overlap in one direction.
      * I.e., particles on this node might also be spread to grid indices
@@ -2777,6 +2804,7 @@ init_overlap_comm(pme_overlap_t *  ol,
     }
     snew(ol->comm_data, ol->noverlap_nodes);
 
+    ol->send_size = 0;
     for(b=0; b<ol->noverlap_nodes; b++)
     {
         pgc = &ol->comm_data[b];
@@ -2792,6 +2820,7 @@ init_overlap_comm(pme_overlap_t *  ol,
         send_index1      = min(send_index1,fft_end);
         pgc->send_index0 = fft_start;
         pgc->send_nindex = max(0,send_index1 - pgc->send_index0);
+        ol->send_size    += pgc->send_nindex;
 
         /* We always start receiving to the first index of our slab */
         fft_start        = ol->s2g0[ol->nodeid];
@@ -2806,6 +2835,16 @@ init_overlap_comm(pme_overlap_t *  ol,
         pgc->recv_nindex = max(0,recv_index1 - pgc->recv_index0);
     }
 
+#ifdef GMX_MPI
+    /* Communicate the buffer sizes to receive */
+    for(b=0; b<ol->noverlap_nodes; b++)
+    {
+        MPI_Sendrecv(&ol->send_size             ,1,MPI_INT,ol->send_id[b],b,
+                     &ol->comm_data[b].recv_size,1,MPI_INT,ol->recv_id[b],b,
+                     ol->mpi_comm,&stat);
+    }
+#endif
+
     /* For non-divisible grid we need pme_order iso pme_order-1 */
     snew(ol->sendbuf,norder*commplainsize);
     snew(ol->recvbuf,norder*commplainsize);
@@ -3075,7 +3114,7 @@ int gmx_pme_init(gmx_pme_t *         pmedata,
         pme->nky <= pme->pme_order*(pme->nnodes_minor > 1 ? 2 : 1) ||
         pme->nkz <= pme->pme_order)
     {
-        gmx_fatal(FARGS,"The pme grid dimensions need to be larger than pme_order (%d) and in parallel larger than 2*pme_ordern for x and/or y",pme->pme_order);
+        gmx_fatal(FARGS,"The PME grid sizes need to be larger than pme_order (%d) and for dimensions with domain decomposition larger than 2*pme_order",pme->pme_order);
     }
 
     if (pme->nnodes > 1) {
@@ -3121,20 +3160,26 @@ int gmx_pme_init(gmx_pme_t *         pmedata,
                       pme->nkx,
                       (div_round_up(pme->nky,pme->nnodes_minor)+pme->pme_order)*(pme->nkz+pme->pme_order-1));
 
+    /* Along overlap dim 1 we can send in multiple pulses in sum_fftgrid_dd.
+     * We do this with an offset buffer of equal size, so we need to allocate
+     * extra for the offset. That's what the (+1)*pme->nkz is for.
+     */
     init_overlap_comm(&pme->overlap[1],pme->pme_order,
 #ifdef GMX_MPI
                       pme->mpi_comm_d[1],
 #endif
                       pme->nnodes_minor,pme->nodeid_minor,
                       pme->nky,
-                      (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order)*pme->nkz);
+                      (div_round_up(pme->nkx,pme->nnodes_major)+pme->pme_order+1)*pme->nkz);
 
-    /* Check for a limitation of the (current) sum_fftgrid_dd code */
-    if (pme->nthread > 1 &&
-        (pme->overlap[0].noverlap_nodes > 1 ||
-         pme->overlap[1].noverlap_nodes > 1))
+    /* Check for a limitation of the (current) sum_fftgrid_dd code.
+     * We only allow multiple communication pulses in dim 1, not in dim 0.
+     */
+    if (pme->nthread > 1 && (pme->overlap[0].noverlap_nodes > 1 ||
+                             pme->nkx < pme->nnodes_major*pme->pme_order))
     {
-        gmx_fatal(FARGS,"With threads the number of grid lines per node along x and or y should be pme_order (%d) or more or exactly pme_order-1",pme->pme_order);
+        gmx_fatal(FARGS,"The number of PME grid lines per node along x is %g. But when using OpenMP threads, the number of grid lines per node along x and should be >= pme_order (%d). To resolve this issue, use less nodes along x (and possibly more along y and/or z) by specifying -dd manually.",
+                  pme->nkx/(double)pme->nnodes_major,pme->pme_order);
     }
 
     snew(pme->bsp_mod[XX],pme->nkx);
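For overlap dim 1, the buffers are sized from the local slab width plus the interpolation overlap; the extra (+1)*pme->nkz leaves room for the offset used when packing multiple pulses in sum_fftgrid_dd. A sketch of the arithmetic, assuming div_round_up(a,b) is the usual (a + b - 1)/b:

/* Sketch: reals per "plain" of the dim-1 overlap buffers, matching the
 * commplainsize argument passed to init_overlap_comm above, which then
 * allocates pme_order*commplainsize reals for both sendbuf and recvbuf. */
static int overlap1_commplainsize(int nkx, int nkz,
                                  int nnodes_major, int pme_order)
{
    int nkx_slab = (nkx + nnodes_major - 1)/nnodes_major; /* div_round_up */

    return (nkx_slab + pme_order + 1)*nkz;
}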
@@ -3249,10 +3294,72 @@ int gmx_pme_init(gmx_pme_t *         pmedata,
     }
 
     *pmedata = pme;
-
+    
     return 0;
 }
 
+static void reuse_pmegrids(const pmegrids_t *old,pmegrids_t *new)
+{
+    int d,t;
+
+    for(d=0; d<DIM; d++)
+    {
+        if (new->grid.n[d] > old->grid.n[d])
+        {
+            return;
+        }
+    }
+
+    sfree_aligned(new->grid.grid);
+    new->grid.grid = old->grid.grid;
+
+    if (new->nthread > 1 && new->nthread == old->nthread)
+    {
+        sfree_aligned(new->grid_all);
+        for(t=0; t<new->nthread; t++)
+        {
+            new->grid_th[t].grid = old->grid_th[t].grid;
+        }
+    }
+}
+
+int gmx_pme_reinit(gmx_pme_t *         pmedata,
+                   t_commrec *         cr,
+                   gmx_pme_t           pme_src,
+                   const t_inputrec *  ir,
+                   ivec                grid_size)
+{
+    t_inputrec irc;
+    int homenr;
+    int ret;
+
+    irc = *ir;
+    irc.nkx = grid_size[XX];
+    irc.nky = grid_size[YY];
+    irc.nkz = grid_size[ZZ];
+
+    if (pme_src->nnodes == 1)
+    {
+        homenr = pme_src->atc[0].n;
+    }
+    else
+    {
+        homenr = -1;
+    }
+
+    ret = gmx_pme_init(pmedata,cr,pme_src->nnodes_major,pme_src->nnodes_minor,
+                       &irc,homenr,pme_src->bFEP,FALSE,pme_src->nthread);
+
+    if (ret == 0)
+    {
+        /* We can easily reuse the allocated pme grids in pme_src */
+        reuse_pmegrids(&pme_src->pmegridA,&(*pmedata)->pmegridA);
+        /* We would like to reuse the fft grids, but that's harder */
+    }
+
+    return ret;
+}
+
 
 static void copy_local_grid(gmx_pme_t pme,
                             pmegrids_t *pmegrids,int thread,real *fftgrid)
@@ -3275,9 +3382,9 @@ static void copy_local_grid(gmx_pme_t pme,
 
     pmegrid = &pmegrids->grid_th[thread];
 
-    nsx = pmegrid->n[XX];
-    nsy = pmegrid->n[YY];
-    nsz = pmegrid->n[ZZ];
+    nsx = pmegrid->s[XX];
+    nsy = pmegrid->s[YY];
+    nsz = pmegrid->s[ZZ];
 
     for(d=0; d<DIM; d++)
     {
@@ -3307,41 +3414,6 @@ static void copy_local_grid(gmx_pme_t pme,
     }
 }
 
-static void print_sendbuf(gmx_pme_t pme,real *sendbuf)
-{
-    ivec local_fft_ndata,local_fft_offset,local_fft_size;
-    pme_overlap_t *overlap;
-    int datasize,nind;
-    int i,x,y,z,n;
-
-    gmx_parallel_3dfft_real_limits(pme->pfft_setupA,
-                                   local_fft_ndata,
-                                   local_fft_offset,
-                                   local_fft_size);
-    /* Major dimension */
-    overlap = &pme->overlap[0];
-
-    nind   = overlap->comm_data[0].send_nindex;
-
-    for(y=0; y<local_fft_ndata[YY]; y++) {
-         printf(" %2d",y);
-    }
-    printf("\n");
-
-    i = 0;
-    for(x=0; x<nind; x++) {
-        for(y=0; y<local_fft_ndata[YY]; y++) {
-            n = 0;
-            for(z=0; z<local_fft_ndata[ZZ]; z++) {
-                if (sendbuf[i] != 0) n++;
-                i++;
-            }
-            printf(" %2d",n);
-        }
-        printf("\n");
-    }
-}
-
 static void
 reduce_threadgrid_overlap(gmx_pme_t pme,
                           const pmegrids_t *pmegrids,int thread,
@@ -3476,9 +3548,9 @@ reduce_threadgrid_overlap(gmx_pme_t pme,
 
                 grid_th = pmegrid_f->grid;
 
-                nsx = pmegrid_f->n[XX];
-                nsy = pmegrid_f->n[YY];
-                nsz = pmegrid_f->n[ZZ];
+                nsx = pmegrid_f->s[XX];
+                nsy = pmegrid_f->s[YY];
+                nsz = pmegrid_f->s[ZZ];
 
 #ifdef DEBUG_PME_REDUCE
                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
@@ -3575,11 +3647,12 @@ static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
 {
     ivec local_fft_ndata,local_fft_offset,local_fft_size;
     pme_overlap_t *overlap;
-    int  send_nindex;
-    int  recv_index0,recv_nindex;
+    int  send_index0,send_nindex;
+    int  recv_nindex;
 #ifdef GMX_MPI
     MPI_Status stat;
 #endif
+    int  send_size_y,recv_size_y;
     int  ipulse,send_id,recv_id,datasize,gridsize,size_yx;
     real *sendptr,*recvptr;
     int  x,y,z,indg,indb;
@@ -3596,9 +3669,6 @@ static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
                                    local_fft_offset,
                                    local_fft_size);
 
-    /* Currently supports only a single communication pulse */
-
-/* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
     if (pme->nnodes_minor > 1)
     {
         /* Major dimension */
@@ -3612,66 +3682,70 @@ static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
         {
             size_yx = 0;
         }
-        datasize = (local_fft_ndata[XX]+size_yx)*local_fft_ndata[ZZ];
+        datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
 
-        ipulse = 0;
+        send_size_y = overlap->send_size;
 
-        send_id = overlap->send_id[ipulse];
-        recv_id = overlap->recv_id[ipulse];
-        send_nindex   = overlap->comm_data[ipulse].send_nindex;
-        /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
-        recv_index0 = 0;
-        recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
-
-        sendptr = overlap->sendbuf;
-        recvptr = overlap->recvbuf;
+        for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++)
+        {
+            send_id = overlap->send_id[ipulse];
+            recv_id = overlap->recv_id[ipulse];
+            send_index0   =
+                overlap->comm_data[ipulse].send_index0 -
+                overlap->comm_data[0].send_index0;
+            send_nindex   = overlap->comm_data[ipulse].send_nindex;
+            /* We don't use recv_index0, as we always receive starting at 0 */
+            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
+            recv_size_y   = overlap->comm_data[ipulse].recv_size;
 
-        /*
-        printf("node %d comm %2d x %2d x %2d\n",pme->nodeid,
-               local_fft_ndata[XX]+size_yx,send_nindex,local_fft_ndata[ZZ]);
-        printf("node %d send %f, %f\n",pme->nodeid,
-               sendptr[0],sendptr[send_nindex*datasize-1]);
-        */
+            sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
+            recvptr = overlap->recvbuf;
 
 #ifdef GMX_MPI
-        MPI_Sendrecv(sendptr,send_nindex*datasize,GMX_MPI_REAL,
-                     send_id,ipulse,
-                     recvptr,recv_nindex*datasize,GMX_MPI_REAL,
-                     recv_id,ipulse,
-                     overlap->mpi_comm,&stat);
+            MPI_Sendrecv(sendptr,send_size_y*datasize,GMX_MPI_REAL,
+                         send_id,ipulse,
+                         recvptr,recv_size_y*datasize,GMX_MPI_REAL,
+                         recv_id,ipulse,
+                         overlap->mpi_comm,&stat);
 #endif
 
-        for(x=0; x<local_fft_ndata[XX]; x++)
-        {
-            for(y=0; y<recv_nindex; y++)
+            for(x=0; x<local_fft_ndata[XX]; x++)
             {
-                indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
-                indb = (x*recv_nindex        + y)*local_fft_ndata[ZZ];
-                for(z=0; z<local_fft_ndata[ZZ]; z++)
+                for(y=0; y<recv_nindex; y++)
                 {
-                    fftgrid[indg+z] += recvptr[indb+z];
+                    indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
+                    indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
+                    for(z=0; z<local_fft_ndata[ZZ]; z++)
+                    {
+                        fftgrid[indg+z] += recvptr[indb+z];
+                    }
                 }
             }
-        }
-        if (pme->nnodes_major > 1)
-        {
-            sendptr = pme->overlap[0].sendbuf;
-            for(x=0; x<size_yx; x++)
+
+            if (pme->nnodes_major > 1)
             {
-                for(y=0; y<recv_nindex; y++)
+                /* Copy from the received buffer to the send buffer for dim 0 */
+                sendptr = pme->overlap[0].sendbuf;
+                for(x=0; x<size_yx; x++)
                 {
-                    indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
-                    indb = ((local_fft_ndata[XX] + x)*recv_nindex +y)*local_fft_ndata[ZZ];
-                    for(z=0; z<local_fft_ndata[ZZ]; z++)
+                    for(y=0; y<recv_nindex; y++)
                     {
-                        sendptr[indg+z] += recvptr[indb+z];
+                        indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
+                        indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
+                        for(z=0; z<local_fft_ndata[ZZ]; z++)
+                        {
+                            sendptr[indg+z] += recvptr[indb+z];
+                        }
                     }
                 }
             }
         }
     }
 
-    /* for(ipulse=0;ipulse<overlap->noverlap_nodes;ipulse++) */
+    /* We only support a single pulse here.
+     * This is not a severe limitation, as this code is only used
+     * with OpenMP, where the (PME) domains can be larger.
+     */
     if (pme->nnodes_major > 1)
     {
         /* Major dimension */
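In the dim-1 reduction above, the receive buffer is packed with a row stride of recv_size (the neighbour's total send width, exchanged in init_overlap_comm), while the FFT grid it is added into can be padded in y and z. A simplified single-pulse sketch of that buffer-to-grid accumulation, with hypothetical names:

/* Sketch: add a tightly packed (nx, recv_ny, nz) receive buffer into a grid
 * whose rows are padded to size_y*size_z. In the real code above the buffer
 * row stride is recv_size_y rather than this pulse's recv_nindex, because
 * several pulses can share one receive buffer. */
static void add_recvbuf_to_grid(float *grid, const float *recvbuf,
                                int nx, int recv_ny, int nz,
                                int size_y, int size_z)
{
    int x, y, z;

    for (x = 0; x < nx; x++)
    {
        for (y = 0; y < recv_ny; y++)
        {
            int indg = (x*size_y  + y)*size_z;
            int indb = (x*recv_ny + y)*nz;

            for (z = 0; z < nz; z++)
            {
                grid[indg + z] += recvbuf[indb + z];
            }
        }
    }
}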
@@ -3685,8 +3759,7 @@ static void sum_fftgrid_dd(gmx_pme_t pme,real *fftgrid)
         send_id = overlap->send_id[ipulse];
         recv_id = overlap->recv_id[ipulse];
         send_nindex   = overlap->comm_data[ipulse].send_nindex;
-        /* recv_index0   = overlap->comm_data[ipulse].recv_index0; */
-        recv_index0 = 0;
+        /* We don't use recv_index0, as we always receive starting at 0 */
         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
 
         sendptr = overlap->sendbuf;
@@ -3830,9 +3903,6 @@ static void spread_on_grid(gmx_pme_t pme,
                                       fftgrid,
                                       pme->overlap[0].sendbuf,
                                       pme->overlap[1].sendbuf);
-#ifdef PRINT_PME_SENDBUF
-            print_sendbuf(pme,pme->overlap[0].sendbuf);
-#endif
         }
 #ifdef PME_TIME_THREADS
         c3 = omp_cyc_end(c3);
@@ -3953,12 +4023,48 @@ static void reset_pmeonly_counters(t_commrec *cr,gmx_wallcycle_t wcycle,
 }
 
 
+static void gmx_pmeonly_switch(int *npmedata, gmx_pme_t **pmedata,
+                               ivec grid_size,
+                               t_commrec *cr, t_inputrec *ir,
+                               gmx_pme_t *pme_ret)
+{
+    int ind;
+    gmx_pme_t pme = NULL;
+
+    ind = 0;
+    while (ind < *npmedata)
+    {
+        pme = (*pmedata)[ind];
+        if (pme->nkx == grid_size[XX] &&
+            pme->nky == grid_size[YY] &&
+            pme->nkz == grid_size[ZZ])
+        {
+            *pme_ret = pme;
+
+            return;
+        }
+
+        ind++;
+    }
+
+    (*npmedata)++;
+    srenew(*pmedata,*npmedata);
+
+    /* Generate a new PME data structure, copying part of the old pointers */
+    gmx_pme_reinit(&((*pmedata)[ind]),cr,pme,ir,grid_size);
+
+    *pme_ret = (*pmedata)[ind];
+}
+
+
 int gmx_pmeonly(gmx_pme_t pme,
                 t_commrec *cr,    t_nrnb *nrnb,
                 gmx_wallcycle_t wcycle,
                 real ewaldcoeff,  gmx_bool bGatherOnly,
                 t_inputrec *ir)
 {
+    int npmedata;
+    gmx_pme_t *pmedata;
     gmx_pme_pp_t pme_pp;
     int  natoms;
     matrix box;
@@ -3972,7 +4078,12 @@ int gmx_pmeonly(gmx_pme_t pme,
     int  count;
     gmx_bool bEnerVir;
     gmx_large_int_t step,step_rel;
+    ivec grid_switch;
 
+    /* This data is only used with PME tuning, i.e. switching PME grids */
+    npmedata = 1;
+    snew(pmedata,npmedata);
+    pmedata[0] = pme;
 
     pme_pp = gmx_pme_pp_init(cr);
 
@@ -3981,15 +4092,28 @@ int gmx_pmeonly(gmx_pme_t pme,
     count = 0;
     do /****** this is a quasi-loop over time steps! */
     {
-        /* Domain decomposition */
-        natoms = gmx_pme_recv_q_x(pme_pp,
-                                  &chargeA,&chargeB,box,&x_pp,&f_pp,
-                                  &maxshift_x,&maxshift_y,
-                                  &pme->bFEP,&lambda,
-                                  &bEnerVir,
-                                  &step);
-
-        if (natoms == -1) {
+        /* The reason for having a loop here is PME grid tuning/switching */
+        do
+        {
+            /* Domain decomposition */
+            natoms = gmx_pme_recv_q_x(pme_pp,
+                                      &chargeA,&chargeB,box,&x_pp,&f_pp,
+                                      &maxshift_x,&maxshift_y,
+                                      &pme->bFEP,&lambda,
+                                      &bEnerVir,
+                                      &step,
+                                      grid_switch,&ewaldcoeff);
+
+            if (natoms == -2)
+            {
+                /* Switch the PME grid to grid_switch */
+                gmx_pmeonly_switch(&npmedata,&pmedata,grid_switch,cr,ir,&pme);
+            }
+        }
+        while (natoms == -2);
+
+        if (natoms == -1)
+        {
             /* We should stop: break out of the loop */
             break;
         }
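gmx_pme_recv_q_x now uses the returned atom count as a small protocol: -2 asks for a PME grid switch (the new size arrives in grid_switch, possibly with a new ewaldcoeff), -1 means stop, and anything else is the number of atoms for a normal step. A self-contained sketch of just that control flow, with stub functions standing in for the real calls:

#include <stdio.h>

/* Stubs standing in for gmx_pme_recv_q_x, gmx_pmeonly_switch and the PME
 * work of one step; only the -2/-1/natoms dispatch mirrors the code above. */
static int  recv_from_pp(int step)  { return step < 3 ? 1000 : -1; }
static void switch_grid(void)       { puts("switching PME grid"); }
static void do_pme_step(int natoms) { printf("PME step, %d atoms\n", natoms); }

int main(void)
{
    int step = 0;

    for (;;)
    {
        int natoms;

        do
        {
            natoms = recv_from_pp(step);
            if (natoms == -2)
            {
                switch_grid();      /* then receive again with the new grid */
            }
        }
        while (natoms == -2);

        if (natoms == -1)
        {
            break;                  /* PP nodes signalled the end of the run */
        }

        do_pme_step(natoms);
        step++;
    }

    return 0;
}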
@@ -4259,7 +4383,7 @@ int gmx_pme_do(gmx_pme_t pme,
                 if (thread == 0)
                 {
                     wallcycle_stop(wcycle,ewcPME_FFT);
-
+                    
                     where();
                     GMX_MPE_LOG(ev_gmxfft3d_finish);