Apply clang-format to source tree
[alexxy/gromacs.git] / src / gromacs / ewald / pme_spread.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "pme_spread.h"
40
41 #include "config.h"
42
43 #include <cassert>
44
45 #include <algorithm>
46
47 #include "gromacs/ewald/pme.h"
48 #include "gromacs/fft/parallel_3dfft.h"
49 #include "gromacs/simd/simd.h"
50 #include "gromacs/utility/basedefinitions.h"
51 #include "gromacs/utility/exceptions.h"
52 #include "gromacs/utility/fatalerror.h"
53 #include "gromacs/utility/gmxassert.h"
54 #include "gromacs/utility/smalloc.h"
55
56 #include "pme_grid.h"
57 #include "pme_internal.h"
58 #include "pme_simd.h"
59 #include "pme_spline_work.h"
60
61 /* TODO consider split of pme-spline from this file */
62
63 static void calc_interpolation_idx(const gmx_pme_t* pme, PmeAtomComm* atc, int start, int grid_index, int end, int thread)
64 {
65     int         i;
66     int *       idxptr, tix, tiy, tiz;
67     const real* xptr;
68     real *      fptr, tx, ty, tz;
69     real        rxx, ryx, ryy, rzx, rzy, rzz;
70     int         nx, ny, nz;
71     int *       g2tx, *g2ty, *g2tz;
72     gmx_bool    bThreads;
73     int*        thread_idx = nullptr;
74     int*        tpl_n      = nullptr;
75     int         thread_i;
76
77     nx = pme->nkx;
78     ny = pme->nky;
79     nz = pme->nkz;
80
81     rxx = pme->recipbox[XX][XX];
82     ryx = pme->recipbox[YY][XX];
83     ryy = pme->recipbox[YY][YY];
84     rzx = pme->recipbox[ZZ][XX];
85     rzy = pme->recipbox[ZZ][YY];
86     rzz = pme->recipbox[ZZ][ZZ];
87
88     g2tx = pme->pmegrid[grid_index].g2t[XX];
89     g2ty = pme->pmegrid[grid_index].g2t[YY];
90     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
91
92     bThreads = (atc->nthread > 1);
93     if (bThreads)
94     {
95         thread_idx = atc->thread_idx.data();
96
97         tpl_n = atc->threadMap[thread].n;
98         for (i = 0; i < atc->nthread; i++)
99         {
100             tpl_n[i] = 0;
101         }
102     }
103
104     const real shift = c_pmeMaxUnitcellShift;
105
106     for (i = start; i < end; i++)
107     {
108         xptr   = atc->x[i];
109         idxptr = atc->idx[i];
110         fptr   = atc->fractx[i];
111
112         /* Fractional coordinates along box vectors, add a positive shift to ensure tx/ty/tz are positive for triclinic boxes */
113         tx = nx * (xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + shift);
114         ty = ny * (xptr[YY] * ryy + xptr[ZZ] * rzy + shift);
115         tz = nz * (xptr[ZZ] * rzz + shift);
116
117         tix = static_cast<int>(tx);
118         tiy = static_cast<int>(ty);
119         tiz = static_cast<int>(tz);
120
121 #ifdef DEBUG
122         range_check(tix, 0, c_pmeNeighborUnitcellCount * nx);
123         range_check(tiy, 0, c_pmeNeighborUnitcellCount * ny);
124         range_check(tiz, 0, c_pmeNeighborUnitcellCount * nz);
125 #endif
126         /* Because decomposition only occurs in x and y,
127          * we never have a fraction correction in z.
128          */
129         fptr[XX] = tx - tix + pme->fshx[tix];
130         fptr[YY] = ty - tiy + pme->fshy[tiy];
131         fptr[ZZ] = tz - tiz;
132
133         idxptr[XX] = pme->nnx[tix];
134         idxptr[YY] = pme->nny[tiy];
135         idxptr[ZZ] = pme->nnz[tiz];
136
137 #ifdef DEBUG
138         range_check(idxptr[XX], 0, pme->pmegrid_nx);
139         range_check(idxptr[YY], 0, pme->pmegrid_ny);
140         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
141 #endif
142
143         if (bThreads)
144         {
145             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
146             thread_idx[i] = thread_i;
147             tpl_n[thread_i]++;
148         }
149     }
150
151     if (bThreads)
152     {
153         /* Make a list of particle indices sorted on thread */
154
155         /* Get the cumulative count */
156         for (i = 1; i < atc->nthread; i++)
157         {
158             tpl_n[i] += tpl_n[i - 1];
159         }
160         /* The current implementation distributes particles equally
161          * over the threads, so we could actually allocate for that
162          * in pme_realloc_atomcomm_things.
163          */
164         AtomToThreadMap& threadMap = atc->threadMap[thread];
165         threadMap.i.resize(tpl_n[atc->nthread - 1]);
166         /* Set tpl_n to the cumulative start */
167         for (i = atc->nthread - 1; i >= 1; i--)
168         {
169             tpl_n[i] = tpl_n[i - 1];
170         }
171         tpl_n[0] = 0;
172
173         /* Fill our thread local array with indices sorted on thread */
174         for (i = start; i < end; i++)
175         {
176             threadMap.i[tpl_n[atc->thread_idx[i]]++] = i;
177         }
178         /* Now tpl_n contains the cummulative count again */
179     }
180 }
181
182 static void make_thread_local_ind(const PmeAtomComm* atc, int thread, splinedata_t* spline)
183 {
184     int n, t, i, start, end;
185
186     /* Combine the indices made by each thread into one index */
187
188     n     = 0;
189     start = 0;
190     for (t = 0; t < atc->nthread; t++)
191     {
192         const AtomToThreadMap& threadMap = atc->threadMap[t];
193         /* Copy our part (start - end) from the list of thread t */
194         if (thread > 0)
195         {
196             start = threadMap.n[thread - 1];
197         }
198         end = threadMap.n[thread];
199         for (i = start; i < end; i++)
200         {
201             spline->ind[n++] = threadMap.i[i];
202         }
203     }
204
205     spline->n = n;
206 }
207
208 // At run time, the values of order used and asserted upon mean that
209 // indexing out of bounds does not occur. However compilers don't
210 // always understand that, so we suppress this warning for this code
211 // region.
212 #pragma GCC diagnostic push
213 #pragma GCC diagnostic ignored "-Warray-bounds"
214
215 /* Macro to force loop unrolling by fixing order.
216  * This gives a significant performance gain.
217  */
/* CALC_SPLINE(order): for each dimension j in [0, DIM), computes the
 * `order` B-spline weights and their derivatives from the fractional
 * offset dr = xptr[j], storing them in theta[j][i*order + k] and
 * dtheta[j][i*order + k] for k in [0, order).
 * The recursion is first run up to order-1. The derivatives are the
 * differences of those order-1 values, using data[order-1] == 0 for the
 * last one. One final recursion step then gives the order-`order` weights.
 * The expansion site must have xptr, i, theta and dtheta in scope.
 * data[] holds PME_ORDER_MAX values, so order must be in [3, PME_ORDER_MAX]
 * (make_bsplines asserts this).
 */
