88cbad784d890b2ad511376b98ffe13a1fb5140e
[alexxy/gromacs.git] / src / gromacs / ewald / pme_spread.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017,2018,2019, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "pme_spread.h"
40
41 #include "config.h"
42
43 #include <cassert>
44
45 #include <algorithm>
46
47 #include "gromacs/ewald/pme.h"
48 #include "gromacs/fft/parallel_3dfft.h"
49 #include "gromacs/simd/simd.h"
50 #include "gromacs/utility/basedefinitions.h"
51 #include "gromacs/utility/exceptions.h"
52 #include "gromacs/utility/fatalerror.h"
53 #include "gromacs/utility/smalloc.h"
54
55 #include "pme_grid.h"
56 #include "pme_internal.h"
57 #include "pme_simd.h"
58 #include "pme_spline_work.h"
59
60 /* TODO consider split of pme-spline from this file */
61
62 static void calc_interpolation_idx(const gmx_pme_t *pme, PmeAtomComm *atc,
63                                    int start, int grid_index, int end, int thread)
64 {
65     int             i;
66     int            *idxptr, tix, tiy, tiz;
67     const real     *xptr;
68     real           *fptr, tx, ty, tz;
69     real            rxx, ryx, ryy, rzx, rzy, rzz;
70     int             nx, ny, nz;
71     int            *g2tx, *g2ty, *g2tz;
72     gmx_bool        bThreads;
73     int            *thread_idx = nullptr;
74     int            *tpl_n      = nullptr;
75     int             thread_i;
76
77     nx  = pme->nkx;
78     ny  = pme->nky;
79     nz  = pme->nkz;
80
81     rxx = pme->recipbox[XX][XX];
82     ryx = pme->recipbox[YY][XX];
83     ryy = pme->recipbox[YY][YY];
84     rzx = pme->recipbox[ZZ][XX];
85     rzy = pme->recipbox[ZZ][YY];
86     rzz = pme->recipbox[ZZ][ZZ];
87
88     g2tx = pme->pmegrid[grid_index].g2t[XX];
89     g2ty = pme->pmegrid[grid_index].g2t[YY];
90     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
91
92     bThreads = (atc->nthread > 1);
93     if (bThreads)
94     {
95         thread_idx = atc->thread_idx.data();
96
97         tpl_n = atc->threadMap[thread].n;
98         for (i = 0; i < atc->nthread; i++)
99         {
100             tpl_n[i] = 0;
101         }
102     }
103
104     const real shift = c_pmeMaxUnitcellShift;
105
106     for (i = start; i < end; i++)
107     {
108         xptr   = atc->x[i];
109         idxptr = atc->idx[i];
110         fptr   = atc->fractx[i];
111
112         /* Fractional coordinates along box vectors, add a positive shift to ensure tx/ty/tz are positive for triclinic boxes */
113         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + shift );
114         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + shift );
115         tz = nz * (                                   xptr[ZZ] * rzz + shift );
116
117         tix = static_cast<int>(tx);
118         tiy = static_cast<int>(ty);
119         tiz = static_cast<int>(tz);
120
121 #ifdef DEBUG
122         range_check(tix, 0, c_pmeNeighborUnitcellCount * nx);
123         range_check(tiy, 0, c_pmeNeighborUnitcellCount * ny);
124         range_check(tiz, 0, c_pmeNeighborUnitcellCount * nz);
125 #endif
126         /* Because decomposition only occurs in x and y,
127          * we never have a fraction correction in z.
128          */
129         fptr[XX] = tx - tix + pme->fshx[tix];
130         fptr[YY] = ty - tiy + pme->fshy[tiy];
131         fptr[ZZ] = tz - tiz;
132
133         idxptr[XX] = pme->nnx[tix];
134         idxptr[YY] = pme->nny[tiy];
135         idxptr[ZZ] = pme->nnz[tiz];
136
137 #ifdef DEBUG
138         range_check(idxptr[XX], 0, pme->pmegrid_nx);
139         range_check(idxptr[YY], 0, pme->pmegrid_ny);
140         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
141 #endif
142
143         if (bThreads)
144         {
145             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
146             thread_idx[i] = thread_i;
147             tpl_n[thread_i]++;
148         }
149     }
150
151     if (bThreads)
152     {
153         /* Make a list of particle indices sorted on thread */
154
155         /* Get the cumulative count */
156         for (i = 1; i < atc->nthread; i++)
157         {
158             tpl_n[i] += tpl_n[i-1];
159         }
160         /* The current implementation distributes particles equally
161          * over the threads, so we could actually allocate for that
162          * in pme_realloc_atomcomm_things.
163          */
164         AtomToThreadMap &threadMap = atc->threadMap[thread];
165         threadMap.i.resize(tpl_n[atc->nthread - 1]);
166         /* Set tpl_n to the cumulative start */
167         for (i = atc->nthread-1; i >= 1; i--)
168         {
169             tpl_n[i] = tpl_n[i-1];
170         }
171         tpl_n[0] = 0;
172
173         /* Fill our thread local array with indices sorted on thread */
174         for (i = start; i < end; i++)
175         {
176             threadMap.i[tpl_n[atc->thread_idx[i]]++] = i;
177         }
178         /* Now tpl_n contains the cummulative count again */
179     }
180 }
181
182 static void make_thread_local_ind(const PmeAtomComm *atc,
183                                   int thread, splinedata_t *spline)
184 {
185     int             n, t, i, start, end;
186
187     /* Combine the indices made by each thread into one index */
188
189     n     = 0;
190     start = 0;
191     for (t = 0; t < atc->nthread; t++)
192     {
193         const AtomToThreadMap &threadMap = atc->threadMap[t];
194         /* Copy our part (start - end) from the list of thread t */
195         if (thread > 0)
196         {
197             start = threadMap.n[thread-1];
198         }
199         end = threadMap.n[thread];
200         for (i = start; i < end; i++)
201         {
202             spline->ind[n++] = threadMap.i[i];
203         }
204     }
205
206     spline->n = n;
207 }
208
/* Macro to force loop unrolling by fixing order.
 * This gives a significant performance gain.
 *
 * For each dimension j, the fractional offset dr = xptr[j] is turned into:
 * - `order` spline weights, stored in theta[j][i*(order) + k], and
 * - their derivatives, stored in dtheta[j][i*(order) + k],
 * for k = 0 .. order-1.
 *
 * The weights start from the linear case (data[0] = 1 - dr,
 * data[1] = dr) and each pass of the k-loop raises them by one order.
 * The derivatives are taken as differences of the order-1 weights,
 * before the final raise to `order`.
 *
 * Expects xptr, i, theta and dtheta to be visible at the expansion site.
 * The scratch array holds PME_ORDER_MAX entries, so order must not exceed
 * PME_ORDER_MAX.
 */
