SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / ewald / pme_spread.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017 by the GROMACS development team.
7  * Copyright (c) 2018,2019,2020, by the GROMACS development team, led by
8  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
9  * and including many others, as listed in the AUTHORS file in the
10  * top-level source directory and at http://www.gromacs.org.
11  *
12  * GROMACS is free software; you can redistribute it and/or
13  * modify it under the terms of the GNU Lesser General Public License
14  * as published by the Free Software Foundation; either version 2.1
15  * of the License, or (at your option) any later version.
16  *
17  * GROMACS is distributed in the hope that it will be useful,
18  * but WITHOUT ANY WARRANTY; without even the implied warranty of
19  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
20  * Lesser General Public License for more details.
21  *
22  * You should have received a copy of the GNU Lesser General Public
23  * License along with GROMACS; if not, see
24  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
25  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
26  *
27  * If you want to redistribute modifications to GROMACS, please
28  * consider that scientific software is very special. Version
29  * control is crucial - bugs must be traceable. We will be happy to
30  * consider code for inclusion in the official distribution, but
31  * derived work must not be called official GROMACS. Details are found
32  * in the README & COPYING files - if they are missing, get the
33  * official version at http://www.gromacs.org.
34  *
35  * To help us fund GROMACS development, we humbly ask that you cite
36  * the research papers on the package. Check out http://www.gromacs.org.
37  */
38 #include "gmxpre.h"
39
40 #include "pme_spread.h"
41
42 #include "config.h"
43
44 #include <cassert>
45
46 #include <algorithm>
47
48 #include "gromacs/ewald/pme.h"
49 #include "gromacs/fft/parallel_3dfft.h"
50 #include "gromacs/simd/simd.h"
51 #include "gromacs/utility/basedefinitions.h"
52 #include "gromacs/utility/exceptions.h"
53 #include "gromacs/utility/fatalerror.h"
54 #include "gromacs/utility/gmxassert.h"
55 #include "gromacs/utility/smalloc.h"
56
57 #include "pme_grid.h"
58 #include "pme_internal.h"
59 #include "pme_simd.h"
60 #include "pme_spline_work.h"
61 #include "spline_vectors.h"
62
63 /* TODO consider split of pme-spline from this file */
64
65 static void calc_interpolation_idx(const gmx_pme_t* pme, PmeAtomComm* atc, int start, int grid_index, int end, int thread)
66 {
67     int         i;
68     int *       idxptr, tix, tiy, tiz;
69     const real* xptr;
70     real *      fptr, tx, ty, tz;
71     real        rxx, ryx, ryy, rzx, rzy, rzz;
72     int         nx, ny, nz;
73     int *       g2tx, *g2ty, *g2tz;
74     gmx_bool    bThreads;
75     int*        thread_idx = nullptr;
76     int*        tpl_n      = nullptr;
77     int         thread_i;
78
79     nx = pme->nkx;
80     ny = pme->nky;
81     nz = pme->nkz;
82
83     rxx = pme->recipbox[XX][XX];
84     ryx = pme->recipbox[YY][XX];
85     ryy = pme->recipbox[YY][YY];
86     rzx = pme->recipbox[ZZ][XX];
87     rzy = pme->recipbox[ZZ][YY];
88     rzz = pme->recipbox[ZZ][ZZ];
89
90     g2tx = pme->pmegrid[grid_index].g2t[XX];
91     g2ty = pme->pmegrid[grid_index].g2t[YY];
92     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
93
94     bThreads = (atc->nthread > 1);
95     if (bThreads)
96     {
97         thread_idx = atc->thread_idx.data();
98
99         tpl_n = atc->threadMap[thread].n;
100         for (i = 0; i < atc->nthread; i++)
101         {
102             tpl_n[i] = 0;
103         }
104     }
105
106     const real shift = c_pmeMaxUnitcellShift;
107
108     for (i = start; i < end; i++)
109     {
110         xptr   = atc->x[i];
111         idxptr = atc->idx[i];
112         fptr   = atc->fractx[i];
113
114         /* Fractional coordinates along box vectors, add a positive shift to ensure tx/ty/tz are positive for triclinic boxes */
115         tx = nx * (xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + shift);
116         ty = ny * (xptr[YY] * ryy + xptr[ZZ] * rzy + shift);
117         tz = nz * (xptr[ZZ] * rzz + shift);
118
119         tix = static_cast<int>(tx);
120         tiy = static_cast<int>(ty);
121         tiz = static_cast<int>(tz);
122
123 #ifdef DEBUG
124         range_check(tix, 0, c_pmeNeighborUnitcellCount * nx);
125         range_check(tiy, 0, c_pmeNeighborUnitcellCount * ny);
126         range_check(tiz, 0, c_pmeNeighborUnitcellCount * nz);
127 #endif
128         /* Because decomposition only occurs in x and y,
129          * we never have a fraction correction in z.
130          */
131         fptr[XX] = tx - tix + pme->fshx[tix];
132         fptr[YY] = ty - tiy + pme->fshy[tiy];
133         fptr[ZZ] = tz - tiz;
134
135         idxptr[XX] = pme->nnx[tix];
136         idxptr[YY] = pme->nny[tiy];
137         idxptr[ZZ] = pme->nnz[tiz];
138
139 #ifdef DEBUG
140         range_check(idxptr[XX], 0, pme->pmegrid_nx);
141         range_check(idxptr[YY], 0, pme->pmegrid_ny);
142         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
143 #endif
144
145         if (bThreads)
146         {
147             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
148             thread_idx[i] = thread_i;
149             tpl_n[thread_i]++;
150         }
151     }
152
153     if (bThreads)
154     {
155         /* Make a list of particle indices sorted on thread */
156
157         /* Get the cumulative count */
158         for (i = 1; i < atc->nthread; i++)
159         {
160             tpl_n[i] += tpl_n[i - 1];
161         }
162         /* The current implementation distributes particles equally
163          * over the threads, so we could actually allocate for that
164          * in pme_realloc_atomcomm_things.
165          */
166         AtomToThreadMap& threadMap = atc->threadMap[thread];
167         threadMap.i.resize(tpl_n[atc->nthread - 1]);
168         /* Set tpl_n to the cumulative start */
169         for (i = atc->nthread - 1; i >= 1; i--)
170         {
171             tpl_n[i] = tpl_n[i - 1];
172         }
173         tpl_n[0] = 0;
174
175         /* Fill our thread local array with indices sorted on thread */
176         for (i = start; i < end; i++)
177         {
178             threadMap.i[tpl_n[atc->thread_idx[i]]++] = i;
179         }
180         /* Now tpl_n contains the cummulative count again */
181     }
182 }
183
184 static void make_thread_local_ind(const PmeAtomComm* atc, int thread, splinedata_t* spline)
185 {
186     int n, t, i, start, end;
187
188     /* Combine the indices made by each thread into one index */
189
190     n     = 0;
191     start = 0;
192     for (t = 0; t < atc->nthread; t++)
193     {
194         const AtomToThreadMap& threadMap = atc->threadMap[t];
195         /* Copy our part (start - end) from the list of thread t */
196         if (thread > 0)
197         {
198             start = threadMap.n[thread - 1];
199         }
200         end = threadMap.n[thread];
201         for (i = start; i < end; i++)
202         {
203             spline->ind[n++] = threadMap.i[i];
204         }
205     }
206
207     spline->n = n;
208 }
209
210 // At run time, the values of order used and asserted upon mean that
211 // indexing out of bounds does not occur. However compilers don't
212 // always understand that, so we suppress this warning for this code
213 // region.
214 #pragma GCC diagnostic push
215 #pragma GCC diagnostic ignored "-Warray-bounds"
216
/* Macro to force loop unrolling by fixing order.
 * This gives a significant performance gain.
 *
 * For each dimension j < DIM, computes the B-spline weights of order
 * (order) for the fractional offset dr = xptr[j], i.e. the relative offset
 * from the lower cell limit. The weights are stored in
 * theta[j][i * (order) + k] for k = 0..(order)-1.
 *
 * The weights in data[] are raised one order at a time with the B-spline
 * recursion (div = 1/(k-1) at step k). The derivatives, stored in
 * dtheta[j][i * (order) + k], are differences of the order (order)-1
 * weights. They are taken before the last recursion step raises the
 * weights to order (order). data[(order)-1] is zeroed up front so that
 * this difference step can read one element past the order (order)-1
 * weights.
 *
 * The expansion uses the names i, xptr, theta and dtheta from the enclosing
 * scope. data has PME_ORDER_MAX elements, so callers must keep
 * 3 <= (order) <= PME_ORDER_MAX; make_bsplines asserts this.
 */
