Merge branch release-2018
[alexxy/gromacs.git] / src / gromacs / ewald / pme-spread.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 1991-2000, University of Groningen, The Netherlands.
5  * Copyright (c) 2001-2004, The GROMACS development team.
6  * Copyright (c) 2013,2014,2015,2016,2017,2018, by the GROMACS development team, led by
7  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
8  * and including many others, as listed in the AUTHORS file in the
9  * top-level source directory and at http://www.gromacs.org.
10  *
11  * GROMACS is free software; you can redistribute it and/or
12  * modify it under the terms of the GNU Lesser General Public License
13  * as published by the Free Software Foundation; either version 2.1
14  * of the License, or (at your option) any later version.
15  *
16  * GROMACS is distributed in the hope that it will be useful,
17  * but WITHOUT ANY WARRANTY; without even the implied warranty of
18  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
19  * Lesser General Public License for more details.
20  *
21  * You should have received a copy of the GNU Lesser General Public
22  * License along with GROMACS; if not, see
23  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
24  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
25  *
26  * If you want to redistribute modifications to GROMACS, please
27  * consider that scientific software is very special. Version
28  * control is crucial - bugs must be traceable. We will be happy to
29  * consider code for inclusion in the official distribution, but
30  * derived work must not be called official GROMACS. Details are found
31  * in the README & COPYING files - if they are missing, get the
32  * official version at http://www.gromacs.org.
33  *
34  * To help us fund GROMACS development, we humbly ask that you cite
35  * the research papers on the package. Check out http://www.gromacs.org.
36  */
37 #include "gmxpre.h"
38
39 #include "pme-spread.h"
40
41 #include "config.h"
42
43 #include <assert.h>
44
45 #include <algorithm>
46
47 #include "gromacs/ewald/pme.h"
48 #include "gromacs/fft/parallel_3dfft.h"
49 #include "gromacs/simd/simd.h"
50 #include "gromacs/utility/basedefinitions.h"
51 #include "gromacs/utility/exceptions.h"
52 #include "gromacs/utility/fatalerror.h"
53 #include "gromacs/utility/smalloc.h"
54
55 #include "pme-grid.h"
56 #include "pme-internal.h"
57 #include "pme-simd.h"
58 #include "pme-spline-work.h"
59
60 /* TODO consider split of pme-spline from this file */
61
62 static void calc_interpolation_idx(const gmx_pme_t *pme, const pme_atomcomm_t *atc,
63                                    int start, int grid_index, int end, int thread)
64 {
65     int             i;
66     int            *idxptr, tix, tiy, tiz;
67     real           *xptr, *fptr, tx, ty, tz;
68     real            rxx, ryx, ryy, rzx, rzy, rzz;
69     int             nx, ny, nz;
70     int            *g2tx, *g2ty, *g2tz;
71     gmx_bool        bThreads;
72     int            *thread_idx = nullptr;
73     thread_plist_t *tpl        = nullptr;
74     int            *tpl_n      = nullptr;
75     int             thread_i;
76
77     nx  = pme->nkx;
78     ny  = pme->nky;
79     nz  = pme->nkz;
80
81     rxx = pme->recipbox[XX][XX];
82     ryx = pme->recipbox[YY][XX];
83     ryy = pme->recipbox[YY][YY];
84     rzx = pme->recipbox[ZZ][XX];
85     rzy = pme->recipbox[ZZ][YY];
86     rzz = pme->recipbox[ZZ][ZZ];
87
88     g2tx = pme->pmegrid[grid_index].g2t[XX];
89     g2ty = pme->pmegrid[grid_index].g2t[YY];
90     g2tz = pme->pmegrid[grid_index].g2t[ZZ];
91
92     bThreads = (atc->nthread > 1);
93     if (bThreads)
94     {
95         thread_idx = atc->thread_idx;
96
97         tpl   = &atc->thread_plist[thread];
98         tpl_n = tpl->n;
99         for (i = 0; i < atc->nthread; i++)
100         {
101             tpl_n[i] = 0;
102         }
103     }
104
105     const real shift = c_pmeMaxUnitcellShift;
106
107     for (i = start; i < end; i++)
108     {
109         xptr   = atc->x[i];
110         idxptr = atc->idx[i];
111         fptr   = atc->fractx[i];
112
113         /* Fractional coordinates along box vectors, add a positive shift to ensure tx/ty/tz are positive for triclinic boxes */
114         tx = nx * ( xptr[XX] * rxx + xptr[YY] * ryx + xptr[ZZ] * rzx + shift );
115         ty = ny * (                  xptr[YY] * ryy + xptr[ZZ] * rzy + shift );
116         tz = nz * (                                   xptr[ZZ] * rzz + shift );
117
118         tix = static_cast<int>(tx);
119         tiy = static_cast<int>(ty);
120         tiz = static_cast<int>(tz);
121
122 #ifdef DEBUG
123         range_check(tix, 0, c_pmeNeighborUnitcellCount * nx);
124         range_check(tiy, 0, c_pmeNeighborUnitcellCount * ny);
125         range_check(tiz, 0, c_pmeNeighborUnitcellCount * nz);
126 #endif
127         /* Because decomposition only occurs in x and y,
128          * we never have a fraction correction in z.
129          */
130         fptr[XX] = tx - tix + pme->fshx[tix];
131         fptr[YY] = ty - tiy + pme->fshy[tiy];
132         fptr[ZZ] = tz - tiz;
133
134         idxptr[XX] = pme->nnx[tix];
135         idxptr[YY] = pme->nny[tiy];
136         idxptr[ZZ] = pme->nnz[tiz];
137
138 #ifdef DEBUG
139         range_check(idxptr[XX], 0, pme->pmegrid_nx);
140         range_check(idxptr[YY], 0, pme->pmegrid_ny);
141         range_check(idxptr[ZZ], 0, pme->pmegrid_nz);
142 #endif
143
144         if (bThreads)
145         {
146             thread_i      = g2tx[idxptr[XX]] + g2ty[idxptr[YY]] + g2tz[idxptr[ZZ]];
147             thread_idx[i] = thread_i;
148             tpl_n[thread_i]++;
149         }
150     }
151
152     if (bThreads)
153     {
154         /* Make a list of particle indices sorted on thread */
155
156         /* Get the cumulative count */
157         for (i = 1; i < atc->nthread; i++)
158         {
159             tpl_n[i] += tpl_n[i-1];
160         }
161         /* The current implementation distributes particles equally
162          * over the threads, so we could actually allocate for that
163          * in pme_realloc_atomcomm_things.
164          */
165         if (tpl_n[atc->nthread-1] > tpl->nalloc)
166         {
167             tpl->nalloc = over_alloc_large(tpl_n[atc->nthread-1]);
168             srenew(tpl->i, tpl->nalloc);
169         }
170         /* Set tpl_n to the cumulative start */
171         for (i = atc->nthread-1; i >= 1; i--)
172         {
173             tpl_n[i] = tpl_n[i-1];
174         }
175         tpl_n[0] = 0;
176
177         /* Fill our thread local array with indices sorted on thread */
178         for (i = start; i < end; i++)
179         {
180             tpl->i[tpl_n[atc->thread_idx[i]]++] = i;
181         }
182         /* Now tpl_n contains the cummulative count again */
183     }
184 }
185
186 static void make_thread_local_ind(const pme_atomcomm_t *atc,
187                                   int thread, splinedata_t *spline)
188 {
189     int             n, t, i, start, end;
190     thread_plist_t *tpl;
191
192     /* Combine the indices made by each thread into one index */
193
194     n     = 0;
195     start = 0;
196     for (t = 0; t < atc->nthread; t++)
197     {
198         tpl = &atc->thread_plist[t];
199         /* Copy our part (start - end) from the list of thread t */
200         if (thread > 0)
201         {
202             start = tpl->n[thread-1];
203         }
204         end = tpl->n[thread];
205         for (i = start; i < end; i++)
206         {
207             spline->ind[n++] = tpl->i[i];
208         }
209     }
210
211     spline->n = n;
212 }
213
214 /* Macro to force loop unrolling by fixing order.
215  * This gives a significant performance gain.
 *
 * The macro is not self-contained: it reads xptr[j] (the fractional offset
 * from the lower cell limit for dimension j) and writes theta[j] and
 * dtheta[j] at element i*order + k, so xptr, theta, dtheta and i must exist
 * in the expanding scope. It also needs DIM, PME_ORDER_MAX and real.
 *
 * For each dimension the weights are built up by recursion in data[].
 * dtheta receives differences of the intermediate weights taken before the
 * final recursion step, theta receives the weights after that final step.
 * data[] has PME_ORDER_MAX entries, so order must not exceed PME_ORDER_MAX.
216  */
217 #define CALC_SPLINE(order)                     \
218     {                                              \
219         for (int j = 0; (j < DIM); j++)            \
220         {                                          \
221             real dr, div;                          \
222             real data[PME_ORDER_MAX];              \
223                                                    \
224             dr  = xptr[j];                         \
225                                                \
226             /* dr is relative offset from lower cell limit */ \
227             data[order-1] = 0;                     \
228             data[1]       = dr;                          \
229             data[0]       = 1 - dr;                      \
230                                                \
231             for (int k = 3; (k < order); k++)      \
232             {                                      \
233                 div       = 1.0/(k - 1.0);               \
234                 data[k-1] = div*dr*data[k-2];      \
235                 for (int l = 1; (l < (k-1)); l++)  \
236                 {                                  \
237                     data[k-l-1] = div*((dr+l)*data[k-l-2]+(k-l-dr)* \
238                                        data[k-l-1]);                \
239                 }                                  \
240                 data[0] = div*(1-dr)*data[0];      \
241             }                                      \
242             /* differentiate */                    \
243             dtheta[j][i*order+0] = -data[0];       \
244             for (int k = 1; (k < order); k++)      \
245             {                                      \
246                 dtheta[j][i*order+k] = data[k-1] - data[k]; \
247             }                                      \
248                                                \
249             div           = 1.0/(order - 1);                 \
250             data[order-1] = div*dr*data[order-2];  \
251             for (int l = 1; (l < (order-1)); l++)  \
252             {                                      \
253                 data[order-l-1] = div*((dr+l)*data[order-l-2]+    \
254                                        (order-l-dr)*data[order-l-1]); \
255             }                                      \
256             data[0] = div*(1 - dr)*data[0];        \
257                                                \
258             for (int k = 0; k < order; k++)        \
259             {                                      \
260                 theta[j][i*order+k]  = data[k];    \
261             }                                      \
262         }                                          \
263     }
264
265 static void make_bsplines(splinevec theta, splinevec dtheta, int order,
266                           rvec fractx[], int nr, int ind[], real coefficient[],
267                           gmx_bool bDoSplines)
268 {
269     /* construct splines for local atoms */
270     int   i, ii;
271     real *xptr;
272
273     for (i = 0; i < nr; i++)
274     {
275         /* With free energy we do not use the coefficient check.
276          * In most cases this will be more efficient than calling make_bsplines
277          * twice, since usually more than half the particles have non-zero coefficients.
278          */
279         ii = ind[i];
280         if (bDoSplines || coefficient[ii] != 0.0)
281         {
282             xptr = fractx[ii];
283             assert(order >= 3 && order <= PME_ORDER_MAX);
284             switch (order)
285             {
286                 case 4:  CALC_SPLINE(4);     break;
287                 case 5:  CALC_SPLINE(5);     break;
288                 default: CALC_SPLINE(order); break;
289             }
290         }
291     }
292 }
293
294 /* This has to be a macro to enable full compiler optimization with xlC (and probably others too) */
/* Adds coefficient*thx[ithx]*thy[ithy]*thz[ithz] to grid for all
 * order*order*order combinations of (ithx, ithy, ithz). The target element is
 * ((i0+ithx)*pny + (j0+ithy))*pnz + (k0+ithz).
 *
 * The macro is not self-contained: ithx, ithy, ithz, index_x, index_xy,
 * index_xyz, valx, valxy, i0, j0, k0, pny, pnz, coefficient, thx, thy, thz
 * and grid must all be declared in the expanding scope.
 */
