d133daaeeaa5acc4668366aa26bd347f9bd44fca
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gather_sycl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief Implements PME force gathering in SYCL.
38  *
39  *  \author Andrey Alekseenko <al42and@gmail.com>
40  */
41
42 #include "gmxpre.h"
43
44 #include "pme_gather_sycl.h"
45
46 #include "gromacs/gpu_utils/gmxsycl.h"
47 #include "gromacs/gpu_utils/gputraits_sycl.h"
48 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
49 #include "gromacs/gpu_utils/syclutils.h"
50 #include "gromacs/math/functions.h"
51
52 #include "pme_gpu_calculate_splines_sycl.h"
53 #include "pme_grid.h"
54 #include "pme_gpu_constants.h"
55 #include "pme_gpu_types_host.h"
56
57
58 /*! \brief Reduce the partial force contributions.
59  *
60  * \tparam     order              The PME order (must be 4).
61  * \tparam     atomDataSize       The number of partial force contributions for each atom (currently
62  *                                order^2 == 16).
63  * \tparam     workGroupSize      The size of a work-group.
64  * \tparam     subGroupSize       The size of a sub-group.
65  *
66  * \param[in]  itemIdx            SYCL thread ID.
67  * \param[out] sm_forces          Shared memory array with the output forces (number of elements
68  *                                is number of atoms per block).
69  * \param[in]  atomIndexLocal     Local atom index.
70  * \param[in]  splineIndex        Spline index.
71  * \param[in]  lineIndex          Line index (same as threadLocalId)
72  * \param[in]  realGridSizeFP     Local grid size constant
73  * \param[in]  fx                 Input force partial component X
74  * \param[in]  fy                 Input force partial component Y
75  * \param[in]  fz                 Input force partial component Z
76  */
77 template<int order, int atomDataSize, int workGroupSize, int subGroupSize>
78 inline void reduceAtomForces(cl::sycl::nd_item<3>        itemIdx,
79                              cl::sycl::local_ptr<Float3> sm_forces,
80                              const int                   atomIndexLocal,
81                              const int                   splineIndex,
82                              const int gmx_unused        lineIndex,
83                              const float                 realGridSizeFP[3],
84                              float&                      fx, // NOLINT(google-runtime-references)
85                              float&                      fy, // NOLINT(google-runtime-references)
86                              float&                      fz)                      // NOLINT(google-runtime-references)
87 {
88     static_assert(gmx::isPowerOfTwo(order));
89     // TODO: find out if this is the best in terms of transactions count
90     static_assert(order == 4, "Only order of 4 is implemented");
91
92     sycl_2020::sub_group sg = itemIdx.get_sub_group();
93
94     static_assert(atomDataSize <= subGroupSize,
95                   "TODO: rework for atomDataSize > subGroupSize (order 8 or larger)");
96
97     fx += sycl_2020::shift_left(sg, fx, 1);
98     fy += sycl_2020::shift_right(sg, fy, 1);
99     fz += sycl_2020::shift_left(sg, fz, 1);
100     if (splineIndex & 1)
101     {
102         fx = fy;
103     }
104     fx += sycl_2020::shift_left(sg, fx, 2);
105     fz += sycl_2020::shift_right(sg, fz, 2);
106     if (splineIndex & 2)
107     {
108         fx = fz;
109     }
110     // We have to just further reduce those groups of 4
111     for (int delta = 4; delta < atomDataSize; delta *= 2)
112     {
113         fx += sycl_2020::shift_left(sg, fx, delta);
114     }
115     const int dimIndex = splineIndex;
116     if (dimIndex < DIM)
117     {
118         const float n                       = realGridSizeFP[dimIndex];
119         sm_forces[atomIndexLocal][dimIndex] = fx * n;
120     }
121 }
122
123 /*! \brief Calculate the sum of the force partial components (in X, Y and Z)
124  *
125  * \tparam     order              The PME order (must be 4).
126  * \tparam     atomsPerWarp       The number of atoms per GPU warp.
127  * \tparam     wrapX              Tells if the grid is wrapped in the X dimension.
128  * \tparam     wrapY              Tells if the grid is wrapped in the Y dimension.
129  * \param[out] fx                 The force partial component in the X dimension.
130  * \param[out] fy                 The force partial component in the Y dimension.
131  * \param[out] fz                 The force partial component in the Z dimension.
132  * \param[in] ithyMin             The thread minimum index in the Y dimension.
133  * \param[in] ithyMax             The thread maximum index in the Y dimension.
134  * \param[in] ixBase              The grid line index base value in the X dimension.
135  * \param[in] iz                  The grid line index in the Z dimension.
136  * \param[in] nx                  The grid real size in the X dimension.
137  * \param[in] ny                  The grid real size in the Y dimension.
138  * \param[in] pny                 The padded grid real size in the Y dimension.
139  * \param[in] pnz                 The padded grid real size in the Z dimension.
140  * \param[in] atomIndexLocal      The atom index for this thread.
141  * \param[in] splineIndexBase     The base value of the spline parameter index.
142  * \param[in] tdz                 The theta and dtheta in the Z dimension.
143  * \param[in] sm_gridlineIndices  Shared memory array of grid line indices.
144  * \param[in] sm_theta            Shared memory array of atom theta values.
145  * \param[in] sm_dtheta           Shared memory array of atom dtheta values.
146  * \param[in] gm_grid             Global memory array of the grid to use.
147  */
148 template<int order, int atomsPerWarp, bool wrapX, bool wrapY>
149 inline void sumForceComponents(cl::sycl::private_ptr<float>      fx,
150                                cl::sycl::private_ptr<float>      fy,
151                                cl::sycl::private_ptr<float>      fz,
152                                const int                         ithyMin,
153                                const int                         ithyMax,
154                                const int                         ixBase,
155                                const int                         iz,
156                                const int                         nx,
157                                const int                         ny,
158                                const int                         pny,
159                                const int                         pnz,
160                                const int                         atomIndexLocal,
161                                const int                         splineIndexBase,
162                                const cl::sycl::float2            tdz,
163                                const cl::sycl::local_ptr<int>    sm_gridlineIndices,
164                                const cl::sycl::local_ptr<float>  sm_theta,
165                                const cl::sycl::local_ptr<float>  sm_dtheta,
166                                const cl::sycl::global_ptr<float> gm_grid)
167 {
168     for (int ithy = ithyMin; ithy < ithyMax; ithy++)
169     {
170         const int splineIndexY = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, YY, ithy);
171         const cl::sycl::float2 tdy{ sm_theta[splineIndexY], sm_dtheta[splineIndexY] };
172
173         int iy = sm_gridlineIndices[atomIndexLocal * DIM + YY] + ithy;
174         if (wrapY & (iy >= ny))
175         {
176             iy -= ny;
177         }
178         const int constOffset = iy * pnz + iz;
179
180 #pragma unroll
181         for (int ithx = 0; ithx < order; ithx++)
182         {
183             int ix = ixBase + ithx;
184             if (wrapX & (ix >= nx))
185             {
186                 ix -= nx;
187             }
188             const int gridIndexGlobal = ix * pny * pnz + constOffset;
189             assert(gridIndexGlobal >= 0);
190             const float gridValue = gm_grid[gridIndexGlobal];
191             assertIsFinite(gridValue);
192             const int splineIndexX = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, XX, ithx);
193             const cl::sycl::float2 tdx{ sm_theta[splineIndexX], sm_dtheta[splineIndexX] };
194             const float            fxy1 = tdz[XX] * gridValue;
195             const float            fz1  = tdz[YY] * gridValue;
196             *fx += tdx[YY] * tdy[XX] * fxy1;
197             *fy += tdx[XX] * tdy[YY] * fxy1;
198             *fz += tdx[XX] * tdy[XX] * fz1;
199         }
200     }
201 }
202
203
204 /*! \brief Calculate the grid forces and store them in shared memory.
205  *
206  * \param[in,out] sm_forces       Shared memory array with the output forces.
207  * \param[in] forceIndexLocal     The local (per thread) index in the sm_forces array.
208  * \param[in] forceIndexGlobal    The index of the thread in the gm_coefficients array.
209  * \param[in] recipBox0           The reciprocal box (first vector).
210  * \param[in] recipBox1           The reciprocal box (second vector).
211  * \param[in] recipBox2           The reciprocal box (third vector).
212  * \param[in] scale               The scale to use when calculating the forces. For gm_coefficientsB
213  * (when using multiple coefficients on a single grid) the scale will be (1.0 - scale).
214  * \param[in] gm_coefficients     Global memory array of the coefficients to use for an unperturbed
215  * or FEP in state A if a single grid is used (\p multiCoefficientsSingleGrid == true).If two
216  * separate grids are used this should be the coefficients of the grid in question.
217  */
218 inline void calculateAndStoreGridForces(cl::sycl::local_ptr<Float3>       sm_forces,
219                                         const int                         forceIndexLocal,
220                                         const int                         forceIndexGlobal,
221                                         const Float3&                     recipBox0,
222                                         const Float3&                     recipBox1,
223                                         const Float3&                     recipBox2,
224                                         const float                       scale,
225                                         const cl::sycl::global_ptr<float> gm_coefficients)
226 {
227     const Float3 atomForces     = sm_forces[forceIndexLocal];
228     float        negCoefficient = -scale * gm_coefficients[forceIndexGlobal];
229     Float3       result;
230     result[XX] = negCoefficient * recipBox0[XX] * atomForces[XX];
231     result[YY] = negCoefficient * (recipBox0[YY] * atomForces[XX] + recipBox1[YY] * atomForces[YY]);
232     result[ZZ] = negCoefficient
233                  * (recipBox0[ZZ] * atomForces[XX] + recipBox1[ZZ] * atomForces[YY]
234                     + recipBox2[ZZ] * atomForces[ZZ]);
235     sm_forces[forceIndexLocal] = result;
236 }
237
238 /*! \brief
239  * A SYCL kernel which gathers the atom forces from the grid.
240  * The grid is assumed to be wrapped in dimension Z.
241  *
242  * \tparam order          PME interpolation order.
243  * \tparam wrapX          A boolean which tells if the grid overlap in dimension X should
244  *                        be wrapped.
245  * \tparam wrapY          A boolean which tells if the grid overlap in dimension Y should
246  *                        be wrapped.
247  * \tparam numGrids       The number of grids to use in the kernel. Can be 1 or 2.
248  * \tparam writeGlobal    Tells if we should read spline values from global memory.
249  * \tparam threadsPerAtom How many threads work on each atom.
250  * \tparam subGroupSize   Size of the sub-group.
251  */
252 template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
253 auto pmeGatherKernel(cl::sycl::handler&                                 cgh,
254                      const int                                          nAtoms,
255                      DeviceAccessor<float, mode::read>                  a_gridA,
256                      OptionalAccessor<float, mode::read, numGrids == 2> a_gridB,
257                      DeviceAccessor<float, mode::read>                  a_coefficientsA,
258                      OptionalAccessor<float, mode::read, numGrids == 2> a_coefficientsB,
259                      OptionalAccessor<Float3, mode::read, !readGlobal>  a_coordinates,
260                      DeviceAccessor<Float3, mode::read_write>           a_forces,
261                      DeviceAccessor<float, mode::read>                  a_theta,
262                      DeviceAccessor<float, mode::read>                  a_dtheta,
263                      DeviceAccessor<int, mode::read>                    a_gridlineIndices,
264                      OptionalAccessor<float, mode::read, !readGlobal>   a_fractShiftsTable,
265                      OptionalAccessor<int, mode::read, !readGlobal>     a_gridlineIndicesTable,
266                      const gmx::IVec                                    tablesOffsets,
267                      const gmx::IVec                                    realGridSize,
268                      const gmx::RVec                                    realGridSizeFP,
269                      const gmx::IVec                                    realGridSizePadded,
270                      const gmx::RVec                                    currentRecipBox0,
271                      const gmx::RVec                                    currentRecipBox1,
272                      const gmx::RVec                                    currentRecipBox2,
273                      const float                                        scale)
274 {
275     static_assert(numGrids == 1 || numGrids == 2);
276
277     constexpr int threadsPerAtomValue = (threadsPerAtom == ThreadsPerAtom::Order) ? order : order * order;
278     constexpr int atomDataSize        = threadsPerAtomValue;
279     constexpr int atomsPerBlock       = (c_gatherMaxWarpsPerBlock * subGroupSize) / atomDataSize;
280     // Number of atoms processed by a single warp in spread and gather
281     static_assert(subGroupSize >= atomDataSize);
282     constexpr int atomsPerWarp        = subGroupSize / atomDataSize;
283     constexpr int blockSize           = atomsPerBlock * atomDataSize;
284     constexpr int splineParamsSize    = atomsPerBlock * DIM * order;
285     constexpr int gridlineIndicesSize = atomsPerBlock * DIM;
286
287     a_gridA.bind(cgh);
288     a_coefficientsA.bind(cgh);
289     a_forces.bind(cgh);
290
291     if constexpr (numGrids == 2)
292     {
293         a_gridB.bind(cgh);
294         a_coefficientsB.bind(cgh);
295     }
296
297     if constexpr (readGlobal)
298     {
299         a_theta.bind(cgh);
300         a_dtheta.bind(cgh);
301         a_gridlineIndices.bind(cgh);
302     }
303     else
304     {
305         a_coordinates.bind(cgh);
306         a_fractShiftsTable.bind(cgh);
307         a_gridlineIndicesTable.bind(cgh);
308     }
309
310     // Gridline indices, ivec
311     cl::sycl::accessor<int, 1, mode::read_write, target::local> sm_gridlineIndices(
312             cl::sycl::range<1>(atomsPerBlock * DIM), cgh);
313     // Spline values
314     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_theta(
315             cl::sycl::range<1>(atomsPerBlock * DIM * order), cgh);
316     // Spline derivatives
317     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_dtheta(
318             cl::sycl::range<1>(atomsPerBlock * DIM * order), cgh);
319     // Coefficients prefetch cache
320     cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_coefficients(
321             cl::sycl::range<1>(atomsPerBlock), cgh);
322     // Coordinates prefetch cache
323     cl::sycl::accessor<Float3, 1, mode::read_write, target::local> sm_coordinates(
324             cl::sycl::range<1>(atomsPerBlock), cgh);
325     // Reduction of partial force contributions
326     cl::sycl::accessor<Float3, 1, mode::read_write, target::local> sm_forces(
327             cl::sycl::range<1>(atomsPerBlock), cgh);
328
329     auto sm_fractCoords = [&]() {
330         if constexpr (!readGlobal)
331         {
332             return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
333                     cl::sycl::range<1>(atomsPerBlock * DIM), cgh);
334         }
335         else
336         {
337             return nullptr;
338         }
339     }();
340
341     return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
342     {
343         assert(blockSize == itemIdx.get_local_range().size());
344         /* These are the atom indices - for the shared and global memory */
345         const int atomIndexLocal = itemIdx.get_local_id(XX);
346         const int blockIndex =
347                 itemIdx.get_group(YY) * itemIdx.get_group_range(ZZ) + itemIdx.get_group(ZZ);
348         // itemIdx.get_group_linear_id();
349
350         const int atomIndexOffset = blockIndex * atomsPerBlock;
351         const int atomIndexGlobal = atomIndexOffset + atomIndexLocal;
352         /* Early return for fully empty blocks at the end
353          * (should only happen for billions of input atoms)
354          */
355         if (atomIndexOffset >= nAtoms)
356         {
357             return;
358         }
359         /* Spline Z coordinates */
360         const int ithz = itemIdx.get_local_id(ZZ);
361         /* These are the spline contribution indices in shared memory */
362         const int splineIndex =
363                 itemIdx.get_local_id(YY) * itemIdx.get_local_range(ZZ) + itemIdx.get_local_id(ZZ);
364
365         const int threadLocalId    = itemIdx.get_local_linear_id();
366         const int threadLocalIdMax = blockSize;
367         assert(threadLocalId < threadLocalIdMax);
368
369         const int lineIndex =
370                 (itemIdx.get_local_id(XX) * (itemIdx.get_local_range(ZZ) * itemIdx.get_local_range(YY)))
371                 + splineIndex; // And to all the block's particles
372         assert(lineIndex == threadLocalId);
373
374         if constexpr (readGlobal)
375         {
376             /* Read splines */
377             const int localGridlineIndicesIndex = threadLocalId;
378             const int globalGridlineIndicesIndex =
379                     blockIndex * gridlineIndicesSize + localGridlineIndicesIndex;
380             // itemIdx.get_group(ZZ) * gridlineIndicesSize + localGridlineIndicesIndex;
381             if (localGridlineIndicesIndex < gridlineIndicesSize)
382             {
383                 sm_gridlineIndices[localGridlineIndicesIndex] =
384                         a_gridlineIndices[globalGridlineIndicesIndex];
385                 assert(sm_gridlineIndices[localGridlineIndicesIndex] >= 0);
386             }
387             /* The loop needed for order threads per atom to make sure we load all data values, as each thread must load multiple values
388                with order*order threads per atom, it is only required for each thread to load one data value */
389
390             const int iMin = 0;
391             const int iMax = (threadsPerAtom == ThreadsPerAtom::Order) ? 3 : 1;
392
393             for (int i = iMin; i < iMax; i++)
394             {
395                 // i will always be zero for order*order threads per atom
396                 const int localSplineParamsIndex = threadLocalId + i * threadLocalIdMax;
397                 const int globalSplineParamsIndex = blockIndex * splineParamsSize + localSplineParamsIndex;
398                 // const int globalSplineParamsIndex = itemIdx.get_group(ZZ) * splineParamsSize + localSplineParamsIndex;
399                 if (localSplineParamsIndex < splineParamsSize)
400                 {
401                     sm_theta[localSplineParamsIndex]  = a_theta[globalSplineParamsIndex];
402                     sm_dtheta[localSplineParamsIndex] = a_dtheta[globalSplineParamsIndex];
403                     assertIsFinite(sm_theta[localSplineParamsIndex]);
404                     assertIsFinite(sm_dtheta[localSplineParamsIndex]);
405                 }
406             }
407
408             itemIdx.barrier(fence_space::local_space);
409         }
410         else
411         {
412             /* Recalculate  Splines  */
413             /* Staging coefficients/charges */
414             pmeGpuStageAtomData<float, atomsPerBlock, 1>(
415                     sm_coefficients.get_pointer(), a_coefficientsA.get_pointer(), itemIdx);
416             /* Staging coordinates */
417             pmeGpuStageAtomData<Float3, atomsPerBlock, 1>(
418                     sm_coordinates.get_pointer(), a_coordinates.get_pointer(), itemIdx);
419             itemIdx.barrier(fence_space::local_space);
420             const Float3 atomX      = sm_coordinates[atomIndexLocal];
421             const float  atomCharge = sm_coefficients[atomIndexLocal];
422
423             calculateSplines<order, atomsPerBlock, atomsPerWarp, true, false, numGrids, subGroupSize>(
424                     atomIndexOffset,
425                     atomX,
426                     atomCharge,
427                     tablesOffsets,
428                     realGridSizeFP,
429                     currentRecipBox0,
430                     currentRecipBox1,
431                     currentRecipBox2,
432                     nullptr,
433                     nullptr,
434                     nullptr,
435                     a_fractShiftsTable.get_pointer(),
436                     a_gridlineIndicesTable.get_pointer(),
437                     sm_theta.get_pointer(),
438                     sm_dtheta.get_pointer(),
439                     sm_gridlineIndices.get_pointer(),
440                     sm_fractCoords.get_pointer(),
441                     itemIdx);
442             subGroupBarrier(itemIdx);
443         }
444         float fx = 0.0F;
445         float fy = 0.0F;
446         float fz = 0.0F;
447
448         const int chargeCheck = pmeGpuCheckAtomCharge(a_coefficientsA[atomIndexGlobal]);
449
450         const int nx  = realGridSize[XX];
451         const int ny  = realGridSize[YY];
452         const int nz  = realGridSize[ZZ];
453         const int pny = realGridSizePadded[YY];
454         const int pnz = realGridSizePadded[ZZ];
455
456         const int atomWarpIndex = atomIndexLocal % atomsPerWarp;
457         const int warpIndex     = atomIndexLocal / atomsPerWarp;
458
459         const int splineIndexBase = getSplineParamIndexBase<order, atomsPerWarp>(warpIndex, atomWarpIndex);
460         const int splineIndexZ = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, ZZ, ithz);
461         const cl::sycl::float2 tdz{ sm_theta[splineIndexZ], sm_dtheta[splineIndexZ] };
462
463         int       iz     = sm_gridlineIndices[atomIndexLocal * DIM + ZZ] + ithz;
464         const int ixBase = sm_gridlineIndices[atomIndexLocal * DIM + XX];
465
466         if (iz >= nz)
467         {
468             iz -= nz;
469         }
470
471         const int ithyMin = (threadsPerAtom == ThreadsPerAtom::Order) ? 0 : itemIdx.get_local_id(YY);
472         const int ithyMax =
473                 (threadsPerAtom == ThreadsPerAtom::Order) ? order : itemIdx.get_local_id(YY) + 1;
474         if (chargeCheck)
475         {
476             sumForceComponents<order, atomsPerWarp, wrapX, wrapY>(&fx,
477                                                                   &fy,
478                                                                   &fz,
479                                                                   ithyMin,
480                                                                   ithyMax,
481                                                                   ixBase,
482                                                                   iz,
483                                                                   nx,
484                                                                   ny,
485                                                                   pny,
486                                                                   pnz,
487                                                                   atomIndexLocal,
488                                                                   splineIndexBase,
489                                                                   tdz,
490                                                                   sm_gridlineIndices.get_pointer(),
491                                                                   sm_theta.get_pointer(),
492                                                                   sm_dtheta.get_pointer(),
493                                                                   a_gridA.get_pointer());
494         }
495         reduceAtomForces<order, atomDataSize, blockSize, subGroupSize>(
496                 itemIdx, sm_forces.get_pointer(), atomIndexLocal, splineIndex, lineIndex, realGridSizeFP, fx, fy, fz);
497         itemIdx.barrier(fence_space::local_space);
498
499         /* Calculating the final forces with no component branching, atomsPerBlock threads */
500         const int forceIndexLocal  = threadLocalId;
501         const int forceIndexGlobal = atomIndexOffset + forceIndexLocal;
502         if (forceIndexLocal < atomsPerBlock)
503         {
504             calculateAndStoreGridForces(sm_forces.get_pointer(),
505                                         forceIndexLocal,
506                                         forceIndexGlobal,
507                                         currentRecipBox0,
508                                         currentRecipBox1,
509                                         currentRecipBox2,
510                                         scale,
511                                         a_coefficientsA.get_pointer());
512         }
513         itemIdx.barrier(fence_space::local_space);
514
515         static_assert(atomsPerBlock <= subGroupSize);
516
517         /* Writing or adding the final forces component-wise, single warp */
518         constexpr int blockForcesSize = atomsPerBlock * DIM;
519         constexpr int numIter         = (blockForcesSize + subGroupSize - 1) / subGroupSize;
520         constexpr int iterThreads     = blockForcesSize / numIter;
521         if (threadLocalId < iterThreads)
522         {
523 #pragma unroll
524             for (int i = 0; i < numIter; i++)
525             {
526                 const int floatIndexLocal  = i * iterThreads + threadLocalId;
527                 const int float3IndexLocal = floatIndexLocal / 3;
528                 const int dimLocal         = floatIndexLocal % 3;
529                 static_assert(blockForcesSize % DIM == 0); // Assures that dimGlobal == dimLocal
530                 const int float3IndexGlobal = blockIndex * atomsPerBlock + float3IndexLocal;
531                 // const int float3IndexGlobal = itemIdx.get_group(ZZ) * atomsPerBlock + float3IndexLocal;
532                 a_forces[float3IndexGlobal][dimLocal] = sm_forces[float3IndexLocal][dimLocal];
533             }
534         }
535
536         if constexpr (numGrids == 2)
537         {
538             /* We must sync here since the same shared memory is used as above. */
539             itemIdx.barrier(fence_space::local_space);
540             fx                     = 0.0F;
541             fy                     = 0.0F;
542             fz                     = 0.0F;
543             const bool chargeCheck = pmeGpuCheckAtomCharge(a_coefficientsB[atomIndexGlobal]);
544             if (chargeCheck)
545             {
546                 sumForceComponents<order, atomsPerWarp, wrapX, wrapY>(&fx,
547                                                                       &fy,
548                                                                       &fz,
549                                                                       ithyMin,
550                                                                       ithyMax,
551                                                                       ixBase,
552                                                                       iz,
553                                                                       nx,
554                                                                       ny,
555                                                                       pny,
556                                                                       pnz,
557                                                                       atomIndexLocal,
558                                                                       splineIndexBase,
559                                                                       tdz,
560                                                                       sm_gridlineIndices.get_pointer(),
561                                                                       sm_theta.get_pointer(),
562                                                                       sm_dtheta.get_pointer(),
563                                                                       a_gridB.get_pointer());
564             }
565             // Reduction of partial force contributions
566             reduceAtomForces<order, atomDataSize, blockSize, subGroupSize>(
567                     itemIdx, sm_forces.get_pointer(), atomIndexLocal, splineIndex, lineIndex, realGridSizeFP, fx, fy, fz);
568             itemIdx.barrier(fence_space::local_space);
569
570             /* Calculating the final forces with no component branching, atomsPerBlock threads */
571             if (forceIndexLocal < atomsPerBlock)
572             {
573                 calculateAndStoreGridForces(sm_forces.get_pointer(),
574                                             forceIndexLocal,
575                                             forceIndexGlobal,
576                                             currentRecipBox0,
577                                             currentRecipBox1,
578                                             currentRecipBox2,
579                                             1.0F - scale,
580                                             a_coefficientsB.get_pointer());
581             }
582
583             itemIdx.barrier(fence_space::local_space);
584
585             /* Writing or adding the final forces component-wise, single warp */
586             if (threadLocalId < iterThreads)
587             {
588 #pragma unroll
589                 for (int i = 0; i < numIter; i++)
590                 {
591                     const int floatIndexLocal  = i * iterThreads + threadLocalId;
592                     const int float3IndexLocal = floatIndexLocal / 3;
593                     const int dimLocal         = floatIndexLocal % 3;
594                     static_assert(blockForcesSize % DIM == 0); // Assures that dimGlobal == dimLocal
595                     const int float3IndexGlobal = blockIndex * atomsPerBlock + float3IndexLocal;
596                     a_forces[float3IndexGlobal][dimLocal] += sm_forces[float3IndexLocal][dimLocal];
597                 }
598             }
599         }
600     };
601 }
602
603 template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
604 PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::PmeGatherKernel()
605 {
606     reset();
607 }
608
609 template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
610 void PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::setArg(
611         size_t argIndex,
612         void*  arg)
613 {
614     if (argIndex == 0)
615     {
616         auto* params   = reinterpret_cast<PmeGpuKernelParams*>(arg);
617         gridParams_    = &params->grid;
618         atomParams_    = &params->atoms;
619         dynamicParams_ = &params->current;
620     }
621     else
622     {
623         GMX_RELEASE_ASSERT(argIndex == 0, "Trying to pass too many args to the kernel");
624     }
625 }
626
627
628 template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
629 cl::sycl::event PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::launch(
630         const KernelLaunchConfig& config,
631         const DeviceStream&       deviceStream)
632 {
633     GMX_RELEASE_ASSERT(gridParams_, "Can not launch the kernel before setting its args");
634     GMX_RELEASE_ASSERT(atomParams_, "Can not launch the kernel before setting its args");
635     GMX_RELEASE_ASSERT(dynamicParams_, "Can not launch the kernel before setting its args");
636
637     using kernelNameType =
638             PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>;
639
640     // SYCL has different multidimensional layout than OpenCL/CUDA.
641     const cl::sycl::range<3> localSize{ config.blockSize[2], config.blockSize[1], config.blockSize[0] };
642     const cl::sycl::range<3> groupRange{ config.gridSize[2], config.gridSize[1], config.gridSize[0] };
643     const cl::sycl::nd_range<3> range{ groupRange * localSize, localSize };
644
645     cl::sycl::queue q = deviceStream.stream();
646
647     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
648         auto kernel = pmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>(
649                 cgh,
650                 atomParams_->nAtoms,
651                 gridParams_->d_realGrid[0],
652                 gridParams_->d_realGrid[1],
653                 atomParams_->d_coefficients[0],
654                 atomParams_->d_coefficients[1],
655                 atomParams_->d_coordinates,
656                 atomParams_->d_forces,
657                 atomParams_->d_theta,
658                 atomParams_->d_dtheta,
659                 atomParams_->d_gridlineIndices,
660                 gridParams_->d_fractShiftsTable,
661                 gridParams_->d_gridlineIndicesTable,
662                 gridParams_->tablesOffsets,
663                 gridParams_->realGridSize,
664                 gridParams_->realGridSizeFP,
665                 gridParams_->realGridSizePadded,
666                 dynamicParams_->recipBox[0],
667                 dynamicParams_->recipBox[1],
668                 dynamicParams_->recipBox[2],
669                 dynamicParams_->scale);
670         cgh.parallel_for<kernelNameType>(range, kernel);
671     });
672
673     // Delete set args, so we don't forget to set them before the next launch.
674     reset();
675
676     return e;
677 }
678
679
680 template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
681 void PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::reset()
682 {
683     gridParams_    = nullptr;
684     atomParams_    = nullptr;
685     dynamicParams_ = nullptr;
686 }
687
688
689 //! Kernel instantiations
/* Disable the "explicit template instantiation 'PmeGatherKernel<...>' will emit a vtable in every
 * translation unit [-Wweak-template-vtables]" warning.
 * The class is only explicitly instantiated in this translation unit, so we should be safe.
 */
