SYCL PME Spread kernel
[alexxy/gromacs.git] / src / gromacs / ewald / pme_gpu_calculate_splines_sycl.h
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief Implements helper routines for PME gather and spline routines.
38  *
39  *  \author Andrey Alekseenko <al42and@gmail.com>
40  */
41
42 #include "gmxpre.h"
43
44 #include <cassert>
45
46 #include "gromacs/gpu_utils/gmxsycl.h"
47 #include "gromacs/gpu_utils/gputraits_sycl.h"
48 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
49
50 #include "pme_grid.h"
51 #include "pme_gpu_constants.h"
52 #include "pme_gpu_types.h"
53
54 namespace
55 {
56
57 /*! \brief Asserts if the argument is finite.
58  *
59  *  The function works for any data type, that can be casted to float. Note that there is also
60  *  a specialized implementation for float3 data type.
61  *
62  * \param[in] arg  Argument to check.
63  */
64 template<typename T>
65 inline void assertIsFinite(T arg);
66
67 #if defined(NDEBUG) || GMX_SYCL_HIPSYCL
68 // We have no cl::sycl::isfinite in hipSYCL yet
69 template<typename T>
70 inline void assertIsFinite(T /* arg */)
71 {
72 }
73 #else
74 template<>
75 inline void assertIsFinite(Float3 gmx_used_in_debug arg)
76 {
77     assert(cl::sycl::isfinite(arg[0]));
78     assert(cl::sycl::isfinite(arg[1]));
79     assert(cl::sycl::isfinite(arg[2]));
80 }
81
82 template<typename T>
83 inline void assertIsFinite(T gmx_used_in_debug arg)
84 {
85     assert(cl::sycl::isfinite(static_cast<float>(arg)));
86 }
87 #endif
88
89 } // namespace
90
91 using cl::sycl::access::fence_space;
92 using cl::sycl::access::mode;
93 using cl::sycl::access::target;
94
95 /*! \internal \brief
96  * Gets a base of the unique index to an element in a spline parameter buffer (theta/dtheta),
97  * which is laid out for GPU spread/gather kernels. The base only corresponds to the atom index within the execution block.
98  * Feed the result into getSplineParamIndex() to get a full index.
99  * TODO: it's likely that both parameters can be just replaced with a single atom index, as they are derived from it.
100  * Do that, verifying that the generated code is not bloated, and/or revise the spline indexing scheme.
101  * Removing warp dependency would also be nice (and would probably coincide with removing c_pmeSpreadGatherAtomsPerWarp).
102  *
103  * \tparam order                 PME order
104  * \tparam atomsPerSubGroup      Number of atoms processed by a sub group
105  * \param[in] subGroupIndex      Sub group index in the work group.
106  * \param[in] atomSubGroupIndex  Atom index in the sub group (from 0 to atomsPerSubGroup - 1).
107  *
108  * \returns Index into theta or dtheta array using GPU layout.
109  */
110 template<int order, int atomsPerSubGroup>
111 static inline int getSplineParamIndexBase(int subGroupIndex, int atomSubGroupIndex)
112 {
113     assert((atomSubGroupIndex >= 0) && (atomSubGroupIndex < atomsPerSubGroup));
114     constexpr int dimIndex    = 0;
115     constexpr int splineIndex = 0;
116     // The zeroes are here to preserve the full index formula for reference
117     return (((splineIndex + order * subGroupIndex) * DIM + dimIndex) * atomsPerSubGroup + atomSubGroupIndex);
118 }
119
120 /*! \internal \brief
121  * Gets a unique index to an element in a spline parameter buffer (theta/dtheta),
122  * which is laid out for GPU spread/gather kernels. The index is wrt to the execution block,
123  * in range(0, atomsPerBlock * order * DIM).
124  * This function consumes result of getSplineParamIndexBase() and adjusts it for \p dimIndex and \p splineIndex.
125  *
126  * \tparam order               PME order
127  * \tparam atomsPerSubGroup    Number of atoms processed by a sub group
128  * \param[in] paramIndexBase   Must be result of getSplineParamIndexBase().
129  * \param[in] dimIndex         Dimension index (from 0 to 2)
130  * \param[in] splineIndex      Spline contribution index (from 0 to \p order - 1)
131  *
132  * \returns Index into theta or dtheta array using GPU layout.
133  */
134 template<int order, int atomsPerSubGroup>
135 static inline int getSplineParamIndex(int paramIndexBase, int dimIndex, int splineIndex)
136 {
137     assert((dimIndex >= XX) && (dimIndex < DIM));
138     assert((splineIndex >= 0) && (splineIndex < order));
139     return (paramIndexBase + (splineIndex * DIM + dimIndex) * atomsPerSubGroup);
140 }
141
142 /*! \internal \brief
143  * An inline function for skipping the zero-charge atoms when we have \c c_skipNeutralAtoms set to \c true.
144  *
145  * \returns                   \c true if atom should be processed, \c false otherwise.
146  * \param[in] charge          The atom charge.
147  */
148 static inline bool pmeGpuCheckAtomCharge(const float charge)
149 {
150     assertIsFinite(charge);
151     return c_skipNeutralAtoms ? (charge != 0.0F) : true;
152 }
153
154 /*! \brief
155  * General purpose function for loading atom-related data from global to shared memory.
156  *
157  * \tparam T Data type (float/int/...).
158  * \tparam atomsPerWorkGroup Number of atoms processed by a block - should be accounted for
159  *                           in the size of the shared memory array.
160  * \tparam dataCountPerAtom Number of data elements
161  *                          per single atom (e.g. \c DIM for an rvec coordinates array).
162  * \param[out] sm_destination Shared memory array for output.
163  * \param[in]  gm_source Global memory array for input.
164  * \param[in]  itemIdx SYCL thread ID.
165  */
166 template<typename T, int atomsPerWorkGroup, int dataCountPerAtom>
167 static inline void pmeGpuStageAtomData(cl::sycl::local_ptr<T>        sm_destination,
168                                        const cl::sycl::global_ptr<T> gm_source,
169                                        cl::sycl::nd_item<3>          itemIdx)
170 {
171     const int blockIndex      = itemIdx.get_group_linear_id();
172     const int localIndex      = itemIdx.get_local_linear_id();
173     const int globalIndexBase = blockIndex * atomsPerWorkGroup * dataCountPerAtom;
174     const int globalIndex     = globalIndexBase + localIndex;
175     if (localIndex < atomsPerWorkGroup * dataCountPerAtom)
176     {
177         assertIsFinite(gm_source[globalIndex]);
178         sm_destination[localIndex] = gm_source[globalIndex];
179     }
180 }
181
182 /*! \brief
183  * PME GPU spline parameter and gridline indices calculation.
184  * This corresponds to the CPU functions calc_interpolation_idx() and make_bsplines().
185  * First stage of the whole kernel.
186  *
187  * \tparam order                PME interpolation order.
188  * \tparam atomsPerBlock        Number of atoms processed by a block - should be accounted for
189  *                              in the sizes of the shared memory arrays.
190  * \tparam atomsPerWarp         Number of atoms processed by a warp
191  * \tparam writeSmDtheta        Bool controlling if the theta derivative should be written to
192  *                              shared memory. Enables calculation of dtheta if set.
193  * \tparam writeGlobal          A boolean which tells if the theta values and gridlines should
194  *                              be written to global memory. Enables calculation of dtheta if set.
195  * \tparam numGrids             The number of grids using the splines.
196  * \tparam subGroupSize         The size of a sub-group (warp).
197  * \param[in]  atomIndexOffset        Starting atom index for the execution block in the global
198  *                                    memory.
199  * \param[in]  atomX                  Coordinates of atom processed by thread.
200  * \param[in]  atomCharge             Charge/coefficient of atom processed by thread.
201  * \param[in]  tablesOffsets          Offsets for X/Y/Z components of \p gm_fractShiftsTable and
202  *                                    \p gm_gridlineIndicesTable.
203  * \param[in]  realGridSizeFP         Real-space grid dimensions, converted to floating point.
204  * \param[in]  currentRecipBox0       Current reciprocal (inverted unit cell) box, vector 1.
205  * \param[in]  currentRecipBox1       Current reciprocal (inverted unit cell) box, vector 2.
206  * \param[in]  currentRecipBox2       Current reciprocal (inverted unit cell) box, vector 3.
207  * \param[out] gm_theta               Atom spline values in the global memory.
208  *                                    Used only if \p writeGlobal is \c true.
209  * \param[out] gm_dtheta              Derivatives of atom spline values in the global memory.
210  *                                    Used only if \p writeGlobal is \c true.
211  * \param[out] gm_gridlineIndices     Atom gridline indices in the global memory.
212  *                                    Used only if \p writeGlobal is \c true.
213  * \param[in] gm_fractShiftsTable     Fractional shifts lookup table in the global memory.
214  * \param[in] gm_gridlineIndicesTable Gridline indices lookup table in the global memory.
215  * \param[out] sm_theta               Atom spline values in the local memory.
216  * \param[out] sm_dtheta              Derivatives of atom spline values in the local memory.
217  * \param[out] sm_gridlineIndices     Atom gridline indices in the local memory.
218  * \param[out] sm_fractCoords         Fractional coordinates in the local memory.
219  * \param[in]  itemIdx                SYCL thread ID.
220  */
221
222 template<int order, int atomsPerBlock, int atomsPerWarp, bool writeSmDtheta, bool writeGlobal, int numGrids, int subGroupSize>
223 static inline void calculateSplines(const int                         atomIndexOffset,
224                                     const Float3                      atomX,
225                                     const float                       atomCharge,
226                                     const gmx::IVec                   tablesOffsets,
227                                     const gmx::RVec                   realGridSizeFP,
228                                     const gmx::RVec                   currentRecipBox0,
229                                     const gmx::RVec                   currentRecipBox1,
230                                     const gmx::RVec                   currentRecipBox2,
231                                     cl::sycl::global_ptr<float>       gm_theta,
232                                     cl::sycl::global_ptr<float>       gm_dtheta,
233                                     cl::sycl::global_ptr<int>         gm_gridlineIndices,
234                                     const cl::sycl::global_ptr<float> gm_fractShiftsTable,
235                                     const cl::sycl::global_ptr<int>   gm_gridlineIndicesTable,
236                                     cl::sycl::local_ptr<float>        sm_theta,
237                                     cl::sycl::local_ptr<float>        sm_dtheta,
238                                     cl::sycl::local_ptr<int>          sm_gridlineIndices,
239                                     cl::sycl::local_ptr<float>        sm_fractCoords,
240                                     cl::sycl::nd_item<3>              itemIdx)
241 {
242     static_assert(numGrids == 1 || numGrids == 2);
243     static_assert(numGrids == 1 || c_skipNeutralAtoms == false);
244
245     /* Thread index w.r.t. block */
246     const int threadLocalId = itemIdx.get_local_linear_id();
247     /* Warp index w.r.t. block - could probably be obtained easier? */
248     const int warpIndex = threadLocalId / subGroupSize;
249     /* Atom index w.r.t. warp - alternating 0 1 0 1 ... */
250     const int atomWarpIndex = itemIdx.get_local_id(0) % atomsPerWarp;
251     /* Atom index w.r.t. block/shared memory */
252     const int atomIndexLocal = warpIndex * atomsPerWarp + atomWarpIndex;
253
254     /* Spline contribution index in one dimension */
255     const int threadLocalIdXY =
256             (itemIdx.get_local_id(1) * itemIdx.get_group_range(2)) + itemIdx.get_local_id(2);
257     const int orderIndex = threadLocalIdXY / DIM;
258     /* Dimension index */
259     const int dimIndex = threadLocalIdXY % DIM;
260
261     /* Multi-purpose index of rvec/ivec atom data */
262     const int sharedMemoryIndex = atomIndexLocal * DIM + dimIndex;
263
264     float splineData[order];
265
266     const int localCheck = (dimIndex < DIM) && (orderIndex < 1);
267
268     /* we have 4 threads per atom, but can only use 3 here for the dimensions */
269     if (localCheck)
270     {
271         /* Indices interpolation */
272         if (orderIndex == 0)
273         {
274             int   tableIndex, tInt;
275             float n, t;
276             assert(atomIndexLocal < DIM * atomsPerBlock);
277             // Switch structure inherited from CUDA.
278             // TODO: Issue #4153: Direct indexing with dimIndex can be better with SYCL
279             switch (dimIndex)
280             {
281                 case XX:
282                     tableIndex = tablesOffsets[XX];
283                     n          = realGridSizeFP[XX];
284                     t          = atomX[XX] * currentRecipBox0[XX] + atomX[YY] * currentRecipBox0[YY]
285                         + atomX[ZZ] * currentRecipBox0[ZZ];
286                     break;
287
288                 case YY:
289                     tableIndex = tablesOffsets[YY];
290                     n          = realGridSizeFP[YY];
291                     t = atomX[YY] * currentRecipBox1[YY] + atomX[ZZ] * currentRecipBox1[ZZ];
292                     break;
293
294                 case ZZ:
295                     tableIndex = tablesOffsets[ZZ];
296                     n          = realGridSizeFP[ZZ];
297                     t          = atomX[ZZ] * currentRecipBox2[ZZ];
298                     break;
299             }
300             const float shift = c_pmeMaxUnitcellShift;
301             /* Fractional coordinates along box vectors, adding a positive shift to ensure t is positive for triclinic boxes */
302             t    = (t + shift) * n;
303             tInt = static_cast<int>(t);
304             assert(sharedMemoryIndex < atomsPerBlock * DIM);
305             sm_fractCoords[sharedMemoryIndex] = t - tInt;
306             tableIndex += tInt;
307             assert(tInt >= 0);
308             assert(tInt < c_pmeNeighborUnitcellCount * n);
309
310             // TODO: Issue #4153: use shared table for both parameters to share the fetch, as index is always same.
311             sm_fractCoords[sharedMemoryIndex] += gm_fractShiftsTable[tableIndex];
312             sm_gridlineIndices[sharedMemoryIndex] = gm_gridlineIndicesTable[tableIndex];
313             if constexpr (writeGlobal)
314             {
315                 gm_gridlineIndices[atomIndexOffset * DIM + sharedMemoryIndex] =
316                         sm_gridlineIndices[sharedMemoryIndex];
317             }
318         }
319
320         /* B-spline calculation */
321         const int chargeCheck = pmeGpuCheckAtomCharge(atomCharge);
322         /* With FEP (numGrids == 2), we might have 0 charge in state A, but !=0 in state B, so we always calculate splines */
323         if (numGrids == 2 || chargeCheck)
324         {
325             const float dr = sm_fractCoords[sharedMemoryIndex];
326             assertIsFinite(dr);
327
328             /* dr is relative offset from lower cell limit */
329             splineData[order - 1] = 0.0F;
330             splineData[1]         = dr;
331             splineData[0]         = 1.0F - dr;
332
333 #pragma unroll
334             for (int k = 3; k < order; k++)
335             {
336                 const float div   = 1.0F / (k - 1.0F);
337                 splineData[k - 1] = div * dr * splineData[k - 2];
338 #pragma unroll
339                 for (int l = 1; l < (k - 1); l++)
340                 {
341                     splineData[k - l - 1] =
342                             div * ((dr + l) * splineData[k - l - 2] + (k - l - dr) * splineData[k - l - 1]);
343                 }
344                 splineData[0] = div * (1.0F - dr) * splineData[0];
345             }
346
347             const int thetaIndexBase =
348                     getSplineParamIndexBase<order, atomsPerWarp>(warpIndex, atomWarpIndex);
349             const int thetaGlobalOffsetBase = atomIndexOffset * DIM * order;
350             /* only calculate dtheta if we are saving it to shared or global memory */
351             if constexpr (writeSmDtheta || writeGlobal)
352             {
353                 /* Differentiation and storing the spline derivatives (dtheta) */
354 #pragma unroll
355                 for (int o = 0; o < order; o++)
356                 {
357                     const int thetaIndex =
358                             getSplineParamIndex<order, atomsPerWarp>(thetaIndexBase, dimIndex, o);
359
360                     const float dtheta = ((o > 0) ? splineData[o - 1] : 0.0F) - splineData[o];
361                     assertIsFinite(dtheta);
362                     assert(thetaIndex < order * DIM * atomsPerBlock);
363                     if constexpr (writeSmDtheta)
364                     {
365                         sm_dtheta[thetaIndex] = dtheta;
366                     }
367                     if constexpr (writeGlobal)
368                     {
369                         const int thetaGlobalIndex  = thetaGlobalOffsetBase + thetaIndex;
370                         gm_dtheta[thetaGlobalIndex] = dtheta;
371                     }
372                 }
373             }
374
375             const float div       = 1.0F / (order - 1.0F);
376             splineData[order - 1] = div * dr * splineData[order - 2];
377 #pragma unroll
378             for (int k = 1; k < (order - 1); k++)
379             {
380                 splineData[order - k - 1] = div
381                                             * ((dr + k) * splineData[order - k - 2]
382                                                + (order - k - dr) * splineData[order - k - 1]);
383             }
384             splineData[0] = div * (1.0F - dr) * splineData[0];
385
386             /* Storing the spline values (theta) */
387 #pragma unroll
388             for (int o = 0; o < order; o++)
389             {
390                 const int thetaIndex =
391                         getSplineParamIndex<order, atomsPerWarp>(thetaIndexBase, dimIndex, o);
392                 assert(thetaIndex < order * DIM * atomsPerBlock);
393                 sm_theta[thetaIndex] = splineData[o];
394                 assertIsFinite(sm_theta[thetaIndex]);
395                 if constexpr (writeGlobal)
396                 {
397                     const int thetaGlobalIndex = thetaGlobalOffsetBase + thetaIndex;
398                     gm_theta[thetaGlobalIndex] = splineData[o];
399                 }
400             }
401         }
402     }
403 }