/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
 * Copyright (c) 2001-2004, The GROMACS development team.
 * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
 * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
/* TODO find out what this file should be called */
#include "gmxpre.h"

#include "pme_grid.h"

#include "config.h"

#include <cstdlib>

#include "gromacs/ewald/pme.h"
#include "gromacs/fft/parallel_3dfft.h"
#include "gromacs/math/vec.h"
#include "gromacs/timing/cyclecounter.h"
#include "gromacs/utility/fatalerror.h"
#include "gromacs/utility/smalloc.h"

#include "pme_internal.h"

#ifdef DEBUG_PME
#    include "gromacs/fileio/pdbio.h"
#    include "gromacs/utility/cstringutil.h"
#    include "gromacs/utility/futil.h"
#endif

#include "pme_simd.h"

/* GMX_CACHE_SEP should be a multiple of the SIMD and SIMD4 register size
 * to preserve alignment.
 */
#define GMX_CACHE_SEP 64

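/* Communicates the grid overlap regions between neighboring PME ranks.
 * With direction == GMX_SUM_GRID_FORWARD the received overlap data is added
 * into the local grid; with the reverse direction the (summed) local grid
 * values are copied back out to the overlap regions of the neighbors.
 * The minor (y) dimension is handled first and requires packing into a
 * contiguous send buffer; the major (x) dimension is already contiguous.
 */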
void gmx_sum_qgrid_dd(gmx_pme_t* pme, real* grid, const int direction)
{
#if GMX_MPI
    pme_overlap_t* overlap;
    int            send_index0, send_nindex;
    int            recv_index0, recv_nindex;
    MPI_Status     stat;
    int            i, j, k, ix, iy, iz, icnt;
    int            send_id, recv_id, datasize;
    real*          p;
    real *         sendptr, *recvptr;

    /* Start with minor-rank communication. This is a bit of a pain since it is not contiguous */
    overlap = &pme->overlap[1];

    for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
    {
        /* Since we have already (un)wrapped the overlap in the z-dimension,
         * we only have to communicate 0 to nkz (not pmegrid_nz).
         */
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            send_id     = overlap->comm_data[ipulse].send_id;
            recv_id     = overlap->comm_data[ipulse].recv_id;
            send_index0 = overlap->comm_data[ipulse].send_index0;
            send_nindex = overlap->comm_data[ipulse].send_nindex;
            recv_index0 = overlap->comm_data[ipulse].recv_index0;
            recv_nindex = overlap->comm_data[ipulse].recv_nindex;
        }
        else
        {
            send_id     = overlap->comm_data[ipulse].recv_id;
            recv_id     = overlap->comm_data[ipulse].send_id;
            send_index0 = overlap->comm_data[ipulse].recv_index0;
            send_nindex = overlap->comm_data[ipulse].recv_nindex;
            recv_index0 = overlap->comm_data[ipulse].send_index0;
            recv_nindex = overlap->comm_data[ipulse].send_nindex;
        }

        /* Copy data to contiguous send buffer */
        if (debug)
        {
            fprintf(debug,
                    "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,
                    overlap->nodeid,
                    send_id,
                    pme->pmegrid_start_iy,
                    send_index0 - pme->pmegrid_start_iy,
                    send_index0 - pme->pmegrid_start_iy + send_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < send_nindex; j++)
            {
                iy = j + send_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    overlap->sendbuf[icnt++] =
                            grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz];
                }
            }
        }

        datasize = pme->pmegrid_nx * pme->nkz;

        MPI_Sendrecv(overlap->sendbuf.data(),
                     send_nindex * datasize,
                     GMX_MPI_REAL,
                     send_id,
                     ipulse,
                     overlap->recvbuf.data(),
                     recv_nindex * datasize,
                     GMX_MPI_REAL,
                     recv_id,
                     ipulse,
                     overlap->mpi_comm,
                     &stat);

        /* Get data from contiguous recv buffer */
        if (debug)
        {
            fprintf(debug,
                    "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,
                    overlap->nodeid,
                    recv_id,
                    pme->pmegrid_start_iy,
                    recv_index0 - pme->pmegrid_start_iy,
                    recv_index0 - pme->pmegrid_start_iy + recv_nindex);
        }
        icnt = 0;
        for (i = 0; i < pme->pmegrid_nx; i++)
        {
            ix = i;
            for (j = 0; j < recv_nindex; j++)
            {
                iy = j + recv_index0 - pme->pmegrid_start_iy;
                for (k = 0; k < pme->nkz; k++)
                {
                    iz = k;
                    if (direction == GMX_SUM_GRID_FORWARD)
                    {
                        grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz] +=
                                overlap->recvbuf[icnt++];
                    }
                    else
                    {
                        grid[ix * (pme->pmegrid_ny * pme->pmegrid_nz) + iy * (pme->pmegrid_nz) + iz] =
                                overlap->recvbuf[icnt++];
                    }
                }
            }
        }
    }
    /* The major dimension is easier, no copying required,
     * but we might have to sum to a separate array.
     * Since we don't copy, we have to communicate up to pmegrid_nz,
     * not nkz as for the minor direction.
     */
    overlap = &pme->overlap[0];

    for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
    {
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            send_id     = overlap->comm_data[ipulse].send_id;
            recv_id     = overlap->comm_data[ipulse].recv_id;
            send_index0 = overlap->comm_data[ipulse].send_index0;
            send_nindex = overlap->comm_data[ipulse].send_nindex;
            recv_index0 = overlap->comm_data[ipulse].recv_index0;
            recv_nindex = overlap->comm_data[ipulse].recv_nindex;
            recvptr     = overlap->recvbuf.data();
        }
        else
        {
            send_id     = overlap->comm_data[ipulse].recv_id;
            recv_id     = overlap->comm_data[ipulse].send_id;
            send_index0 = overlap->comm_data[ipulse].recv_index0;
            send_nindex = overlap->comm_data[ipulse].recv_nindex;
            recv_index0 = overlap->comm_data[ipulse].send_index0;
            recv_nindex = overlap->comm_data[ipulse].send_nindex;
            recvptr = grid + (recv_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
        }

        sendptr = grid + (send_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
        datasize = pme->pmegrid_ny * pme->pmegrid_nz;

        if (debug)
        {
            fprintf(debug,
                    "PME send rank %d %d -> %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,
                    overlap->nodeid,
                    send_id,
                    pme->pmegrid_start_ix,
                    send_index0 - pme->pmegrid_start_ix,
                    send_index0 - pme->pmegrid_start_ix + send_nindex);
            fprintf(debug,
                    "PME recv rank %d %d <- %d grid start %d Communicating %d to %d\n",
                    pme->nodeid,
                    overlap->nodeid,
                    recv_id,
                    pme->pmegrid_start_ix,
                    recv_index0 - pme->pmegrid_start_ix,
                    recv_index0 - pme->pmegrid_start_ix + recv_nindex);
        }

        MPI_Sendrecv(sendptr,
                     send_nindex * datasize,
                     GMX_MPI_REAL,
                     send_id,
                     ipulse,
                     recvptr,
                     recv_nindex * datasize,
                     GMX_MPI_REAL,
                     recv_id,
                     ipulse,
                     overlap->mpi_comm,
                     &stat);

        /* ADD data from contiguous recv buffer */
        if (direction == GMX_SUM_GRID_FORWARD)
        {
            p = grid + (recv_index0 - pme->pmegrid_start_ix) * (pme->pmegrid_ny * pme->pmegrid_nz);
            for (i = 0; i < recv_nindex * datasize; i++)
            {
                p[i] += overlap->recvbuf[i];
            }
        }
    }
#else  // GMX_MPI
    GMX_UNUSED_VALUE(pme);
    GMX_UNUSED_VALUE(grid);
    GMX_UNUSED_VALUE(direction);

    GMX_RELEASE_ASSERT(false, "gmx_sum_qgrid_dd() should not be called without MPI");
#endif // GMX_MPI
}

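/* Copies the interior of the local PME spread grid (which is padded with
 * overlap regions) into the local FFT grid for the grid with index grid_index,
 * converting between the two index strides.
 */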
int copy_pmegrid_to_fftgrid(const gmx_pme_t* pme, const real* pmegrid, real* fftgrid, int grid_index)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    ivec local_pme_size;
    int  ix, iy, iz;
    int  pmeidx, fftidx;

    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_real_limits(
            pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset, local_fft_size);

    local_pme_size[0] = pme->pmegrid_nx;
    local_pme_size[1] = pme->pmegrid_ny;
    local_pme_size[2] = pme->pmegrid_nz;

    /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
       the offset is identical, and the PME grid always has more data (due to overlap)
     */
    {
#ifdef DEBUG_PME
        FILE *fp, *fp2;
        char  fn[STRLEN];
        real  val;
        sprintf(fn, "pmegrid%d.pdb", pme->nodeid);
        fp = gmx_ffopen(fn, "w");
        sprintf(fn, "pmegrid%d.txt", pme->nodeid);
        fp2 = gmx_ffopen(fn, "w");
#endif

        for (ix = 0; ix < local_fft_ndata[XX]; ix++)
        {
            for (iy = 0; iy < local_fft_ndata[YY]; iy++)
            {
                for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
                {
                    pmeidx = ix * (local_pme_size[YY] * local_pme_size[ZZ])
                             + iy * (local_pme_size[ZZ]) + iz;
                    fftidx = ix * (local_fft_size[YY] * local_fft_size[ZZ])
                             + iy * (local_fft_size[ZZ]) + iz;
                    fftgrid[fftidx] = pmegrid[pmeidx];
#ifdef DEBUG_PME
                    val = 100 * pmegrid[pmeidx];
                    if (pmegrid[pmeidx] != 0)
                    {
                        gmx_fprintf_pdb_atomline(fp,
                                                 epdbATOM,
                                                 pmeidx,
                                                 "CA",
                                                 ' ',
                                                 "GLY",
                                                 ' ',
                                                 pmeidx,
                                                 ' ',
                                                 5.0 * ix,
                                                 5.0 * iy,
                                                 5.0 * iz,
                                                 1.0,
                                                 val,
                                                 "");
                    }
                    if (pmegrid[pmeidx] != 0)
                    {
                        fprintf(fp2,
                                "%-12s  %5d  %5d  %5d  %12.5e\n",
                                "qgrid",
                                pme->pmegrid_start_ix + ix,
                                pme->pmegrid_start_iy + iy,
                                pme->pmegrid_start_iz + iz,
                                pmegrid[pmeidx]);
                    }
#endif
                }
            }
        }
#ifdef DEBUG_PME
        gmx_ffclose(fp);
        gmx_ffclose(fp2);
#endif
    }
    return 0;
}


#ifdef PME_TIME_THREADS
static gmx_cycles_t omp_cyc_start()
{
    return gmx_cycles_read();
}

static gmx_cycles_t omp_cyc_end(gmx_cycles_t c)
{
    return gmx_cycles_read() - c;
}
#endif

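/* Copies the local FFT grid back into the interior of the (larger) local PME
 * grid for the grid with index grid_index. The work is divided over nthread
 * threads by splitting the x-y grid lines into equal ranges; the calling
 * thread copies only its own range.
 */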
int copy_fftgrid_to_pmegrid(struct gmx_pme_t* pme, const real* fftgrid, real* pmegrid, int grid_index, int nthread, int thread)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    ivec local_pme_size;
    int  ixy0, ixy1, ixy, ix, iy, iz;
    int  pmeidx, fftidx;
#ifdef PME_TIME_THREADS
    gmx_cycles_t  c1;
    static double cs1 = 0;
    static int    cnt = 0;
#endif

#ifdef PME_TIME_THREADS
    c1 = omp_cyc_start();
#endif
    /* Dimensions should be identical for A/B grid, so we just use A here */
    gmx_parallel_3dfft_real_limits(
            pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset, local_fft_size);

    local_pme_size[0] = pme->pmegrid_nx;
    local_pme_size[1] = pme->pmegrid_ny;
    local_pme_size[2] = pme->pmegrid_nz;

    /* The fftgrid is always 'justified' to the lower-left corner of the PME grid,
       the offset is identical, and the PME grid always has more data (due to overlap)
     */
    ixy0 = ((thread)*local_fft_ndata[XX] * local_fft_ndata[YY]) / nthread;
    ixy1 = ((thread + 1) * local_fft_ndata[XX] * local_fft_ndata[YY]) / nthread;

    for (ixy = ixy0; ixy < ixy1; ixy++)
    {
        ix = ixy / local_fft_ndata[YY];
        iy = ixy - ix * local_fft_ndata[YY];

        pmeidx = (ix * local_pme_size[YY] + iy) * local_pme_size[ZZ];
        fftidx = (ix * local_fft_size[YY] + iy) * local_fft_size[ZZ];
        for (iz = 0; iz < local_fft_ndata[ZZ]; iz++)
        {
            pmegrid[pmeidx + iz] = fftgrid[fftidx + iz];
        }
    }

#ifdef PME_TIME_THREADS
    c1 = omp_cyc_end(c1);
    cs1 += (double)c1;
    cnt++;
    if (cnt % 20 == 0)
    {
        printf("copy %.2f\n", cs1 * 1e-9);
    }
#endif

    return 0;
}

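/* Adds the periodic overlap regions of the local PME grid back onto the
 * primary cell: the pme_order-1 extra planes beyond the grid end are added to
 * the corresponding planes at the start, so that contributions spread beyond
 * the grid boundary end up at their periodic images. This is always done in z;
 * it is also done in y and x when the respective dimension is not decomposed
 * over multiple ranks (nnodes_minor/nnodes_major == 1).
 */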
void wrap_periodic_pmegrid(const gmx_pme_t* pme, real* pmegrid)
{
    int nx, ny, nz, pny, pnz, ny_x, overlap, ix, iy, iz;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    pny = pme->pmegrid_ny;
    pnz = pme->pmegrid_nz;

    overlap = pme->pme_order - 1;

    /* Add periodic overlap in z */
    for (ix = 0; ix < pme->pmegrid_nx; ix++)
    {
        for (iy = 0; iy < pme->pmegrid_ny; iy++)
        {
            for (iz = 0; iz < overlap; iz++)
            {
                pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[(ix * pny + iy) * pnz + nz + iz];
            }
        }
    }

    if (pme->nnodes_minor == 1)
    {
        for (ix = 0; ix < pme->pmegrid_nx; ix++)
        {
            for (iy = 0; iy < overlap; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[(ix * pny + ny + iy) * pnz + iz];
                }
            }
        }
    }

    if (pme->nnodes_major == 1)
    {
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);

        for (ix = 0; ix < overlap; ix++)
        {
            for (iy = 0; iy < ny_x; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix * pny + iy) * pnz + iz] += pmegrid[((nx + ix) * pny + iy) * pnz + iz];
                }
            }
        }
    }
}

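/* The inverse of wrap_periodic_pmegrid(): copies the planes at the start of
 * the grid into the pme_order-1 overlap planes beyond the grid end, so that
 * the overlap regions hold valid periodic images. The copies are done in x
 * (when nnodes_major == 1), then y (when nnodes_minor == 1), then z; the y and
 * z loops are parallelized over the PME OpenMP threads.
 */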
void unwrap_periodic_pmegrid(struct gmx_pme_t* pme, real* pmegrid)
{
    int nx, ny, nz, pny, pnz, ny_x, overlap, ix;

    nx = pme->nkx;
    ny = pme->nky;
    nz = pme->nkz;

    pny = pme->pmegrid_ny;
    pnz = pme->pmegrid_nz;

    overlap = pme->pme_order - 1;

    if (pme->nnodes_major == 1)
    {
        ny_x = (pme->nnodes_minor == 1 ? ny : pme->pmegrid_ny);

        for (ix = 0; ix < overlap; ix++)
        {
            int iy, iz;

            for (iy = 0; iy < ny_x; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[((nx + ix) * pny + iy) * pnz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
                }
            }
        }
    }

    if (pme->nnodes_minor == 1)
    {
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
        for (ix = 0; ix < pme->pmegrid_nx; ix++)
        {
            // Trivial OpenMP region that does not throw, no need for try/catch
            int iy, iz;

            for (iy = 0; iy < overlap; iy++)
            {
                for (iz = 0; iz < nz; iz++)
                {
                    pmegrid[(ix * pny + ny + iy) * pnz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
                }
            }
        }
    }

    /* Copy periodic overlap in z */
#pragma omp parallel for num_threads(pme->nthread) schedule(static)
    for (ix = 0; ix < pme->pmegrid_nx; ix++)
    {
        // Trivial OpenMP region that does not throw, no need for try/catch
        int iy, iz;

        for (iy = 0; iy < pme->pmegrid_ny; iy++)
        {
            for (iz = 0; iz < overlap; iz++)
            {
                pmegrid[(ix * pny + iy) * pnz + nz + iz] = pmegrid[(ix * pny + iy) * pnz + iz];
            }
        }
    }
}

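/* With SIMD4 spreading/gathering, rounds the z size of the PME grid up to a
 * multiple of 4 for the pme_order values that use aligned SIMD4 accesses
 * along z, so that those accesses stay within aligned memory.
 */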
void set_grid_alignment(int gmx_unused* pmegrid_nz, int gmx_unused pme_order)
{
#ifdef PME_SIMD4_SPREAD_GATHER
    if (pme_order == 5
#    if !PME_4NSIMD_GATHER
        || pme_order == 4
#    endif
    )
    {
        /* Round nz up to a multiple of 4 to ensure alignment */
        *pmegrid_nz = ((*pmegrid_nz + 3) & ~3);
    }
#endif
}

static void set_gridsize_alignment(int gmx_unused* gridsize, int gmx_unused pme_order)
{
#ifdef PME_SIMD4_SPREAD_GATHER
#    if !PME_4NSIMD_GATHER
    if (pme_order == 4)
    {
        /* Add extra elements to ensure that aligned operations do not go
         * beyond the allocated grid size.
         * Note that for pme_order=5, the pme grid z-size alignment
         * ensures that we will not go beyond the grid size.
         */
        *gridsize += 4;
    }
#    endif
#endif
}

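/* Initializes one (sub)grid descriptor: ci is the thread-cell index,
 * (x0,y0,z0) is the grid offset and (x1,y1,z1) the end of the interior region;
 * the stored size n includes pme_order-1 overlap points per dimension. With
 * set_alignment the allocated z size s[ZZ] may be rounded up for SIMD. If ptr
 * is nullptr an aligned grid buffer is allocated, otherwise the given storage
 * is used.
 */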
void pmegrid_init(pmegrid_t* grid,
                  int        cx,
                  int        cy,
                  int        cz,
                  int        x0,
                  int        y0,
                  int        z0,
                  int        x1,
                  int        y1,
                  int        z1,
                  gmx_bool   set_alignment,
                  int        pme_order,
                  real*      ptr)
{
    int nz, gridsize;

    grid->ci[XX]     = cx;
    grid->ci[YY]     = cy;
    grid->ci[ZZ]     = cz;
    grid->offset[XX] = x0;
    grid->offset[YY] = y0;
    grid->offset[ZZ] = z0;
    grid->n[XX]      = x1 - x0 + pme_order - 1;
    grid->n[YY]      = y1 - y0 + pme_order - 1;
    grid->n[ZZ]      = z1 - z0 + pme_order - 1;
    copy_ivec(grid->n, grid->s);

    nz = grid->s[ZZ];
    set_grid_alignment(&nz, pme_order);
    if (set_alignment)
    {
        grid->s[ZZ] = nz;
    }
    else if (nz != grid->s[ZZ])
    {
        gmx_incons("pmegrid_init call with an unaligned z size");
    }

    grid->order = pme_order;
    if (ptr == nullptr)
    {
        gridsize = grid->s[XX] * grid->s[YY] * grid->s[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        snew_aligned(grid->grid, gridsize, SIMD4_ALIGNMENT);
    }
    else
    {
        grid->grid = ptr;
    }
}

static int div_round_up(int enumerator, int denominator)
{
    return (enumerator + denominator - 1) / denominator;
}

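/* Chooses a division of the local grid of size n into
 * nsub[XX] x nsub[YY] x nsub[ZZ] = nthread thread sub-grids. Among all
 * factorizations of nthread it picks the one that minimizes the number of
 * grid points per thread including the overlap of ovl points, preferring
 * fewer cuts in the minor dimensions on ties. The result can be overridden
 * with the GMX_PME_THREAD_DIVISION environment variable.
 */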
static void make_subgrid_division(const ivec n, int ovl, int nthread, ivec nsub)
{
    int   gsize_opt, gsize;
    int   nsx, nsy, nsz;
    char* env;

    gsize_opt = -1;
    for (nsx = 1; nsx <= nthread; nsx++)
    {
        if (nthread % nsx == 0)
        {
            for (nsy = 1; nsy <= nthread; nsy++)
            {
                if (nsx * nsy <= nthread && nthread % (nsx * nsy) == 0)
                {
                    nsz = nthread / (nsx * nsy);

                    /* Determine the number of grid points per thread */
                    gsize = (div_round_up(n[XX], nsx) + ovl) * (div_round_up(n[YY], nsy) + ovl)
                            * (div_round_up(n[ZZ], nsz) + ovl);

                    /* Minimize the number of grid points per thread
                     * and, secondarily, the number of cuts in minor dimensions.
                     */
                    if (gsize_opt == -1 || gsize < gsize_opt
                        || (gsize == gsize_opt && (nsz < nsub[ZZ] || (nsz == nsub[ZZ] && nsy < nsub[YY]))))
                    {
                        nsub[XX]  = nsx;
                        nsub[YY]  = nsy;
                        nsub[ZZ]  = nsz;
                        gsize_opt = gsize;
                    }
                }
            }
        }
    }

    env = getenv("GMX_PME_THREAD_DIVISION");
    if (env != nullptr)
    {
        sscanf(env, "%20d %20d %20d", &nsub[XX], &nsub[YY], &nsub[ZZ]);
    }

    if (nsub[XX] * nsub[YY] * nsub[ZZ] != nthread)
    {
        gmx_fatal(FARGS,
                  "PME grid thread division (%d x %d x %d) does not match the total number of "
                  "threads (%d)",
                  nsub[XX],
                  nsub[YY],
                  nsub[ZZ],
                  nthread);
    }
}

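/* Sets up the full local PME grid in grids->grid and, when bUseThreads is set,
 * one sub-grid per thread according to make_subgrid_division(), all packed
 * into a single aligned allocation (grids->grid_all) with GMX_CACHE_SEP
 * padding between the per-thread grids. Also fills the g2t tables used to map
 * grid lines to the owning thread and determines nthread_comm, the number of
 * neighboring thread cells that the overlap communication covers in each
 * dimension.
 */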
void pmegrids_init(pmegrids_t* grids,
                   int         nx,
                   int         ny,
                   int         nz,
                   int         nz_base,
                   int         pme_order,
                   gmx_bool    bUseThreads,
                   int         nthread,
                   int         overlap_x,
                   int         overlap_y)
{
    ivec n, n_base;
    int  t, x, y, z, d, i, tfac;
    int  max_comm_lines = -1;

    n[XX] = nx - (pme_order - 1);
    n[YY] = ny - (pme_order - 1);
    n[ZZ] = nz - (pme_order - 1);

    copy_ivec(n, n_base);
    n_base[ZZ] = nz_base;

    pmegrid_init(&grids->grid, 0, 0, 0, 0, 0, 0, n[XX], n[YY], n[ZZ], FALSE, pme_order, nullptr);

    grids->nthread = nthread;

    make_subgrid_division(n_base, pme_order - 1, grids->nthread, grids->nc);

    if (bUseThreads)
    {
        ivec nst;
        int  gridsize;

        for (d = 0; d < DIM; d++)
        {
            nst[d] = div_round_up(n[d], grids->nc[d]) + pme_order - 1;
        }
        set_grid_alignment(&nst[ZZ], pme_order);

        if (debug)
        {
            fprintf(debug,
                    "pmegrid thread local division: %d x %d x %d\n",
                    grids->nc[XX],
                    grids->nc[YY],
                    grids->nc[ZZ]);
            fprintf(debug, "pmegrid %d %d %d max thread pmegrid %d %d %d\n", nx, ny, nz, nst[XX], nst[YY], nst[ZZ]);
        }

        snew(grids->grid_th, grids->nthread);
        t        = 0;
        gridsize = nst[XX] * nst[YY] * nst[ZZ];
        set_gridsize_alignment(&gridsize, pme_order);
        snew_aligned(grids->grid_all,
                     grids->nthread * gridsize + (grids->nthread + 1) * GMX_CACHE_SEP,
                     SIMD4_ALIGNMENT);

        for (x = 0; x < grids->nc[XX]; x++)
        {
            for (y = 0; y < grids->nc[YY]; y++)
            {
                for (z = 0; z < grids->nc[ZZ]; z++)
                {
                    pmegrid_init(&grids->grid_th[t],
                                 x,
                                 y,
                                 z,
                                 (n[XX] * (x)) / grids->nc[XX],
                                 (n[YY] * (y)) / grids->nc[YY],
                                 (n[ZZ] * (z)) / grids->nc[ZZ],
                                 (n[XX] * (x + 1)) / grids->nc[XX],
                                 (n[YY] * (y + 1)) / grids->nc[YY],
                                 (n[ZZ] * (z + 1)) / grids->nc[ZZ],
                                 TRUE,
                                 pme_order,
                                 grids->grid_all + GMX_CACHE_SEP + t * (gridsize + GMX_CACHE_SEP));
                    t++;
                }
            }
        }
    }
    else
    {
        grids->grid_th = nullptr;
    }

    tfac = 1;
    for (d = DIM - 1; d >= 0; d--)
    {
        snew(grids->g2t[d], n[d]);
        t = 0;
        for (i = 0; i < n[d]; i++)
        {
            /* The second check should match the parameters
             * of the pmegrid_init call above.
             */
            while (t + 1 < grids->nc[d] && i >= (n[d] * (t + 1)) / grids->nc[d])
            {
                t++;
            }
            grids->g2t[d][i] = t * tfac;
        }

        tfac *= grids->nc[d];

        switch (d)
        {
            case XX: max_comm_lines = overlap_x; break;
            case YY: max_comm_lines = overlap_y; break;
            case ZZ: max_comm_lines = pme_order - 1; break;
        }
        grids->nthread_comm[d] = 0;
        while ((n[d] * grids->nthread_comm[d]) / grids->nc[d] < max_comm_lines
               && grids->nthread_comm[d] < grids->nc[d])
        {
            grids->nthread_comm[d]++;
        }
        if (debug != nullptr)
        {
            fprintf(debug,
                    "pmegrid thread grid communication range in %c: %d\n",
                    'x' + d,
                    grids->nthread_comm[d]);
        }
        /* It should be possible to make grids->nthread_comm[d]==grids->nc[d]
         * work, but this is not a problematic restriction.
         */
        if (grids->nc[d] > 1 && grids->nthread_comm[d] > grids->nc[d])
        {
            gmx_fatal(FARGS,
                      "Too many threads for PME (%d) compared to the number of grid lines, reduce "
                      "the number of threads doing PME",
                      grids->nthread);
        }
    }
}

void pmegrids_destroy(pmegrids_t* grids)
{
    if (grids->grid.grid != nullptr)
    {
        sfree_aligned(grids->grid.grid);

        if (grids->nthread > 0)
        {
            sfree_aligned(grids->grid_all);
            sfree(grids->grid_th);
        }
        for (int d = 0; d < DIM; d++)
        {
            sfree(grids->g2t[d]);
        }
    }
}

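/* Builds two lookup tables of size c_pmeNeighborUnitcellCount*n: one mapping a
 * (shifted) global grid index to the local grid index on this rank, and one
 * with the fraction shift (-1, 0 or +1) that must be applied when rounding
 * pushes an index just outside the local range. See the comments below for
 * why the shifts are needed with PME domain decomposition.
 */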
void make_gridindex_to_localindex(int n, int local_start, int local_range, int** global_to_local, real** fraction_shift)
{
    /* Here we construct arrays for looking up the grid line index and
     * fraction for particles. This is done because it is slightly
     * faster than the modulo operation and because we need to take
     * care of rounding issues, see below.
     * We use an array size of c_pmeNeighborUnitcellCount times the grid size
     * to allow for particles to be out of the triclinic unit-cell.
     */
    const int arraySize = c_pmeNeighborUnitcellCount * n;
    int*      gtl;
    real*     fsh;

    snew(gtl, arraySize);
    snew(fsh, arraySize);

    for (int i = 0; i < arraySize; i++)
    {
        /* Transform global grid index to the local grid index.
         * Our local grid always runs from 0 to local_range-1.
         */
        gtl[i] = (i - local_start + n) % n;
        /* For coordinates that fall within the local grid the fraction
         * is correct, we don't need to shift it.
         */
        fsh[i] = 0;
        /* Check if we are using domain decomposition for PME */
        if (local_range < n)
        {
            /* Due to rounding issues i could be 1 beyond the lower or
             * upper boundary of the local grid. Correct the index for this.
             * If we shift the index, we need to shift the fraction by
             * the same amount in the other direction to not affect
             * the weights.
             * Note that due to this shifting the weights at the end of
             * the spline might change, but that will only involve values
             * between zero and values close to the precision of a real,
             * which is anyhow the accuracy of the whole mesh calculation.
             */
            if (gtl[i] == n - 1)
            {
                /* When this i is used, we should round the local index up */
                gtl[i] = 0;
                fsh[i] = -1;
            }
            else if (gtl[i] == local_range && local_range > 0)
            {
                /* When this i is used, we should round the local index down */
                gtl[i] = local_range - 1;
                fsh[i] = 1;
            }
        }
    }

    *global_to_local = gtl;
    *fraction_shift  = fsh;
}

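/* Lets newgrid reuse the grid storage of oldgrid when every dimension of the
 * new grid fits within the old one: the freshly allocated buffers of newgrid
 * are freed and replaced by oldgrid's buffers, including the per-thread grids
 * when both grids use the same number of threads.
 */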
void reuse_pmegrids(const pmegrids_t* oldgrid, pmegrids_t* newgrid)
{
    int d, t;

    for (d = 0; d < DIM; d++)
    {
        if (newgrid->grid.n[d] > oldgrid->grid.n[d])
        {
            return;
        }
    }

    sfree_aligned(newgrid->grid.grid);
    newgrid->grid.grid = oldgrid->grid.grid;

    if (newgrid->grid_th != nullptr && newgrid->nthread == oldgrid->nthread)
    {
        sfree_aligned(newgrid->grid_all);
        newgrid->grid_all = oldgrid->grid_all;
        for (t = 0; t < newgrid->nthread; t++)
        {
            newgrid->grid_th[t].grid = oldgrid->grid_th[t].grid;
        }
    }
}