295 #define DO_BSPLINE(order)                            \
296     for (ithx = 0; (ithx < order); ithx++)                    \
297     {                                                    \
298         index_x = (i0+ithx)*pny*pnz;                     \
299         valx    = coefficient*thx[ithx];                          \
300                                                      \
301         for (ithy = 0; (ithy < order); ithy++)                \
302         {                                                \
303             valxy    = valx*thy[ithy];                   \
304             index_xy = index_x+(j0+ithy)*pnz;            \
305                                                      \
306             for (ithz = 0; (ithz < order); ithz++)            \
307             {                                            \
308                 index_xyz        = index_xy+(k0+ithz);   \
309                 grid[index_xyz] += valxy*thz[ithz];      \
310             }                                            \
311         }                                                \
312     }
313
314
/* Spreads the coefficients of the atoms listed in spline->ind onto the
 * thread-local grid pmegrid->grid, using the spline weights in spline->theta.
 *
 * The whole grid (s[XX]*s[YY]*s[ZZ] values) is zeroed first. Atoms with a
 * zero coefficient are skipped. The atom grid indices from atc->idx are made
 * local to this grid by subtracting pmegrid->offset.
 *
 * The local variable names are used by the DO_BSPLINE macro, so they must not
 * be renamed independently of it.
 * NOTE(review): thz_aligned and the gmx_unused work argument are not used in
 * the code visible here; presumably they are consumed by the code textually
 * included from pme-simd4.h -- confirm in that header.
 */