#define CALC_SPLINE(order)                     \
    {                                              \
        for (int j = 0; (j < DIM); j++)            \
        {                                          \
            real dr, div;                          \
            real data[PME_ORDER_MAX];              \
                                                   \
            dr  = xptr[j];                         \
                                               \
            /* dr is relative offset from lower cell limit */ \
            data[(order)-1] = 0;                     \
            data[1]         = dr;                          \
            data[0]         = 1 - dr;                      \
                                               \
            for (int k = 3; (k < (order)); k++)      \
            {                                      \
                div       = 1.0/(k - 1.0);               \
                data[k-1] = div*dr*data[k-2];      \
                for (int l = 1; (l < (k-1)); l++)  \
                {                                  \
                    data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
                                       data[k-l-1]);                \
                }                                  \
                data[0] = div*(1-dr)*data[0];      \
            }                                      \
            /* differentiate */                    \
            dtheta[j][i*(order)+0] = -data[0];       \
            for (int k = 1; (k < (order)); k++)      \
            {                                      \
                dtheta[j][i*(order)+k] = data[k-1] - data[k]; \
            }                                      \
                                               \
            div             = 1.0/((order) - 1);                 \
            data[(order)-1] = div*dr*data[(order)-2];  \
            for (int l = 1; (l < ((order)-1)); l++)  \
            {                                      \
                data[(order)-l-1] = div*((dr+l)*data[(order)-l-2]+    \
                                         ((order)-l-dr)*data[(order)-l-1]); \
            }                                      \
            data[0] = div*(1 - dr)*data[0];        \
                                               \
            for (int k = 0; k < (order); k++)        \
            {                                      \
                theta[j][i*(order)+k]  = data[k];    \
            }                                      \
        }                                          \
    }