#define CALC_SPLINE(order)                                                                               \
    {                                                                                                    \
        for (int j = 0; (j < DIM); j++)                                                                  \
        {                                                                                                \
            real dr, div;                                                                                \
            real data[PME_ORDER_MAX];                                                                    \
                                                                                                         \
            dr = xptr[j];                                                                                \
                                                                                                         \
            /* dr is relative offset from lower cell limit */                                            \
            data[(order)-1] = 0;                                                                         \
            data[1]         = dr;                                                                        \
            data[0]         = 1 - dr;                                                                    \
                                                                                                         \
            for (int k = 3; (k < (order)); k++)                                                          \
            {                                                                                            \
                div         = 1.0 / (k - 1.0);                                                           \
                data[k - 1] = div * dr * data[k - 2];                                                    \
                for (int l = 1; (l < (k - 1)); l++)                                                      \
                {                                                                                        \
                    data[k - l - 1] =                                                                    \
                            div * ((dr + l) * data[k - l - 2] + (k - l - dr) * data[k - l - 1]);         \
                }                                                                                        \
                data[0] = div * (1 - dr) * data[0];                                                      \
            }                                                                                            \
            /* differentiate */                                                                          \
            dtheta[j][i * (order) + 0] = -data[0];                                                       \
            for (int k = 1; (k < (order)); k++)                                                          \
            {                                                                                            \
                dtheta[j][i * (order) + k] = data[k - 1] - data[k];                                      \
            }                                                                                            \
                                                                                                         \
            div             = 1.0 / ((order)-1);                                                         \
            data[(order)-1] = div * dr * data[(order)-2];                                                \
            for (int l = 1; (l < ((order)-1)); l++)                                                      \
            {                                                                                            \
                data[(order)-l - 1] =                                                                    \
                        div * ((dr + l) * data[(order)-l - 2] + ((order)-l - dr) * data[(order)-l - 1]); \
            }                                                                                            \
            data[0] = div * (1 - dr) * data[0];                                                          \
                                                                                                         \
            for (int k = 0; k < (order); k++)                                                            \
            {                                                                                            \
                theta[j][i * (order) + k] = data[k];                                                     \
            }                                                                                            \
        }                                                                                                \
    }