315 static void spread_coefficients_bsplines_thread(const pmegrid_t                   *pmegrid,
316                                                 const pme_atomcomm_t              *atc,
317                                                 splinedata_t                      *spline,
318                                                 struct pme_spline_work gmx_unused *work)
319 {
320
321     /* spread coefficients from home atoms to local grid */
322     real          *grid;
323     int            i, nn, n, ithx, ithy, ithz, i0, j0, k0;
324     int       *    idxptr;
325     int            order, norder, index_x, index_xy, index_xyz;
326     real           valx, valxy, coefficient;
327     real          *thx, *thy, *thz;
328     int            pnx, pny, pnz, ndatatot;
329     int            offx, offy, offz;
330
331 #if defined PME_SIMD4_SPREAD_GATHER && !defined PME_SIMD4_UNALIGNED
332     alignas(GMX_SIMD_ALIGNMENT) real  thz_aligned[GMX_SIMD4_WIDTH*2];
333 #endif
334
    /* Dimensions of the thread-local grid; pny and pnz act as strides */
335     pnx = pmegrid->s[XX];
336     pny = pmegrid->s[YY];
337     pnz = pmegrid->s[ZZ];
338
    /* Offset subtracted from the atom grid indices to index this grid */
339     offx = pmegrid->offset[XX];
340     offy = pmegrid->offset[YY];
341     offz = pmegrid->offset[ZZ];
342
    /* Clear the complete thread-local grid before accumulating into it */
343     ndatatot = pnx*pny*pnz;
344     grid     = pmegrid->grid;
345     for (i = 0; i < ndatatot; i++)
346     {
347         grid[i] = 0;
348     }
349
350     order = pmegrid->order;
351
352     for (nn = 0; nn < spline->n; nn++)
353     {
354         n           = spline->ind[nn];
355         coefficient = atc->coefficient[n];
356
357         if (coefficient != 0)
358         {
359             idxptr = atc->idx[n];
            /* The spline data is indexed by list position nn, not atom index n */
360             norder = nn*order;
361
362             i0   = idxptr[XX] - offx;
363             j0   = idxptr[YY] - offy;
364             k0   = idxptr[ZZ] - offz;
365
366             thx = spline->theta[XX] + norder;
367             thy = spline->theta[YY] + norder;
368             thz = spline->theta[ZZ] + norder;
369
            /* Orders 4 and 5 use either the code included from pme-simd4.h
             * (when PME_SIMD4_SPREAD_GATHER is defined) or DO_BSPLINE with a
             * constant order; all other orders use the generic DO_BSPLINE.
             */
370             switch (order)
371             {
372                 case 4:
373 #ifdef PME_SIMD4_SPREAD_GATHER
374 #ifdef PME_SIMD4_UNALIGNED
375 #define PME_SPREAD_SIMD4_ORDER4
376 #else
377 #define PME_SPREAD_SIMD4_ALIGNED
378 #define PME_ORDER 4
379 #endif
380 #include "pme-simd4.h"
381 #else
382                     DO_BSPLINE(4);
383 #endif
384                     break;
385                 case 5:
386 #ifdef PME_SIMD4_SPREAD_GATHER
387 #define PME_SPREAD_SIMD4_ALIGNED
388 #define PME_ORDER 5
389 #include "pme-simd4.h"
390 #else
391                     DO_BSPLINE(5);
392 #endif
393                     break;
394                 default:
395                     DO_BSPLINE(order);
396                     break;
397             }
398         }
399     }
400 }
401
402 static void copy_local_grid(const gmx_pme_t *pme, const pmegrids_t *pmegrids,
403                             int grid_index, int thread, real *fftgrid)
404 {
405     ivec local_fft_ndata, local_fft_offset, local_fft_size;
406     int  fft_my, fft_mz;
407     int  nsy, nsz;
408     ivec nf;
409     int  offx, offy, offz, x, y, z, i0, i0t;
410     int  d;
411     real *grid_th;
412
413     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
414                                    local_fft_ndata,
415                                    local_fft_offset,
416                                    local_fft_size);
417     fft_my = local_fft_size[YY];
418     fft_mz = local_fft_size[ZZ];
419
420     const pmegrid_t *pmegrid = &pmegrids->grid_th[thread];
421
422     nsy = pmegrid->s[YY];
423     nsz = pmegrid->s[ZZ];
424
425     for (d = 0; d < DIM; d++)
426     {
427         nf[d] = std::min(pmegrid->n[d] - (pmegrid->order - 1),
428                          local_fft_ndata[d] - pmegrid->offset[d]);
429     }
430
431     offx = pmegrid->offset[XX];
432     offy = pmegrid->offset[YY];
433     offz = pmegrid->offset[ZZ];
434
435     /* Directly copy the non-overlapping parts of the local grids.
436      * This also initializes the full grid.
437      */
438     grid_th = pmegrid->grid;
439     for (x = 0; x < nf[XX]; x++)
440     {
441         for (y = 0; y < nf[YY]; y++)
442         {
443             i0  = ((offx + x)*fft_my + (offy + y))*fft_mz + offz;
444             i0t = (x*nsy + y)*nsz;
445             for (z = 0; z < nf[ZZ]; z++)
446             {
447                 fftgrid[i0+z] = grid_th[i0t+z];
448             }
449         }
450     }
451 }
452
/* Sums the contributions that other threads spread into the grid region
 * owned by \p thread.
 *
 * pme         PME data; provides the FFT setup, pme_order, the rank counts
 *             nnodes_major/nnodes_minor and the overlap information
 * pmegrids    the set of thread-local spreading grids (grid_th)
 * thread      index in pmegrids->grid_th of the grid region to reduce into
 * fftgrid     rank-local FFT grid; contributions are added here unless they
 *             are routed to a communication buffer
 * commbuf_x   receives data from thread grids that wrap around in x while
 *             pme->nnodes_major > 1 (and y needs no communication)
 * commbuf_y   receives data from thread grids that wrap around in y while
 *             pme->nnodes_minor > 1; the x+y corner volume is stored at an
 *             offset of buf_my*fft_nx*fft_nz in this buffer
 * grid_index  selects the entry of pme->pfft_setup used for the grid limits
 *
 * The first write to each communication buffer region assigns instead of
 * adds (tracked by the bClearBuf* flags), so the buffers need no clearing.
 */
