ab759f30deda0d41c894f1b2a4c0361d3c8603ea
[alexxy/gromacs.git] / src / gromacs / ewald / pme_grid.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 /* TODO find out what this file should be called */
38 #include "gmxpre.h"
39
40 #include "pme_grid.h"
41
42 #include "config.h"
43
44 #include <cstdlib>
45
46 #include "gromacs/ewald/pme.h"
47 #include "gromacs/fft/parallel_3dfft.h"
48 #include "gromacs/math/vec.h"
49 #include "gromacs/timing/cyclecounter.h"
50 #include "gromacs/utility/fatalerror.h"
51 #include "gromacs/utility/smalloc.h"
52
53 #include "pme_internal.h"
54
55 #ifdef DEBUG_PME
56 #    include "gromacs/fileio/pdbio.h"
57 #    include "gromacs/utility/cstringutil.h"
58 #    include "gromacs/utility/futil.h"
59 #endif
60
61 #include "pme_simd.h"
62
63 /* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
64  * to preserve alignment.
65  */
66 #define GMX_CACHE_SEP 64
67
68 void gmx_sum_qgrid_dd(gmx_pme_t* pme, real* grid, const int direction)
69 {
70 #if GMX_MPI
71     pme_overlap_t* overlap;
72     int            send_index0, send_nindex;
73     int            recv_index0, recv_nindex;
74     MPI_Status     stat;
75     int            i, j, k, ix, iy, iz, icnt;
76     int            send_id, recv_id, datasize;
77     real*          p;
78     real *         sendptr, *recvptr;
79
80     /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
81     overlap = &pme->overlap[1];
82
83     for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
84     {
85         /* Since we have already (un)wrapped the overlap in the z-dimension,
86          * we only have to communicate 0 to nkz (not pmegrid_nz).
87          */
88         if (direction == GMX_SUM_GRID_FORWARD)
89         {
90             send_id     = overlap->comm_data[ipulse].send_id;
91             recv_id     = overlap->comm_data[ipulse].recv_id;
92             send_index0 = overlap->comm_data[ipulse].send_index0;
93             send_nindex = overlap->comm_data[ipulse].send_nindex;
94             recv_index0 = overlap->comm_data[ipulse].recv_index0;
95             recv_nindex = overlap->comm_data[ipulse].recv_nindex;
96         }
97         else
98         {
99             send_id     = overlap->comm_data[ipulse].recv_id;
100             recv_id     = overlap->comm_data[ipulse].send_id;
101             send_index0 = overlap->comm_data[ipulse].recv_index0;
102             send_nindex = overlap->comm_data[ipulse].recv_nindex;
103             recv_index0 = overlap->comm_data[ipulse].send_index0;
104             recv_nindex = overlap->comm_data[ipulse].send_nindex;
105         }
106
107         /* Copy data to contiguous send buffer */
108         if (debug)
109         {
110             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
111                     pme->nodeid, overlap->nodeid, send_id, pme->pmegrid_start_iy,
112                     send_index0 - pme->pmegrid_start_iy,
113                     send_index0 - pme->pmegrid_start_iy + send_nindex);
114         }
115         icnt = 0;
116         for (i = 0; i < pme->pmegrid_nx; i++)
117         {
118             ix = i;
119             for (j = 0; j < send_nindex; j++)
120             {
121                 iy = j + send_index0 - pme->pmegrid_start_iy;
122                 for (k = 0; k < pme->nkz; k++)
123                 {
124                     iz = k;
125                     overlap->sendbuf[icnt++] =
126                             grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz];
127                 }
128             }
129         }
130
131         datasize = pme->pmegrid_nx * pme->nkz;
132
133         MPI_Sendrecv(overlap->sendbuf.data(), send_nindex * datasize, GMX_MPI_REAL, send_id, ipulse,
134                      overlap->recvbuf.data(), recv_nindex * datasize, GMX_MPI_REAL, recv_id, ipulse,
135                      overlap->mpi_comm, &stat);
136
137         /* Get data from contiguous recv buffer */
138         if (debug)
139         {
140             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
141                     pme->nodeid, overlap->nodeid, recv_id, pme->pmegrid_start_iy,
142                     recv_index0 - pme->pmegrid_start_iy,
143                     recv_index0 - pme->pmegrid_start_iy + recv_nindex);
144         }
145         icnt = 0;
146         for (i = 0; i < pme->pmegrid_nx; i++)
147         {
148             ix = i;
149             for (j = 0; j < recv_nindex; j++)
150             {
151                 iy = j + recv_index0 - pme->pmegrid_start_iy;
152                 for (k = 0; k < pme->nkz; k++)
153                 {
154                     iz = k;
155                     if (direction == GMX_SUM_GRID_FORWARD)
156                     {
157                         grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz] +=
158                                 overlap->recvbuf[icnt++];
159                     }
160                     else
161                     {
162                         grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz] =
163                                 overlap->recvbuf[icnt++];
164                     }
165                 }
166             }
167         }
168     }
169
170     /* Major dimension is easier, no copying required,
171      * but we might have to sum to separate array.
172      * Since we don't copy, we have to communicate up to pmegrid_nz,
173      * not nkz as for the minor direction.
174      */
175     overlap = &pme->overlap[0];
176
177     for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
178     {
179         if (direction == GMX_SUM_GRID_FORWARD)
180         {
181             send_id     = overlap->comm_data[ipulse].send_id;
182             recv_id     = overlap->comm_data[ipulse].recv_id;
183             send_index0 = overlap->comm_data[ipulse].send_index0;
184             send_nindex = overlap->comm_data[ipulse].send_nindex;
185             recv_index0 = overlap->comm_data[ipulse].recv_index0;
186             recv_nindex = overlap->comm_data[ipulse].recv_nindex;
187             recvptr     = overlap->recvbuf.data();
188         }
189         else
190         {
191             send_id     = overlap->comm_data[ipulse].recv_id;
192             recv_id     = overlap->comm_data[ipulse].send_id;
193             send_index0 = overlap->comm_data[ipulse].recv_index0;
194             send_nindex = overlap->comm_data[ipulse].recv_nindex;
195             recv_index0 = overlap->comm_data[ipulse].send_index0;
196             recv_nindex = overlap->comm_data[ipulse].send_nindex;
197             recvptr = grid + (recv_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
198         }
199
200         sendptr = grid + (send_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
201         datasize = pme->pmegrid_ny * pme->pmegrid_nz;
202
203         if (debug)
204         {
205             fprintf(debug, "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
206                     pme->nodeid, overlap->nodeid, send_id, pme->pmegrid_start_ix,
207                     send_index0 - pme->pmegrid_start_ix,
208                     send_index0 - pme->pmegrid_start_ix + send_nindex);
209             fprintf(debug, "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
210                     pme->nodeid, overlap->nodeid, recv_id, pme->pmegrid_start_ix,
211                     recv_index0 - pme->pmegrid_start_ix,
212                     recv_index0 - pme->pmegrid_start_ix + recv_nindex);
213         }
214
215         MPI_Sendrecv(sendptr, send_nindex * datasize, GMX_MPI_REAL, send_id, ipulse, recvptr,
216                      recv_nindex * datasize, GMX_MPI_REAL, recv_id, ipulse, overlap->mpi_comm, &stat);
217
218         /* ADD data from contiguous recv buffer */
219         if (direction == GMX_SUM_GRID_FORWARD)
220         {
221             p = grid + (recv_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
222             for (i = 0; i < recv_nindex * datasize; i++)
223             {
224                 p[i] += overlap->recvbuf[i];
225             }
226         }
227     }
228 #else  // GMX_MPI
229     GMX_UNUSED_VALUE(pme);
230     GMX_UNUSED_VALUE(grid);
231     GMX_UNUSED_VALUE(direction);
232
233     GMX_RELEASE_ASSERT(false, "gmx_sum_qgrid_dd() should not be called without MPI");
234 #endif // GMX_MPI
235 }
236
237
238 int copy_pmegrid_to_fftgrid(const gmx_pme_t* pme, const real* pmegrid, real* fftgrid, int grid_index)
239 {
240     ivec local_fft_ndata, local_fft_offset, local_fft_size;
241     ivec local_pme_size;
242     int  ix, iy, iz;
243     int  pmeidx, fftidx;
244
245     /* Dimensions should be identical for A/B grid, so we just use A here */
246     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
247                                    local_fft_size);
248
249     local_pme_size[0] = pme->pmegrid_nx;
250     local_pme_size[1] = pme->pmegrid_ny;
251     local_pme_size[2] = pme->pmegrid_nz;
252
253     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
254        the offset is identical, and the PME grid always has more data (due to overlap)
255      */
256     {
257 #ifdef DEBUG_PME
258         FILE *fp, *fp2;
259         char  fn[STRLEN];
260         real  val;
261         sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
262         fp = gmx_ffopen(fn, "w");
263         sprintf(fn, "pmegrid%d.txt", pme->nodeid);
264         fp2 = gmx_ffopen(fn, "w");
265 #endif
266
267         for (ix = 0; ix < local_fft_ndata[XX]; ix++)
268         {
269             for (iy = 0; iy < local_fft_ndata[YY]; iy++)
270             {
271                 for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
272                 {
273                     pmeidx = ix * (local_pme_size[YY] * local_pme_size[ZZ])
274                              + iy * (local_pme_size[ZZ]) + iz;
275                     fftidx = ix * (local_fft_size[YY] * local_fft_size[ZZ])
276                              + iy * (local_fft_size[ZZ]) + iz;
277                     fftgrid[fftidx] = pmegrid[pmeidx];
278 #ifdef DEBUG_PME
279                     val = 100 * pmegrid[pmeidx];
280                     if (pmegrid[pmeidx] != 0)
281                     {
282                         gmx_fprintf_pdb_atomline(fp, epdbATOM, pmeidx, "CA", ' ', "GLY", ' ', pmeidx,
283                                                  ' ', 5.0 * ix, 5.0 * iy, 5.0 * iz, 1.0, val, "");
284                     }
285                     if (pmegrid[pmeidx] != 0)
286                     {
287                         fprintf(fp2, "%-12s  %5d  %5d  %5d  %12.5e\n", "qgrid",
288                                 pme->pmegrid_start_ix + ix, pme->pmegrid_start_iy + iy,
289                                 pme->pmegrid_start_iz + iz, pmegrid[pmeidx]);
290                     }
291 #endif
292                 }
293             }
294         }
295 #ifdef DEBUG_PME
296         gmx_ffclose(fp);
297         gmx_ffclose(fp2);
298 #endif
299     }
300     return 0;
301 }
302
303
#ifdef PME_TIME_THREADS
/* Reads the cycle counter to mark the start of a timed region */
static gmx_cycles_t omp_cyc_start()
{
    const gmx_cycles_t startCycles = gmx_cycles_read();

    return startCycles;
}

/* Returns the cycle count elapsed since the counter value c was read */
static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
{
    const gmx_cycles_t currentCycles = gmx_cycles_read();

    return currentCycles - c;
}
#endif
315
316
317 int copy_fftgrid_to_pmegrid(struct gmx_pme_t* pme, const real* fftgrid, real* pmegrid, int grid_index, int nthread, int thread)
318 {
319     ivec local_fft_ndata, local_fft_offset, local_fft_size;
320     ivec local_pme_size;
321     int  ixy0, ixy1, ixy, ix, iy, iz;
322     int  pmeidx, fftidx;
323 #ifdef PME_TIME_THREADS
324     gmx_cycles_t  c1;
325     static double cs1 = 0;
326     static int    cnt = 0;
327 #endif
328
329 #ifdef PME_TIME_THREADS
330     c1 = omp_cyc_start();
331 #endif
332     /* Dimensions should be identical for A/B grid, so we just use A here */
333     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
334                                    local_fft_size);
335
336     local_pme_size[0] = pme->pmegrid_nx;
337     local_pme_size[1] = pme->pmegrid_ny;
338     local_pme_size[2] = pme->pmegrid_nz;
339
340     /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
341        the offset is identical, and the PME grid always has more data (due to overlap)
342      */
343     ixy0 = ((thread)*local_fft_ndata[XX] * local_fft_ndata[YY]) / nthread;
344     ixy1 = ((thread + 1) * local_fft_ndata[XX] * local_fft_ndata[YY]) / nthread;
345
346     for (ixy = ixy0; ixy < ixy1; ixy++)
347     {
348         ix = ixy / local_fft_ndata[YY];
349         iy = ixy - ix * local_fft_ndata[YY];
350
351         pmeidx = (ix * local_pme_size[YY] + iy) * local_pme_size[ZZ];
352         fftidx = (ix * local_fft_size[YY] + iy) * local_fft_size[ZZ];
353         for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
354         {
355             pmegrid[pmeidx + iz] = fftgrid[fftidx + iz];
356         }
357     }
358
359 #ifdef PME_TIME_THREADS
360     c1 = omp_cyc_end(c1);
361     cs1 += (double)c1;
362     cnt++;
363     if (cnt % 20 == 0)
364     {
365         printf("copy %.2f\n", cs1 * 1e-9);
366     }
367 #endif
368
369     return 0;
370 }
371
372
373 void wrap_periodic_pmegrid(const gmx_pme_t* pme, real* pmegrid)
374 {
375     int nx, ny, nz, pny, pnz, ny_x, overlap, ix, iy, iz;
376
377     nx = pme->nkx;
378     ny = pme->nky;
379     nz = pme->nkz;
380
381     pny = pme->pmegrid_ny;
382     pnz = pme->pmegrid_nz;
383
384     overlap = pme->pme_order - 1;
385
386     /* Add periodic overlap in z */
387     for (ix = 0; ix < pme->pmegrid_nx; ix++)
388     {
389         for (iy = 0; iy < pme->pmegrid_ny; iy++)
390         {
391             for (iz = 0; iz < overlap; iz++)
392             {
393                 pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[(ix * pny + iy) * pnz + nz + iz];
394             }
395         }
396     }
397
398     if (pme->nnodes_minor == 1)
399     {
400         for (ix = 0; ix < pme->pmegrid_nx; ix++)
401         {
402             for (iy = 0; iy < overlap; iy++)
403             {
404                 for (iz = 0; iz < nz; iz++)
405                 {
406                     pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[(ix * pny + ny + iy) * pnz + iz];
407                 }
408             }
409         }
410     }
411
412     if (pme->nnodes_major == 1)
413     {
414         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
415
416         for (ix = 0; ix < overlap; ix++)
417         {
418             for (iy = 0; iy < ny_x; iy++)
419             {
420                 for (iz = 0; iz < nz; iz++)
421                 {
422                     pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[((nx + ix) * pny + iy) * pnz + iz];
423                 }
424             }
425         }
426     }
427 }
428
429
430 void unwrap_periodic_pmegrid(struct gmx_pme_t* pme, real* pmegrid)
431 {
432     int nx, ny, nz, pny, pnz, ny_x, overlap, ix;
433
434     nx = pme->nkx;
435     ny = pme->nky;
436     nz = pme->nkz;
437
438     pny = pme->pmegrid_ny;
439     pnz = pme->pmegrid_nz;
440
441     overlap = pme->pme_order - 1;
442
443     if (pme->nnodes_major == 1)
444     {
445         ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);
446
447         for (ix = 0; ix < overlap; ix++)
448         {
449             int iy, iz;
450
451             for (iy = 0; iy < ny_x; iy++)
452             {
453                 for (iz = 0; iz < nz; iz++)
454                 {
455                     pmegrid[((nx + ix) * pny + iy) * pnz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
456                 }
457             }
458         }
459     }
460
461     if (pme->nnodes_minor == 1)
462     {
463 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
464         for (ix = 0; ix < pme->pmegrid_nx; ix++)
465         {
466             // Trivial OpenMP region that does not throw, no need for try/catch
467             int iy, iz;
468
469             for (iy = 0; iy < overlap; iy++)
470             {
471                 for (iz = 0; iz < nz; iz++)
472                 {
473                     pmegrid[(ix * pny + ny + iy) * pnz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
474                 }
475             }
476         }
477     }
478
479     /* Copy periodic overlap in z */
480 #pragma omp parallel for num_threads(pme->nthread) schedule(static)
481     for (ix = 0; ix < pme->pmegrid_nx; ix++)
482     {
483         // Trivial OpenMP region that does not throw, no need for try/catch
484         int iy, iz;
485
486         for (iy = 0; iy < pme->pmegrid_ny; iy++)
487         {
488             for (iz = 0; iz < overlap; iz++)
489             {
490                 pmegrid[(ix * pny + iy) * pnz + nz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
491             }
492         }
493     }
494 }
495
496 void set_grid_alignment(int gmx_unused* pmegrid_nz, int gmx_unused pme_order)
497 {
498 #ifdef PME_SIMD4_SPREAD_GATHER
499     if (pme_order == 5
500 #    if !PME_4NSIMD_GATHER
501         || pme_order == 4
502 #    endif
503     )
504     {
505         /* Round nz up to a multiple of 4 to ensure alignment */
506         *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
507     }
508 #endif
509 }
510
511 static void set_gridsize_alignment(int gmx_unused* gridsize, int gmx_unused pme_order)
512 {
513 #ifdef PME_SIMD4_SPREAD_GATHER
514 #    if !PME_4NSIMD_GATHER
515     if (pme_order == 4)
516     {
517         /* Add extra elements to ensured aligned operations do not go
518          * beyond the allocated grid size.
519          * Note that for pme_order=5, the pme grid z-size alignment
520          * ensures that we will not go beyond the grid size.
521          */
522         *gridsize += 4;
523     }
524 #    endif
525 #endif
526 }
527
528 void pmegrid_init(pmegrid_t* grid,
529                   int        cx,
530                   int        cy,
531                   int        cz,
532                   int        x0,
533                   int        y0,
534                   int        z0,
535                   int        x1,
536                   int        y1,
537                   int        z1,
538                   gmx_bool   set_alignment,
539                   int        pme_order,
540                   real*      ptr)
541 {
542     int nz, gridsize;
543
544     grid->ci[XX]     = cx;
545     grid->ci[YY]     = cy;
546     grid->ci[ZZ]     = cz;
547     grid->offset[XX] = x0;
548     grid->offset[YY] = y0;
549     grid->offset[ZZ] = z0;
550     grid->n[XX]      = x1 - x0 + pme_order - 1;
551     grid->n[YY]      = y1 - y0 + pme_order - 1;
552     grid->n[ZZ]      = z1 - z0 + pme_order - 1;
553     copy_ivec(grid->n, grid->s);
554
555     nz = grid->s[ZZ];
556     set_grid_alignment(&nz, pme_order);
557     if (set_alignment)
558     {
559         grid->s[ZZ] = nz;
560     }
561     else if (nz != grid->s[ZZ])
562     {
563         gmx_incons("pmegrid_init call with an unaligned z size");
564     }
565
566     grid->order = pme_order;
567     if (ptr == nullptr)
568     {
569         gridsize = grid->s[XX] * grid->s[YY] * grid->s[ZZ];
570         set_gridsize_alignment(&gridsize, pme_order);
571         snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
572     }
573     else
574     {
575         grid->grid = ptr;
576     }
577 }
578
/* Returns enumerator divided by denominator, rounded up.
 * Intended for a non-negative enumerator and a positive denominator.
 */
static int div_round_up(int enumerator, int denominator)
{
    const int roundedUpSum = enumerator + denominator - 1;

    return roundedUpSum / denominator;
}
583
584 static void make_subgrid_division(const ivec n, int ovl, int nthread, ivec nsub)
585 {
586     int   gsize_opt, gsize;
587     int   nsx, nsy, nsz;
588     char* env;
589
590     gsize_opt = -1;
591     for (nsx = 1; nsx <= nthread; nsx++)
592     {
593         if (nthread % nsx == 0)
594         {
595             for (nsy = 1; nsy <= nthread; nsy++)
596             {
597                 if (nsx * nsy <= nthread && nthread % (nsx * nsy) == 0)
598                 {
599                     nsz = nthread / (nsx * nsy);
600
601                     /* Determine the number of grid points per thread */
602                     gsize = (div_round_up(n[XX], nsx) + ovl) * (div_round_up(n[YY], nsy) + ovl)
603                             * (div_round_up(n[ZZ], nsz) + ovl);
604
605                     /* Minimize the number of grids points per thread
606                      * and, secondarily, the number of cuts in minor dimensions.
607                      */
608                     if (gsize_opt == -1 || gsize < gsize_opt
609                         || (gsize == gsize_opt && (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
610                     {
611                         nsub[XX]  = nsx;
612                         nsub[YY]  = nsy;
613                         nsub[ZZ]  = nsz;
614                         gsize_opt = gsize;
615                     }
616                 }
617             }
618         }
619     }
620
621     env = getenv("GMX_PME_THREAD_DIVISION");
622     if (env != nullptr)
623     {
624         sscanf(env, "%20d %20d %20d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
625     }
626
627     if (nsub[XX] * nsub[YY] * nsub[ZZ] != nthread)
628     {
629         gmx_fatal(FARGS,
630                   "PME grid thread division (%d x %d x %d) does not match the total number of "
631                   "threads (%d)",
632                   nsub[XX], nsub[YY], nsub[ZZ], nthread);
633     }
634 }
635
636 void pmegrids_init(pmegrids_t* grids,
637                    int         nx,
638                    int         ny,
639                    int         nz,
640                    int         nz_base,
641                    int         pme_order,
642                    gmx_bool    bUseThreads,
643                    int         nthread,
644                    int         overlap_x,
645                    int         overlap_y)
646 {
647     ivec n, n_base;
648     int  t, x, y, z, d, i, tfac;
649     int  max_comm_lines = -1;
650
651     n[XX] = nx - (pme_order - 1);
652     n[YY] = ny - (pme_order - 1);
653     n[ZZ] = nz - (pme_order - 1);
654
655     copy_ivec(n, n_base);
656     n_base[ZZ] = nz_base;
657
658     pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order, nullptr);
659
660     grids->nthread = nthread;
661
662     make_subgrid_division(n_base, pme_order - 1, grids->nthread, grids->nc);
663
664     if (bUseThreads)
665     {
666         ivec nst;
667         int  gridsize;
668
669         for (d = 0; d < DIM; d++)
670         {
671             nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
672         }
673         set_grid_alignment(&nst[ZZ], pme_order);
674
675         if (debug)
676         {
677             fprintf(debug, "pmegrid thread local division: %d x %d x %d\n", grids->nc[XX],
678                     grids->nc[YY], grids->nc[ZZ]);
679             fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n", nx, ny, nz, nst[XX],
680                     nst[YY], nst[ZZ]);
681         }
682
683         snew(grids->grid_th, grids->nthread);
684         t        = 0;
685         gridsize = nst[XX] * nst[YY] * nst[ZZ];
686         set_gridsize_alignment(&gridsize, pme_order);
687         snew_aligned(grids->grid_all, grids->nthread * gridsize + (grids->nthread + 1) * GMX_CACHE_SEP,
688                      SIMD4_ALIGNMENT);
689
690         for (x = 0; x < grids->nc[XX]; x++)
691         {
692             for (y = 0; y < grids->nc[YY]; y++)
693             {
694                 for (z = 0; z < grids->nc[ZZ]; z++)
695                 {
696                     pmegrid_init(&grids->grid_th[t], x, y, z, (n[XX] * (x)) / grids->nc[XX],
697                                  (n[YY] * (y)) / grids->nc[YY], (n[ZZ] * (z)) / grids->nc[ZZ],
698                                  (n[XX] * (x + 1)) / grids->nc[XX], (n[YY] * (y + 1)) / grids->nc[YY],
699                                  (n[ZZ] * (z + 1)) / grids->nc[ZZ], TRUE, pme_order,
700                                  grids->grid_all + GMX_CACHE_SEP + t * (gridsize + GMX_CACHE_SEP));
701                     t++;
702                 }
703             }
704         }
705     }
706     else
707     {
708         grids->grid_th = nullptr;
709     }
710
711     tfac = 1;
712     for (d = DIM - 1; d >= 0; d--)
713     {
714         snew(grids->g2t[d], n[d]);
715         t = 0;
716         for (i = 0; i < n[d]; i++)
717         {
718             /* The second check should match the parameters
719              * of the pmegrid_init call above.
720              */
721             while (t + 1 < grids->nc[d] && i >= (n[d] * (t + 1)) / grids->nc[d])
722             {
723                 t++;
724             }
725             grids->g2t[d][i] = t * tfac;
726         }
727
728         tfac *= grids->nc[d];
729
730         switch (d)
731         {
732             case XX: max_comm_lines = overlap_x; break;
733             case YY: max_comm_lines = overlap_y; break;
734             case ZZ: max_comm_lines = pme_order - 1; break;
735         }
736         grids->nthread_comm[d] = 0;
737         while ((n[d] * grids->nthread_comm[d]) / grids->nc[d] < max_comm_lines
738                && grids->nthread_comm[d] < grids->nc[d])
739         {
740             grids->nthread_comm[d]++;
741         }
742         if (debug != nullptr)
743         {
744             fprintf(debug, "pmegrid thread grid communication range in %c: %d\n", 'x' + d,
745                     grids->nthread_comm[d]);
746         }
747         /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
748          * work, but this is not a problematic restriction.
749          */
750         if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
751         {
752             gmx_fatal(FARGS,
753                       "Too many threads for PME (%d) compared to the number of grid lines, reduce "
754                       "the number of threads doing PME",
755                       grids->nthread);
756         }
757     }
758 }
759
760 void pmegrids_destroy(pmegrids_t* grids)
761 {
762     if (grids->grid.grid != nullptr)
763     {
764         sfree_aligned(grids->grid.grid);
765
766         if (grids->nthread > 0)
767         {
768             sfree_aligned(grids->grid_all);
769             sfree(grids->grid_th);
770         }
771         for (int d = 0; d < DIM; d++)
772         {
773             sfree(grids->g2t[d]);
774         }
775     }
776 }
777
778 void make_gridindex_to_localindex(int n, int local_start, int local_range, int** global_to_local, real** fraction_shift)
779 {
780     /* Here we construct array for looking up the grid line index and
781      * fraction for particles. This is done because it is slighlty
782      * faster than the modulo operation and to because we need to take
783      * care of rounding issues, see below.
784      * We use an array size of c_pmeNeighborUnitcellCount times the grid size
785      * to allow for particles to be out of the triclinic unit-cell.
786      */
787     const int arraySize = c_pmeNeighborUnitcellCount * n;
788     int*      gtl;
789     real*     fsh;
790
791     snew(gtl, arraySize);
792     snew(fsh, arraySize);
793
794     for (int i = 0; i < arraySize; i++)
795     {
796         /* Transform global grid index to the local grid index.
797          * Our local grid always runs from 0 to local_range-1.
798          */
799         gtl[i] = (i - local_start + n) % n;
800         /* For coordinates that fall within the local grid the fraction
801          * is correct, we don't need to shift it.
802          */
803         fsh[i] = 0;
804         /* Check if we are using domain decomposition for PME */
805         if (local_range < n)
806         {
807             /* Due to rounding issues i could be 1 beyond the lower or
808              * upper boundary of the local grid. Correct the index for this.
809              * If we shift the index, we need to shift the fraction by
810              * the same amount in the other direction to not affect
811              * the weights.
812              * Note that due to this shifting the weights at the end of
813              * the spline might change, but that will only involve values
814              * between zero and values close to the precision of a real,
815              * which is anyhow the accuracy of the whole mesh calculation.
816              */
817             if (gtl[i] == n - 1)
818             {
819                 /* When this i is used, we should round the local index up */
820                 gtl[i] = 0;
821                 fsh[i] = -1;
822             }
823             else if (gtl[i] == local_range && local_range > 0)
824             {
825                 /* When this i is used, we should round the local index down */
826                 gtl[i] = local_range - 1;
827                 fsh[i] = 1;
828             }
829         }
830     }
831
832     *global_to_local = gtl;
833     *fraction_shift  = fsh;
834 }
835
836 void reuse_pmegrids(const pmegrids_t* oldgrid, pmegrids_t* newgrid)
837 {
838     int d, t;
839
840     for (d = 0; d < DIM; d++)
841     {
842         if (newgrid->grid.n[d] > oldgrid->grid.n[d])
843         {
844             return;
845         }
846     }
847
848     sfree_aligned(newgrid->grid.grid);
849     newgrid->grid.grid = oldgrid->grid.grid;
850
851     if (newgrid->grid_th != nullptr && newgrid->nthread == oldgrid->nthread)
852     {
853         sfree_aligned(newgrid->grid_all);
854         newgrid->grid_all = oldgrid->grid_all;
855         for (t = 0; t < newgrid->nthread; t++)
856         {
857             newgrid->grid_th[t].grid = oldgrid->grid_th[t].grid;
858         }
859     }
860 }