3d277d314666eb1d123aa07be7455d5a8b69fca0
[alexxy/gromacs.git] / src / gromacs / ewald / pme_grid.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 /* TODO find out what this file should be called */
39 #include "gmxpre.h"
40
41 #include "pme_grid.h"
42
43 #include "config.h"
44
45 #include <cstdlib>
46
47 #include "gromacs/ewald/pme.h"
48 #include "gromacs/fft/parallel_3dfft.h"
49 #include "gromacs/math/vec.h"
50 #include "gromacs/timing/cyclecounter.h"
51 #include "gromacs/utility/fatalerror.h"
52 #include "gromacs/utility/smalloc.h"
53
54 #include "pme_internal.h"
55
56 #ifdef DEBUG_PME
57 #    include "gromacs/fileio/pdbio.h"
58 #    include "gromacs/utility/cstringutil.h"
59 #    include "gromacs/utility/futil.h"
60 #endif
61
62 #include "pme_simd.h"
63
64 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
65  * to preserve alignment.
66  */
67 #define GMX_CACHE_SEP 64
68
69 void gmx_sum_qgrid_dd(gmx_pme_t* pme, real* grid, const int direction)
70 {
71 #if GMX_MPI
72     pme_overlap_t* overlap;
73     int            send_index0, send_nindex;
74     int            recv_index0, recv_nindex;
75     MPI_Status     stat;
76     int            i, j, k, ix, iy, iz, icnt;
77     int            send_id, recv_id, datasize;
78     real*          p;
79     real *         sendptr, *recvptr;
80
81     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
82     overlap = &pme->overlap[1];
83
84     for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
85     {
86         /* Since we have already (un)wrapped the overlap in the z-dimension,
87          * we only have to communicate 0 to nkz (not pmegrid_nz).
88          */
89         if (direction == GMX_SUM_GRID_FORWARD)
90         {
91             send_id     = overlap->comm_data[ipulse].send_id;
92             recv_id     = overlap->comm_data[ipulse].recv_id;
93             send_index0 = overlap->comm_data[ipulse].send_index0;
94             send_nindex = overlap->comm_data[ipulse].send_nindex;
95             recv_index0 = overlap->comm_data[ipulse].recv_index0;
96             recv_nindex = overlap->comm_data[ipulse].recv_nindex;
97         }
98         else
99         {
100             send_id     = overlap->comm_data[ipulse].recv_id;
101             recv_id     = overlap->comm_data[ipulse].send_id;
102             send_index0 = overlap->comm_data[ipulse].recv_index0;
103             send_nindex = overlap->comm_data[ipulse].recv_nindex;
104             recv_index0 = overlap->comm_data[ipulse].send_index0;
105             recv_nindex = overlap->comm_data[ipulse].send_nindex;
106         }
107
108         /* Copy data to contiguous send buffer */
109         if (debug)
110         {
111             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
112                     pme->nodeid, overlap->nodeid, send_id, pme->pmegrid_start_iy,
113                     send_index0 - pme->pmegrid_start_iy,
114                     send_index0 - pme->pmegrid_start_iy + send_nindex);
115         }
116         icnt = 0;
117         for (i = 0; i < pme->pmegrid_nx; i++)
118         {
119             ix = i;
120             for (j = 0; j < send_nindex; j++)
121             {
122                 iy = j + send_index0 - pme->pmegrid_start_iy;
123                 for (k = 0; k < pme->nkz; k++)
124                 {
125                     iz = k;
126                     overlap->sendbuf[icnt++] =
127                             grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz];
128                 }
129             }
130         }
131
132         datasize = pme->pmegrid_nx * pme->nkz;
133
134         MPI_Sendrecv(overlap->sendbuf.data(), send_nindex * datasize, GMX_MPI_REAL, send_id, ipulse,
135                      overlap->recvbuf.data(), recv_nindex * datasize, GMX_MPI_REAL, recv_id, ipulse,
136                      overlap->mpi_comm, &stat);
137
138         /* Get data from contiguous recv buffer */
139         if (debug)
140         {
141             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
142                     pme->nodeid, overlap->nodeid, recv_id, pme->pmegrid_start_iy,
143                     recv_index0 - pme->pmegrid_start_iy,
144                     recv_index0 - pme->pmegrid_start_iy + recv_nindex);
145         }
146         icnt = 0;
147         for (i = 0; i < pme->pmegrid_nx; i++)
148         {
149             ix = i;
150             for (j = 0; j < recv_nindex; j++)
151             {
152                 iy = j + recv_index0 - pme->pmegrid_start_iy;
153                 for (k = 0; k < pme->nkz; k++)
154                 {
155                     iz = k;
156                     if (direction == GMX_SUM_GRID_FORWARD)
157                     {
158                         grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz] +=
159                                 overlap->recvbuf[icnt++];
160                     }
161                     else
162                     {
163                         grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz] =
164                                 overlap->recvbuf[icnt++];
165                     }
166                 }
167             }
168         }
169     }
170
171     /* Major dimension is easier, no copying required,
172      * but we might have to sum to separate array.
173      * Since we don't copy, we have to communicate up to pmegrid_nz,
174      * not nkz as for the minor direction.
175      */
176     overlap = &pme->overlap[0];
177
178     for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
179     {
180         if (direction == GMX_SUM_GRID_FORWARD)
181         {
182             send_id     = overlap->comm_data[ipulse].send_id;
183             recv_id     = overlap->comm_data[ipulse].recv_id;
184             send_index0 = overlap->comm_data[ipulse].send_index0;
185             send_nindex = overlap->comm_data[ipulse].send_nindex;
186             recv_index0 = overlap->comm_data[ipulse].recv_index0;
187             recv_nindex = overlap->comm_data[ipulse].recv_nindex;
188             recvptr     = overlap->recvbuf.data();
189         }
190         else
191         {
192             send_id     = overlap->comm_data[ipulse].recv_id;
193             recv_id     = overlap->comm_data[ipulse].send_id;
194             send_index0 = overlap->comm_data[ipulse].recv_index0;
195             send_nindex = overlap->comm_data[ipulse].recv_nindex;
196             recv_index0 = overlap->comm_data[ipulse].send_index0;
197             recv_nindex = overlap->comm_data[ipulse].send_nindex;
198             recvptr = grid + (recv_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
199         }
200
201         sendptr = grid + (send_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
202         datasize = pme->pmegrid_ny * pme->pmegrid_nz;
203
204         if (debug)
205         {
206             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
207                     pme->nodeid, overlap->nodeid, send_id, pme->pmegrid_start_ix,
208                     send_index0 - pme->pmegrid_start_ix,
209                     send_index0 - pme->pmegrid_start_ix + send_nindex);
210             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
211                     pme->nodeid, overlap->nodeid, recv_id, pme->pmegrid_start_ix,
212                     recv_index0 - pme->pmegrid_start_ix,
213                     recv_index0 - pme->pmegrid_start_ix + recv_nindex);
214         }
215
216         MPI_Sendrecv(sendptr, send_nindex * datasize, GMX_MPI_REAL, send_id, ipulse, recvptr,
217                      recv_nindex * datasize, GMX_MPI_REAL, recv_id, ipulse, overlap->mpi_comm, &stat);
218
219         /* ADD data from contiguous recv buffer */
220         if (direction == GMX_SUM_GRID_FORWARD)
221         {
222             p = grid + (recv_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
223             for (i = 0; i < recv_nindex * datasize; i++)
224             {
225                 p[i] += overlap->recvbuf[i];
226             }
227         }
228     }
229 #else  // GMX_MPI
230     GMX_UNUSED_VALUE(pme);
231     GMX_UNUSED_VALUE(grid);
232     GMX_UNUSED_VALUE(direction);
233
234     GMX_RELEASE_ASSERT(false, "gmx_sum_qgrid_dd() should not be called without MPI");
235 #endif // GMX_MPI
236 }
237
238
239 int copy_pmegrid_to_fftgrid(const gmx_pme_t* pme, const real* pmegrid, real* fftgrid, int grid_index)
240 {
241     ivec local_fft_ndata, local_fft_offset, local_fft_size;
242     ivec local_pme_size;
243     int  ix, iy, iz;
244     int  pmeidx, fftidx;
245
246     /* Dimensions should be identical for A/B grid, so we just use A here */
247     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
248                                    local_fft_size);
249
250     local_pme_size[0] = pme->pmegrid_nx;
251     local_pme_size[1] = pme->pmegrid_ny;
252     local_pme_size[2] = pme->pmegrid_nz;
253
254     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
255        the offset is identical, and the PME grid always has more data (due to overlap)
256      */
257     {
258 #ifdef DEBUG_PME
259         FILE *fp, *fp2;
260         char  fn[STRLEN];
261         real  val;
262         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
263         fp = gmx_ffopen(fn, "w");
264         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
265         fp2 = gmx_ffopen(fn, "w");
266 #endif
267
268         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
269         {
270             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
271             {
272                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
273                 {
274                     pmeidx = ix * (local_pme_size[YY] * local_pme_size[ZZ])
275                              + iy * (local_pme_size[ZZ]) + iz;
276                     fftidx = ix * (local_fft_size[YY] * local_fft_size[ZZ])
277                              + iy * (local_fft_size[ZZ]) + iz;
278                     fftgrid[fftidx] = pmegrid[pmeidx];
279 #ifdef DEBUG_PME
280                     val = 100 * pmegrid[pmeidx];
281                     if (pmegrid[pmeidx] != 0)
282                     {
283                         gmx_fprintf_pdb_atomline(fp, epdbATOM, pmeidx, "CA", ' ', "GLY", ' ', pmeidx,
284                                                  ' ', 5.0 * ix, 5.0 * iy, 5.0 * iz, 1.0, val, "");
285                     }
286                     if (pmegrid[pmeidx] != 0)
287                     {
288                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n", "qgrid",
289                                 pme->pmegrid_start_ix + ix, pme->pmegrid_start_iy + iy,
290                                 pme->pmegrid_start_iz + iz, pmegrid[pmeidx]);
291                     }
292 #endif
293                 }
294             }
295         }
296 #ifdef DEBUG_PME
297         gmx_ffclose(fp);
298         gmx_ffclose(fp2);
299 #endif
300     }
301     return 0;
302 }
303
304
#ifdef PME_TIME_THREADS
/* Returns the current cycle count, to be passed later to omp_cyc_end() */
static gmx_cycles_t omp_cyc_start()
{
    const gmx_cycles_t startCycles = gmx_cycles_read();

    return startCycles;
}

