/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */
/*! \internal \file
 * \brief Implements PME force gathering in SYCL.
 *
 * \author Andrey Alekseenko <al42and@gmail.com>
 */
#include "pme_gather_sycl.h"

#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/gpu_utils/gputraits_sycl.h"
#include "gromacs/gpu_utils/sycl_kernel_utils.h"
#include "gromacs/gpu_utils/syclutils.h"
#include "gromacs/math/functions.h"

#include "pme_gpu_calculate_splines_sycl.h"
#include "pme_gpu_constants.h"
#include "pme_gpu_types_host.h"
/*! \brief Reduce the partial force contributions.
 *
 * \tparam order         The PME order (must be 4).
 * \tparam atomDataSize  The number of partial force contributions for each atom (currently
 *                       order^2 == 16).
 * \tparam workGroupSize The size of a work-group.
 * \tparam subGroupSize  The size of a sub-group.
 *
 * \param[in]  itemIdx        SYCL thread ID.
 * \param[out] sm_forces      Shared memory array with the output forces (number of elements
 *                            is number of atoms per block).
 * \param[in]  atomIndexLocal Local atom index.
 * \param[in]  splineIndex    Spline index.
 * \param[in]  lineIndex      Line index (same as threadLocalId).
 * \param[in]  realGridSizeFP Local grid size constant.
 * \param[in]  fx             Input force partial component X.
 * \param[in]  fy             Input force partial component Y.
 * \param[in]  fz             Input force partial component Z.
 */
template<int order, int atomDataSize, int workGroupSize, int subGroupSize>
inline void reduceAtomForces(cl::sycl::nd_item<3>        itemIdx,
                             cl::sycl::local_ptr<Float3> sm_forces,
                             const int                   atomIndexLocal,
                             const int                   splineIndex,
                             const int gmx_unused        lineIndex,
                             const float                 realGridSizeFP[3],
                             float& fx, // NOLINT(google-runtime-references)
                             float& fy, // NOLINT(google-runtime-references)
                             float& fz) // NOLINT(google-runtime-references)
{
    static_assert(gmx::isPowerOfTwo(order));
    // TODO: find out if this is the best in terms of transactions count
    static_assert(order == 4, "Only order of 4 is implemented");

    sycl_2020::sub_group sg = itemIdx.get_sub_group();

    static_assert(atomDataSize <= subGroupSize,
                  "TODO: rework for atomDataSize > subGroupSize (order 8 or larger)");
    fx += sycl_2020::shift_left(sg, fx, 1);
    fy += sycl_2020::shift_right(sg, fy, 1);
    fz += sycl_2020::shift_left(sg, fz, 1);
    if (splineIndex & 1)
    {
        fx = fy;
    }
    fx += sycl_2020::shift_left(sg, fx, 2);
    fz += sycl_2020::shift_right(sg, fz, 2);
    if (splineIndex & 2)
    {
        fx = fz;
    }
    // We have to just further reduce those groups of 4
    for (int delta = 4; delta < atomDataSize; delta *= 2)
    {
        fx += sycl_2020::shift_left(sg, fx, delta);
    }
    const int dimIndex = splineIndex;
    if (dimIndex < DIM)
    {
        const float n                       = realGridSizeFP[dimIndex];
        sm_forces[atomIndexLocal][dimIndex] = fx * n;
    }
}
/*! \brief Calculate the sum of the force partial components (in X, Y and Z).
 *
 * \tparam order        The PME order (must be 4).
 * \tparam atomsPerWarp The number of atoms per GPU warp.
 * \tparam wrapX        Tells if the grid is wrapped in the X dimension.
 * \tparam wrapY        Tells if the grid is wrapped in the Y dimension.
 * \param[out] fx                 The force partial component in the X dimension.
 * \param[out] fy                 The force partial component in the Y dimension.
 * \param[out] fz                 The force partial component in the Z dimension.
 * \param[in]  ithyMin            The thread minimum index in the Y dimension.
 * \param[in]  ithyMax            The thread maximum index in the Y dimension.
 * \param[in]  ixBase             The grid line index base value in the X dimension.
 * \param[in]  iz                 The grid line index in the Z dimension.
 * \param[in]  nx                 The grid real size in the X dimension.
 * \param[in]  ny                 The grid real size in the Y dimension.
 * \param[in]  pny                The padded grid real size in the Y dimension.
 * \param[in]  pnz                The padded grid real size in the Z dimension.
 * \param[in]  atomIndexLocal     The atom index for this thread.
 * \param[in]  splineIndexBase    The base value of the spline parameter index.
 * \param[in]  tdz                The theta and dtheta in the Z dimension.
 * \param[in]  sm_gridlineIndices Shared memory array of grid line indices.
 * \param[in]  sm_theta           Shared memory array of atom theta values.
 * \param[in]  sm_dtheta          Shared memory array of atom dtheta values.
 * \param[in]  gm_grid            Global memory array of the grid to use.
 */
template<int order, int atomsPerWarp, bool wrapX, bool wrapY>
inline void sumForceComponents(cl::sycl::private_ptr<float> fx,
                               cl::sycl::private_ptr<float> fy,
                               cl::sycl::private_ptr<float> fz,
                               const int                    ithyMin,
                               const int                    ithyMax,
                               const int                    ixBase,
                               const int                    iz,
                               const int                    nx,
                               const int                    ny,
                               const int                    pny,
                               const int                    pnz,
                               const int                    atomIndexLocal,
                               const int                    splineIndexBase,
                               const cl::sycl::float2       tdz,
                               const cl::sycl::local_ptr<int>    sm_gridlineIndices,
                               const cl::sycl::local_ptr<float>  sm_theta,
                               const cl::sycl::local_ptr<float>  sm_dtheta,
                               const cl::sycl::global_ptr<float> gm_grid)
{
    for (int ithy = ithyMin; ithy < ithyMax; ithy++)
    {
        const int splineIndexY = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, YY, ithy);
        const cl::sycl::float2 tdy{ sm_theta[splineIndexY], sm_dtheta[splineIndexY] };

        int iy = sm_gridlineIndices[atomIndexLocal * DIM + YY] + ithy;
        if (wrapY & (iy >= ny))
        {
            iy -= ny;
        }
        const int constOffset = iy * pnz + iz;
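
        // The real grid is stored x-major: the flat index is ix * pny * pnz + iy * pnz + iz,
        // so constOffset carries the Y/Z part, and the X part is added per ithx below.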
        for (int ithx = 0; ithx < order; ithx++)
        {
            int ix = ixBase + ithx;
            if (wrapX & (ix >= nx))
            {
                ix -= nx;
            }
            const int gridIndexGlobal = ix * pny * pnz + constOffset;
            assert(gridIndexGlobal >= 0);
            const float gridValue = gm_grid[gridIndexGlobal];
            assertIsFinite(gridValue);
            const int splineIndexX = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, XX, ithx);
            const cl::sycl::float2 tdx{ sm_theta[splineIndexX], sm_dtheta[splineIndexX] };
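            /* Differentiating the spline interpolation of the grid potential: the force
             * component along a dimension uses dtheta in that dimension and theta in the
             * other two. tdx/tdy/tdz pack (theta, dtheta) as (XX, YY). */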
            const float fxy1 = tdz[XX] * gridValue;
            const float fz1  = tdz[YY] * gridValue;
            *fx += tdx[YY] * tdy[XX] * fxy1;
            *fy += tdx[XX] * tdy[YY] * fxy1;
            *fz += tdx[XX] * tdy[XX] * fz1;
        }
    }
}
/*! \brief Calculate the grid forces and store them in shared memory.
 *
 * \param[in,out] sm_forces        Shared memory array with the output forces.
 * \param[in]     forceIndexLocal  The local (per thread) index in the sm_forces array.
 * \param[in]     forceIndexGlobal The index of the thread in the gm_coefficients array.
 * \param[in]     recipBox0        The reciprocal box (first vector).
 * \param[in]     recipBox1        The reciprocal box (second vector).
 * \param[in]     recipBox2        The reciprocal box (third vector).
 * \param[in]     scale            The scale to use when calculating the forces. For
 *                                 gm_coefficientsB (when using multiple coefficients on a single
 *                                 grid) the scale will be (1.0 - scale).
 * \param[in]     gm_coefficients  Global memory array of the coefficients to use for an
 *                                 unperturbed or FEP in state A if a single grid is used
 *                                 (\p multiCoefficientsSingleGrid == true). If two separate
 *                                 grids are used this should be the coefficients of the grid
 *                                 in question.
 */
inline void calculateAndStoreGridForces(cl::sycl::local_ptr<Float3>       sm_forces,
                                        const int                         forceIndexLocal,
                                        const int                         forceIndexGlobal,
                                        const Float3&                     recipBox0,
                                        const Float3&                     recipBox1,
                                        const Float3&                     recipBox2,
                                        const float                       scale,
                                        const cl::sycl::global_ptr<float> gm_coefficients)
{
    const Float3 atomForces     = sm_forces[forceIndexLocal];
    float        negCoefficient = -scale * gm_coefficients[forceIndexGlobal];
    Float3       result;
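    /* The reciprocal box is triangular (GROMACS convention), which is why the transform
     * from grid-fractional to Cartesian force components below needs only these terms. */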
    result[XX] = negCoefficient * recipBox0[XX] * atomForces[XX];
    result[YY] = negCoefficient * (recipBox0[YY] * atomForces[XX] + recipBox1[YY] * atomForces[YY]);
    result[ZZ] = negCoefficient
                 * (recipBox0[ZZ] * atomForces[XX] + recipBox1[ZZ] * atomForces[YY]
                    + recipBox2[ZZ] * atomForces[ZZ]);
    sm_forces[forceIndexLocal] = result;
}
/*! \brief
 * A SYCL kernel which gathers the atom forces from the grid.
 * The grid is assumed to be wrapped in dimension Z.
 *
 * \tparam order          PME interpolation order.
 * \tparam wrapX          A boolean which tells if the grid overlap in dimension X should
 *                        be wrapped.
 * \tparam wrapY          A boolean which tells if the grid overlap in dimension Y should
 *                        be wrapped.
 * \tparam numGrids       The number of grids to use in the kernel. Can be 1 or 2.
 * \tparam readGlobal     Tells if we should read spline values from global memory.
 * \tparam threadsPerAtom How many threads work on each atom.
 * \tparam subGroupSize   Size of the sub-group.
 */
template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
auto pmeGatherKernel(cl::sycl::handler& cgh,
                     const int          nAtoms,
                     DeviceAccessor<float, mode::read>                  a_gridA,
                     OptionalAccessor<float, mode::read, numGrids == 2> a_gridB,
                     DeviceAccessor<float, mode::read>                  a_coefficientsA,
                     OptionalAccessor<float, mode::read, numGrids == 2> a_coefficientsB,
                     OptionalAccessor<Float3, mode::read, !readGlobal>  a_coordinates,
                     DeviceAccessor<Float3, mode::read_write>           a_forces,
                     DeviceAccessor<float, mode::read>                  a_theta,
                     DeviceAccessor<float, mode::read>                  a_dtheta,
                     DeviceAccessor<int, mode::read>                    a_gridlineIndices,
                     OptionalAccessor<float, mode::read, !readGlobal>   a_fractShiftsTable,
                     OptionalAccessor<int, mode::read, !readGlobal>     a_gridlineIndicesTable,
                     const gmx::IVec                                    tablesOffsets,
                     const gmx::IVec                                    realGridSize,
                     const gmx::RVec                                    realGridSizeFP,
                     const gmx::IVec                                    realGridSizePadded,
                     const gmx::RVec                                    currentRecipBox0,
                     const gmx::RVec                                    currentRecipBox1,
                     const gmx::RVec                                    currentRecipBox2,
                     const float                                        scale)
{
    static_assert(numGrids == 1 || numGrids == 2);

    constexpr int threadsPerAtomValue = (threadsPerAtom == ThreadsPerAtom::Order) ? order : order * order;
    constexpr int atomDataSize  = threadsPerAtomValue;
    constexpr int atomsPerBlock = (c_gatherMaxWarpsPerBlock * subGroupSize) / atomDataSize;
    // Number of atoms processed by a single warp in spread and gather
    static_assert(subGroupSize >= atomDataSize);
    constexpr int atomsPerWarp        = subGroupSize / atomDataSize;
    constexpr int blockSize           = atomsPerBlock * atomDataSize;
    constexpr int splineParamsSize    = atomsPerBlock * DIM * order;
    constexpr int gridlineIndicesSize = atomsPerBlock * DIM;
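
    // For example, with order == 4 and subGroupSize == 16: ThreadsPerAtom::OrderSquared
    // gives atomDataSize == 16 and atomsPerWarp == 1, while ThreadsPerAtom::Order gives
    // atomDataSize == 4 and atomsPerWarp == 4.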
    a_gridA.bind(cgh);
    a_coefficientsA.bind(cgh);
    a_forces.bind(cgh);
    if constexpr (numGrids == 2)
    {
        a_gridB.bind(cgh);
        a_coefficientsB.bind(cgh);
    }
    if constexpr (readGlobal)
    {
        a_theta.bind(cgh);
        a_dtheta.bind(cgh);
        a_gridlineIndices.bind(cgh);
    }
    else
    {
        a_coordinates.bind(cgh);
        a_fractShiftsTable.bind(cgh);
        a_gridlineIndicesTable.bind(cgh);
    }
    // Gridline indices, ivec
    cl::sycl::accessor<int, 1, mode::read_write, target::local> sm_gridlineIndices(
            cl::sycl::range<1>(atomsPerBlock * DIM), cgh);
    // Spline values
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_theta(
            cl::sycl::range<1>(atomsPerBlock * DIM * order), cgh);
    // Spline derivatives
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_dtheta(
            cl::sycl::range<1>(atomsPerBlock * DIM * order), cgh);
    // Coefficients prefetch cache
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_coefficients(
            cl::sycl::range<1>(atomsPerBlock), cgh);
    // Coordinates prefetch cache
    cl::sycl::accessor<Float3, 1, mode::read_write, target::local> sm_coordinates(
            cl::sycl::range<1>(atomsPerBlock), cgh);
    // Reduction of partial force contributions
    cl::sycl::accessor<Float3, 1, mode::read_write, target::local> sm_forces(
            cl::sycl::range<1>(atomsPerBlock), cgh);
    auto sm_fractCoords = [&]() {
        if constexpr (!readGlobal)
        {
            return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
                    cl::sycl::range<1>(atomsPerBlock * DIM), cgh);
        }
        else
        {
            return nullptr;
        }
    }();
    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        assert(blockSize == itemIdx.get_local_range().size());
        /* These are the atom indices - for the shared and global memory */
        const int atomIndexLocal = itemIdx.get_local_id(XX);
        const int blockIndex =
                itemIdx.get_group(YY) * itemIdx.get_group_range(ZZ) + itemIdx.get_group(ZZ);
        // itemIdx.get_group_linear_id();

        const int atomIndexOffset = blockIndex * atomsPerBlock;
        const int atomIndexGlobal = atomIndexOffset + atomIndexLocal;
        /* Early return for fully empty blocks at the end
         * (should only happen for billions of input atoms) */
        if (atomIndexOffset >= nAtoms)
        {
            return;
        }
        /* Spline Z coordinates */
        const int ithz = itemIdx.get_local_id(ZZ);

        /* These are the spline contribution indices in shared memory */
        const int splineIndex =
                itemIdx.get_local_id(YY) * itemIdx.get_local_range(ZZ) + itemIdx.get_local_id(ZZ);

        const int threadLocalId    = itemIdx.get_local_linear_id();
        const int threadLocalIdMax = blockSize;
        assert(threadLocalId < threadLocalIdMax);
        const int lineIndex =
                (itemIdx.get_local_id(XX) * (itemIdx.get_local_range(ZZ) * itemIdx.get_local_range(YY)))
                + splineIndex; // And to all the block's particles
        assert(lineIndex == threadLocalId);
        if constexpr (readGlobal)
        {
            /* Read splines */
            const int localGridlineIndicesIndex = threadLocalId;
            const int globalGridlineIndicesIndex =
                    blockIndex * gridlineIndicesSize + localGridlineIndicesIndex;
            // itemIdx.get_group(ZZ) * gridlineIndicesSize + localGridlineIndicesIndex;
            if (localGridlineIndicesIndex < gridlineIndicesSize)
            {
                sm_gridlineIndices[localGridlineIndicesIndex] =
                        a_gridlineIndices[globalGridlineIndicesIndex];
                assert(sm_gridlineIndices[localGridlineIndicesIndex] >= 0);
            }
            /* The loop is needed for order threads per atom, to make sure we load all the
             * data values, as each thread must load multiple values; with order*order
             * threads per atom, each thread only needs to load one data value */
            const int iMin = 0;
            const int iMax = (threadsPerAtom == ThreadsPerAtom::Order) ? 3 : 1;

            for (int i = iMin; i < iMax; i++)
            {
                // i will always be zero for order*order threads per atom
                const int localSplineParamsIndex = threadLocalId + i * threadLocalIdMax;
                const int globalSplineParamsIndex = blockIndex * splineParamsSize + localSplineParamsIndex;
                // const int globalSplineParamsIndex = itemIdx.get_group(ZZ) * splineParamsSize + localSplineParamsIndex;
                if (localSplineParamsIndex < splineParamsSize)
                {
                    sm_theta[localSplineParamsIndex]  = a_theta[globalSplineParamsIndex];
                    sm_dtheta[localSplineParamsIndex] = a_dtheta[globalSplineParamsIndex];
                    assertIsFinite(sm_theta[localSplineParamsIndex]);
                    assertIsFinite(sm_dtheta[localSplineParamsIndex]);
                }
            }
            itemIdx.barrier(fence_space::local_space);
        }
        else
        {
            /* Recalculate splines */
            /* Staging coefficients/charges */
            pmeGpuStageAtomData<float, atomsPerBlock, 1>(
                    sm_coefficients.get_pointer(), a_coefficientsA.get_pointer(), itemIdx);
            /* Staging coordinates */
            pmeGpuStageAtomData<Float3, atomsPerBlock, 1>(
                    sm_coordinates.get_pointer(), a_coordinates.get_pointer(), itemIdx);
            itemIdx.barrier(fence_space::local_space);
            const Float3 atomX      = sm_coordinates[atomIndexLocal];
            const float  atomCharge = sm_coefficients[atomIndexLocal];
            calculateSplines<order, atomsPerBlock, atomsPerWarp, true, false, numGrids, subGroupSize>(
                    atomIndexOffset,
                    atomX,
                    atomCharge,
                    tablesOffsets,
                    realGridSizeFP,
                    currentRecipBox0,
                    currentRecipBox1,
                    currentRecipBox2,
                    nullptr, // gm_theta: assumed unused, splines are not written back to global memory here
                    nullptr, // gm_dtheta: assumed unused, see above
                    nullptr, // gm_gridlineIndices: assumed unused, see above
                    a_fractShiftsTable.get_pointer(),
                    a_gridlineIndicesTable.get_pointer(),
                    sm_theta.get_pointer(),
                    sm_dtheta.get_pointer(),
                    sm_gridlineIndices.get_pointer(),
                    sm_fractCoords.get_pointer(),
                    itemIdx);
            subGroupBarrier(itemIdx);
        }

        float fx = 0.0F;
        float fy = 0.0F;
        float fz = 0.0F;

        const bool chargeCheck = pmeGpuCheckAtomCharge(a_coefficientsA[atomIndexGlobal]);
        const int nx  = realGridSize[XX];
        const int ny  = realGridSize[YY];
        const int nz  = realGridSize[ZZ];
        const int pny = realGridSizePadded[YY];
        const int pnz = realGridSizePadded[ZZ];

        const int atomWarpIndex = atomIndexLocal % atomsPerWarp;
        const int warpIndex     = atomIndexLocal / atomsPerWarp;

        const int splineIndexBase = getSplineParamIndexBase<order, atomsPerWarp>(warpIndex, atomWarpIndex);
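
        // The spline parameters are laid out in shared memory per warp: splineIndexBase
        // selects the (warp, atom-within-warp) slot, and getSplineParamIndex offsets into
        // it for a given dimension and spline index.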
        const int splineIndexZ = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, ZZ, ithz);
        const cl::sycl::float2 tdz{ sm_theta[splineIndexZ], sm_dtheta[splineIndexZ] };
        int       iz     = sm_gridlineIndices[atomIndexLocal * DIM + ZZ] + ithz;
        const int ixBase = sm_gridlineIndices[atomIndexLocal * DIM + XX];

        if (iz >= nz)
        {
            iz -= nz;
        }
        const int ithyMin = (threadsPerAtom == ThreadsPerAtom::Order) ? 0 : itemIdx.get_local_id(YY);
        const int ithyMax =
                (threadsPerAtom == ThreadsPerAtom::Order) ? order : itemIdx.get_local_id(YY) + 1;
        if (chargeCheck)
        {
            sumForceComponents<order, atomsPerWarp, wrapX, wrapY>(&fx,
                                                                  &fy,
                                                                  &fz,
                                                                  ithyMin,
                                                                  ithyMax,
                                                                  ixBase,
                                                                  iz,
                                                                  nx,
                                                                  ny,
                                                                  pny,
                                                                  pnz,
                                                                  atomIndexLocal,
                                                                  splineIndexBase,
                                                                  tdz,
                                                                  sm_gridlineIndices.get_pointer(),
                                                                  sm_theta.get_pointer(),
                                                                  sm_dtheta.get_pointer(),
                                                                  a_gridA.get_pointer());
        }
        reduceAtomForces<order, atomDataSize, blockSize, subGroupSize>(
                itemIdx, sm_forces.get_pointer(), atomIndexLocal, splineIndex, lineIndex, realGridSizeFP, fx, fy, fz);
        itemIdx.barrier(fence_space::local_space);
        /* Calculating the final forces with no component branching, atomsPerBlock threads */
        const int forceIndexLocal  = threadLocalId;
        const int forceIndexGlobal = atomIndexOffset + forceIndexLocal;
        if (forceIndexLocal < atomsPerBlock)
        {
            calculateAndStoreGridForces(sm_forces.get_pointer(),
                                        forceIndexLocal,
                                        forceIndexGlobal,
                                        currentRecipBox0,
                                        currentRecipBox1,
                                        currentRecipBox2,
                                        scale,
                                        a_coefficientsA.get_pointer());
        }
        itemIdx.barrier(fence_space::local_space);

        static_assert(atomsPerBlock <= subGroupSize);

        /* Writing or adding the final forces component-wise, single warp */
        constexpr int blockForcesSize = atomsPerBlock * DIM;
        constexpr int numIter         = (blockForcesSize + subGroupSize - 1) / subGroupSize;
        constexpr int iterThreads     = blockForcesSize / numIter;
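        // iterThreads threads each write numIter floats, covering all blockForcesSize force
        // components; writing single floats rather than whole Float3 values keeps the
        // global memory accesses of neighboring threads contiguous.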
        if (threadLocalId < iterThreads)
        {
            for (int i = 0; i < numIter; i++)
            {
                const int floatIndexLocal  = i * iterThreads + threadLocalId;
                const int float3IndexLocal = floatIndexLocal / 3;
                const int dimLocal         = floatIndexLocal % 3;
                static_assert(blockForcesSize % DIM == 0); // Assures that dimGlobal == dimLocal
                const int float3IndexGlobal = blockIndex * atomsPerBlock + float3IndexLocal;
                // const int float3IndexGlobal = itemIdx.get_group(ZZ) * atomsPerBlock + float3IndexLocal;
                a_forces[float3IndexGlobal][dimLocal] = sm_forces[float3IndexLocal][dimLocal];
            }
        }
        if constexpr (numGrids == 2)
        {
            /* We must sync here since the same shared memory is used as above. */
            itemIdx.barrier(fence_space::local_space);
            fx = 0.0F;
            fy = 0.0F;
            fz = 0.0F;
            const bool chargeCheck = pmeGpuCheckAtomCharge(a_coefficientsB[atomIndexGlobal]);
            if (chargeCheck)
            {
                sumForceComponents<order, atomsPerWarp, wrapX, wrapY>(&fx,
                                                                      &fy,
                                                                      &fz,
                                                                      ithyMin,
                                                                      ithyMax,
                                                                      ixBase,
                                                                      iz,
                                                                      nx,
                                                                      ny,
                                                                      pny,
                                                                      pnz,
                                                                      atomIndexLocal,
                                                                      splineIndexBase,
                                                                      tdz,
                                                                      sm_gridlineIndices.get_pointer(),
                                                                      sm_theta.get_pointer(),
                                                                      sm_dtheta.get_pointer(),
                                                                      a_gridB.get_pointer());
            }
            // Reduction of partial force contributions
            reduceAtomForces<order, atomDataSize, blockSize, subGroupSize>(
                    itemIdx, sm_forces.get_pointer(), atomIndexLocal, splineIndex, lineIndex, realGridSizeFP, fx, fy, fz);
            itemIdx.barrier(fence_space::local_space);
            /* Calculating the final forces with no component branching, atomsPerBlock threads */
            if (forceIndexLocal < atomsPerBlock)
            {
                calculateAndStoreGridForces(sm_forces.get_pointer(),
                                            forceIndexLocal,
                                            forceIndexGlobal,
                                            currentRecipBox0,
                                            currentRecipBox1,
                                            currentRecipBox2,
                                            1.0F - scale,
                                            a_coefficientsB.get_pointer());
            }

            itemIdx.barrier(fence_space::local_space);
            /* Writing or adding the final forces component-wise, single warp */
            if (threadLocalId < iterThreads)
            {
                for (int i = 0; i < numIter; i++)
                {
                    const int floatIndexLocal  = i * iterThreads + threadLocalId;
                    const int float3IndexLocal = floatIndexLocal / 3;
                    const int dimLocal         = floatIndexLocal % 3;
                    static_assert(blockForcesSize % DIM == 0); // Assures that dimGlobal == dimLocal
                    const int float3IndexGlobal = blockIndex * atomsPerBlock + float3IndexLocal;
                    a_forces[float3IndexGlobal][dimLocal] += sm_forces[float3IndexLocal][dimLocal];
                }
            }
        }
    };
}
template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::PmeGatherKernel()
{
    reset();
}
template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
void PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::setArg(size_t argIndex,
                                                                                                      void* arg)
{
    if (argIndex == 0)
    {
        auto* params = reinterpret_cast<PmeGpuKernelParams*>(arg);

        gridParams_    = &params->grid;
        atomParams_    = &params->atoms;
        dynamicParams_ = &params->current;
    }
    else
    {
        GMX_RELEASE_ASSERT(argIndex == 0, "Trying to pass too many args to the kernel");
    }
}
template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
cl::sycl::event PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::launch(
        const KernelLaunchConfig& config,
        const DeviceStream&       deviceStream)
{
    GMX_RELEASE_ASSERT(gridParams_, "Can not launch the kernel before setting its args");
    GMX_RELEASE_ASSERT(atomParams_, "Can not launch the kernel before setting its args");
    GMX_RELEASE_ASSERT(dynamicParams_, "Can not launch the kernel before setting its args");
    using kernelNameType =
            PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>;

    // SYCL has different multidimensional layout than OpenCL/CUDA.
    const cl::sycl::range<3>    localSize{ config.blockSize[2], config.blockSize[1], config.blockSize[0] };
    const cl::sycl::range<3>    groupRange{ config.gridSize[2], config.gridSize[1], config.gridSize[0] };
    const cl::sycl::nd_range<3> range{ groupRange * localSize, localSize };

    cl::sycl::queue q = deviceStream.stream();
    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel = pmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>(
                cgh,
                atomParams_->nAtoms,
                gridParams_->d_realGrid[0],
                gridParams_->d_realGrid[1],
                atomParams_->d_coefficients[0],
                atomParams_->d_coefficients[1],
                atomParams_->d_coordinates,
                atomParams_->d_forces,
                atomParams_->d_theta,
                atomParams_->d_dtheta,
                atomParams_->d_gridlineIndices,
                gridParams_->d_fractShiftsTable,
                gridParams_->d_gridlineIndicesTable,
                gridParams_->tablesOffsets,
                gridParams_->realGridSize,
                gridParams_->realGridSizeFP,
                gridParams_->realGridSizePadded,
                dynamicParams_->recipBox[0],
                dynamicParams_->recipBox[1],
                dynamicParams_->recipBox[2],
                dynamicParams_->scale);
        cgh.parallel_for<kernelNameType>(range, kernel);
    });
    // Delete set args, so we don't forget to set them before the next launch.
    reset();

    return e;
}
template<int order, bool wrapX, bool wrapY, int numGrids, bool readGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
void PmeGatherKernel<order, wrapX, wrapY, numGrids, readGlobal, threadsPerAtom, subGroupSize>::reset()
{
    gridParams_    = nullptr;
    atomParams_    = nullptr;
    dynamicParams_ = nullptr;
}
//! Kernel instantiations
/* Disable the "explicit template instantiation 'PmeGatherKernel<...>' will emit a vtable in every
 * translation unit [-Wweak-template-vtables]" warning.
 * It is only explicitly instantiated in this translation unit, so we should be safe.
 */
#ifdef __clang__
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wweak-template-vtables"
#endif
#define INSTANTIATE_3(order, numGrids, readGlobal, threadsPerAtom, subGroupSize) \
    template class PmeGatherKernel<order, true, true, numGrids, readGlobal, threadsPerAtom, subGroupSize>;

#define INSTANTIATE_2(order, numGrids, threadsPerAtom, subGroupSize)    \
    INSTANTIATE_3(order, numGrids, true, threadsPerAtom, subGroupSize); \
    INSTANTIATE_3(order, numGrids, false, threadsPerAtom, subGroupSize);

#define INSTANTIATE(order, subGroupSize)                                 \
    INSTANTIATE_2(order, 1, ThreadsPerAtom::Order, subGroupSize);        \
    INSTANTIATE_2(order, 1, ThreadsPerAtom::OrderSquared, subGroupSize); \
    INSTANTIATE_2(order, 2, ThreadsPerAtom::Order, subGroupSize);        \
    INSTANTIATE_2(order, 2, ThreadsPerAtom::OrderSquared, subGroupSize);
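
// Each INSTANTIATE below expands to eight kernel instantiations:
// {1, 2} grids x {true, false} readGlobal x {Order, OrderSquared} threads per atom.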
#if GMX_SYCL_DPCPP
INSTANTIATE(4, 16); // TODO: Choose best value, Issue #4153.
#elif GMX_SYCL_HIPSYCL
INSTANTIATE(4, 32); // Sub-group sizes for hipSYCL targets.
INSTANTIATE(4, 64);
#endif

#ifdef __clang__
#    pragma clang diagnostic pop
#endif