259
260 static void make_bsplines(splinevec theta, splinevec dtheta, int order,
261                           rvec fractx[], int nr, const int ind[], const real coefficient[],
262                           gmx_bool bDoSplines)
263 {
264     /* construct splines for local atoms */
265     int   i, ii;
266     real *xptr;
267
268     for (i = 0; i < nr; i++)
269     {
270         /* With free energy we do not use the coefficient check.
271          * In most cases this will be more efficient than calling make_bsplines
272          * twice, since usually more than half the particles have non-zero coefficients.
273          */
274         ii = ind[i];
275         if (bDoSplines || coefficient[ii] != 0.0)
276         {
277             xptr = fractx[ii];
278             assert(order >= 3 && order <= PME_ORDER_MAX);
279             switch (order)
280             {
281                 case 4:  CALC_SPLINE(4);     break;
282                 case 5:  CALC_SPLINE(5);     break;
283                 default: CALC_SPLINE(order); break;
284             }
285         }
286     }
287 }
288
/* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
/* For all ithx, ithy, ithz in [0, order), adds
 *     coefficient * thx[ithx] * thy[ithy] * thz[ithz]
 * to grid[((i0 + ithx)*pny + (j0 + ithy))*pnz + (k0 + ithz)].
 * The macro declares nothing itself. It uses, and overwrites, the caller's
 * locals ithx, ithy, ithz, index_x, index_xy, index_xyz, valx and valxy.
 */
#define DO_BSPLINE(order)                            \
    for (ithx = 0; (ithx < (order)); ithx++)                    \
    {                                                    \
        index_x = (i0+ithx)*pny*pnz;                     \
        valx    = coefficient*thx[ithx];                          \
                                                     \
        for (ithy = 0; (ithy < (order)); ithy++)                \
        {                                                \
            valxy    = valx*thy[ithy];                   \
            index_xy = index_x+(j0+ithy)*pnz;            \
                                                     \
            for (ithz = 0; (ithz < (order)); ithz++)            \
            {                                            \
                index_xyz        = index_xy+(k0+ithz);   \
                grid[index_xyz] += valxy*thz[ithz];      \
            }                                            \
        }                                                \
    }
308
309
/* Spreads the coefficients of the atoms listed in spline->ind onto the
 * thread-local grid pmegrid, using the B-spline weights in spline->theta.
 *
 * The whole local grid (s[XX]*s[YY]*s[ZZ] points) is zeroed first, so on
 * return it holds only the contributions added by this call.
 * For each listed atom:
 * - atoms with a zero coefficient are skipped;
 * - the grid start index is atc->idx minus pmegrid->offset;
 * - the weights for the nn-th listed atom start at
 *   theta.coefficients[d] + nn*order.
 *
 * NOTE: the DO_BSPLINE macro refers to the local variables below by name
 * (grid, i0/j0/k0, pny/pnz, thx/thy/thz, coefficient, ithx..., index_x...,
 * valx/valxy). The included pme_simd4.h kernel presumably does too
 * (thz_aligned and `work` are only there for it); verify before renaming.
 */
static void spread_coefficients_bsplines_thread(const pmegrid_t                   *pmegrid,
                                                const PmeAtomComm                 *atc,
                                                splinedata_t                      *spline,
                                                struct pme_spline_work gmx_unused *work)
{

    /* spread coefficients from home atoms to local grid */
    real          *grid;
    int            i, nn, n, ithx, ithy, ithz, i0, j0, k0;
    const int     *idxptr;
    int            order, norder, index_x, index_xy, index_xyz;
    real           valx, valxy, coefficient;
    real          *thx, *thy, *thz;
    int            pnx, pny, pnz, ndatatot;
    int            offx, offy, offz;

#if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
    alignas(GMX_SIMD_ALIGNMENT) real  thz_aligned[GMX_SIMD4_WIDTH*2];
#endif