218 #define CALC_SPLINE(order)                                                                               \
219     {                                                                                                    \
220         for (int j = 0; (j < DIM); j++)                                                                  \
221         {                                                                                                \
222             real dr, div;                                                                                \
223             real data[PME_ORDER_MAX];                                                                    \
224                                                                                                          \
225             dr = xptr[j];                                                                                \
226                                                                                                          \
227             /* dr is relative offset from lower cell limit */                                            \
228             data[(order)-1] = 0;                                                                         \
229             data[1]         = dr;                                                                        \
230             data[0]         = 1 - dr;                                                                    \
231                                                                                                          \
232             for (int k = 3; (k < (order)); k++)                                                          \
233             {                                                                                            \
234                 div         = 1.0 / (k - 1.0);                                                           \
235                 data[k - 1] = div * dr * data[k - 2];                                                    \
236                 for (int l = 1; (l < (k - 1)); l++)                                                      \
237                 {                                                                                        \
238                     data[k - l - 1] =                                                                    \
239                             div * ((dr + l) * data[k - l - 2] + (k - l - dr) * data[k - l - 1]);         \
240                 }                                                                                        \
241                 data[0] = div * (1 - dr) * data[0];                                                      \
242             }                                                                                            \
243             /* differentiate */                                                                          \
244             dtheta[j][i * (order) + 0] = -data[0];                                                       \
245             for (int k = 1; (k < (order)); k++)                                                          \
246             {                                                                                            \
247                 dtheta[j][i * (order) + k] = data[k - 1] - data[k];                                      \
248             }                                                                                            \
249                                                                                                          \
250             div             = 1.0 / ((order)-1);                                                         \
251             data[(order)-1] = div * dr * data[(order)-2];                                                \
252             for (int l = 1; (l < ((order)-1)); l++)                                                      \
253             {                                                                                            \
254                 data[(order)-l - 1] =                                                                    \
255                         div * ((dr + l) * data[(order)-l - 2] + ((order)-l - dr) * data[(order)-l - 1]); \
256             }                                                                                            \
257             data[0] = div * (1 - dr) * data[0];                                                          \
258                                                                                                          \
259             for (int k = 0; k < (order); k++)                                                            \
260             {                                                                                            \
261                 theta[j][i * (order) + k] = data[k];                                                     \
262             }                                                                                            \
263         }                                                                                                \
264     }
265
266 static void make_bsplines(splinevec  theta,
267                           splinevec  dtheta,
268                           int        order,
269                           rvec       fractx[],
270                           int        nr,
271                           const int  ind[],
272                           const real coefficient[],
273                           gmx_bool   bDoSplines)
274 {
275     /* construct splines for local atoms */
276     int   i, ii;
277     real* xptr;
278
279     for (i = 0; i < nr; i++)
280     {
281         /* With free energy we do not use the coefficient check.
282          * In most cases this will be more efficient than calling make_bsplines
283          * twice, since usually more than half the particles have non-zero coefficients.
284          */
285         ii = ind[i];
286         if (bDoSplines || coefficient[ii] != 0.0)
287         {
288             xptr = fractx[ii];
289             assert(order >= 3 && order <= PME_ORDER_MAX);
290             switch (order)
291             {
292                 case 4: CALC_SPLINE(4) break;
293                 case 5: CALC_SPLINE(5) break;
294                 default: CALC_SPLINE(order) break;
295             }
296         }
297     }
298 }
299
/* End of the -Warray-bounds suppression region opened before CALC_SPLINE */
300 #pragma GCC diagnostic pop
301
302 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
/* DO_BSPLINE(order) adds coefficient * thx[ithx] * thy[ithy] * thz[ithz] to the
 * order^3 points of `grid` starting at (i0, j0, k0). The grid is indexed as
 * (x * pny + y) * pnz + z, i.e. z runs fastest.
 * It does not declare anything itself: ithx, ithy, ithz, index_x, index_xy,
 * index_xyz, valx, valxy, coefficient, thx, thy, thz, i0, j0, k0, pny, pnz
 * and grid must all be declared in the function where it is expanded.
 */