453 static void
454 reduce_threadgrid_overlap(const gmx_pme_t *pme,
455                           const pmegrids_t *pmegrids, int thread,
456                           real *fftgrid, real *commbuf_x, real *commbuf_y,
457                           int grid_index)
458 {
459     ivec local_fft_ndata, local_fft_offset, local_fft_size;
460     int  fft_nx, fft_ny, fft_nz;
461     int  fft_my, fft_mz;
462     int  buf_my = -1;
463     int  nsy, nsz;
464     ivec localcopy_end, commcopy_end;
465     int  offx, offy, offz, x, y, z, i0, i0t;
466     int  sx, sy, sz, fx, fy, fz, tx1, ty1, tz1, ox, oy, oz;
467     gmx_bool bClearBufX, bClearBufY, bClearBufXY, bClearBuf;
468     gmx_bool bCommX, bCommY;
469     int  d;
470     int  thread_f;
471     const pmegrid_t *pmegrid, *pmegrid_g, *pmegrid_f;
472     const real *grid_th;
473     real *commbuf = nullptr;
474
475     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
476                                    local_fft_ndata,
477                                    local_fft_offset,
478                                    local_fft_size);
479     fft_nx = local_fft_ndata[XX];
480     fft_ny = local_fft_ndata[YY];
481     fft_nz = local_fft_ndata[ZZ];
482
483     fft_my = local_fft_size[YY];
484     fft_mz = local_fft_size[ZZ];
485
486     /* This routine is called when all threads have finished spreading.
487      * Here each thread sums grid contributions calculated by other threads
488      * to the thread local grid volume.
489      * To minimize the number of grid copying operations,
490      * this routine sums immediately from the pmegrid to the fftgrid.
491      */
492
493     /* Determine which part of the full node grid we should operate on,
494      * this is our thread local part of the full grid.
495      */
496     pmegrid = &pmegrids->grid_th[thread];
497
498     for (d = 0; d < DIM; d++)
499     {
500         /* Determine up to where our thread needs to copy from the
501          * thread-local charge spreading grid to the rank-local FFT grid.
502          * This is up to our spreading grid end minus order-1 and
503          * not beyond the local FFT grid.
504          */
505         localcopy_end[d] =
506             std::min(pmegrid->offset[d] + pmegrid->n[d] - (pmegrid->order - 1),
507                      local_fft_ndata[d]);
508
509         /* Determine up to where our thread needs to copy from the
510          * thread-local charge spreading grid to the communication buffer.
511          * Note: only relevant with communication, ignored otherwise.
512          */
513         commcopy_end[d]  = localcopy_end[d];
514         if (pmegrid->ci[d] == pmegrids->nc[d] - 1)
515         {
516             /* The last thread should copy up to the last pme grid line.
517              * When the rank-local FFT grid is narrower than pme-order,
518              * we need the max below to ensure copying of all data.
519              */
520             commcopy_end[d] = std::max(commcopy_end[d], pme->pme_order);
521         }
522     }
523
524     offx = pmegrid->offset[XX];
525     offy = pmegrid->offset[YY];
526     offz = pmegrid->offset[ZZ];
527
528
    /* Each flag is TRUE until the corresponding buffer region is first written */