    /* Local grid sizes (including the overlap region) and offsets */
    pnx = pmegrid->s[XX];
    pny = pmegrid->s[YY];
    pnz = pmegrid->s[ZZ];

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];

    /* Clear the whole thread-local grid before accumulating */
    ndatatot = pnx*pny*pnz;
    grid     = pmegrid->grid;
    for (i = 0; i < ndatatot; i++)
    {
        grid[i] = 0;
    }

    order = pmegrid->order;

    for (nn = 0; nn < spline->n; nn++)
    {
        n           = spline->ind[nn];
        coefficient = atc->coefficient[n];

        if (coefficient != 0)
        {
            idxptr = atc->idx[n];
            /* Spline weights are stored per list entry nn, not per atom n */
            norder = nn*order;

            i0   = idxptr[XX] - offx;
            j0   = idxptr[YY] - offy;
            k0   = idxptr[ZZ] - offz;

            thx = spline->theta.coefficients[XX] + norder;
            thy = spline->theta.coefficients[YY] + norder;
            thz = spline->theta.coefficients[ZZ] + norder;

            switch (order)
            {
                case 4:
#ifdef PME_SIMD4_SPREAD_GATHER
#ifdef PME_SIMD4_UNALIGNED
#define PME_SPREAD_SIMD4_ORDER4
#else
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 4
#endif
#include "pme_simd4.h"
#else
                    DO_BSPLINE(4);
#endif
                    break;
                case 5:
#ifdef PME_SIMD4_SPREAD_GATHER
#define PME_SPREAD_SIMD4_ALIGNED
#define PME_ORDER 5
#include "pme_simd4.h"
#else
                    DO_BSPLINE(5);
#endif
                    break;
                default:
                    DO_BSPLINE(order);
                    break;
            }
        }
    }
}
396
397 static void copy_local_grid(const gmx_pme_t *pme, const pmegrids_t *pmegrids,
398                             int grid_index, int thread, real *fftgrid)
399 {
400     ivec local_fft_ndata, local_fft_offset, local_fft_size;
401     int  fft_my, fft_mz;
402     int  nsy, nsz;
403     ivec nf;
404     int  offx, offy, offz, x, y, z, i0, i0t;
405     int  d;
406     real *grid_th;
407
408     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
409                                    local_fft_ndata,
410                                    local_fft_offset,
411                                    local_fft_size);
412     fft_my = local_fft_size[YY];
413     fft_mz = local_fft_size[ZZ];
414
415     const pmegrid_t *pmegrid = &pmegrids->grid_th[thread];
416
417     nsy = pmegrid->s[YY];
418     nsz = pmegrid->s[ZZ];
419
420     for (d = 0; d < DIM; d++)
421     {
422         nf[d] = std::min(pmegrid->n[d] - (pmegrid->order - 1),
423                          local_fft_ndata[d] - pmegrid->offset[d]);
424     }
425
426     offx = pmegrid->offset[XX];
427     offy = pmegrid->offset[YY];
428     offz = pmegrid->offset[ZZ];
429
430     /* Directly copy the non-overlapping parts of the local grids.
431      * This also initializes the full grid.
432      */
433     grid_th = pmegrid->grid;
434     for (x = 0; x < nf[XX]; x++)
435     {
436         for (y = 0; y < nf[YY]; y++)
437         {
438             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
439             i0t = (x*nsy + y)*nsz;
440             for (z = 0; z < nf[ZZ]; z++)
441             {
442                 fftgrid[i0+z] = grid_th[i0t+z];
443             }
444         }
445     }
446 }
447
/* Sums into the region of the rank-local FFT grid owned by thread `thread`
 * the contributions that other threads' spreading grids make to it.
 *
 * Which source threads are visited:
 * - Source threads lie at offsets sx/sy/sz in [-nthread_comm[d], 0] from
 *   our thread coordinates ci.
 * - A negative resulting coordinate wraps around (f += nc[d]); the
 *   matching FFT-grid shift is applied to ox/oy/oz.
 * - Our own grid (sx == sy == sz == 0) is skipped, because the caller has
 *   already added it.
 *
 * Where the data goes:
 * - A wrap in x or y lands on another rank when that dimension has more
 *   than one rank (bCommX/bCommY). Such data is written to commbuf_x or
 *   commbuf_y instead of fftgrid. When both apply, it goes to commbuf_y
 *   at offset buf_my*fft_nx*fft_nz.
 * - The first write to each buffer (section) assigns; later writes
 *   accumulate.
 * - A wrap in z never triggers communication.
 */