#ifdef __clang__
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wweak-template-vtables"
#endif

/* Helper macros for the explicit instantiations below.
 *
 * None of the macros ends in a semicolon: the terminating ';' is supplied at the
 * point of use, so that no empty declarations (";;") are produced at namespace scope.
 */

//! Instantiates one kernel flavor; wrapX and wrapY are fixed to true.
#define INSTANTIATE_3(order, numGrids, readGlobal, threadsPerAtom, subGroupSize) \
    template class PmeGatherKernel<order, true, true, numGrids, readGlobal, threadsPerAtom, subGroupSize>

//! Instantiates both values of readGlobal.
#define INSTANTIATE_2(order, numGrids, threadsPerAtom, subGroupSize)    \
    INSTANTIATE_3(order, numGrids, true, threadsPerAtom, subGroupSize); \
    INSTANTIATE_3(order, numGrids, false, threadsPerAtom, subGroupSize)

//! Instantiates every combination of one or two grids with both ThreadsPerAtom modes.
#define INSTANTIATE(order, subGroupSize)                                 \
    INSTANTIATE_2(order, 1, ThreadsPerAtom::Order, subGroupSize);        \
    INSTANTIATE_2(order, 1, ThreadsPerAtom::OrderSquared, subGroupSize); \
    INSTANTIATE_2(order, 2, ThreadsPerAtom::Order, subGroupSize);        \
    INSTANTIATE_2(order, 2, ThreadsPerAtom::OrderSquared, subGroupSize)

#if GMX_SYCL_DPCPP
INSTANTIATE(4, 16); // TODO: Choose best value, Issue #4153.
#elif GMX_SYCL_HIPSYCL
INSTANTIATE(4, 32);
INSTANTIATE(4, 64);
#endif

// The helpers have served their purpose; do not leak them past this point.
#undef INSTANTIATE
#undef INSTANTIATE_2
#undef INSTANTIATE_3

#ifdef __clang__
#    pragma clang diagnostic pop
#endif