303 #define DO_BSPLINE(order)                             \
304     for (ithx = 0; (ithx < (order)); ithx++)          \
305     {                                                 \
306         index_x = (i0 + ithx) * pny * pnz;            \
307         valx    = coefficient * thx[ithx];            \
308                                                       \
309         for (ithy = 0; (ithy < (order)); ithy++)      \
310         {                                             \
311             valxy    = valx * thy[ithy];              \
312             index_xy = index_x + (j0 + ithy) * pnz;   \
313                                                       \
314             for (ithz = 0; (ithz < (order)); ithz++)  \
315             {                                         \
316                 index_xyz = index_xy + (k0 + ithz);   \
317                 grid[index_xyz] += valxy * thz[ithz]; \
318             }                                         \
319         }                                             \
320     }
321
322
/*! \brief Spreads the coefficients of the atoms listed in \p spline onto the thread-local grid \p pmegrid.
 *
 * The whole thread-local grid (s[XX]*s[YY]*s[ZZ] values) is zeroed first.
 * Then, for each list entry nn whose atom has a non-zero coefficient, the
 * coefficient times the product of the x, y and z spline weights (stored
 * starting at offset nn*order) is added to an order^3 block that starts at
 * the atom's grid index minus pmegrid->offset.
 * Orders 4 and 5 use the SIMD4 kernels from pme_simd4.h when
 * PME_SIMD4_SPREAD_GATHER is defined; all other cases use DO_BSPLINE.
 *
 * \param[in]  pmegrid  Thread-local grid; the values in its grid storage are overwritten.
 * \param[in]  atc      Atom data providing coefficient[] and idx[].
 * \param[in]  spline   Atom list (ind, n) and spline weights (theta) for this thread.
 * \param[in]  work     Not referenced in this function body (gmx_unused).
 *                      NOTE(review): presumably used by the pme_simd4.h kernels; confirm.
 */