/* Returns the number of cycles elapsed since the count c was read */
static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
{
    const gmx_cycles_t nowCycles = gmx_cycles_read();

    return nowCycles - c;
}
#endif
316
317
318 int copy_fftgrid_to_pmegrid(struct gmx_pme_t* pme, const real* fftgrid, real* pmegrid, int grid_index, int nthread, int thread)
319 {
320     ivec local_fft_ndata, local_fft_offset, local_fft_size;
321     ivec local_pme_size;
322     int  ixy0, ixy1, ixy, ix, iy, iz;
323     int  pmeidx, fftidx;
324 #ifdef PME_TIME_THREADS
325     gmx_cycles_t  c1;
326     static double cs1 = 0;
327     static int    cnt = 0;
328 #endif
329
330 #ifdef PME_TIME_THREADS
331     c1 = omp_cyc_start();
332 #endif
333     /* Dimensions should be identical for A/B grid, so we just use A here */
334     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
335                                    local_fft_size);
336
337     local_pme_size[0] = pme->pmegrid_nx;
338     local_pme_size[1] = pme->pmegrid_ny;
339     local_pme_size[2] = pme->pmegrid_nz;
340
341     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
342        the offset is identical, and the PME grid always has more data (due to overlap)
343      */
344     ixy0 = ((thread)*local_fft_ndata[XX] * local_fft_ndata[YY]) / nthread;
345     ixy1 = ((thread + 1) * local_fft_ndata[XX] * local_fft_ndata[YY]) / nthread;
346
347     for (ixy = ixy0; ixy < ixy1; ixy++)
348     {
349         ix = ixy / local_fft_ndata[YY];
350         iy = ixy - ix * local_fft_ndata[YY];
351
352         pmeidx = (ix * local_pme_size[YY] + iy) * local_pme_size[ZZ];
353         fftidx = (ix * local_fft_size[YY] + iy) * local_fft_size[ZZ];
354         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
355         {
356             pmegrid[pmeidx + iz] = fftgrid[fftidx + iz];
357         }
358     }
359
360 #ifdef PME_TIME_THREADS
361     c1 = omp_cyc_end(c1);
362     cs1 += (double)c1;
363     cnt++;
364     if (cnt % 20 == 0)
365     {
366         printf("copy %.2f\n", cs1 * 1e-9);
367     }
368 #endif
369
370     return 0;
371 }
372
373
374 void wrap_periodic_pmegrid(const gmx_pme_t* pme, real* pmegrid)
375 {
376     int nx, ny, nz, pny, pnz, ny_x, overlap, ix, iy, iz;
377
378     nx = pme->nkx;
379     ny = pme->nky;
380     nz = pme->nkz;
381
382     pny = pme->pmegrid_ny;
383     pnz = pme->pmegrid_nz;
384
385     overlap = pme->pme_order - 1;
386
387     /* Add periodic overlap in z */
388     for (ix = 0; ix < pme->pmegrid_nx; ix++)
389     {
390         for (iy = 0; iy < pme->pmegrid_ny; iy++)
391         {
392             for (iz = 0; iz < overlap; iz++)
393             {
394                 pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[(ix * pny + iy) * pnz + nz + iz];
395             }
396         }
397     }
398
399     if (pme->nnodes_minor == 1)
400     {
401         for (ix = 0; ix < pme->pmegrid_nx; ix++)
402         {
403             for (iy = 0; iy < overlap; iy++)
404             {
405                 for (iz = 0; iz < nz; iz++)
406                 {
407                     pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[(ix * pny + ny + iy) * pnz + iz];
408                 }
409             }
410         }
411     }
412
413     if (pme->nnodes_major == 1)
414     {
415         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
416
417         for (ix = 0; ix < overlap; ix++)
418         {
419             for (iy = 0; iy < ny_x; iy++)
420             {
421                 for (iz = 0; iz < nz; iz++)
422                 {
423                     pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[((nx + ix) * pny + iy) * pnz + iz];
424                 }
425             }
426         }
427     }
428 }
429
430
431 void unwrap_periodic_pmegrid(struct gmx_pme_t* pme, real* pmegrid)
432 {
433     int nx, ny, nz, pny, pnz, ny_x, overlap, ix;
434
435     nx = pme->nkx;
436     ny = pme->nky;
437     nz = pme->nkz;
438
439     pny = pme->pmegrid_ny;
440     pnz = pme->pmegrid_nz;
441
442     overlap = pme->pme_order - 1;
443
444     if (pme->nnodes_major == 1)
445     {
446         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
447
448         for (ix = 0; ix < overlap; ix++)
449         {
450             int iy, iz;
451
452             for (iy = 0; iy < ny_x; iy++)
453             {
454                 for (iz = 0; iz < nz; iz++)
455                 {
456                     pmegrid[((nx + ix) * pny + iy) * pnz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
457                 }
458             }
459         }
460     }
461
462     if (pme->nnodes_minor == 1)
463     {
464 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
465         for (ix = 0; ix < pme->pmegrid_nx; ix++)
466         {
467             // Trivial OpenMP region that does not throw, no need for try/catch
468             int iy, iz;
469
470             for (iy = 0; iy < overlap; iy++)
471             {
472                 for (iz = 0; iz < nz; iz++)
473                 {
474                     pmegrid[(ix * pny + ny + iy) * pnz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
475                 }
476             }
477         }
478     }
479
480     /* Copy periodic overlap in z */
481 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
482     for (ix = 0; ix < pme->pmegrid_nx; ix++)
483     {
484         // Trivial OpenMP region that does not throw, no need for try/catch
485         int iy, iz;
486
487         for (iy = 0; iy < pme->pmegrid_ny; iy++)
488         {
489             for (iz = 0; iz < overlap; iz++)
490             {
491                 pmegrid[(ix * pny + iy) * pnz + nz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
492             }
493         }
494     }
495 }
496
497 void set_grid_alignment(int gmx_unused* pmegrid_nz, int gmx_unused pme_order)
498 {
499 #ifdef PME_SIMD4_SPREAD_GATHER
500     if (pme_order == 5
501 #    if !PME_4NSIMD_GATHER
502         || pme_order == 4
503 #    endif
504     )
505     {
506         /* Round nz up to a multiple of 4 to ensure alignment */
507         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
508     }
509 #endif
510 }
511
512 static void set_gridsize_alignment(int gmx_unused* gridsize, int gmx_unused pme_order)
513 {
514 #ifdef PME_SIMD4_SPREAD_GATHER
515 #    if !PME_4NSIMD_GATHER
516     if (pme_order == 4)
517     {
518         /* Add extra elements to ensured aligned operations do not go
519          * beyond the allocated grid size.
520          * Note that for pme_order=5, the pme grid z-size alignment
521          * ensures that we will not go beyond the grid size.
522          */
523         *gridsize += 4;
524     }
525 #    endif
526 #endif
527 }
528
529 void pmegrid_init(pmegrid_t* grid,
530                   int        cx,
531                   int        cy,
532                   int        cz,
533                   int        x0,
534                   int        y0,
535                   int        z0,
536                   int        x1,
537                   int        y1,
538                   int        z1,
539                   gmx_bool   set_alignment,
540                   int        pme_order,
541                   real*      ptr)
542 {
543     int nz, gridsize;
544
545     grid->ci[XX]     = cx;
546     grid->ci[YY]     = cy;
547     grid->ci[ZZ]     = cz;
548     grid->offset[XX] = x0;
549     grid->offset[YY] = y0;
550     grid->offset[ZZ] = z0;
551     grid->n[XX]      = x1 - x0 + pme_order - 1;
552     grid->n[YY]      = y1 - y0 + pme_order - 1;
553     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
554     copy_ivec(grid->n, grid->s);
555
556     nz = grid->s[ZZ];
557     set_grid_alignment(&nz, pme_order);
558     if (set_alignment)
559     {
560         grid->s[ZZ] = nz;
561     }
562     else if (nz != grid->s[ZZ])
563     {
564         gmx_incons("pmegrid_init call with an unaligned z size");
565     }
566
567     grid->order = pme_order;
568     if (ptr == nullptr)
569     {
570         gridsize = grid->s[XX] * grid->s[YY] * grid->s[ZZ];
571         set_gridsize_alignment(&gridsize, pme_order);
572         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
573     }
574     else
575     {
576         grid->grid = ptr;
577     }
578 }
579
/* Returns numerator / denominator rounded up, for a non-negative numerator
 * and a positive denominator.
 */
static int div_round_up(int numerator, int denominator)
{
    const int biasedNumerator = numerator + denominator - 1;

    return biasedNumerator / denominator;
}
584
585 static void make_subgrid_division(const ivec n, int ovl, int nthread, ivec nsub)
586 {
587     int   gsize_opt, gsize;
588     int   nsx, nsy, nsz;
589     char* env;
590
591     gsize_opt = -1;
592     for (nsx = 1; nsx <= nthread; nsx++)
593     {
594         if (nthread % nsx == 0)
595         {
596             for (nsy = 1; nsy <= nthread; nsy++)
597             {
598                 if (nsx * nsy <= nthread && nthread % (nsx * nsy) == 0)
599                 {
600                     nsz = nthread / (nsx * nsy);
601
602                     /* Determine the number of grid points per thread */
603                     gsize = (div_round_up(n[XX], nsx) + ovl) * (div_round_up(n[YY], nsy) + ovl)
604                             * (div_round_up(n[ZZ], nsz) + ovl);
605
606                     /* Minimize the number of grids points per thread
607                      * and, secondarily, the number of cuts in minor dimensions.
608                      */
609                     if (gsize_opt == -1 || gsize < gsize_opt
610                         || (gsize == gsize_opt && (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
611                     {
612                         nsub[XX]  = nsx;
613                         nsub[YY]  = nsy;
614                         nsub[ZZ]  = nsz;
615                         gsize_opt = gsize;
616                     }
617                 }
618             }
619         }
620     }
621
622     env = getenv("GMX_PME_THREAD_DIVISION");
623     if (env != nullptr)
624     {
625         sscanf(env, "%20d %20d %20d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
626     }
627
628     if (nsub[XX] * nsub[YY] * nsub[ZZ] != nthread)
629     {
630         gmx_fatal(FARGS,
631                   "PME grid thread division (%d x %d x %d) does not match the total number of "
632                   "threads (%d)",
633                   nsub[XX], nsub[YY], nsub[ZZ], nthread);
634     }
635 }
636
637 void pmegrids_init(pmegrids_t* grids,
638                    int         nx,
639                    int         ny,
640                    int         nz,
641                    int         nz_base,
642                    int         pme_order,
643                    gmx_bool    bUseThreads,
644                    int         nthread,
645                    int         overlap_x,
646                    int         overlap_y)
647 {
648     ivec n, n_base;
649     int  t, x, y, z, d, i, tfac;
650     int  max_comm_lines = -1;
651
652     n[XX] = nx - (pme_order - 1);
653     n[YY] = ny - (pme_order - 1);
654     n[ZZ] = nz - (pme_order - 1);
655
656     copy_ivec(n, n_base);
657     n_base[ZZ] = nz_base;
658
659     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order, nullptr);
660
661     grids->nthread = nthread;
662
663     make_subgrid_division(n_base, pme_order - 1, grids->nthread, grids->nc);
664
665     if (bUseThreads)
666     {
667         ivec nst;
668         int  gridsize;
669
670         for (d = 0; d < DIM; d++)
671         {
672             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
673         }
674         set_grid_alignment(&nst[ZZ], pme_order);
675
676         if (debug)
677         {
678             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n", grids->nc[XX],
679                     grids->nc[YY], grids->nc[ZZ]);
680             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n", nx, ny, nz, nst[XX],
681                     nst[YY], nst[ZZ]);
682         }
683
684         snew(grids->grid_th, grids->nthread);
685         t        = 0;
686         gridsize = nst[XX] * nst[YY] * nst[ZZ];
687         set_gridsize_alignment(&gridsize, pme_order);
688         snew_aligned(grids->grid_all, grids->nthread * gridsize + (grids->nthread + 1) * GMX_CACHE_SEP,
689                      SIMD4_ALIGNMENT);
690
691         for (x = 0; x < grids->nc[XX]; x++)
692         {
693             for (y = 0; y < grids->nc[YY]; y++)
694             {
695                 for (z = 0; z < grids->nc[ZZ]; z++)
696                 {
697                     pmegrid_init(&grids->grid_th[t], x, y, z, (n[XX] * (x)) / grids->nc[XX],
698                                  (n[YY] * (y)) / grids->nc[YY], (n[ZZ] * (z)) / grids->nc[ZZ],
699                                  (n[XX] * (x + 1)) / grids->nc[XX], (n[YY] * (y + 1)) / grids->nc[YY],
700                                  (n[ZZ] * (z + 1)) / grids->nc[ZZ], TRUE, pme_order,
701                                  grids->grid_all + GMX_CACHE_SEP + t * (gridsize + GMX_CACHE_SEP));
702                     t++;
703                 }
704             }
705         }
706     }
707     else
708     {
709         grids->grid_th = nullptr;
710     }
711
712     tfac = 1;
713     for (d = DIM - 1; d >= 0; d--)
714     {
715         snew(grids->g2t[d], n[d]);
716         t = 0;
717         for (i = 0; i < n[d]; i++)
718         {
719             /* The second check should match the parameters
720              * of the pmegrid_init call above.
721              */
722             while (t + 1 < grids->nc[d] && i >= (n[d] * (t + 1)) / grids->nc[d])
723             {
724                 t++;
725             }
726             grids->g2t[d][i] = t * tfac;
727         }
728
729         tfac *= grids->nc[d];
730
731         switch (d)
732         {
733             case XX: max_comm_lines = overlap_x; break;
734             case YY: max_comm_lines = overlap_y; break;
735             case ZZ: max_comm_lines = pme_order - 1; break;
736         }
737         grids->nthread_comm[d] = 0;
738         while ((n[d] * grids->nthread_comm[d]) / grids->nc[d] < max_comm_lines
739                && grids->nthread_comm[d] < grids->nc[d])
740         {
741             grids->nthread_comm[d]++;
742         }
743         if (debug != nullptr)
744         {
745             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n", 'x' + d,
746                     grids->nthread_comm[d]);
747         }
748         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
749          * work, but this is not a problematic restriction.
750          */
751         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
752         {
753             gmx_fatal(FARGS,
754                       "Too many threads for PME (%d) compared to the number of grid lines, reduce "
755                       "the number of threads doing PME",
756                       grids->nthread);
757         }
758     }
759 }
760
761 void pmegrids_destroy(pmegrids_t* grids)
762 {
763     if (grids->grid.grid != nullptr)
764     {
765         sfree_aligned(grids->grid.grid);
766
767         if (grids->nthread > 0)
768         {
769             sfree_aligned(grids->grid_all);
770             sfree(grids->grid_th);
771         }
772         for (int d = 0; d < DIM; d++)
773         {
774             sfree(grids->g2t[d]);
775         }
776     }
777 }
778
779 void make_gridindex_to_localindex(int n, int local_start, int local_range, int** global_to_local, real** fraction_shift)
780 {
781     /* Here we construct array for looking up the grid line index and
782      * fraction for particles. This is done because it is slighlty
783      * faster than the modulo operation and to because we need to take
784      * care of rounding issues, see below.
785      * We use an array size of c_pmeNeighborUnitcellCount times the grid size
786      * to allow for particles to be out of the triclinic unit-cell.
787      */
788     const int arraySize = c_pmeNeighborUnitcellCount * n;
789     int*      gtl;
790     real*     fsh;
791
792     snew(gtl, arraySize);
793     snew(fsh, arraySize);
794
795     for (int i = 0; i < arraySize; i++)
796     {
797         /* Transform global grid index to the local grid index.
798          * Our local grid always runs from 0 to local_range-1.
799          */
800         gtl[i] = (i - local_start + n) % n;
801         /* For coordinates that fall within the local grid the fraction
802          * is correct, we don't need to shift it.
803          */
804         fsh[i] = 0;
805         /* Check if we are using domain decomposition for PME */
806         if (local_range < n)
807         {
808             /* Due to rounding issues i could be 1 beyond the lower or
809              * upper boundary of the local grid. Correct the index for this.
810              * If we shift the index, we need to shift the fraction by
811              * the same amount in the other direction to not affect
812              * the weights.
813              * Note that due to this shifting the weights at the end of
814              * the spline might change, but that will only involve values
815              * between zero and values close to the precision of a real,
816              * which is anyhow the accuracy of the whole mesh calculation.
817              */
818             if (gtl[i] == n - 1)
819             {
820                 /* When this i is used, we should round the local index up */
821                 gtl[i] = 0;
822                 fsh[i] = -1;
823             }
824             else if (gtl[i] == local_range && local_range > 0)
825             {
826                 /* When this i is used, we should round the local index down */
827                 gtl[i] = local_range - 1;
828                 fsh[i] = 1;
829             }
830         }
831     }
832
833     *global_to_local = gtl;
834     *fraction_shift  = fsh;
835 }
836
837 void reuse_pmegrids(const pmegrids_t* oldgrid, pmegrids_t* newgrid)
838 {
839     int d, t;
840
841     for (d = 0; d < DIM; d++)
842     {
843         if (newgrid->grid.n[d] > oldgrid->grid.n[d])
844         {
845             return;
846         }
847     }
848
849     sfree_aligned(newgrid->grid.grid);
850     newgrid->grid.grid = oldgrid->grid.grid;
851
852     if (newgrid->grid_th != nullptr && newgrid->nthread == oldgrid->nthread)
853     {
854         sfree_aligned(newgrid->grid_all);
855         newgrid->grid_all = oldgrid->grid_all;
856         for (t = 0; t < newgrid->nthread; t++)
857         {
858             newgrid->grid_th[t].grid = oldgrid->grid_th[t].grid;
859         }
860     }
861 }