/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2021, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 *  \brief Implements PME GPU spline calculation and charge spreading in SYCL.
 *
 *  \author Andrey Alekseenko <al42and@gmail.com>
 */

#include "gmxpre.h"

#include "pme_spread_sycl.h"

#include "gromacs/gpu_utils/gmxsycl.h"
#include "gromacs/gpu_utils/gputraits_sycl.h"
#include "gromacs/gpu_utils/sycl_kernel_utils.h"
#include "gromacs/gpu_utils/syclutils.h"

#include "pme_gpu_calculate_splines_sycl.h"
#include "pme_gpu_types_host.h"
#include "pme_grid.h"

/*! \brief
 * Charge spreading onto the grid.
 * This corresponds to the CPU function spread_coefficients_bsplines_thread().
 * Optional second stage of the spline_and_spread_kernel.
 *
 * \tparam     order                PME interpolation order.
 * \tparam     wrapX                Whether the grid overlap in dimension X should be wrapped.
 * \tparam     wrapY                Whether the grid overlap in dimension Y should be wrapped.
 * \tparam     threadsPerAtom       How many threads work on each atom.
 * \tparam     subGroupSize         Size of the sub-group.
 *
 * \param[in]  atomCharge           Charge/coefficient of the atom processed by the current thread.
 * \param[in]  realGridSize         Size of the real grid.
 * \param[in]  realGridSizePadded   Padded size of the real grid.
 * \param[in,out]  gm_grid          Device pointer to the real grid to which charges are added.
 * \param[in]  sm_gridlineIndices   Atom gridline indices in the local memory.
 * \param[in]  sm_theta             Atom spline values in the local memory.
 * \param[in]  itemIdx              Current thread ID.
 */
template<int order, bool wrapX, bool wrapY, ThreadsPerAtom threadsPerAtom, int subGroupSize>
inline void spread_charges(const float                      atomCharge,
                           const int                        realGridSize[DIM],
                           const int                        realGridSizePadded[DIM],
                           cl::sycl::global_ptr<float>      gm_grid,
                           const cl::sycl::local_ptr<int>   sm_gridlineIndices,
                           const cl::sycl::local_ptr<float> sm_theta,
                           const cl::sycl::nd_item<3>&      itemIdx)
{
    // Number of atoms processed by a single warp in spread and gather
    const int threadsPerAtomValue = (threadsPerAtom == ThreadsPerAtom::Order) ? order : order * order;
    const int atomsPerWarp        = subGroupSize / threadsPerAtomValue;
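    /* Illustrative arithmetic (comment only): with order == 4 and, e.g., a
     * sub-group of 32 threads, ThreadsPerAtom::Order gives 32 / 4 == 8 atoms
     * per sub-group, while ThreadsPerAtom::OrderSquared gives 32 / 16 == 2. */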

    const int nx  = realGridSize[XX];
    const int ny  = realGridSize[YY];
    const int nz  = realGridSize[ZZ];
    const int pny = realGridSizePadded[YY];
    const int pnz = realGridSizePadded[ZZ];
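    /* The real grid uses a row-major layout with padded Y and Z dimensions:
     * element (ix, iy, iz) lives at ix * pny * pnz + iy * pnz + iz, which is
     * how 'offset' and 'gridIndexGlobal' are assembled below. */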

    const int atomIndexLocal = itemIdx.get_local_id(0);

    const int chargeCheck = pmeGpuCheckAtomCharge(atomCharge);

    if (chargeCheck)
    {
        // Spline Z coordinates
        const int ithz = itemIdx.get_local_id(2);

        const int ixBase = sm_gridlineIndices[atomIndexLocal * DIM + XX];
        const int iyBase = sm_gridlineIndices[atomIndexLocal * DIM + YY];
        int       iz     = sm_gridlineIndices[atomIndexLocal * DIM + ZZ] + ithz;
        if (iz >= nz)
        {
            iz -= nz;
        }
        /* Atom index w.r.t. warp - cycles through 0, 1, ..., atomsPerWarp - 1 */
        const int atomWarpIndex = atomIndexLocal % atomsPerWarp;
        /* Warp index w.r.t. block - could probably be obtained easier? */
        const int warpIndex = atomIndexLocal / atomsPerWarp;

        const int splineIndexBase = getSplineParamIndexBase<order, atomsPerWarp>(warpIndex, atomWarpIndex);
        const int splineIndexZ = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, ZZ, ithz);
        const float thetaZ     = sm_theta[splineIndexZ];

        /* The loop runs over the Y splines; it degenerates to a single
         * iteration when order*order threads per atom are used */
        const int ithyMin = (threadsPerAtom == ThreadsPerAtom::Order) ? 0 : itemIdx.get_local_id(YY);
        const int ithyMax =
                (threadsPerAtom == ThreadsPerAtom::Order) ? order : itemIdx.get_local_id(YY) + 1;
        for (int ithy = ithyMin; ithy < ithyMax; ithy++)
        {
            int iy = iyBase + ithy;
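            /* The bitwise '&' (rather than '&&') below evaluates both operands
             * without a short-circuit branch, presumably to keep this inner
             * loop branch-free; the X wrap further down uses the same pattern. */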
            if (wrapY & (iy >= ny))
            {
                iy -= ny;
            }

            const int splineIndexY = getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, YY, ithy);
            const float thetaY = sm_theta[splineIndexY];
            const float Val    = thetaZ * thetaY * (atomCharge);
            assertIsFinite(Val);
            const int offset = iy * pnz + iz;

#pragma unroll
            for (int ithx = 0; (ithx < order); ithx++)
            {
                int ix = ixBase + ithx;
                if (wrapX & (ix >= nx))
                {
                    ix -= nx;
                }
                const int gridIndexGlobal = ix * pny * pnz + offset;
                const int splineIndexX =
                        getSplineParamIndex<order, atomsPerWarp>(splineIndexBase, XX, ithx);
                const float thetaX = sm_theta[splineIndexX];
                assertIsFinite(thetaX);
                assertIsFinite(gm_grid[gridIndexGlobal]);
                atomicFetchAdd(gm_grid[gridIndexGlobal], thetaX * Val);
            }
        }
    }
}


/*! \brief
 * A spline computation and charge spreading kernel function.
 *
 * Two tuning parameters can be used for additional performance. For small systems, and for
 * debugging, writeGlobal should be used: it removes the need to recalculate the theta values
 * in the gather kernel. For large systems, running order threads per atom
 * (threadsPerAtom == ThreadsPerAtom::Order) gives higher performance than order*order threads.
 *
 * \tparam order          PME interpolation order.
 * \tparam computeSplines A boolean which tells if the spline parameter and gridline indices'
 *                        computation should be performed.
 * \tparam spreadCharges  A boolean which tells if the charge spreading should be performed.
 * \tparam wrapX          A boolean which tells if the grid overlap in dimension X should be wrapped.
 * \tparam wrapY          A boolean which tells if the grid overlap in dimension Y should be wrapped.
 * \tparam numGrids       The number of grids to use in the kernel. Can be 1 or 2.
 * \tparam writeGlobal    A boolean which tells if the theta values and gridlines should be written
 *                        to global memory.
 * \tparam threadsPerAtom How many threads work on each atom.
 * \tparam subGroupSize   Size of the sub-group.
 */
template<int order, bool computeSplines, bool spreadCharges, bool wrapX, bool wrapY, int numGrids, bool writeGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
auto pmeSplineAndSpreadKernel(
        cl::sycl::handler&                                                        cgh,
        const int                                                                 nAtoms,
        OptionalAccessor<float, mode::read_write, spreadCharges>                  a_realGrid_0,
        OptionalAccessor<float, mode::read_write, numGrids == 2 && spreadCharges> a_realGrid_1,
        OptionalAccessor<float, mode::read_write, writeGlobal || computeSplines>  a_theta,
        OptionalAccessor<float, mode::write, computeSplines && writeGlobal>       a_dtheta,
        OptionalAccessor<int, mode::write, writeGlobal>                           a_gridlineIndices,
        OptionalAccessor<float, mode::read, computeSplines>                       a_fractShiftsTable,
        OptionalAccessor<int, mode::read, computeSplines>                         a_gridlineIndicesTable,
        DeviceAccessor<float, mode::read>                                         a_coefficients_0,
        OptionalAccessor<float, mode::read, numGrids == 2 && spreadCharges>       a_coefficients_1,
        OptionalAccessor<Float3, mode::read, computeSplines>                      a_coordinates,
        const gmx::IVec                                                           tablesOffsets,
        const gmx::IVec                                                           realGridSize,
        const gmx::RVec                                                           realGridSizeFP,
        const gmx::IVec                                                           realGridSizePadded,
        const gmx::RVec                                                           currentRecipBox0,
        const gmx::RVec                                                           currentRecipBox1,
        const gmx::RVec                                                           currentRecipBox2)
{
    constexpr int threadsPerAtomValue = (threadsPerAtom == ThreadsPerAtom::Order) ? order : order * order;
    constexpr int spreadMaxThreadsPerBlock = c_spreadMaxWarpsPerBlock * subGroupSize;
    constexpr int atomsPerBlock            = spreadMaxThreadsPerBlock / threadsPerAtomValue;
    // Number of atoms processed by a single warp in spread and gather
    static_assert(subGroupSize >= threadsPerAtomValue);
    constexpr int atomsPerWarp = subGroupSize / threadsPerAtomValue;
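    /* Worked example (illustrative only; assuming c_spreadMaxWarpsPerBlock == 4):
     * order == 4, subGroupSize == 32 and ThreadsPerAtom::OrderSquared give
     * threadsPerAtomValue == 16, spreadMaxThreadsPerBlock == 128,
     * atomsPerBlock == 8, and atomsPerWarp == 2. */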

    if constexpr (spreadCharges)
    {
        cgh.require(a_realGrid_0);
    }
    if constexpr (writeGlobal || computeSplines)
    {
        cgh.require(a_theta);
    }
    if constexpr (computeSplines && writeGlobal)
    {
        cgh.require(a_dtheta);
    }
    if constexpr (writeGlobal)
    {
        cgh.require(a_gridlineIndices);
    }
    if constexpr (computeSplines)
    {
        cgh.require(a_fractShiftsTable);
        cgh.require(a_gridlineIndicesTable);
        cgh.require(a_coordinates);
    }
    cgh.require(a_coefficients_0);
    if constexpr (numGrids == 2 && spreadCharges)
    {
        cgh.require(a_realGrid_1);
        cgh.require(a_coefficients_1);
    }

    // Gridline indices, ivec
    cl::sycl::accessor<int, 1, mode::read_write, target::local> sm_gridlineIndices(
            cl::sycl::range<1>(atomsPerBlock * DIM), cgh);
    // Charges
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_coefficients(
            cl::sycl::range<1>(atomsPerBlock), cgh);
    // Spline values
    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_theta(
            cl::sycl::range<1>(atomsPerBlock * DIM * order), cgh);
    auto sm_fractCoords = [&]() {
        if constexpr (computeSplines)
        {
            return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
                    cl::sycl::range<1>(atomsPerBlock * DIM), cgh);
        }
        else
        {
            return nullptr;
        }
    }();
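    /* The immediately-invoked lambda above makes the local-memory buffer exist
     * only when splines are computed: with 'if constexpr' only the taken branch
     * is instantiated, so sm_fractCoords deduces to either a local accessor or
     * nullptr at compile time. */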

    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
    {
        const int blockIndex      = itemIdx.get_group_linear_id();
        const int atomIndexOffset = blockIndex * atomsPerBlock;

        /* Thread index w.r.t. block */
        const int threadLocalId = itemIdx.get_local_linear_id();
        /* Warp index w.r.t. block - could probably be obtained easier? */
        const int warpIndex = threadLocalId / subGroupSize;

        /* Atom index w.r.t. warp */
        const int atomWarpIndex = itemIdx.get_local_id(XX) % atomsPerWarp;
        /* Atom index w.r.t. block/shared memory */
        const int atomIndexLocal = warpIndex * atomsPerWarp + atomWarpIndex;
        /* Atom index w.r.t. global memory */
        const int atomIndexGlobal = atomIndexOffset + atomIndexLocal;

        /* Early return for fully empty blocks at the end
         * (should only happen for billions of input atoms) */
        if (atomIndexOffset >= nAtoms)
        {
            return;
        }

        /* Charges, required for both spline and spread */
        pmeGpuStageAtomData<float, atomsPerBlock, 1>(
                sm_coefficients.get_pointer(), a_coefficients_0.get_pointer(), itemIdx);
        itemIdx.barrier(fence_space::local_space);
        const float atomCharge = sm_coefficients[atomIndexLocal];

        if constexpr (computeSplines)
        {
            // SYCL-TODO: Use prefetching? Issue #4153.
            const Float3 atomX = a_coordinates[atomIndexGlobal];
            // Lambdas below can be avoided when hipSYCL merges https://github.com/illuhad/hipSYCL/pull/629.
            cl::sycl::global_ptr<float> gm_dtheta = [&]() {
                if constexpr (writeGlobal)
                {
                    return a_dtheta.get_pointer();
                }
                else
                {
                    return nullptr;
                }
            }();
            cl::sycl::global_ptr<int> gm_gridlineIndices = [&]() {
                if constexpr (writeGlobal)
                {
                    return a_gridlineIndices.get_pointer();
                }
                else
                {
                    return nullptr;
                }
            }();
            calculateSplines<order, atomsPerBlock, atomsPerWarp, false, writeGlobal, numGrids, subGroupSize>(
                    atomIndexOffset,
                    atomX,
                    atomCharge,
                    tablesOffsets,
                    realGridSizeFP,
                    currentRecipBox0,
                    currentRecipBox1,
                    currentRecipBox2,
                    a_theta.get_pointer(),
                    gm_dtheta,
                    gm_gridlineIndices,
                    a_fractShiftsTable.get_pointer(),
                    a_gridlineIndicesTable.get_pointer(),
                    sm_theta.get_pointer(),
                    nullptr,
                    sm_gridlineIndices.get_pointer(),
                    sm_fractCoords.get_pointer(),
                    itemIdx);
            subGroupBarrier(itemIdx);
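            /* A sub-group barrier presumably suffices here instead of a full
             * work-group barrier: the threads that will read this atom's
             * spline data in spread_charges() belong to the same sub-group as
             * the threads that just wrote it. */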
        }
        else
        {
            /* Staging the data for spread
             * (the data is assumed to be in GPU global memory with proper layout already,
             * as in after running the spline kernel)
             */
            /* Spline data - only thetas (dthetas will only be needed in gather) */
            pmeGpuStageAtomData<float, atomsPerBlock, DIM * order>(
                    sm_theta.get_pointer(), a_theta.get_pointer(), itemIdx);
            /* Gridline indices */
            pmeGpuStageAtomData<int, atomsPerBlock, DIM>(
                    sm_gridlineIndices.get_pointer(), a_gridlineIndices.get_pointer(), itemIdx);

            itemIdx.barrier(fence_space::local_space);
        }

        /* Spreading */
        if constexpr (spreadCharges)
        {
            spread_charges<order, wrapX, wrapY, threadsPerAtom, subGroupSize>(
                    atomCharge,
                    realGridSize,
                    realGridSizePadded,
                    a_realGrid_0.get_pointer(),
                    sm_gridlineIndices.get_pointer(),
                    sm_theta.get_pointer(),
                    itemIdx);
        }
        if constexpr (numGrids == 2 && spreadCharges)
        {
            itemIdx.barrier(fence_space::local_space);
            pmeGpuStageAtomData<float, atomsPerBlock, 1>(
                    sm_coefficients.get_pointer(), a_coefficients_1.get_pointer(), itemIdx);
            itemIdx.barrier(fence_space::local_space);
            const float atomCharge = sm_coefficients[atomIndexLocal];

            spread_charges<order, wrapX, wrapY, threadsPerAtom, subGroupSize>(
                    atomCharge,
                    realGridSize,
                    realGridSizePadded,
                    a_realGrid_1.get_pointer(),
                    sm_gridlineIndices.get_pointer(),
                    sm_theta.get_pointer(),
                    itemIdx);
        }
    };
}

template<int order, bool computeSplines, bool spreadCharges, bool wrapX, bool wrapY, int numGrids, bool writeGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
PmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, wrapX, wrapY, numGrids, writeGlobal, threadsPerAtom, subGroupSize>::PmeSplineAndSpreadKernel()
{
    reset();
}

template<int order, bool computeSplines, bool spreadCharges, bool wrapX, bool wrapY, int numGrids, bool writeGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
void PmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, wrapX, wrapY, numGrids, writeGlobal, threadsPerAtom, subGroupSize>::setArg(
        size_t argIndex,
        void*  arg)
{
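    /* Mirrors the CUDA/OpenCL-style kernel-argument interface used by the PME
     * GPU code: all parameters arrive packed in a single PmeGpuKernelParams
     * struct, passed as argument 0. */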
    if (argIndex == 0)
    {
        auto* params   = reinterpret_cast<PmeGpuKernelParams*>(arg);
        gridParams_    = &params->grid;
        atomParams_    = &params->atoms;
        dynamicParams_ = &params->current;
    }
    else
    {
        GMX_RELEASE_ASSERT(argIndex == 0, "Trying to pass too many args to the kernel");
    }
}


template<int order, bool computeSplines, bool spreadCharges, bool wrapX, bool wrapY, int numGrids, bool writeGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
cl::sycl::event
PmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, wrapX, wrapY, numGrids, writeGlobal, threadsPerAtom, subGroupSize>::launch(
        const KernelLaunchConfig& config,
        const DeviceStream&       deviceStream)
{
    GMX_RELEASE_ASSERT(gridParams_, "Can not launch the kernel before setting its args");
    GMX_RELEASE_ASSERT(atomParams_, "Can not launch the kernel before setting its args");
    GMX_RELEASE_ASSERT(dynamicParams_, "Can not launch the kernel before setting its args");

    using kernelNameType =
            PmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, wrapX, wrapY, numGrids, writeGlobal, threadsPerAtom, subGroupSize>;

    // SYCL's multidimensional index order is reversed compared to OpenCL/CUDA: the last dimension is the fastest-varying one.
    const cl::sycl::range<3> localSize{ config.blockSize[2], config.blockSize[1], config.blockSize[0] };
    const cl::sycl::range<3> groupRange{ config.gridSize[2], config.gridSize[1], config.gridSize[0] };
    const cl::sycl::nd_range<3> range{ groupRange * localSize, localSize };
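    /* Example of the reversal (illustrative): a CUDA-style block size of
     * { X, Y, Z }, with X the fastest-varying dimension, becomes
     * cl::sycl::range<3>{ Z, Y, X }, so the contiguous dimension stays last
     * as SYCL expects. */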

    cl::sycl::queue q = deviceStream.stream();

    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
        auto kernel =
                pmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, wrapX, wrapY, numGrids, writeGlobal, threadsPerAtom, subGroupSize>(
                        cgh,
                        atomParams_->nAtoms,
                        gridParams_->d_realGrid[0],
                        gridParams_->d_realGrid[1],
                        atomParams_->d_theta,
                        atomParams_->d_dtheta,
                        atomParams_->d_gridlineIndices,
                        gridParams_->d_fractShiftsTable,
                        gridParams_->d_gridlineIndicesTable,
                        atomParams_->d_coefficients[0],
                        atomParams_->d_coefficients[1],
                        atomParams_->d_coordinates,
                        gridParams_->tablesOffsets,
                        gridParams_->realGridSize,
                        gridParams_->realGridSizeFP,
                        gridParams_->realGridSizePadded,
                        dynamicParams_->recipBox[0],
                        dynamicParams_->recipBox[1],
                        dynamicParams_->recipBox[2]);
        cgh.parallel_for<kernelNameType>(range, kernel);
    });

    // Clear the stored args, so we don't forget to set them before the next launch.
    reset();

    return e;
}


template<int order, bool computeSplines, bool spreadCharges, bool wrapX, bool wrapY, int numGrids, bool writeGlobal, ThreadsPerAtom threadsPerAtom, int subGroupSize>
void PmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, wrapX, wrapY, numGrids, writeGlobal, threadsPerAtom, subGroupSize>::reset()
{
    gridParams_    = nullptr;
    atomParams_    = nullptr;
    dynamicParams_ = nullptr;
}


//! Kernel instantiations
/* Disable the "explicit template instantiation 'PmeSplineAndSpreadKernel<...>' will emit a vtable in every
 * translation unit [-Wweak-template-vtables]" warning.
 * It is only explicitly instantiated in this translation unit, so we should be safe.
 */
#ifdef __clang__
#    pragma clang diagnostic push
#    pragma clang diagnostic ignored "-Wweak-template-vtables"
#endif

#define INSTANTIATE_3(order, computeSplines, spreadCharges, numGrids, writeGlobal, threadsPerAtom, subGroupSize) \
    template class PmeSplineAndSpreadKernel<order, computeSplines, spreadCharges, true, true, numGrids, writeGlobal, threadsPerAtom, subGroupSize>;

#define INSTANTIATE_2(order, numGrids, threadsPerAtom, subGroupSize)                 \
    INSTANTIATE_3(order, true, true, numGrids, true, threadsPerAtom, subGroupSize);  \
    INSTANTIATE_3(order, true, false, numGrids, true, threadsPerAtom, subGroupSize); \
    INSTANTIATE_3(order, false, true, numGrids, true, threadsPerAtom, subGroupSize); \
    INSTANTIATE_3(order, true, true, numGrids, false, threadsPerAtom, subGroupSize);

#define INSTANTIATE(order, subGroupSize)                                 \
    INSTANTIATE_2(order, 1, ThreadsPerAtom::Order, subGroupSize);        \
    INSTANTIATE_2(order, 1, ThreadsPerAtom::OrderSquared, subGroupSize); \
    INSTANTIATE_2(order, 2, ThreadsPerAtom::Order, subGroupSize);        \
    INSTANTIATE_2(order, 2, ThreadsPerAtom::OrderSquared, subGroupSize);
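
/* For illustration: each INSTANTIATE(order, subGroupSize) below expands to
 * 4 x 4 = 16 explicit instantiations, all with wrapX == wrapY == true; e.g.,
 * the first instantiation produced by INSTANTIATE(4, 16) is
 * PmeSplineAndSpreadKernel<4, true, true, true, true, 1, true, ThreadsPerAtom::Order, 16>. */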

#if GMX_SYCL_DPCPP
INSTANTIATE(4, 16); // TODO: Choose best value, Issue #4153.
#elif GMX_SYCL_HIPSYCL
INSTANTIATE(4, 32);
INSTANTIATE(4, 64);
#endif

#ifdef __clang__
#    pragma clang diagnostic pop
#endif