323 static void spread_coefficients_bsplines_thread(const pmegrid_t*       pmegrid,
324                                                 const PmeAtomComm*     atc,
325                                                 splinedata_t*          spline,
326                                                 struct pme_spline_work gmx_unused* work)
327 {
328
329     /* spread coefficients from home atoms to local grid */
330     real*      grid;
331     int        i, nn, n, ithx, ithy, ithz, i0, j0, k0;
332     const int* idxptr;
333     int        order, norder, index_x, index_xy, index_xyz;
334     real       valx, valxy, coefficient;
335     real *     thx, *thy, *thz;
336     int        pnx, pny, pnz, ndatatot;
337     int        offx, offy, offz;
338
    /* Aligned scratch, only declared for the aligned SIMD4 variant.
     * NOTE(review): presumably used by the code included from pme_simd4.h. */
339 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
340     alignas(GMX_SIMD_ALIGNMENT) real thz_aligned[GMX_SIMD4_WIDTH * 2];
341 #endif
342
343     pnx = pmegrid->s[XX];
344     pny = pmegrid->s[YY];
345     pnz = pmegrid->s[ZZ];
346
347     offx = pmegrid->offset[XX];
348     offy = pmegrid->offset[YY];
349     offz = pmegrid->offset[ZZ];
350
    /* Contributions are accumulated below, so the grid must start at zero */
351     ndatatot = pnx * pny * pnz;
352     grid     = pmegrid->grid;
353     for (i = 0; i < ndatatot; i++)
354     {
355         grid[i] = 0;
356     }
357
358     order = pmegrid->order;
359
360     for (nn = 0; nn < spline->n; nn++)
361     {
362         n           = spline->ind[nn];
363         coefficient = atc->coefficient[n];
364
        /* Atoms with a zero coefficient add nothing to the grid */
365         if (coefficient != 0)
366         {
367             idxptr = atc->idx[n];
            /* The spline weights of list entry nn start at nn*order in each dimension */
368             norder = nn * order;
369
            /* Block start relative to this thread's grid */
370             i0 = idxptr[XX] - offx;
371             j0 = idxptr[YY] - offy;
372             k0 = idxptr[ZZ] - offz;
373
374             thx = spline->theta.coefficients[XX] + norder;
375             thy = spline->theta.coefficients[YY] + norder;
376             thz = spline->theta.coefficients[ZZ] + norder;
377
            /* For orders 4 and 5 with PME_SIMD4_SPREAD_GATHER, pme_simd4.h is textually
             * included inside this switch; the PME_* macros defined just before each
             * include select the kernel variant. NOTE(review): the included code
             * presumably refers to this function's locals by name, so check
             * pme_simd4.h before renaming any of them.
             */
378             switch (order)
379             {
380                 case 4:
381 #ifdef PME_SIMD4_SPREAD_GATHER
382 #    ifdef PME_SIMD4_UNALIGNED
383 #        define PME_SPREAD_SIMD4_ORDER4
384 #    else
385 #        define PME_SPREAD_SIMD4_ALIGNED
386 #        define PME_ORDER 4
387 #    endif
388 #    include "pme_simd4.h"
389 #else
390                     DO_BSPLINE(4)
391 #endif
392                     break;
393                 case 5:
394 #ifdef PME_SIMD4_SPREAD_GATHER
395 #    define PME_SPREAD_SIMD4_ALIGNED
396 #    define PME_ORDER 5
397 #    include "pme_simd4.h"
398 #else
399                     DO_BSPLINE(5)
400 #endif
401                     break;
402                 default: DO_BSPLINE(order) break;
403             }
404         }
405     }
406 }
407
408 static void copy_local_grid(const gmx_pme_t* pme, const pmegrids_t* pmegrids, int grid_index, int thread, real* fftgrid)
409 {
410     ivec  local_fft_ndata, local_fft_offset, local_fft_size;
411     int   fft_my, fft_mz;
412     int   nsy, nsz;
413     ivec  nf;
414     int   offx, offy, offz, x, y, z, i0, i0t;
415     int   d;
416     real* grid_th;
417
418     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
419                                    local_fft_size);
420     fft_my = local_fft_size[YY];
421     fft_mz = local_fft_size[ZZ];
422
423     const pmegrid_t* pmegrid = &pmegrids->grid_th[thread];
424
425     nsy = pmegrid->s[YY];
426     nsz = pmegrid->s[ZZ];
427
428     for (d = 0; d < DIM; d++)
429     {
430         nf[d] = std::min(pmegrid->n[d] - (pmegrid->order - 1), local_fft_ndata[d] - pmegrid->offset[d]);
431     }
432
433     offx = pmegrid->offset[XX];
434     offy = pmegrid->offset[YY];
435     offz = pmegrid->offset[ZZ];
436
437     /* Directly copy the non-overlapping parts of the local grids.
438      * This also initializes the full grid.
439      */
440     grid_th = pmegrid->grid;
441     for (x = 0; x < nf[XX]; x++)
442     {
443         for (y = 0; y < nf[YY]; y++)
444         {
445             i0  = ((offx + x) * fft_my + (offy + y)) * fft_mz + offz;
446             i0t = (x * nsy + y) * nsz;
447             for (z = 0; z < nf[ZZ]; z++)
448             {
449                 fftgrid[i0 + z] = grid_th[i0t + z];
450             }
451         }
452     }
453 }
454
/*! \brief Adds other threads' grid contributions that overlap the grid region owned by \p thread.
 *
 * Loops over the thread grids at relative thread-decomposition offsets
 * sx, sy, sz in [-nthread_comm[d], 0]. The offset (0,0,0) is skipped, because
 * this thread's own contribution must already have been added by the caller.
 * The parts of those grids that fall in this thread's region are summed
 * directly into \p fftgrid.
 * If the source thread index wraps around below zero in x or y, the source
 * grid is shifted down by one rank-local FFT grid width. When there is more
 * than one PME rank along that dimension, such contributions are written to
 * \p commbuf_x or \p commbuf_y instead of \p fftgrid. The x+y corner goes into
 * commbuf_y, after buf_my*fft_nx*fft_nz entries.
 * Each buffer region (x, y, corner) is overwritten the first time it is
 * written in this call and accumulated into afterwards.
 *
 * \param[in]     pme         PME setup (FFT grid layout, rank counts and overlap info).
 * \param[in]     pmegrids    Thread-local spreading grids and their thread decomposition.
 * \param[in]     thread      Index of this thread's grid in pmegrids->grid_th.
 * \param[in,out] fftgrid     Rank-local FFT grid, accumulated into.
 * \param[out]    commbuf_x   Buffer for contributions wrapping in x when nnodes_major > 1.
 * \param[out]    commbuf_y   Buffer for contributions wrapping in y when nnodes_minor > 1.
 * \param[in]     grid_index  Which PME grid's FFT setup to use.
 */