267
268 static void make_bsplines(splinevec  theta,
269                           splinevec  dtheta,
270                           int        order,
271                           rvec       fractx[],
272                           int        nr,
273                           const int  ind[],
274                           const real coefficient[],
275                           gmx_bool   bDoSplines)
276 {
277     /* construct splines for local atoms */
278     int   i, ii;
279     real* xptr;
280
281     for (i = 0; i < nr; i++)
282     {
283         /* With free energy we do not use the coefficient check.
284          * In most cases this will be more efficient than calling make_bsplines
285          * twice, since usually more than half the particles have non-zero coefficients.
286          */
287         ii = ind[i];
288         if (bDoSplines || coefficient[ii] != 0.0)
289         {
290             xptr = fractx[ii];
291             assert(order >= 3 && order <= PME_ORDER_MAX);
292             switch (order)
293             {
294                 case 4: CALC_SPLINE(4) break;
295                 case 5: CALC_SPLINE(5) break;
296                 default: CALC_SPLINE(order) break;
297             }
298         }
299     }
300 }
301
302 #pragma GCC diagnostic pop
303
/* Spread one coefficient onto the thread-local grid with an order^3 stencil.
 *
 * For all ithx, ithy, ithz in [0, order) this adds
 * coefficient * thx[ithx] * thy[ithy] * thz[ithz] to
 * grid[(i0 + ithx) * pny * pnz + (j0 + ithy) * pnz + (k0 + ithz)].
 *
 * The macro reads coefficient, thx, thy, thz, i0, j0, k0, pny, pnz and grid
 * from the enclosing scope. It overwrites that scope's ithx, ithy, ithz,
 * index_x, index_xy, index_xyz, valx and valxy.
 *
 * This has to be a macro to enable full compiler optimization with xlC
 * (and probably others too).
 */
#define DO_BSPLINE(order)                             \
    for (ithx = 0; (ithx < (order)); ithx++)          \
    {                                                 \
        index_x = (i0 + ithx) * pny * pnz;            \
        valx    = coefficient * thx[ithx];            \
                                                      \
        for (ithy = 0; (ithy < (order)); ithy++)      \
        {                                             \
            valxy    = valx * thy[ithy];              \
            index_xy = index_x + (j0 + ithy) * pnz;   \
                                                      \
            for (ithz = 0; (ithz < (order)); ithz++)  \
            {                                         \
                index_xyz = index_xy + (k0 + ithz);   \
                grid[index_xyz] += valxy * thz[ithz]; \
            }                                         \
        }                                             \
    }
323
324
/*! \brief Spread the coefficients of the atoms listed in \p spline onto one thread-local grid.
 *
 * The whole grid pmegrid->grid (s[XX] * s[YY] * s[ZZ] values) is first set
 * to zero. Then each atom n = spline->ind[nn] with a non-zero
 * atc->coefficient[n] is handled. Its coefficient is added to an
 * order x order x order block of grid points, weighted by the spline values
 * stored at offset nn * order in spline->theta.coefficients. The block
 * starts at the atom's grid index atc->idx[n] minus the thread grid offset.
 *
 * Orders 4 and 5 use the SIMD4 code in pme_simd4.h when
 * PME_SIMD4_SPREAD_GATHER is defined, otherwise a fixed-order DO_BSPLINE.
 * Other orders use DO_BSPLINE with the run-time order.
 *
 * NOTE(review): pme_simd4.h is textually included inside the switch and
 * presumably refers to the local variables of this function, and to work,
 * by name. That would explain the gmx_unused on work in non-SIMD builds.
 * Do not rename these locals.
 *
 * \param[in] pmegrid  Thread-local spreading grid; the data at pmegrid->grid is overwritten
 * \param[in] atc      Atom data (coefficients and grid indices)
 * \param[in] spline   This thread's atom list and precomputed spline weights
 * \param[in] work     Spline work data, presumably only used by the SIMD code paths
 */
static void spread_coefficients_bsplines_thread(const pmegrid_t*       pmegrid,
                                                const PmeAtomComm*     atc,
                                                splinedata_t*          spline,
                                                struct pme_spline_work gmx_unused* work)
{