static void
reduce_threadgrid_overlap(const gmx_pme_t *pme,
                          const pmegrids_t *pmegrids, int thread,
                          real *fftgrid, real *commbuf_x, real *commbuf_y,
                          int grid_index)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    int  fft_nx, fft_ny, fft_nz;
    int  fft_my, fft_mz;
    int  buf_my = -1;
    int  nsy, nsz;
    ivec localcopy_end, commcopy_end;
    int  offx, offy, offz, x, y, z, i0, i0t;
    int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
    gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
    gmx_bool bCommX, bCommY;
    int  d;
    int  thread_f;
    const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
    const real *grid_th;
    real *commbuf = nullptr;

    gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);
    fft_nx = local_fft_ndata[XX];
    fft_ny = local_fft_ndata[YY];
    fft_nz = local_fft_ndata[ZZ];

    /* Allocated (strided) sizes of the FFT grid, used for indexing */
    fft_my = local_fft_size[YY];
    fft_mz = local_fft_size[ZZ];

    /* This routine is called when all threads have finished spreading.
     * Here each thread sums grid contributions calculated by other threads
     * to the thread local grid volume.
     * To minimize the number of grid copying operations,
     * this routine sums immediately from the pmegrid to the fftgrid.
     */

    /* Determine which part of the full node grid we should operate on,
     * this is our thread local part of the full grid.
     */
    pmegrid = &pmegrids->grid_th[thread];

    for (d = 0; d < DIM; d++)
    {
        /* Determine up to where our thread needs to copy from the
         * thread-local charge spreading grid to the rank-local FFT grid.
         * This is up to our spreading grid end minus order-1 and
         * not beyond the local FFT grid.
         */
        localcopy_end[d] =
            std::min(pmegrid->offset[d] + pmegrid->n[d] - (pmegrid->order - 1),
                     local_fft_ndata[d]);

        /* Determine up to where our thread needs to copy from the
         * thread-local charge spreading grid to the communication buffer.
         * Note: only relevant with communication, ignored otherwise.
         */
        commcopy_end[d]  = localcopy_end[d];
        if (pmegrid->ci[d] == pmegrids->nc[d] - 1)
        {
            /* The last thread should copy up to the last pme grid line.
             * When the rank-local FFT grid is narrower than pme-order,
             * we need the max below to ensure copying of all data.
             */
            commcopy_end[d] = std::max(commcopy_end[d], pme->pme_order);
        }
    }

    offx = pmegrid->offset[XX];
    offy = pmegrid->offset[YY];
    offz = pmegrid->offset[ZZ];


    /* Each communication buffer (section) is assigned on first use and
     * accumulated into afterwards.
     */
    bClearBufX  = TRUE;
    bClearBufY  = TRUE;
    bClearBufXY = TRUE;

    /* Now loop over all the thread data blocks that contribute
     * to the grid region we (our thread) are operating on.
     */
    /* Note that fft_nx/y is equal to the number of grid points
     * between the first point of our node grid and the one of the next node.
     */
    for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
    {
        fx     = pmegrid->ci[XX] + sx;
        ox     = 0;
        bCommX = FALSE;
        if (fx < 0)
        {
            fx    += pmegrids->nc[XX];
            ox    -= fft_nx;
            bCommX = (pme->nnodes_major > 1);
        }
        pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
        ox       += pmegrid_g->offset[XX];
        /* Determine the end of our part of the source grid.
         * Use our thread local source grid and target grid part
         */
        tx1 = std::min(ox + pmegrid_g->n[XX],
                       !bCommX ? localcopy_end[XX] : commcopy_end[XX]);

        for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
        {
            fy     = pmegrid->ci[YY] + sy;
            oy     = 0;
            bCommY = FALSE;
            if (fy < 0)
            {
                fy    += pmegrids->nc[YY];
                oy    -= fft_ny;
                bCommY = (pme->nnodes_minor > 1);
            }
            pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
            oy       += pmegrid_g->offset[YY];
            /* Determine the end of our part of the source grid.
             * Use our thread local source grid and target grid part
             */
            ty1 = std::min(oy + pmegrid_g->n[YY],
                           !bCommY ? localcopy_end[YY] : commcopy_end[YY]);

            for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
            {
                fz = pmegrid->ci[ZZ] + sz;
                oz = 0;
                if (fz < 0)
                {
                    fz += pmegrids->nc[ZZ];
                    oz -= fft_nz;
                }
                pmegrid_g = &pmegrids->grid_th[fz];
                oz       += pmegrid_g->offset[ZZ];
                tz1       = std::min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);

                if (sx == 0 && sy == 0 && sz == 0)
                {
                    /* We have already added our local contribution
                     * before calling this routine, so skip it here.
                     */
                    continue;
                }

                /* Linear index of the source thread (x-major, then y, then z) */
                thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;

                pmegrid_f = &pmegrids->grid_th[thread_f];

                grid_th = pmegrid_f->grid;

                nsy = pmegrid_f->s[YY];
                nsz = pmegrid_f->s[ZZ];

#ifdef DEBUG_PME_REDUCE
                printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
                       pme->nodeid, thread, thread_f,
                       pme->pmegrid_start_ix,
                       pme->pmegrid_start_iy,
                       pme->pmegrid_start_iz,
                       sx, sy, sz,
                       offx-ox, tx1-ox, offx, tx1,
                       offy-oy, ty1-oy, offy, ty1,
                       offz-oz, tz1-oz, offz, tz1);
#endif

                if (!(bCommX || bCommY))
                {
                    /* Copy from the thread local grid to the node grid */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x*fft_my + y)*fft_mz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
                            for (z = offz; z < tz1; z++)
                            {
                                fftgrid[i0+z] += grid_th[i0t+z];
                            }
                        }
                    }
                }
                else
                {
                    /* The order of this conditional decides
                     * where the corner volume gets stored with x+y decomp.
                     */
                    if (bCommY)
                    {
                        commbuf = commbuf_y;
                        /* The y-size of the communication buffer is set by
                         * the overlap of the grid part of our local slab
                         * with the part starting at the next slab.
                         */
                        buf_my  =
                            pme->overlap[1].s2g1[pme->nodeid_minor] -
                            pme->overlap[1].s2g0[pme->nodeid_minor+1];
                        if (bCommX)
                        {
                            /* We index commbuf modulo the local grid size */
                            commbuf += buf_my*fft_nx*fft_nz;

                            bClearBuf   = bClearBufXY;
                            bClearBufXY = FALSE;
                        }
                        else
                        {
                            bClearBuf  = bClearBufY;
                            bClearBufY = FALSE;
                        }
                    }
                    else
                    {
                        commbuf    = commbuf_x;
                        buf_my     = fft_ny;
                        bClearBuf  = bClearBufX;
                        bClearBufX = FALSE;
                    }

                    /* Copy to the communication buffer */
                    for (x = offx; x < tx1; x++)
                    {
                        for (y = offy; y < ty1; y++)
                        {
                            i0  = (x*buf_my + y)*fft_nz;
                            i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;

                            if (bClearBuf)
                            {
                                /* First access of commbuf, initialize it */
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0+z]  = grid_th[i0t+z];
                                }
                            }
                            else
                            {
                                for (z = offz; z < tz1; z++)
                                {
                                    commbuf[i0+z] += grid_th[i0t+z];
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}
697
698
/* Adds the grid overlap received from neighbouring PME ranks to fftgrid.
 *
 * Minor (y) direction, when nnodes_minor > 1:
 * - For each pulse of pme->overlap[1], part of overlap[1].sendbuf is sent
 *   and data is received into overlap[1].recvbuf.
 * - The first local_fft_ndata[XX] x-planes of the received data are added
 *   to fftgrid.
 * - When nnodes_major > 1 as well, the remaining size_yx x-planes are
 *   accumulated into overlap[0].sendbuf, so that they are forwarded in the
 *   x exchange below.
 *
 * Major (x) direction, when nnodes_major > 1:
 * - A single pulse of pme->overlap[0] is exchanged.
 * - The received data is added to the first recv_nindex x-planes of
 *   fftgrid.
 *
 * The send buffers are presumably filled beforehand from the
 * reduce_threadgrid_overlap comm buffers; verify at the caller.
 * Without GMX_MPI the MPI_Sendrecv calls are compiled out.
 */
static void sum_fftgrid_dd(const gmx_pme_t *pme, real *fftgrid, int grid_index)
{
    ivec local_fft_ndata, local_fft_offset, local_fft_size;
    int  send_index0, send_nindex;
    int  recv_nindex;
#if GMX_MPI
    MPI_Status stat;
#endif
    int  recv_size_y;
    int  size_yx;
    int  x, y, z, indg, indb;

    /* Note that this routine is only used for forward communication.
     * Since the force gathering, unlike the coefficient spreading,
     * can be trivially parallelized over the particles,
     * the backwards process is much simpler and can use the "old"
     * communication setup.
     */

    gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
                                   local_fft_ndata,
                                   local_fft_offset,
                                   local_fft_size);

    if (pme->nnodes_minor > 1)
    {
        /* Minor dimension */
        const pme_overlap_t *overlap = &pme->overlap[1];

        /* Number of extra x-planes carried along for the later x exchange */
        if (pme->nnodes_major > 1)
        {
            size_yx = pme->overlap[0].comm_data[0].send_nindex;
        }
        else
        {
            size_yx = 0;
        }
#if GMX_MPI
        int datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];

        int send_size_y = overlap->send_size;
#endif

        for (size_t ipulse = 0; ipulse < overlap->comm_data.size(); ipulse++)
        {
            send_index0   =
                overlap->comm_data[ipulse].send_index0 -
                overlap->comm_data[0].send_index0;
            send_nindex   = overlap->comm_data[ipulse].send_nindex;
            /* We don't use recv_index0, as we always receive starting at 0 */
            recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
            recv_size_y   = overlap->comm_data[ipulse].recv_size;

            auto *sendptr = const_cast<real *>(overlap->sendbuf.data()) + send_index0 * local_fft_ndata[ZZ];
            auto *recvptr = const_cast<real *>(overlap->recvbuf.data());

            if (debug != nullptr)
            {
                fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n",
                        local_fft_ndata[XX], send_nindex, local_fft_ndata[ZZ]);
            }

#if GMX_MPI
            int send_id = overlap->comm_data[ipulse].send_id;
            int recv_id = overlap->comm_data[ipulse].recv_id;
            MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
                         send_id, ipulse,
                         recvptr, recv_size_y*datasize, GMX_MPI_REAL,
                         recv_id, ipulse,
                         overlap->mpi_comm, &stat);
#endif

            /* Add the received y-overlap to our local FFT grid */
            for (x = 0; x < local_fft_ndata[XX]; x++)
            {
                for (y = 0; y < recv_nindex; y++)
                {
                    indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
                    indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
                    for (z = 0; z < local_fft_ndata[ZZ]; z++)
                    {
                        fftgrid[indg+z] += recvptr[indb+z];
                    }
                }
            }

            if (pme->nnodes_major > 1)
            {
                /* Copy from the received buffer to the send buffer for dim 0 */
                sendptr = const_cast<real *>(pme->overlap[0].sendbuf.data());
                for (x = 0; x < size_yx; x++)
                {
                    for (y = 0; y < recv_nindex; y++)
                    {
                        indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
                        indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
                        for (z = 0; z < local_fft_ndata[ZZ]; z++)
                        {
                            sendptr[indg+z] += recvptr[indb+z];
                        }
                    }
                }
            }
        }
    }

    /* We only support a single pulse here.
     * This is not a severe limitation, as this code is only used
     * with OpenMP and with OpenMP the (PME) domains can be larger.
     */
    if (pme->nnodes_major > 1)
    {
        /* Major dimension */
        const pme_overlap_t *overlap = &pme->overlap[0];

        size_t ipulse = 0;

        send_nindex   = overlap->comm_data[ipulse].send_nindex;
        /* We don't use recv_index0, as we always receive starting at 0 */
        recv_nindex   = overlap->comm_data[ipulse].recv_nindex;

        if (debug != nullptr)
        {
            fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n",
                    send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
        }

#if GMX_MPI
        int datasize  = local_fft_ndata[YY]*local_fft_ndata[ZZ];
        int send_id   = overlap->comm_data[ipulse].send_id;
        int recv_id   = overlap->comm_data[ipulse].recv_id;
        auto *sendptr = const_cast<real *>(overlap->sendbuf.data());
        auto *recvptr = const_cast<real *>(overlap->recvbuf.data());
        MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
                     send_id, ipulse,
                     recvptr, recv_nindex*datasize, GMX_MPI_REAL,
                     recv_id, ipulse,
                     overlap->mpi_comm, &stat);