529     bClearBufX  = TRUE;
530     bClearBufY  = TRUE;
531     bClearBufXY = TRUE;
532
533     /* Now loop over all the thread data blocks that contribute
534      * to the grid region we (our thread) are operating on.
535      */
536     /* Note that fft_nx/y is equal to the number of grid points
537      * between the first point of our node grid and the one of the next node.
538      */
539     for (sx = 0; sx >= -pmegrids->nthread_comm[XX]; sx--)
540     {
541         fx     = pmegrid->ci[XX] + sx;
542         ox     = 0;
543         bCommX = FALSE;
544         if (fx < 0)
545         {
            /* Wrap around to the last thread cells and shift their origin
             * by one local FFT grid length; with more than one rank along
             * this dimension the data goes through a communication buffer.
             */
546             fx    += pmegrids->nc[XX];
547             ox    -= fft_nx;
548             bCommX = (pme->nnodes_major > 1);
549         }
550         pmegrid_g = &pmegrids->grid_th[fx*pmegrids->nc[YY]*pmegrids->nc[ZZ]];
551         ox       += pmegrid_g->offset[XX];
552         /* Determine the end of our part of the source grid.
553          * Use our thread local source grid and target grid part
554          */
555         tx1 = std::min(ox + pmegrid_g->n[XX],
556                        !bCommX ? localcopy_end[XX] : commcopy_end[XX]);
557
558         for (sy = 0; sy >= -pmegrids->nthread_comm[YY]; sy--)
559         {
560             fy     = pmegrid->ci[YY] + sy;
561             oy     = 0;
562             bCommY = FALSE;
563             if (fy < 0)
564             {
565                 fy    += pmegrids->nc[YY];
566                 oy    -= fft_ny;
567                 bCommY = (pme->nnodes_minor > 1);
568             }
569             pmegrid_g = &pmegrids->grid_th[fy*pmegrids->nc[ZZ]];
570             oy       += pmegrid_g->offset[YY];
571             /* Determine the end of our part of the source grid.
572              * Use our thread local source grid and target grid part
573              */
574             ty1 = std::min(oy + pmegrid_g->n[YY],
575                            !bCommY ? localcopy_end[YY] : commcopy_end[YY]);
576
577             for (sz = 0; sz >= -pmegrids->nthread_comm[ZZ]; sz--)
578             {
579                 fz = pmegrid->ci[ZZ] + sz;
580                 oz = 0;
581                 if (fz < 0)
582                 {
                    /* z wraps around as well, but never sets a comm flag */
583                     fz += pmegrids->nc[ZZ];
584                     oz -= fft_nz;
585                 }
586                 pmegrid_g = &pmegrids->grid_th[fz];
587                 oz       += pmegrid_g->offset[ZZ];
588                 tz1       = std::min(oz + pmegrid_g->n[ZZ], localcopy_end[ZZ]);
589
590                 if (sx == 0 && sy == 0 && sz == 0)
591                 {
592                     /* We have already added our local contribution
593                      * before calling this routine, so skip it here.
594                      */
595                     continue;
596                 }
597
                /* Linear index of the source thread grid */
598                 thread_f = (fx*pmegrids->nc[YY] + fy)*pmegrids->nc[ZZ] + fz;
599
600                 pmegrid_f = &pmegrids->grid_th[thread_f];
601
602                 grid_th = pmegrid_f->grid;
603
604                 nsy = pmegrid_f->s[YY];
605                 nsz = pmegrid_f->s[ZZ];
606
607 #ifdef DEBUG_PME_REDUCE
608                 printf("n%d t%d add %d  %2d %2d %2d  %2d %2d %2d  %2d-%2d %2d-%2d, %2d-%2d %2d-%2d, %2d-%2d %2d-%2d\n",
609                        pme->nodeid, thread, thread_f,
610                        pme->pmegrid_start_ix,
611                        pme->pmegrid_start_iy,
612                        pme->pmegrid_start_iz,
613                        sx, sy, sz,
614                        offx-ox, tx1-ox, offx, tx1,
615                        offy-oy, ty1-oy, offy, ty1,
616                        offz-oz, tz1-oz, offz, tz1);
617 #endif
618
619                 if (!(bCommX || bCommY))
620                 {
621                     /* Copy from the thread local grid to the node grid */
622                     for (x = offx; x < tx1; x++)
623                     {
624                         for (y = offy; y < ty1; y++)
625                         {
626                             i0  = (x*fft_my + y)*fft_mz;
                            /* -oz is folded into the row start so that both
                             * grids can be indexed with the same z below.
                             */
627                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
628                             for (z = offz; z < tz1; z++)
629                             {
630                                 fftgrid[i0+z] += grid_th[i0t+z];
631                             }
632                         }
633                     }
634                 }
635                 else
636                 {
637                     /* The order of this conditional decides
638                      * where the corner volume gets stored with x+y decomp.
639                      */
640                     if (bCommY)
641                     {
642                         commbuf = commbuf_y;
643                         /* The y-size of the communication buffer is set by
644                          * the overlap of the grid part of our local slab
645                          * with the part starting at the next slab.
646                          */
647                         buf_my  =
648                             pme->overlap[1].s2g1[pme->nodeid_minor] -
649                             pme->overlap[1].s2g0[pme->nodeid_minor+1];
650                         if (bCommX)
651                         {
652                             /* We index commbuf modulo the local grid size */
653                             commbuf += buf_my*fft_nx*fft_nz;
654
655                             bClearBuf   = bClearBufXY;
656                             bClearBufXY = FALSE;
657                         }
658                         else
659                         {
660                             bClearBuf  = bClearBufY;
661                             bClearBufY = FALSE;
662                         }
663                     }
664                     else
665                     {
666                         commbuf    = commbuf_x;
667                         buf_my     = fft_ny;
668                         bClearBuf  = bClearBufX;
669                         bClearBufX = FALSE;
670                     }
671
672                     /* Copy to the communication buffer */
673                     for (x = offx; x < tx1; x++)
674                     {
675                         for (y = offy; y < ty1; y++)
676                         {
677                             i0  = (x*buf_my + y)*fft_nz;
678                             i0t = ((x - ox)*nsy + (y - oy))*nsz - oz;
679
680                             if (bClearBuf)
681                             {
682                                 /* First access of commbuf, initialize it */
683                                 for (z = offz; z < tz1; z++)
684                                 {
685                                     commbuf[i0+z]  = grid_th[i0t+z];
686                                 }
687                             }
688                             else
689                             {
690                                 for (z = offz; z < tz1; z++)
691                                 {
692                                     commbuf[i0+z] += grid_th[i0t+z];
693                                 }
694                             }
695                         }
696                     }
697                 }
698             }
699         }
700     }
701 }
702
703
/* Adds to \p fftgrid the grid overlap data exchanged with other ranks.
 *
 * With pme->nnodes_minor > 1 the buffers of pme->overlap[1] are exchanged,
 * one MPI_Sendrecv per pulse, and the received data is summed into fftgrid.
 * When pme->nnodes_major > 1 as well, the extra size_yx x-lines received in
 * that exchange are added to the send buffer of pme->overlap[0].
 * With pme->nnodes_major > 1 the buffers of pme->overlap[0] are then
 * exchanged in a single pulse and summed into fftgrid.
 *
 * The MPI calls are only compiled when GMX_MPI is set. The grid limits are
 * taken from pme->pfft_setup[grid_index].
 */