    /* spread coefficients from home atoms to local grid */
    real*      grid;
    int        i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    const int* idxptr;
    int        order, norder, index_x, index_xy, index_xyz;
    real       valx, valxy, coefficient;
    real *     thx, *thy, *thz;
    int        pnx, pny, pnz, ndatatot;
    int        offx, offy, offz;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    /* Aligned scratch buffer for thz, presumably used by the aligned SIMD4 path in pme_simd4.h */
    alignas(GMX_SIMD_ALIGNMENT) real thz_aligned[GMX_SIMD4_WIDTH * 2];
#endif

    /* Allocated sizes of the thread-local grid, used as strides */
    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    /* Offset of this thread's grid; atom grid indices are made relative to it */
    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the whole thread-local grid, since the loops below accumulate into it */
    ndatatot = pnx * pny * pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n           = spline->ind[nn];
        coefficient = atc->coefficient[n];

        /* Atoms with a zero coefficient contribute nothing */
        if (coefficient != 0)
        {
            idxptr = atc->idx[n];
            /* Spline weights are stored per list position nn, order values each */
            norder = nn * order;

            /* Lower corner of the stencil in thread-local grid coordinates */
            i0 = idxptr[XX] - offx;
            j0 = idxptr[YY] - offy;
            k0 = idxptr[ZZ] - offz;

            thx = spline->theta.coefficients[XX] + norder;
            thy = spline->theta.coefficients[YY] + norder;
            thz = spline->theta.coefficients[ZZ] + norder;

            /* Orders 4 and 5 get specialized (SIMD4 or fixed-order) code */
            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#    ifdef PME_SIMD4_UNALIGNED
#        define PME_SPREAD_SIMD4_ORDER4
#    else
#        define PME_SPREAD_SIMD4_ALIGNED
#        define PME_ORDER 4
#    endif
#    include "pme_simd4.h"
#else
                    DO_BSPLINE(4)
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#    define PME_SPREAD_SIMD4_ALIGNED
#    define PME_ORDER 5
#    include "pme_simd4.h"
#else
                    DO_BSPLINE(5)
#endif
                    break;
                default: DO_BSPLINE(order) break;
            }
        }
    }
}
409
410 static void copy_local_grid(const gmx_pme_t* pme, const pmegrids_t* pmegrids, int grid_index, int thread, real* fftgrid)
411 {
412     ivec  local_fft_ndata, local_fft_offset, local_fft_size;
413     int   fft_my, fft_mz;
414     int   nsy, nsz;
415     ivec  nf;
416     int   offx, offy, offz, x, y, z, i0, i0t;
417     int   d;
418     real* grid_th;
419
420     gmx_parallel_3dfft_real_limits(
421             pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset, local_fft_size);
422     fft_my = local_fft_size[YY];
423     fft_mz = local_fft_size[ZZ];
424
425     const pmegrid_t* pmegrid = &pmegrids->grid_th[thread];
426
427     nsy = pmegrid->s[YY];
428     nsz = pmegrid->s[ZZ];
429
430     for (d = 0; d < DIM; d++)
431     {
432         nf[d] = std::min(pmegrid->n[d] - (pmegrid->order - 1), local_fft_ndata[d] - pmegrid->offset[d]);
433     }
434
435     offx = pmegrid->offset[XX];
436     offy = pmegrid->offset[YY];
437     offz = pmegrid->offset[ZZ];
438
439     /* Directly copy the non-overlapping parts of the local grids.
440      * This also initializes the full grid.
441      */
442     grid_th = pmegrid->grid;
443     for (x = 0; x < nf[XX]; x++)
444     {
445         for (y = 0; y < nf[YY]; y++)
446         {
447             i0  = ((offx + x) * fft_my + (offy + y)) * fft_mz + offz;
448             i0t = (x * nsy + y) * nsz;
449             for (z = 0; z < nf[ZZ]; z++)
450             {
451                 fftgrid[i0 + z] = grid_th[i0t + z];
452             }
453         }
454     }
455 }
456
/*! \brief Sum the contributions of other threads' grids that overlap our part of the FFT grid.
 *
 * Called after all threads have finished spreading. The calling thread's
 * own contribution must already be in place; the (0,0,0) source is skipped
 * here. The loops visit the source thread grids at thread-grid offsets
 * sx, sy, sz in [-nthread_comm[d], 0] relative to our thread, wrapping
 * around periodically. For each, the overlapping region is summed:
 * - directly into \p fftgrid when no inter-rank communication is involved;
 * - otherwise into \p commbuf_x or \p commbuf_y. This applies when the
 *   source wraps in x with nnodes_major > 1, or in y with
 *   nnodes_minor > 1. Presumably the buffers are later communicated to a
 *   neighboring rank.
 *
 * The first write to each buffer region in this call initializes it; later
 * writes accumulate. Data wrapping in both x and y goes to commbuf_y, after
 * the y-only part (offset buf_my * fft_nx * fft_nz). The z dimension never
 * triggers communication here.
 *
 * \param[in]     pme         PME setup
 * \param[in]     pmegrids    The set of thread-local spreading grids
 * \param[in]     thread      The calling thread
 * \param[in,out] fftgrid     Rank-local real-space FFT grid
 * \param[out]    commbuf_x   Buffer for x-wrapped contributions
 * \param[out]    commbuf_y   Buffer for y-wrapped (and x+y corner) contributions
 * \param[in]     grid_index  Index of the PME grid
 */