#endif

        /* Add the received x-overlap to the first recv_nindex x-planes */
        for (x = 0; x < recv_nindex; x++)
        {
            for (y = 0; y < local_fft_ndata[YY]; y++)
            {
                indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
                indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
                for (z = 0; z < local_fft_ndata[ZZ]; z++)
                {
                    fftgrid[indg + z] += overlap->recvbuf[indb + z];
                }
            }
        }
    }
}
852
853 void spread_on_grid(const gmx_pme_t *pme,
854                     PmeAtomComm *atc, const pmegrids_t *grids,
855                     gmx_bool bCalcSplines, gmx_bool bSpread,
856                     real *fftgrid, gmx_bool bDoSplines, int grid_index)
857 {
858     int nthread, thread;
859 #ifdef PME_TIME_THREADS
860     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
861     static double cs1     = 0, cs2 = 0, cs3 = 0;
862     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
863     static int cnt        = 0;
864 #endif
865
866     nthread = pme->nthread;
867     assert(nthread > 0);
868
869 #ifdef PME_TIME_THREADS
870     c1 = omp_cyc_start();
871 #endif
872     if (bCalcSplines)
873     {
874 #pragma omp parallel for num_threads(nthread) schedule(static)
875         for (thread = 0; thread < nthread; thread++)
876         {
877             try
878             {
879                 int start, end;
880
881                 start = atc->numAtoms()* thread   /nthread;
882                 end   = atc->numAtoms()*(thread+1)/nthread;
883
884                 /* Compute fftgrid index for all atoms,
885                  * with help of some extra variables.
886                  */
887                 calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
888             }
889             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
890         }
891     }
892 #ifdef PME_TIME_THREADS
893     c1   = omp_cyc_end(c1);
894     cs1 += (double)c1;
895 #endif
896
897 #ifdef PME_TIME_THREADS
898     c2 = omp_cyc_start();
899 #endif
900 #pragma omp parallel for num_threads(nthread) schedule(static)
901     for (thread = 0; thread < nthread; thread++)
902     {
903         try
904         {
905             splinedata_t *spline;
906
907             /* make local bsplines  */
908             if (grids == nullptr || !pme->bUseThreads)
909             {
910                 spline = &atc->spline[0];
911
912                 spline->n = atc->numAtoms();
913             }
914             else
915             {
916                 spline = &atc->spline[thread];
917
918                 if (grids->nthread == 1)
919                 {
920                     /* One thread, we operate on all coefficients */
921                     spline->n = atc->numAtoms();
922                 }
923                 else
924                 {
925                     /* Get the indices our thread should operate on */
926                     make_thread_local_ind(atc, thread, spline);
927                 }
928             }
929
930             if (bCalcSplines)
931             {
932                 make_bsplines(spline->theta.coefficients,
933                               spline->dtheta.coefficients,
934                               pme->pme_order,
935                               as_rvec_array(atc->fractx.data()), spline->n, spline->ind.data(),
936                               atc->coefficient.data(), bDoSplines);
937             }
938
939             if (bSpread)
940             {
941                 /* put local atoms on grid. */
942                 const pmegrid_t *grid = pme->bUseThreads ? &grids->grid_th[thread] : &grids->grid;
943
944 #ifdef PME_TIME_SPREAD
945                 ct1a = omp_cyc_start();
946 #endif
947                 spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
948
949                 if (pme->bUseThreads)
950                 {
951                     copy_local_grid(pme, grids, grid_index, thread, fftgrid);
952                 }
953 #ifdef PME_TIME_SPREAD
954                 ct1a          = omp_cyc_end(ct1a);
955                 cs1a[thread] += (double)ct1a;
956 #endif
957             }
958         }
959         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
960     }
961 #ifdef PME_TIME_THREADS
962     c2   = omp_cyc_end(c2);
963     cs2 += (double)c2;
964 #endif
965
966     if (bSpread && pme->bUseThreads)
967     {
968 #ifdef PME_TIME_THREADS
969         c3 = omp_cyc_start();
970 #endif
971 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
972         for (thread = 0; thread < grids->nthread; thread++)
973         {
974             try
975             {
976                 reduce_threadgrid_overlap(pme, grids, thread,
977                                           fftgrid,
978                                           const_cast<real *>(pme->overlap[0].sendbuf.data()),
979                                           const_cast<real *>(pme->overlap[1].sendbuf.data()),
980                                           grid_index);
981             }
982             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
983         }
984 #ifdef PME_TIME_THREADS
985         c3   = omp_cyc_end(c3);
986         cs3 += (double)c3;
987 #endif
988
989         if (pme->nnodes > 1)
990         {
991             /* Communicate the overlapping part of the fftgrid.
992              * For this communication call we need to check pme->bUseThreads
993              * to have all ranks communicate here, regardless of pme->nthread.
994              */
995             sum_fftgrid_dd(pme, fftgrid, grid_index);
996         }
997     }
998
999 #ifdef PME_TIME_THREADS
1000     cnt++;
1001     if (cnt % 20 == 0)
1002     {
1003         printf("idx %.2f spread %.2f red %.2f",
1004                cs1*1e-9, cs2*1e-9, cs3*1e-9);
1005 #ifdef PME_TIME_SPREAD
1006         for (thread = 0; thread < nthread; thread++)
1007         {
1008             printf(" %.2f", cs1a[thread]*1e-9);
1009         }
1010 #endif
1011         printf("\n");
1012     }
1013 #endif
1014 }