704 static void sum_fftgrid_dd(const gmx_pme_t *pme, real *fftgrid, int grid_index)
705 {
706     ivec local_fft_ndata, local_fft_offset, local_fft_size;
707     int  send_index0, send_nindex;
708     int  recv_nindex;
709 #if GMX_MPI
710     MPI_Status stat;
711 #endif
712     int  recv_size_y;
713     int  ipulse, size_yx;
714     real *sendptr, *recvptr;
715     int  x, y, z, indg, indb;
716
717     /* Note that this routine is only used for forward communication.
718      * Since the force gathering, unlike the coefficient spreading,
719      * can be trivially parallelized over the particles,
720      * the backwards process is much simpler and can use the "old"
721      * communication setup.
722      */
723
724     gmx_parallel_3dfft_real_limits(pme->pfft_setup[grid_index],
725                                    local_fft_ndata,
726                                    local_fft_offset,
727                                    local_fft_size);
728
729     if (pme->nnodes_minor > 1)
730     {
731         /* Minor dimension */
732         const pme_overlap_t *overlap = &pme->overlap[1];
733
        /* Number of extra x-lines carried along for the later exchange
         * over overlap[0]; zero when there is no such exchange.
         */
734         if (pme->nnodes_major > 1)
735         {
736             size_yx = pme->overlap[0].comm_data[0].send_nindex;
737         }
738         else
739         {
740             size_yx = 0;
741         }
742 #if GMX_MPI
743         int datasize = (local_fft_ndata[XX] + size_yx)*local_fft_ndata[ZZ];
744
745         int send_size_y = overlap->send_size;
746 #endif
747
748         for (ipulse = 0; ipulse < overlap->noverlap_nodes; ipulse++)
749         {
750             send_index0   =
751                 overlap->comm_data[ipulse].send_index0 -
752                 overlap->comm_data[0].send_index0;
753             send_nindex   = overlap->comm_data[ipulse].send_nindex;
754             /* We don't use recv_index0, as we always receive starting at 0 */
755             recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
756             recv_size_y   = overlap->comm_data[ipulse].recv_size;
757
758             sendptr = overlap->sendbuf + send_index0*local_fft_ndata[ZZ];
759             recvptr = overlap->recvbuf;
760
761             if (debug != nullptr)
762             {
763                 fprintf(debug, "PME fftgrid comm y %2d x %2d x %2d\n",
764                         local_fft_ndata[XX], send_nindex, local_fft_ndata[ZZ]);
765             }
766
767 #if GMX_MPI
768             int send_id = overlap->send_id[ipulse];
769             int recv_id = overlap->recv_id[ipulse];
770             MPI_Sendrecv(sendptr, send_size_y*datasize, GMX_MPI_REAL,
771                          send_id, ipulse,
772                          recvptr, recv_size_y*datasize, GMX_MPI_REAL,
773                          recv_id, ipulse,
774                          overlap->mpi_comm, &stat);
775 #endif
776
            /* Sum the received lines into the local FFT grid; the grid is
             * strided by local_fft_size, the buffer by recv_size_y and ndata.
             */
777             for (x = 0; x < local_fft_ndata[XX]; x++)
778             {
779                 for (y = 0; y < recv_nindex; y++)
780                 {
781                     indg = (x*local_fft_size[YY] + y)*local_fft_size[ZZ];
782                     indb = (x*recv_size_y        + y)*local_fft_ndata[ZZ];
783                     for (z = 0; z < local_fft_ndata[ZZ]; z++)
784                     {
785                         fftgrid[indg+z] += recvptr[indb+z];
786                     }
787                 }
788             }
789
790             if (pme->nnodes_major > 1)
791             {
792                 /* Copy from the received buffer to the send buffer for dim 0 */
793                 sendptr = pme->overlap[0].sendbuf;
794                 for (x = 0; x < size_yx; x++)
795                 {
796                     for (y = 0; y < recv_nindex; y++)
797                     {
798                         indg = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
799                         indb = ((local_fft_ndata[XX] + x)*recv_size_y + y)*local_fft_ndata[ZZ];
800                         for (z = 0; z < local_fft_ndata[ZZ]; z++)
801                         {
802                             sendptr[indg+z] += recvptr[indb+z];
803                         }
804                     }
805                 }
806             }
807         }
808     }
809
810     /* We only support a single pulse here.
811      * This is not a severe limitation, as this code is only used
812      * with OpenMP and with OpenMP the (PME) domains can be larger.
813      */
814     if (pme->nnodes_major > 1)
815     {
816         /* Major dimension */
817         const pme_overlap_t *overlap = &pme->overlap[0];
818
819         ipulse = 0;
820
821         send_nindex   = overlap->comm_data[ipulse].send_nindex;
822         /* We don't use recv_index0, as we always receive starting at 0 */
823         recv_nindex   = overlap->comm_data[ipulse].recv_nindex;
824
825         recvptr = overlap->recvbuf;
826
827         if (debug != nullptr)
828         {
829             fprintf(debug, "PME fftgrid comm x %2d x %2d x %2d\n",
830                     send_nindex, local_fft_ndata[YY], local_fft_ndata[ZZ]);
831         }
832
833 #if GMX_MPI
834         int datasize = local_fft_ndata[YY]*local_fft_ndata[ZZ];
835         int send_id  = overlap->send_id[ipulse];
836         int recv_id  = overlap->recv_id[ipulse];
837         sendptr      = overlap->sendbuf;
838         MPI_Sendrecv(sendptr, send_nindex*datasize, GMX_MPI_REAL,
839                      send_id, ipulse,
840                      recvptr, recv_nindex*datasize, GMX_MPI_REAL,
841                      recv_id, ipulse,
842                      overlap->mpi_comm, &stat);
843 #endif
844
        /* Sum the received x-lines into the start of the local FFT grid */
845         for (x = 0; x < recv_nindex; x++)
846         {
847             for (y = 0; y < local_fft_ndata[YY]; y++)
848             {
849                 indg = (x*local_fft_size[YY]  + y)*local_fft_size[ZZ];
850                 indb = (x*local_fft_ndata[YY] + y)*local_fft_ndata[ZZ];
851                 for (z = 0; z < local_fft_ndata[ZZ]; z++)
852                 {
853                     fftgrid[indg+z] += recvptr[indb+z];
854                 }
855             }
856         }
857     }
858 }
859
860 void spread_on_grid(const gmx_pme_t *pme,
861                     const pme_atomcomm_t *atc, const pmegrids_t *grids,
862                     gmx_bool bCalcSplines, gmx_bool bSpread,
863                     real *fftgrid, gmx_bool bDoSplines, int grid_index)
864 {
865     int nthread, thread;
866 #ifdef PME_TIME_THREADS
867     gmx_cycles_t c1, c2, c3, ct1a, ct1b, ct1c;
868     static double cs1     = 0, cs2 = 0, cs3 = 0;
869     static double cs1a[6] = {0, 0, 0, 0, 0, 0};
870     static int cnt        = 0;
871 #endif
872
873     nthread = pme->nthread;
874     assert(nthread > 0);
875
876 #ifdef PME_TIME_THREADS
877     c1 = omp_cyc_start();
878 #endif
879     if (bCalcSplines)
880     {
881 #pragma omp parallel for num_threads(nthread) schedule(static)
882         for (thread = 0; thread < nthread; thread++)
883         {
884             try
885             {
886                 int start, end;
887
888                 start = atc->n* thread   /nthread;
889                 end   = atc->n*(thread+1)/nthread;
890
891                 /* Compute fftgrid index for all atoms,
892                  * with help of some extra variables.
893                  */
894                 calc_interpolation_idx(pme, atc, start, grid_index, end, thread);
895             }
896             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
897         }
898     }
899 #ifdef PME_TIME_THREADS
900     c1   = omp_cyc_end(c1);
901     cs1 += (double)c1;
902 #endif
903
904 #ifdef PME_TIME_THREADS
905     c2 = omp_cyc_start();
906 #endif
907 #pragma omp parallel for num_threads(nthread) schedule(static)
908     for (thread = 0; thread < nthread; thread++)
909     {
910         try
911         {
912             splinedata_t *spline;
913
914             /* make local bsplines  */
915             if (grids == nullptr || !pme->bUseThreads)
916             {
917                 spline = &atc->spline[0];
918
919                 spline->n = atc->n;
920             }
921             else
922             {
923                 spline = &atc->spline[thread];
924
925                 if (grids->nthread == 1)
926                 {
927                     /* One thread, we operate on all coefficients */
928                     spline->n = atc->n;
929                 }
930                 else
931                 {
932                     /* Get the indices our thread should operate on */
933                     make_thread_local_ind(atc, thread, spline);
934                 }
935             }
936
937             if (bCalcSplines)
938             {
939                 make_bsplines(spline->theta, spline->dtheta, pme->pme_order,
940                               atc->fractx, spline->n, spline->ind, atc->coefficient, bDoSplines);
941             }
942
943             if (bSpread)
944             {
945                 /* put local atoms on grid. */
946                 const pmegrid_t *grid = pme->bUseThreads ? &grids->grid_th[thread] : &grids->grid;
947
948 #ifdef PME_TIME_SPREAD
949                 ct1a = omp_cyc_start();
950 #endif
951                 spread_coefficients_bsplines_thread(grid, atc, spline, pme->spline_work);
952
953                 if (pme->bUseThreads)
954                 {
955                     copy_local_grid(pme, grids, grid_index, thread, fftgrid);
956                 }
957 #ifdef PME_TIME_SPREAD
958                 ct1a          = omp_cyc_end(ct1a);
959                 cs1a[thread] += (double)ct1a;
960 #endif
961             }
962         }
963         GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
964     }
965 #ifdef PME_TIME_THREADS
966     c2   = omp_cyc_end(c2);
967     cs2 += (double)c2;
968 #endif
969
970     if (bSpread && pme->bUseThreads)
971     {
972 #ifdef PME_TIME_THREADS
973         c3 = omp_cyc_start();
974 #endif
975 #pragma omp parallel for num_threads(grids->nthread) schedule(static)
976         for (thread = 0; thread < grids->nthread; thread++)
977         {
978             try
979             {
980                 reduce_threadgrid_overlap(pme, grids, thread,
981                                           fftgrid,
982                                           pme->overlap[0].sendbuf,
983                                           pme->overlap[1].sendbuf,
984                                           grid_index);
985             }
986             GMX_CATCH_ALL_AND_EXIT_WITH_FATAL_ERROR;
987         }
988 #ifdef PME_TIME_THREADS
989         c3   = omp_cyc_end(c3);
990         cs3 += (double)c3;
991 #endif
992
993         if (pme->nnodes > 1)
994         {
995             /* Communicate the overlapping part of the fftgrid.
996              * For this communication call we need to check pme->bUseThreads
997              * to have all ranks communicate here, regardless of pme->nthread.
998              */
999             sum_fftgrid_dd(pme, fftgrid, grid_index);
1000         }
1001     }
1002
1003 #ifdef PME_TIME_THREADS
1004     cnt++;
1005     if (cnt % 20 == 0)
1006     {
1007         printf("idx %.2f spread %.2f red %.2f",
1008                cs1*1e-9, cs2*1e-9, cs3*1e-9);
1009 #ifdef PME_TIME_SPREAD
1010         for (thread = 0; thread < nthread; thread++)
1011         {
1012             printf(" %.2f", cs1a[thread]*1e-9);
1013         }
1014 #endif
1015         printf("\n");
1016     }
1017 #endif
1018 }