static void reduce_threadgrid_overlap(const gmx_pme_t*  pme,
                                      const pmegrids_t* pmegrids,
                                      int               thread,
                                      real*             fftgrid,
                                      real*             commbuf_x,
                                      real*             commbuf_y,
                                      int               grid_index)
{
    ivec             local_fft_ndata, local_fft_offset, local_fft_size;
    int              fft_nx, fft_ny, fft_nz;
    int              fft_my, fft_mz;
    int              buf_my = -1;
    int              nsy, nsz;
    ivec             localcopy_end, commcopy_end;
    int              offx, offy, offz, x, y, z, i0, i0t;
    int              sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
    gmx_bool         bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
    gmx_bool         bCommX, bCommY;
    int              d;
    int              thread_f;
    const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
    const real*      grid_th;
    real*            commbuf = nullptr;

    gmx_parallel_3dfft_real_limits(
            pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset, local_fft_size);
    /* fft_n*: number of local FFT grid points; fft_m*: allocated sizes, used as strides */
    fft_nx = local_fft_ndata[XX];
    fft_ny = local_fft_ndata[YY];
    fft_nz = local_fft_ndata[ZZ];

    fft_my = local_fft_size[YY];
    fft_mz = local_fft_size[ZZ];

    /* This routine is called when all threads have finished spreading.
     * Here each thread sums grid contributions calculated by other threads
     * to the thread local grid volume.
     * To minimize the number of grid copying operations,
     * this routine sums immediately from the pmegrid to the fftgrid.
     */

    /* Determine which part of the full node grid we should operate on,
     * this is our thread local part of the full grid.
     */
    pmegrid = &pmegrids->grid_th[thread];

    for (d = 0; d < DIM; d++)
    {
        /* Determine up to where our thread needs to copy from the
         * thread-local charge spreading grid to the rank-local FFT grid.
         * This is up to our spreading grid end minus order-1 and
         * not beyond the local FFT grid.
         */
        localcopy_end[d] = std::min(pmegrid->offset[d] + pmegrid->n[d] - (pmegrid->order - 1),
                                    local_fft_ndata[d]);

        /* Determine up to where our thread needs to copy from the
         * thread-local charge spreading grid to the communication buffer.
         * Note: only relevant with communication, ignored otherwise.
         */
        commcopy_end[d] = localcopy_end[d];
        if (pmegrid->ci[d] == pmegrids->nc[d] - 1)
        {
            /* The last thread should copy up to the last pme grid line.
             * When the rank-local FFT grid is narrower than pme-order,
             * we need the max below to ensure copying of all data.
             */
            commcopy_end[d] = std::max(commcopy_end[d], pme->pme_order);
        }
    }

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];


    /* Each communication buffer region is initialized (assigned)
     * on its first use in this call and accumulated into afterwards.
     */
    bClearBufX  = TRUE;
    bClearBufY  = TRUE;
    bClearBufXY = TRUE;

    /* Now loop over all the thread data blocks that contribute
     * to the grid region we (our thread) are operating on.
     */
    /* Note that fft_nx/y is equal to the number of grid points
     * between the first point of our node grid and the one of the next node.
     */
    for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
    {
        fx     = pmegrid->ci[XX] + sx;
        ox     = 0;
        bCommX = FALSE;
        if (fx < 0)
        {
            /* Wrap around to the last threads in x; shift their origin by
             * -fft_nx into our grid coordinates. This requires communication
             * when there are multiple ranks along the major dimension.
             */
            fx += pmegrids->nc[XX];
            ox -= fft_nx;
            bCommX = (pme->nnodes_major > 1);
        }
        /* NOTE(review): the x offset and size are taken from thread grid (fx,0,0),
         * presumably shared by all thread grids with x index fx.
         */
        pmegrid_g = &pmegrids->grid_th[fx * pmegrids->nc[YY] * pmegrids->nc[ZZ]];
        ox += pmegrid_g->offset[XX];
        /* Determine the end of our part of the source grid.
         * Use our thread local source grid and target grid part
         */
        tx1 = std::min(ox + pmegrid_g->n[XX], !bCommX ? localcopy_end[XX] : commcopy_end[XX]);

        for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
        {
            fy     = pmegrid->ci[YY] + sy;
            oy     = 0;
            bCommY = FALSE;
            if (fy < 0)
            {
                /* Same wrap-around as for x, with minor-dimension ranks */
                fy += pmegrids->nc[YY];
                oy -= fft_ny;
                bCommY = (pme->nnodes_minor > 1);
            }
            pmegrid_g = &pmegrids->grid_th[fy * pmegrids->nc[ZZ]];
            oy += pmegrid_g->offset[YY];
            /* Determine the end of our part of the source grid.
             * Use our thread local source grid and target grid part
             */
            ty1 = std::min(oy + pmegrid_g->n[YY], !bCommY ? localcopy_end[YY] : commcopy_end[YY]);

            for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
            {
                /* z wraps around as well, but never needs communication */
                fz = pmegrid->ci[ZZ] + sz;
                oz = 0;
                if (fz < 0)
                {
                    fz += pmegrids->nc[ZZ];
                    oz -= fft_nz;
                }
                pmegrid_g = &pmegrids->grid_th[fz];
                oz += pmegrid_g->offset[ZZ];
                tz1 = std::min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);

                if (sx == 0 && sy == 0 && sz == 0)
                {
                    /* We have already added our local contribution
                     * before calling this routine, so skip it here.
                     */
                    continue;
                }

                thread_f = (fx * pmegrids->nc[YY] + fy) * pmegrids->nc[ZZ] + fz;

                pmegrid_f = &pmegrids->grid_th[thread_f];

                grid_th = pmegrid_f->grid;

                nsy = pmegrid_f->s[YY];
                nsz = pmegrid_f->s[ZZ];

#ifdef DEBUG_PME_REDUCE
                printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d "
                       "%2d-%2d, %2d-%2d %2d-%2d\n",
                       pme->nodeid,
                       thread,
                       thread_f,
                       pme->pmegrid_start_ix,
                       pme->pmegrid_start_iy,
                       pme->pmegrid_start_iz,
                       sx,
                       sy,
                       sz,
                       offx - ox,
                       tx1 - ox,
                       offx,
                       tx1,
                       offy - oy,
                       ty1 - oy,
                       offy,
                       ty1,
                       offz - oz,
                       tz1 - oz,
                       offz,
                       tz1);
#endif

                if (!(bCommX || bCommY))
                {
                    /* Copy from the thread local grid to the node grid */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x * fft_my + y) * fft_mz;
                            i0t = ((x - ox) * nsy + (y - oy)) * nsz - oz;
                            for (z = offz; z < tz1; z++)
                            {
                                fftgrid[i0 + z] += grid_th[i0t + z];
                            }
                        }
                    }
                }
                else
                {
                    /* The order of this conditional decides
                     * where the corner volume gets stored with x+y decomp.
                     */
                    if (bCommY)
                    {
                        commbuf = commbuf_y;
                        /* The y-size of the communication buffer is set by
                         * the overlap of the grid part of our local slab
                         * with the part starting at the next slab.
                         */
                        buf_my = pme->overlap[1].s2g1[pme->nodeid_minor]
                                 - pme->overlap[1].s2g0[pme->nodeid_minor + 1];
                        if (bCommX)
                        {
                            /* We index commbuf modulo the local grid size */
                            commbuf += buf_my * fft_nx * fft_nz;

                            bClearBuf   = bClearBufXY;
                            bClearBufXY = FALSE;
                        }
                        else
                        {
                            bClearBuf  = bClearBufY;
                            bClearBufY = FALSE;
                        }
                    }
                    else
                    {
                        commbuf    = commbuf_x;
                        buf_my     = fft_ny;
                        bClearBuf  = bClearBufX;
                        bClearBufX = FALSE;
                    }

                    /* Copy to the communication buffer; its layout is x-major
                     * with buf_my lines in y and fft_nz points in z.
                     */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x * buf_my + y) * fft_nz;
                            i0t = ((x - ox) * nsy + (y - oy)) * nsz - oz;

                            if (bClearBuf)
                            {
                                /* First access of commbuf, initialize it */
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0 + z] = grid_th[i0t + z];
                                }
                            }
                            else
                            {
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0 + z] += grid_th[i0t + z];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
716
717
718 static void sum_fftgrid_dd(const gmx_pme_t* pme, real* fftgrid, int grid_index)
719 {
720     ivec local_fft_ndata, local_fft_offset, local_fft_size;
721     int  send_index0, send_nindex;
722     int  recv_nindex;
723 #if GMX_MPI
724     MPI_Status stat;
725 #endif
726     int recv_size_y;
727     int size_yx;
728     int x, y, z, indg, indb;
729
730     /* Note that this routine is only used for forward communication.
731      * Since the force gathering, unlike the coefficient spreading,
732      * can be trivially parallelized over the particles,
733      * the backwards process is much simpler and can use the "old"
734      * communication setup.
735      */
736
737     gmx_parallel_3dfft_real_limits(
738             pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset, local_fft_size);
739
740     if (pme->nnodes_minor > 1)
741     {
742         /* Major dimension */
743         const pme_overlap_t* overlap = &pme->overlap[1];
744
745         if (pme->nnodes_major > 1)
746         {
747             size_yx = pme->overlap[0].comm_data[0].send_nindex;
748         }
749         else
750         {
751             size_yx = 0;
752         }
753 #if GMX_MPI
754         int datasize = (local_fft_ndata[XX] + size_yx) * local_fft_ndata[ZZ];
755
756         int send_size_y = overlap->send_size;
757 #endif
758
759         for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
760         {
761             send_index0 = overlap->comm_data[ipulse].send_index0 - overlap->comm_data[0].send_index0;
762             send_nindex = overlap->comm_data[ipulse].send_nindex;
763             /* We don't use recv_index0, as we always receive starting at 0 */
764             recv_nindex = overlap->comm_data[ipulse].recv_nindex;
765             recv_size_y = overlap->comm_data[ipulse].recv_size;
766
767             auto* sendptr =
768                     const_cast<real*>(overlap->sendbuf.data()) + send_index0 * local_fft_ndata[ZZ];
769             auto* recvptr = const_cast<real*>(overlap->recvbuf.data());
770
771             if (debug != nullptr)
772             {
773                 fprintf(debug,
774                         "PME fftgrid comm y %2d x %2d x %2d\n",
775                         local_fft_ndata[XX],
776                         send_nindex,
777                         local_fft_ndata[ZZ]);
778             }
779
780 #if GMX_MPI
781             int send_id = overlap->comm_data[ipulse].send_id;
782             int recv_id = overlap->comm_data[ipulse].recv_id;
783             MPI_Sendrecv(sendptr,
784                          send_size_y * datasize,
785                          GMX_MPI_REAL,
786                          send_id,
787                          ipulse,
788                          recvptr,
789                          recv_size_y * datasize,
790                          GMX_MPI_REAL,
791                          recv_id,
792                          ipulse,
793                          overlap->mpi_comm,
794                          &stat);
795 #endif
796
797             for (x = 0; x < local_fft_ndata[XX]; x++)
798             {
799                 for (y = 0; y < recv_nindex; y++)
800                 {
801                     indg = (x * local_fft_size[YY] + y) * local_fft_size[ZZ];
802                     indb = (x * recv_size_y + y) * local_fft_ndata[ZZ];
803                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
804                     {
805                         fftgrid[indg + z] += recvptr[indb + z];
806                     }
807                 }
808             }
809
810             if (pme->nnodes_major > 1)
811             {
812                 /* Copy from the received buffer to the send buffer for dim 0 */
813                 sendptr = const_cast<real*>(pme->overlap[0].sendbuf.data());
814                 for (x = 0; x < size_yx; x++)
815                 {
816                     for (y = 0; y < recv_nindex; y++)
817                     {
818                         indg = (x * local_fft_ndata[YY] + y) * local_fft_ndata[ZZ];
819                         indb = ((local_fft_ndata[XX] + x) * recv_size_y + y) * local_fft_ndata[ZZ];
820                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
821                         {
822                             sendptr[indg + z] += recvptr[indb + z];
823                         }
824                     }
825                 }
826             }
827         }
828     }
829
830     /* We only support a single pulse here.
831      * This is not a severe limitation, as this code is only used
832      * with OpenMP and with OpenMP the (PME) domains can be larger.
833      */
834     if (pme->nnodes_major > 1)
835     {
836         /* Major dimension */
837         const pme_overlap_t* overlap = &pme->overlap[0];
838
839         size_t ipulse = 0;
840
841         send_nindex = overlap->comm_data[ipulse].send_nindex;
842         /* We don't use recv_index0, as we always receive starting at 0 */
843         recv_nindex = overlap->comm_data[ipulse].recv_nindex;
844
845         if (debug != nullptr)
846         {
847             fprintf(debug,
848                     "PME fftgrid comm x %2d x %2d x %2d\n",
849                     send_nindex,
850                     local_fft_ndata[YY],
851                     local_fft_ndata[ZZ]);
852         }
853
854 #if GMX_MPI
855         int   datasize = local_fft_ndata[YY] * local_fft_ndata[ZZ];
856         int   send_id  = overlap->comm_data[ipulse].send_id;
857         int   recv_id  = overlap->comm_data[ipulse].recv_id;
858         auto* sendptr  = const_cast<real*>(overlap->sendbuf.data());
859         auto* recvptr  = const_cast<real*>(overlap->recvbuf.data());
860         MPI_Sendrecv(sendptr,
861                      send_nindex * datasize,
862                      GMX_MPI_REAL,
863                      send_id,
864                      ipulse,
865                      recvptr,
866                      recv_nindex * datasize,
867                      GMX_MPI_REAL,
868                      recv_id,
869                      ipulse,
870                      overlap->mpi_comm,
871                      &stat);
872 #endif
873
874         for (x = 0; x < recv_nindex; x++)
875         {
876             for (y = 0; y < local_fft_ndata[YY]; y++)
877             {
878                 indg = (x * local_fft_size[YY] + y) * local_fft_size[ZZ];
879                 indb = (x * local_fft_ndata[YY] + y) * local_fft_ndata[ZZ];
880                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
881                 {
882                     fftgrid[indg + z] += overlap->recvbuf[indb + z];
883                 }
884             }
885         }
886     }
887 }
888
889 void spread_on_grid(const gmx_pme_t*  pme,
890                     PmeAtomComm*      atc,
891                     const pmegrids_t* grids,
892                     gmx_bool          bCalcSplines,
893                     gmx_bool          bSpread,
894                     real*             fftgrid,
895                     gmx_bool          bDoSplines,
896                     int               grid_index)
897 {
898 #ifdef PME_TIME_THREADS
899     gmx_cycles_t  c1, c2, c3, ct1a, ct1b, ct1c;
900     static double cs1 = 0, cs2 = 0, cs3 = 0;
901     static double cs1a[6] = { 0, 0, 0, 0, 0, 0 };
902     static int    cnt     = 0;
903 #endif
904
905     const int nthread = pme->nthread;
906     assert(nthread > 0);
907     GMX_ASSERT(grids != nullptr || !bSpread, "If there's no grid, we cannot be spreading");
908
909 #ifdef PME_TIME_THREADS
910     c1 = omp_cyc_start();
911 #endif
912     if (bCalcSplines)
913     {
914 #pragma omp parallel for num_threads(nthread) schedule(static)
915         for (int thread = 0; thread < nthread; thread++)
916         {
917             try
918             {
919                 int start, end;
920
921                 start = atc->numAtoms() * thread / nthread;
922                 end   = atc->numAtoms() * (thread + 1) / nthread;
923
924                 /* Compute fftgrid index for all atoms,
925                  * with help of some extra variables.
926                  */
927                 calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
928             }
929             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
930         }
931     }
932 #ifdef PME_TIME_THREADS
933     c1 = omp_cyc_end(c1);
934     cs1 += (double)c1;
935 #endif
936
937 #ifdef PME_TIME_THREADS
938     c2 = omp_cyc_start();
939 #endif
940 #pragma omp parallel for num_threads(nthread) schedule(static)
941     for (int thread = 0; thread < nthread; thread++)
942     {
943         try
944         {
945             splinedata_t* spline;
946
947             /* make local bsplines  */
948             if (grids == nullptr || !pme->bUseThreads)
949             {
950                 spline = &atc->spline[0];
951
952                 spline->n = atc->numAtoms();
953             }
954             else
955             {
956                 spline = &atc->spline[thread];
957
958                 if (grids->nthread == 1)
959                 {
960                     /* One thread, we operate on all coefficients */
961                     spline->n = atc->numAtoms();
962                 }
963                 else
964                 {
965                     /* Get the indices our thread should operate on */
966                     make_thread_local_ind(atc, thread, spline);
967                 }
968             }
969
970             if (bCalcSplines)
971             {
972                 make_bsplines(spline->theta.coefficients,
973                               spline->dtheta.coefficients,
974                               pme->pme_order,
975                               as_rvec_array(atc->fractx.data()),
976                               spline->n,
977                               spline->ind.data(),
978                               atc->coefficient.data(),
979                               bDoSplines);
980             }
981
982             if (bSpread)
983             {
984                 /* put local atoms on grid. */
985                 const pmegrid_t* grid = pme->bUseThreads ? &grids->grid_th[thread] : &grids->grid;
986
987 #ifdef PME_TIME_SPREAD
988                 ct1a = omp_cyc_start();
989 #endif
990                 spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
991
992                 if (pme->bUseThreads)
993                 {
994                     copy_local_grid(pme, grids, grid_index, thread, fftgrid);
995                 }
996 #ifdef PME_TIME_SPREAD
997                 ct1a = omp_cyc_end(ct1a);
998                 cs1a[thread] += (double)ct1a;
999 #endif
1000             }
1001         }
1002         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
1003     }
1004 #ifdef PME_TIME_THREADS
1005     c2 = omp_cyc_end(c2);
1006     cs2 += (double)c2;
1007 #endif
1008
1009     if (bSpread && pme->bUseThreads)
1010     {
1011 #ifdef PME_TIME_THREADS
1012         c3 = omp_cyc_start();
1013 #endif
1014 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
1015         for (int thread = 0; thread < grids->nthread; thread++)
1016         {
1017             try
1018             {
1019                 reduce_threadgrid_overlap(pme,
1020                                           grids,
1021                                           thread,
1022                                           fftgrid,
1023                                           const_cast<real*>(pme->overlap[0].sendbuf.data()),
1024                                           const_cast<real*>(pme->overlap[1].sendbuf.data()),
1025                                           grid_index);
1026             }
1027             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
1028         }
1029 #ifdef PME_TIME_THREADS
1030         c3 = omp_cyc_end(c3);
1031         cs3 += (double)c3;
1032 #endif
1033
1034         if (pme->nnodes > 1)
1035         {
1036             /* Communicate the overlapping part of the fftgrid.
1037              * For this communication call we need to check pme->bUseThreads
1038              * to have all ranks communicate here, regardless of pme->nthread.
1039              */
1040             sum_fftgrid_dd(pme, fftgrid, grid_index);
1041         }
1042     }
1043
1044 #ifdef PME_TIME_THREADS
1045     cnt++;
1046     if (cnt % 20 == 0)
1047     {
1048         printf("idx %.2f spread %.2f red %.2f", cs1 * 1e-9, cs2 * 1e-9, cs3 * 1e-9);
1049 #    ifdef PME_TIME_SPREAD
1050         for (int thread = 0; thread < nthread; thread++)
1051         {
1052             printf(" %.2f", cs1a[thread] * 1e-9);
1053         }
1054 #    endif
1055         printf("\n");
1056     }
1057 #endif
1058 }