455 static void reduce_threadgrid_overlap(const gmx_pme_t*  pme,
456                                       const pmegrids_t* pmegrids,
457                                       int               thread,
458                                       real*             fftgrid,
459                                       real*             commbuf_x,
460                                       real*             commbuf_y,
461                                       int               grid_index)
462 {
463     ivec             local_fft_ndata, local_fft_offset, local_fft_size;
464     int              fft_nx, fft_ny, fft_nz;
465     int              fft_my, fft_mz;
466     int              buf_my = -1;
467     int              nsy, nsz;
468     ivec             localcopy_end, commcopy_end;
469     int              offx, offy, offz, x, y, z, i0, i0t;
470     int              sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
471     gmx_bool         bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
472     gmx_bool         bCommX, bCommY;
473     int              d;
474     int              thread_f;
475     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
476     const real*      grid_th;
477     real*            commbuf = nullptr;
478
479     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
480                                    local_fft_size);
481     fft_nx = local_fft_ndata[XX];
482     fft_ny = local_fft_ndata[YY];
483     fft_nz = local_fft_ndata[ZZ];
484
485     fft_my = local_fft_size[YY];
486     fft_mz = local_fft_size[ZZ];
487
488     /* This routine is called when all threads have finished spreading.
489      * Here each thread sums grid contributions calculated by other threads
490      * to the thread local grid volume.
491      * To minimize the number of grid copying operations,
492      * this routine sums immediately from the pmegrid to the fftgrid.
493      */
494
495     /* Determine which part of the full node grid we should operate on,
496      * this is our thread local part of the full grid.
497      */
498     pmegrid = &pmegrids->grid_th[thread];
499
500     for (d = 0; d < DIM; d++)
501     {
502         /* Determine up to where our thread needs to copy from the
503          * thread-local charge spreading grid to the rank-local FFT grid.
504          * This is up to our spreading grid end minus order-1 and
505          * not beyond the local FFT grid.
506          */
507         localcopy_end[d] = std::min(pmegrid->offset[d] + pmegrid->n[d] - (pmegrid->order - 1),
508                                     local_fft_ndata[d]);
509
510         /* Determine up to where our thread needs to copy from the
511          * thread-local charge spreading grid to the communication buffer.
512          * Note: only relevant with communication, ignored otherwise.
513          */
514         commcopy_end[d] = localcopy_end[d];
515         if (pmegrid->ci[d] == pmegrids->nc[d] - 1)
516         {
517             /* The last thread should copy up to the last pme grid line.
518              * When the rank-local FFT grid is narrower than pme-order,
519              * we need the max below to ensure copying of all data.
520              */
521             commcopy_end[d] = std::max(commcopy_end[d], pme->pme_order);
522         }
523     }
524
525     offx = pmegrid->offset[XX];
526     offy = pmegrid->offset[YY];
527     offz = pmegrid->offset[ZZ];
528
529
    /* The first write to each communication buffer region (x, y, x+y corner)
     * in this call overwrites it; later writes accumulate.
     */
