/* src/gromacs/ewald/pme_grid.cpp */
/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
 * Copyright (c) 2001-2004, The GROMACS development team.
 * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
/* TODO find out what this file should be called */
#include "gmxpre.h"

#include "pme_grid.h"

#include "config.h"

#include <cstdlib>

#include "gromacs/ewald/pme.h"
#include "gromacs/fft/parallel_3dfft.h"
#include "gromacs/math/vec.h"
#include "gromacs/timing/cyclecounter.h"
#include "gromacs/utility/fatalerror.h"
#include "gromacs/utility/smalloc.h"

#include "pme_internal.h"

#ifdef DEBUG_PME
#include "gromacs/fileio/pdbio.h"
#include "gromacs/utility/cstringutil.h"
#include "gromacs/utility/futil.h"
#endif

#include "pme_simd.h"

/* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
 * to preserve alignment.
 */
#define GMX_CACHE_SEP 64

#if GMX_MPI
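/* Sum or redistribute the grid overlap regions between PME ranks.
 * With direction == GMX_SUM_GRID_FORWARD the overlap data received from
 * the neighboring rank is added onto the local grid; in the backward
 * direction the send/receive roles are swapped and the received data
 * overwrites the local overlap region instead. The minor (y) dimension
 * is not contiguous in memory, so it is packed into overlap->sendbuf
 * before the MPI_Sendrecv; the major (x) dimension is sent directly
 * from the grid.
 */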
void gmx_sum_qgrid_dd(struct gmx_pme_t *pme, real *grid, int direction)
{
    pme_overlap_t *overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            send_id, recv_id, datasize;
    real          *p;
    real          *sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            send_id       = overlap->comm_data[ipulse].send_id;
            recv_id       = overlap->comm_data[ipulse].recv_id;
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            send_id       = overlap->comm_data[ipulse].recv_id;
            recv_id       = overlap->comm_data[ipulse].send_id;
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy,
                    send_index0-pme->pmegrid_start_iy+send_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] = grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz];
                }
            }
        }

        datasize      = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf.data(), send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     overlap->recvbuf.data(), recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy,
                    recv_index0-pme->pmegrid_start_iy+recv_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    if (direction == GMX_SUM_GRID_FORWARD)
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz] += overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        grid[ix*(pme->pmegrid_ny*pme->pmegrid_nz)+iy*(pme->pmegrid_nz)+iz]  = overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }

    /* Major dimension is easier, no copying required,
     * but we might have to sum to a separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
    {
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            send_id       = overlap->comm_data[ipulse].send_id;
            recv_id       = overlap->comm_data[ipulse].recv_id;
            send_index0   = overlap->comm_data[ipulse].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            recv_index0   = overlap->comm_data[ipulse].recv_index0;
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recvptr       = overlap->recvbuf.data();
        }
        else
        {
            send_id       = overlap->comm_data[ipulse].recv_id;
            recv_id       = overlap->comm_data[ipulse].send_id;
            send_index0   = overlap->comm_data[ipulse].recv_index0;
            send_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_index0   = overlap->comm_data[ipulse].send_index0;
            recv_nindex   = overlap->comm_data[ipulse].send_nindex;
            recvptr       = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        }

        sendptr       = grid + (send_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
        datasize      = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, send_id,
                    pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix,
                    send_index0-pme->pmegrid_start_ix+send_nindex);
            fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid, overlap->nodeid, recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix,
                    recv_index0-pme->pmegrid_start_ix+recv_nindex);
        }

        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            p = grid + (recv_index0-pme->pmegrid_start_ix)*(pme->pmegrid_ny*pme->pmegrid_nz);
            for (i = 0; i < recv_nindex*datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
}
#endif


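/* Copy the spread (PME) grid, which includes overlap padding, into the
 * FFT grid. Both grids are flat, row-major arrays (x outermost, z
 * innermost); only the allocated strides differ, so each element is
 * addressed as
 *   idx = ix*(size[YY]*size[ZZ]) + iy*size[ZZ] + iz
 * with the PME or FFT sizes substituted for size[].
 */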
int copy_pmegrid_to_fftgrid(const gmx_pme_t *pme, const real *pmegrid, real *fftgrid, int grid_index)
{
    ivec    local_fft_ndata, local_fft_offset, local_fft_size;
    ivec    local_pme_size;
    int     ix, iy, iz;
    int     pmeidx, fftidx;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);

    local_pme_size[0] = pme->pmegrid_nx;
    local_pme_size[1] = pme->pmegrid_ny;
    local_pme_size[2] = pme->pmegrid_nz;

    /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
       the offset is identical, and the PME grid always has more data (due to overlap)
     */
    {
#ifdef DEBUG_PME
        FILE *fp, *fp2;
        char  fn[STRLEN];
        real  val;
        sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
        fp = gmx_ffopen(fn, "w");
        sprintf(fn, "pmegrid%d.txt", pme->nodeid);
        fp2 = gmx_ffopen(fn, "w");
#endif

        for (ix = 0; ix < local_fft_ndata[XX]; ix++)
        {
            for (iy = 0; iy < local_fft_ndata[YY]; iy++)
            {
                for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
                {
                    pmeidx          = ix*(local_pme_size[YY]*local_pme_size[ZZ])+iy*(local_pme_size[ZZ])+iz;
                    fftidx          = ix*(local_fft_size[YY]*local_fft_size[ZZ])+iy*(local_fft_size[ZZ])+iz;
                    fftgrid[fftidx] = pmegrid[pmeidx];
#ifdef DEBUG_PME
                    val = 100*pmegrid[pmeidx];
                    if (pmegrid[pmeidx] != 0)
                    {
                        gmx_fprintf_pdb_atomline(fp, epdbATOM, pmeidx, "CA", ' ', "GLY", ' ', pmeidx, ' ',
                                                 5.0*ix, 5.0*iy, 5.0*iz, 1.0, val, "");
                    }
                    if (pmegrid[pmeidx] != 0)
                    {
                        fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n",
                                "qgrid",
                                pme->pmegrid_start_ix + ix,
                                pme->pmegrid_start_iy + iy,
                                pme->pmegrid_start_iz + iz,
                                pmegrid[pmeidx]);
                    }
#endif
                }
            }
        }
#ifdef DEBUG_PME
        gmx_ffclose(fp);
        gmx_ffclose(fp2);
#endif
    }
    return 0;
}


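/* Simple cycle-counter helpers for the per-thread timing printed in
 * copy_fftgrid_to_pmegrid; only compiled when PME_TIME_THREADS is
 * defined.
 */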
#ifdef PME_TIME_THREADS
static gmx_cycles_t omp_cyc_start()
{
    return gmx_cycles_read();
}

static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
{
    return gmx_cycles_read() - c;
}
#endif


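/* Copy the FFT grid back into the (padded) PME grid. The x*y column
 * index range is divided evenly over nthread callers, selected by the
 * thread argument; each call copies whole z-lines, so concurrent calls
 * with distinct thread indices touch disjoint parts of the grids.
 */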
int copy_fftgrid_to_pmegrid(struct gmx_pme_t *pme, const real *fftgrid, real *pmegrid, int grid_index,
                            int nthread, int thread)
{
    ivec          local_fft_ndata, local_fft_offset, local_fft_size;
    ivec          local_pme_size;
    int           ixy0, ixy1, ixy, ix, iy, iz;
    int           pmeidx, fftidx;
#ifdef PME_TIME_THREADS
    gmx_cycles_t  c1;
    static double cs1 = 0;
    static int    cnt = 0;
#endif

#ifdef PME_TIME_THREADS
    c1 = omp_cyc_start();
#endif
    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);

    local_pme_size[0] = pme->pmegrid_nx;
    local_pme_size[1] = pme->pmegrid_ny;
    local_pme_size[2] = pme->pmegrid_nz;

    /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
       the offset is identical, and the PME grid always has more data (due to overlap)
     */
    ixy0 = ((thread  )*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;
    ixy1 = ((thread+1)*local_fft_ndata[XX]*local_fft_ndata[YY])/nthread;

    for (ixy = ixy0; ixy < ixy1; ixy++)
    {
        ix = ixy/local_fft_ndata[YY];
        iy = ixy - ix*local_fft_ndata[YY];

        pmeidx = (ix*local_pme_size[YY] + iy)*local_pme_size[ZZ];
        fftidx = (ix*local_fft_size[YY] + iy)*local_fft_size[ZZ];
        for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
        {
            pmegrid[pmeidx+iz] = fftgrid[fftidx+iz];
        }
    }

#ifdef PME_TIME_THREADS
    c1   = omp_cyc_end(c1);
    cs1 += (double)c1;
    cnt++;
    if (cnt % 20 == 0)
    {
        printf("copy %.2f\n", cs1*1e-9);
    }
#endif

    return 0;
}


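/* Fold the periodic overlap regions of the spread grid back onto the
 * principal grid: the pme_order-1 extra planes beyond the grid end are
 * added to the corresponding planes at the start of the grid, in z
 * always, and in y and x only when that dimension is not decomposed
 * over multiple ranks.
 */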
void wrap_periodic_pmegrid(const gmx_pme_t *pme, real *pmegrid)
{
    int     nx, ny, nz, pny, pnz, ny_x, overlap, ix, iy, iz;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    pny = pme->pmegrid_ny;
    pnz = pme->pmegrid_nz;

    overlap = pme->pme_order - 1;

    /* Add periodic overlap in z */
    for (ix = 0; ix < pme->pmegrid_nx; ix++)
    {
        for (iy = 0; iy < pme->pmegrid_ny; iy++)
        {
            for (iz = 0; iz < overlap; iz++)
            {
                pmegrid[(ix*pny+iy)*pnz+iz] +=
                    pmegrid[(ix*pny+iy)*pnz+nz+iz];
            }
        }
    }

    if (pme->nnodes_minor == 1)
    {
        for (ix = 0; ix < pme->pmegrid_nx; ix++)
        {
            for (iy = 0; iy < overlap; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix*pny+iy)*pnz+iz] +=
                        pmegrid[(ix*pny+ny+iy)*pnz+iz];
                }
            }
        }
    }

    if (pme->nnodes_major == 1)
    {
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);

        for (ix = 0; ix < overlap; ix++)
        {
            for (iy = 0; iy < ny_x; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix*pny+iy)*pnz+iz] +=
                        pmegrid[((nx+ix)*pny+iy)*pnz+iz];
                }
            }
        }
    }
}


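/* Inverse of wrap_periodic_pmegrid: copy the values at the start of the
 * grid into the periodic overlap regions beyond the grid end (z always;
 * y and x only when that dimension is not decomposed over ranks), so
 * that interpolation can index past the principal cell without explicit
 * wrapping.
 */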
void unwrap_periodic_pmegrid(struct gmx_pme_t *pme, real *pmegrid)
{
    int     nx, ny, nz, pny, pnz, ny_x, overlap, ix;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    pny = pme->pmegrid_ny;
    pnz = pme->pmegrid_nz;

    overlap = pme->pme_order - 1;

    if (pme->nnodes_major == 1)
    {
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);

        for (ix = 0; ix < overlap; ix++)
        {
            int iy, iz;

            for (iy = 0; iy < ny_x; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[((nx+ix)*pny+iy)*pnz+iz] =
                        pmegrid[(ix*pny+iy)*pnz+iz];
                }
            }
        }
    }

    if (pme->nnodes_minor == 1)
    {
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
        for (ix = 0; ix < pme->pmegrid_nx; ix++)
        {
            // Trivial OpenMP region that does not throw, no need for try/catch
            int iy, iz;

            for (iy = 0; iy < overlap; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix*pny+ny+iy)*pnz+iz] =
                        pmegrid[(ix*pny+iy)*pnz+iz];
                }
            }
        }
    }

    /* Copy periodic overlap in z */
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
    for (ix = 0; ix < pme->pmegrid_nx; ix++)
    {
        // Trivial OpenMP region that does not throw, no need for try/catch
        int iy, iz;

        for (iy = 0; iy < pme->pmegrid_ny; iy++)
        {
            for (iz = 0; iz < overlap; iz++)
            {
                pmegrid[(ix*pny+iy)*pnz+nz+iz] =
                    pmegrid[(ix*pny+iy)*pnz+iz];
            }
        }
    }
}

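/* When SIMD4 spreading/gathering is in use, round the grid z-size up to
 * a multiple of 4 so that SIMD4 loads and stores along z stay aligned.
 * The rounding below uses the standard bit trick: (n + 3) & ~3 is the
 * smallest multiple of 4 that is >= n.
 */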
void set_grid_alignment(int gmx_unused *pmegrid_nz, int gmx_unused pme_order)
{
#ifdef PME_SIMD4_SPREAD_GATHER
    if (pme_order == 5
#if !PME_4NSIMD_GATHER
        || pme_order == 4
#endif
        )
    {
        /* Round nz up to a multiple of 4 to ensure alignment */
        *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
    }
#endif
}

static void set_gridsize_alignment(int gmx_unused *gridsize, int gmx_unused pme_order)
{
#ifdef PME_SIMD4_SPREAD_GATHER
#if !PME_4NSIMD_GATHER
    if (pme_order == 4)
    {
        /* Add extra elements to ensure that aligned operations do not go
         * beyond the allocated grid size.
         * Note that for pme_order=5, the pme grid z-size alignment
         * ensures that we will not go beyond the grid size.
         */
        *gridsize += 4;
    }
#endif
#endif
}

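/* Initialize a single PME grid (the full grid or one thread sub-grid).
 * (cx,cy,cz) is the cell index of this sub-grid, (x0,y0,z0) its offset
 * in the full grid and [x0,x1) etc. its interior extent; pme_order-1
 * points of spreading overlap are added in each dimension. If ptr is
 * nullptr the grid storage is allocated here with SIMD4 alignment,
 * otherwise the caller-provided storage is used.
 */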
void pmegrid_init(pmegrid_t *grid,
                  int cx, int cy, int cz,
                  int x0, int y0, int z0,
                  int x1, int y1, int z1,
                  gmx_bool set_alignment,
                  int pme_order,
                  real *ptr)
{
    int nz, gridsize;

    grid->ci[XX]     = cx;
    grid->ci[YY]     = cy;
    grid->ci[ZZ]     = cz;
    grid->offset[XX] = x0;
    grid->offset[YY] = y0;
    grid->offset[ZZ] = z0;
    grid->n[XX]      = x1 - x0 + pme_order - 1;
    grid->n[YY]      = y1 - y0 + pme_order - 1;
    grid->n[ZZ]      = z1 - z0 + pme_order - 1;
    copy_ivec(grid->n, grid->s);

    nz = grid->s[ZZ];
    set_grid_alignment(&nz, pme_order);
    if (set_alignment)
    {
        grid->s[ZZ] = nz;
    }
    else if (nz != grid->s[ZZ])
    {
        gmx_incons("pmegrid_init call with an unaligned z size");
    }

    grid->order = pme_order;
    if (ptr == nullptr)
    {
        gridsize = grid->s[XX]*grid->s[YY]*grid->s[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
    }
    else
    {
        grid->grid = ptr;
    }
}

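/* Integer division that rounds up, e.g. div_round_up(10, 4) == 3. */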
static int div_round_up(int numerator, int denominator)
{
    return (numerator + denominator - 1)/denominator;
}

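/* Choose a thread decomposition nsub[XX] x nsub[YY] x nsub[ZZ] with
 * nsub[XX]*nsub[YY]*nsub[ZZ] == nthread that minimizes the number of
 * grid points (including ovl overlap points per dimension) assigned to
 * each thread, preferring fewer cuts in z and then y as a tie-breaker.
 * Illustrative example: for n = {20, 20, 20}, ovl = 3 and nthread = 4,
 * a 2x2x1 division gives (10+3)*(10+3)*(20+3) points per thread,
 * whereas 4x1x1 gives (5+3)*(20+3)*(20+3), so 2x2x1 is chosen.
 * The result can be overridden with the GMX_PME_THREAD_DIVISION
 * environment variable (three integers).
 */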
static void make_subgrid_division(const ivec n, int ovl, int nthread,
                                  ivec nsub)
{
    int   gsize_opt, gsize;
    int   nsx, nsy, nsz;
    char *env;

    gsize_opt = -1;
    for (nsx = 1; nsx <= nthread; nsx++)
    {
        if (nthread % nsx == 0)
        {
            for (nsy = 1; nsy <= nthread; nsy++)
            {
                if (nsx*nsy <= nthread && nthread % (nsx*nsy) == 0)
                {
                    nsz = nthread/(nsx*nsy);

                    /* Determine the number of grid points per thread */
                    gsize =
                        (div_round_up(n[XX], nsx) + ovl)*
                        (div_round_up(n[YY], nsy) + ovl)*
                        (div_round_up(n[ZZ], nsz) + ovl);

                    /* Minimize the number of grid points per thread
                     * and, secondarily, the number of cuts in minor dimensions.
                     */
                    if (gsize_opt == -1 ||
                        gsize < gsize_opt ||
                        (gsize == gsize_opt &&
                         (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
                    {
                        nsub[XX]  = nsx;
                        nsub[YY]  = nsy;
                        nsub[ZZ]  = nsz;
                        gsize_opt = gsize;
                    }
                }
            }
        }
    }

    env = getenv("GMX_PME_THREAD_DIVISION");
    if (env != nullptr)
    {
        sscanf(env, "%20d %20d %20d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
    }

    if (nsub[XX]*nsub[YY]*nsub[ZZ] != nthread)
    {
        gmx_fatal(FARGS, "PME grid thread division (%d x %d x %d) does not match the total number of threads (%d)", nsub[XX], nsub[YY], nsub[ZZ], nthread);
    }
}

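/* Initialize the full PME grid and, when bUseThreads is set, one
 * sub-grid per thread. All thread sub-grids are carved out of a single
 * SIMD4-aligned allocation (grid_all), spaced GMX_CACHE_SEP reals
 * apart. g2t[d] maps a grid line in dimension d to its thread-grid
 * index contribution, and nthread_comm[d] is the number of neighboring
 * thread grids the overlap (at most max_comm_lines lines) extends into
 * along dimension d.
 */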
void pmegrids_init(pmegrids_t *grids,
                   int nx, int ny, int nz, int nz_base,
                   int pme_order,
                   gmx_bool bUseThreads,
                   int nthread,
                   int overlap_x,
                   int overlap_y)
{
    ivec n, n_base;
    int  t, x, y, z, d, i, tfac;
    int  max_comm_lines = -1;

    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order,
                 nullptr);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order-1, grids->nthread, grids->nc);

    if (bUseThreads)
    {
        ivec nst;
        int  gridsize;

        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug, "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX], grids->nc[YY], grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n",
                    nx, ny, nz,
                    nst[XX], nst[YY], nst[ZZ]);
        }

        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX]*nst[YY]*nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        snew_aligned(grids->grid_all,
                     grids->nthread*gridsize+(grids->nthread+1)*GMX_CACHE_SEP,
                     SIMD4_ALIGNMENT);

        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    pmegrid_init(&grids->grid_th[t],
                                 x, y, z,
                                 (n[XX]*(x  ))/grids->nc[XX],
                                 (n[YY]*(y  ))/grids->nc[YY],
                                 (n[ZZ]*(z  ))/grids->nc[ZZ],
                                 (n[XX]*(x+1))/grids->nc[XX],
                                 (n[YY]*(y+1))/grids->nc[YY],
                                 (n[ZZ]*(z+1))/grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all+GMX_CACHE_SEP+t*(gridsize+GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }
    else
    {
        grids->grid_th = nullptr;
    }

    tfac = 1;
    for (d = DIM-1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d]*(t+1))/grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t*tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x;     break;
            case YY: max_comm_lines = overlap_y;     break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        grids->nthread_comm[d] = 0;
        while ((n[d]*grids->nthread_comm[d])/grids->nc[d] < max_comm_lines &&
               grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != nullptr)
        {
            fprintf(debug, "pmegrid thread grid communication range in %c: %d\n",
                    'x'+d, grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS, "Too many threads for PME (%d) compared to the number of grid lines, reduce the number of threads doing PME", grids->nthread);
        }
    }
}

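/* Free the storage allocated in pmegrids_init; a grids struct whose
 * grid.grid pointer is nullptr is left untouched.
 */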
void pmegrids_destroy(pmegrids_t *grids)
{
    if (grids->grid.grid != nullptr)
    {
        sfree_aligned(grids->grid.grid);

        if (grids->nthread > 0)
        {
            sfree_aligned(grids->grid_all);
            sfree(grids->grid_th);
        }
        for (int d = 0; d < DIM; d++)
        {
            sfree(grids->g2t[d]);
        }
    }
}

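/* Build the lookup tables used when assigning particles to grid lines:
 * global_to_local maps a (shifted) global grid index onto the local
 * grid, and fraction_shift holds the -1/0/+1 correction applied to the
 * interpolation fraction when rounding pushes an index just across a
 * local grid boundary (see the comments inside the loop below).
 */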
void
make_gridindex_to_localindex(int n, int local_start, int local_range,
                             int **global_to_local,
                             real **fraction_shift)
{
    /* Here we construct an array for looking up the grid line index and
     * fraction for particles. This is done because it is slightly
     * faster than the modulo operation and because we need to take
     * care of rounding issues, see below.
     * We use an array size of c_pmeNeighborUnitcellCount times the grid size
     * to allow for particles to be out of the triclinic unit-cell.
     */
    const int arraySize = c_pmeNeighborUnitcellCount * n;
    int     * gtl;
    real    * fsh;

    snew(gtl, arraySize);
    snew(fsh, arraySize);

    for (int i = 0; i < arraySize; i++)
    {
        /* Transform global grid index to the local grid index.
         * Our local grid always runs from 0 to local_range-1.
         */
        gtl[i] = (i - local_start + n) % n;
        /* For coordinates that fall within the local grid the fraction
         * is correct, we don't need to shift it.
         */
        fsh[i] = 0;
        /* Check if we are using domain decomposition for PME */
        if (local_range < n)
        {
            /* Due to rounding issues i could be 1 beyond the lower or
             * upper boundary of the local grid. Correct the index for this.
             * If we shift the index, we need to shift the fraction by
             * the same amount in the other direction to not affect
             * the weights.
             * Note that due to this shifting the weights at the end of
             * the spline might change, but that will only involve values
             * between zero and values close to the precision of a real,
             * which is anyhow the accuracy of the whole mesh calculation.
             */
            if (gtl[i] == n - 1)
            {
                /* When this i is used, we should round the local index up */
                gtl[i] = 0;
                fsh[i] = -1;
            }
            else if (gtl[i] == local_range && local_range > 0)
            {
                /* When this i is used, we should round the local index down */
                gtl[i] = local_range - 1;
                fsh[i] = 1;
            }
        }
    }

    *global_to_local = gtl;
    *fraction_shift  = fsh;
}

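/* Let newgrid re-use the grid storage of oldgrid when the new grid fits
 * within the old one in every dimension (and, for the per-thread grids,
 * when the thread count matches); otherwise newgrid keeps its own
 * allocation.
 */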
void reuse_pmegrids(const pmegrids_t *oldgrid, pmegrids_t *newgrid)
{
    int d, t;

    for (d = 0; d < DIM; d++)
    {
        if (newgrid->grid.n[d] > oldgrid->grid.n[d])
        {
            return;
        }
    }

    sfree_aligned(newgrid->grid.grid);
    newgrid->grid.grid = oldgrid->grid.grid;

    if (newgrid->grid_th != nullptr && newgrid->nthread == oldgrid->nthread)
    {
        sfree_aligned(newgrid->grid_all);
        newgrid->grid_all = oldgrid->grid_all;
        for (t = 0; t < newgrid->nthread; t++)
        {
            newgrid->grid_th[t].grid = oldgrid->grid_th[t].grid;
        }
    }
}