530     bClearBufX  = TRUE;
531     bClearBufY  = TRUE;
532     bClearBufXY = TRUE;
533
534     /* Now loop over all the thread data blocks that contribute
535      * to the grid region we (our thread) are operating on.
536      */
537     /* Note that fft_nx/y is equal to the number of grid points
538      * between the first point of our node grid and the one of the next node.
539      */
540     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
541     {
542         fx     = pmegrid->ci[XX] + sx;
543         ox     = 0;
544         bCommX = FALSE;
        /* The source thread index wrapped around: treat its grid as lying one
         * rank-local grid width (fft_nx) lower. With more than one rank along x
         * its data goes to the communication buffer instead of fftgrid.
         */
545         if (fx < 0)
546         {
547             fx += pmegrids->nc[XX];
548             ox -= fft_nx;
549             bCommX = (pme->nnodes_major > 1);
550         }
        /* Only offset[XX] is read; threads with equal fx presumably share it */
551         pmegrid_g = &pmegrids->grid_th[fx * pmegrids->nc[YY] * pmegrids->nc[ZZ]];
552         ox += pmegrid_g->offset[XX];
553         /* Determine the end of our part of the source grid.
554          * Use our thread local source grid and target grid part
555          */
556         tx1 = std::min(ox + pmegrid_g->n[XX], !bCommX ? localcopy_end[XX] : commcopy_end[XX]);
557
558         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
559         {
560             fy     = pmegrid->ci[YY] + sy;
561             oy     = 0;
562             bCommY = FALSE;
            /* Same wrap-around handling as for x, using fft_ny and the minor rank count */
563             if (fy < 0)
564             {
565                 fy += pmegrids->nc[YY];
566                 oy -= fft_ny;
567                 bCommY = (pme->nnodes_minor > 1);
568             }
569             pmegrid_g = &pmegrids->grid_th[fy * pmegrids->nc[ZZ]];
570             oy += pmegrid_g->offset[YY];
571             /* Determine the end of our part of the source grid.
572              * Use our thread local source grid and target grid part
573              */
574             ty1 = std::min(oy + pmegrid_g->n[YY], !bCommY ? localcopy_end[YY] : commcopy_end[YY]);
575
576             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
577             {
578                 fz = pmegrid->ci[ZZ] + sz;
579                 oz = 0;
                /* There is no communication flag for z: wrapped z contributions are always summed locally */
580                 if (fz < 0)
581                 {
582                     fz += pmegrids->nc[ZZ];
583                     oz -= fft_nz;
584                 }
585                 pmegrid_g = &pmegrids->grid_th[fz];
586                 oz += pmegrid_g->offset[ZZ];
587                 tz1 = std::min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);
588
589                 if (sx == 0 && sy == 0 && sz == 0)
590                 {
591                     /* We have already added our local contribution
592                      * before calling this routine, so skip it here.
593                      */
594                     continue;
595                 }
596
597                 thread_f = (fx * pmegrids->nc[YY] + fy) * pmegrids->nc[ZZ] + fz;
598
599                 pmegrid_f = &pmegrids->grid_th[thread_f];
600
601                 grid_th = pmegrid_f->grid;
602
603                 nsy = pmegrid_f->s[YY];
604                 nsz = pmegrid_f->s[ZZ];
605
606 #ifdef DEBUG_PME_REDUCE
607                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d "
608                        "%2d-%2d, %2d-%2d %2d-%2d\n",
609                        pme->nodeid, thread, thread_f, pme->pmegrid_start_ix, pme->pmegrid_start_iy,
610                        pme->pmegrid_start_iz, sx, sy, sz, offx - ox, tx1 - ox, offx, tx1, offy - oy,
611                        ty1 - oy, offy, ty1, offz - oz, tz1 - oz, offz, tz1);
612 #endif
613
614                 if (!(bCommX || bCommY))
615                 {
616                     /* Copy from the thread local grid to the node grid */
617                     for (x = offx; x < tx1; x++)
618                     {
619                         for (y = offy; y < ty1; y++)
620                         {
621                             i0  = (x * fft_my + y) * fft_mz;
622                             i0t = ((x - ox) * nsy + (y - oy)) * nsz - oz;
623                             for (z = offz; z < tz1; z++)
624                             {
625                                 fftgrid[i0 + z] += grid_th[i0t + z];
626                             }
627                         }
628                     }
629                 }
630                 else
631                 {
632                     /* The order of this conditional decides
633                      * where the corner volume gets stored with x+y decomp.
634                      */
635                     if (bCommY)
636                     {
637                         commbuf = commbuf_y;
638                         /* The y-size of the communication buffer is set by
639                          * the overlap of the grid part of our local slab
640                          * with the part starting at the next slab.
641                          */
642                         buf_my = pme->overlap[1].s2g1[pme->nodeid_minor]
643                                  - pme->overlap[1].s2g0[pme->nodeid_minor + 1];
644                         if (bCommX)
645                         {
646                             /* We index commbuf modulo the local grid size */
647                             commbuf += buf_my * fft_nx * fft_nz;
648
649                             bClearBuf   = bClearBufXY;
650                             bClearBufXY = FALSE;
651                         }
652                         else
653                         {
654                             bClearBuf  = bClearBufY;
655                             bClearBufY = FALSE;
656                         }
657                     }
658                     else
659                     {
660                         commbuf    = commbuf_x;
661                         buf_my     = fft_ny;
662                         bClearBuf  = bClearBufX;
663                         bClearBufX = FALSE;
664                     }
665
666                     /* Copy to the communication buffer */
667                     for (x = offx; x < tx1; x++)
668                     {
669                         for (y = offy; y < ty1; y++)
670                         {
671                             i0  = (x * buf_my + y) * fft_nz;
672                             i0t = ((x - ox) * nsy + (y - oy)) * nsz - oz;
673
674                             if (bClearBuf)
675                             {
676                                 /* First access of commbuf, initialize it */
677                                 for (z = offz; z < tz1; z++)
678                                 {
679                                     commbuf[i0 + z] = grid_th[i0t + z];
680                                 }
681                             }
682                             else
683                             {
684                                 for (z = offz; z < tz1; z++)
685                                 {
686                                     commbuf[i0 + z] += grid_th[i0t + z];
687                                 }
688                             }
689                         }
690                     }
691                 }
692             }
693         }
694     }
695 }
696
697
698 static void sum_fftgrid_dd(const gmx_pme_t* pme, real* fftgrid, int grid_index)
699 {
700     ivec local_fft_ndata, local_fft_offset, local_fft_size;
701     int  send_index0, send_nindex;
702     int  recv_nindex;
703 #if GMX_MPI
704     MPI_Status stat;
705 #endif
706     int recv_size_y;
707     int size_yx;
708     int x, y, z, indg, indb;
709
710     /* Note that this routine is only used for forward communication.
711      * Since the force gathering, unlike the coefficient spreading,
712      * can be trivially parallelized over the particles,
713      * the backwards process is much simpler and can use the "old"
714      * communication setup.
715      */
716
717     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index], local_fft_ndata, local_fft_offset,
718                                    local_fft_size);
719
720     if (pme->nnodes_minor > 1)
721     {
722         /* Major dimension */
723         const pme_overlap_t* overlap = &pme->overlap[1];
724
725         if (pme->nnodes_major > 1)
726         {
727             size_yx = pme->overlap[0].comm_data[0].send_nindex;
728         }
729         else
730         {
731             size_yx = 0;
732         }
733 #if GMX_MPI
734         int datasize = (local_fft_ndata[XX] + size_yx) * local_fft_ndata[ZZ];
735
736         int send_size_y = overlap->send_size;
737 #endif
738
739         for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
740         {
741             send_index0 = overlap->comm_data[ipulse].send_index0 - overlap->comm_data[0].send_index0;
742             send_nindex = overlap->comm_data[ipulse].send_nindex;
743             /* We don't use recv_index0, as we always receive starting at 0 */
744             recv_nindex = overlap->comm_data[ipulse].recv_nindex;
745             recv_size_y = overlap->comm_data[ipulse].recv_size;
746
747             auto* sendptr =
748                     const_cast<real*>(overlap->sendbuf.data()) + send_index0 * local_fft_ndata[ZZ];
749             auto* recvptr = const_cast<real*>(overlap->recvbuf.data());
750
751             if (debug != nullptr)
752             {
753                 fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n", local_fft_ndata[XX],
754                         send_nindex, local_fft_ndata[ZZ]);
755             }
756
757 #if GMX_MPI
758             int send_id = overlap->comm_data[ipulse].send_id;
759             int recv_id = overlap->comm_data[ipulse].recv_id;
760             MPI_Sendrecv(sendptr, send_size_y * datasize, GMX_MPI_REAL, send_id, ipulse, recvptr,
761                          recv_size_y * datasize, GMX_MPI_REAL, recv_id, ipulse, overlap->mpi_comm, &stat);
762 #endif
763
764             for (x = 0; x < local_fft_ndata[XX]; x++)
765             {
766                 for (y = 0; y < recv_nindex; y++)
767                 {
768                     indg = (x * local_fft_size[YY] + y) * local_fft_size[ZZ];
769                     indb = (x * recv_size_y + y) * local_fft_ndata[ZZ];
770                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
771                     {
772                         fftgrid[indg + z] += recvptr[indb + z];
773                     }
774                 }
775             }
776
777             if (pme->nnodes_major > 1)
778             {
779                 /* Copy from the received buffer to the send buffer for dim 0 */
780                 sendptr = const_cast<real*>(pme->overlap[0].sendbuf.data());
781                 for (x = 0; x < size_yx; x++)
782                 {
783                     for (y = 0; y < recv_nindex; y++)
784                     {
785                         indg = (x * local_fft_ndata[YY] + y) * local_fft_ndata[ZZ];
786                         indb = ((local_fft_ndata[XX] + x) * recv_size_y + y) * local_fft_ndata[ZZ];
787                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
788                         {
789                             sendptr[indg + z] += recvptr[indb + z];
790                         }
791                     }
792                 }
793             }
794         }
795     }
796
797     /* We only support a single pulse here.
798      * This is not a severe limitation, as this code is only used
799      * with OpenMP and with OpenMP the (PME) domains can be larger.
800      */
801     if (pme->nnodes_major > 1)
802     {
803         /* Major dimension */
804         const pme_overlap_t* overlap = &pme->overlap[0];
805
806         size_t ipulse = 0;
807
808         send_nindex = overlap->comm_data[ipulse].send_nindex;
809         /* We don't use recv_index0, as we always receive starting at 0 */
810         recv_nindex = overlap->comm_data[ipulse].recv_nindex;
811
812         if (debug != nullptr)
813         {
814             fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n", send_nindex, local_fft_ndata[YY],
815                     local_fft_ndata[ZZ]);
816         }
817
818 #if GMX_MPI
819         int   datasize = local_fft_ndata[YY] * local_fft_ndata[ZZ];
820         int   send_id  = overlap->comm_data[ipulse].send_id;
821         int   recv_id  = overlap->comm_data[ipulse].recv_id;
822         auto* sendptr  = const_cast<real*>(overlap->sendbuf.data());
823         auto* recvptr  = const_cast<real*>(overlap->recvbuf.data());
824         MPI_Sendrecv(sendptr, send_nindex * datasize, GMX_MPI_REAL, send_id, ipulse, recvptr,
825                      recv_nindex * datasize, GMX_MPI_REAL, recv_id, ipulse, overlap->mpi_comm, &stat);
826 #endif
827
828         for (x = 0; x < recv_nindex; x++)
829         {
830             for (y = 0; y < local_fft_ndata[YY]; y++)
831             {
832                 indg = (x * local_fft_size[YY] + y) * local_fft_size[ZZ];
833                 indb = (x * local_fft_ndata[YY] + y) * local_fft_ndata[ZZ];
834                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
835                 {
836                     fftgrid[indg + z] += overlap->recvbuf[indb + z];
837                 }
838             }
839         }
840     }
841 }
842
843 void spread_on_grid(const gmx_pme_t*  pme,
844                     PmeAtomComm*      atc,
845                     const pmegrids_t* grids,
846                     gmx_bool          bCalcSplines,
847                     gmx_bool          bSpread,
848                     real*             fftgrid,
849                     gmx_bool          bDoSplines,
850                     int               grid_index)
851 {
852     int nthread, thread;
853 #ifdef PME_TIME_THREADS
854     gmx_cycles_t  c1, c2, c3, ct1a, ct1b, ct1c;
855     static double cs1 = 0, cs2 = 0, cs3 = 0;
856     static double cs1a[6] = { 0, 0, 0, 0, 0, 0 };
857     static int    cnt     = 0;
858 #endif
859
860     nthread = pme->nthread;
861     assert(nthread > 0);
862     GMX_ASSERT(grids != nullptr || !bSpread, "If there's no grid, we cannot be spreading");
863
864 #ifdef PME_TIME_THREADS
865     c1 = omp_cyc_start();
866 #endif
867     if (bCalcSplines)
868     {
869 #pragma omp parallel for num_threads(nthread) schedule(static)
870         for (thread = 0; thread < nthread; thread++)
871         {
872             try
873             {
874                 int start, end;
875
876                 start = atc->numAtoms() * thread / nthread;
877                 end   = atc->numAtoms() * (thread + 1) / nthread;
878
879                 /* Compute fftgrid index for all atoms,
880                  * with help of some extra variables.
881                  */
882                 calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
883             }
884             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
885         }
886     }
887 #ifdef PME_TIME_THREADS
888     c1 = omp_cyc_end(c1);
889     cs1 += (double)c1;
890 #endif
891
892 #ifdef PME_TIME_THREADS
893     c2 = omp_cyc_start();
894 #endif
895 #pragma omp parallel for num_threads(nthread) schedule(static)
896     for (thread = 0; thread < nthread; thread++)
897     {
898         try
899         {
900             splinedata_t* spline;
901
902             /* make local bsplines  */
903             if (grids == nullptr || !pme->bUseThreads)
904             {
905                 spline = &atc->spline[0];
906
907                 spline->n = atc->numAtoms();
908             }
909             else
910             {
911                 spline = &atc->spline[thread];
912
913                 if (grids->nthread == 1)
914                 {
915                     /* One thread, we operate on all coefficients */
916                     spline->n = atc->numAtoms();
917                 }
918                 else
919                 {
920                     /* Get the indices our thread should operate on */
921                     make_thread_local_ind(atc, thread, spline);
922                 }
923             }
924
925             if (bCalcSplines)
926             {
927                 make_bsplines(spline->theta.coefficients, spline->dtheta.coefficients,
928                               pme->pme_order, as_rvec_array(atc->fractx.data()), spline->n,
929                               spline->ind.data(), atc->coefficient.data(), bDoSplines);
930             }
931
932             if (bSpread)
933             {
934                 /* put local atoms on grid. */
935                 const pmegrid_t* grid = pme->bUseThreads ? &grids->grid_th[thread] : &grids->grid;
936
937 #ifdef PME_TIME_SPREAD
938                 ct1a = omp_cyc_start();
939 #endif
940                 spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
941
942                 if (pme->bUseThreads)
943                 {
944                     copy_local_grid(pme, grids, grid_index, thread, fftgrid);
945                 }
946 #ifdef PME_TIME_SPREAD
947                 ct1a = omp_cyc_end(ct1a);
948                 cs1a[thread] += (double)ct1a;
949 #endif
950             }
951         }
952         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
953     }
954 #ifdef PME_TIME_THREADS
955     c2 = omp_cyc_end(c2);
956     cs2 += (double)c2;
957 #endif
958
959     if (bSpread && pme->bUseThreads)
960     {
961 #ifdef PME_TIME_THREADS
962         c3 = omp_cyc_start();
963 #endif
964 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
965         for (thread = 0; thread < grids->nthread; thread++)
966         {
967             try
968             {
969                 reduce_threadgrid_overlap(pme, grids, thread, fftgrid,
970                                           const_cast<real*>(pme->overlap[0].sendbuf.data()),
971                                           const_cast<real*>(pme->overlap[1].sendbuf.data()), grid_index);
972             }
973             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR
974         }
975 #ifdef PME_TIME_THREADS
976         c3 = omp_cyc_end(c3);
977         cs3 += (double)c3;
978 #endif
979
980         if (pme->nnodes > 1)
981         {
982             /* Communicate the overlapping part of the fftgrid.
983              * For this communication call we need to check pme->bUseThreads
984              * to have all ranks communicate here, regardless of pme->nthread.
985              */
986             sum_fftgrid_dd(pme, fftgrid, grid_index);
987         }
988     }
989
990 #ifdef PME_TIME_THREADS
991     cnt++;
992     if (cnt % 20 == 0)
993     {
994         printf("idx %.2f spread %.2f red %.2f", cs1 * 1e-9, cs2 * 1e-9, cs3 * 1e-9);
995 #    ifdef PME_TIME_SPREAD
996         for (thread = 0; thread < nthread; thread++)
997         {
998             printf(" %.2f", cs1a[thread] * 1e-9);
999         }
1000 #    endif
1001         printf("\n");
1002     }
1003 #endif
1004 }