5a0125c44d286ddcc71bc29f6dd5d80fd5024388
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve_sycl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief Implements PME GPU Fourier grid solving in SYCL.
38  *
39  *  \author Mark Abraham <mark.j.abraham@gmail.com>
40  */
41
42 #include "gmxpre.h"
43
44 #include "pme_solve_sycl.h"
45
46 #include <cassert>
47
48 #include "gromacs/gpu_utils/gmxsycl.h"
49 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
50 #include "gromacs/math/units.h"
51
52 #include "pme_gpu_constants.h"
53
54 using cl::sycl::access::mode;
55
56 /*! \brief
57  * PME complex grid solver kernel function.
58  *
59  * \tparam     gridOrdering             Specifies the dimension ordering of the complex grid.
60  * \tparam     computeEnergyAndVirial   Tells if the reciprocal energy and virial should be
61  *                                        computed.
62  * \tparam     subGroupSize             Describes the width of a SYCL subgroup
63  */
64 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int subGroupSize>
65 auto makeSolveKernel(cl::sycl::handler&                            cgh,
66                      DeviceAccessor<float, mode::read>             a_splineModuli,
67                      DeviceAccessor<SolveKernelParams, mode::read> a_solveKernelParams,
68                      OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
69                      DeviceAccessor<float, mode::read_write> a_fourierGrid)
70 {
71     a_splineModuli.bind(cgh);
72     a_solveKernelParams.bind(cgh);
73     if constexpr (computeEnergyAndVirial)
74     {
75         a_virialAndEnergy.bind(cgh);
76     }
77     a_fourierGrid.bind(cgh);
78
79     /* Reduce 7 outputs per warp in the shared memory */
80     const int stride =
81             8; // this is c_virialAndEnergyCount==7 rounded up to power of 2 for convenience, hence the assert
82     static_assert(c_virialAndEnergyCount == 7);
83     const int reductionBufferSize = c_solveMaxWarpsPerBlock * stride;
84     cl::sycl::accessor<float, 1, mode::read_write, cl::sycl::target::local> sm_virialAndEnergy(
85             cl::sycl::range<1>(reductionBufferSize), cgh);
86
87     /* Each thread works on one cell of the Fourier space complex 3D grid (gm_grid).
88      * Each block handles up to c_solveMaxWarpsPerBlock * subGroupSize cells -
89      * depending on the grid contiguous dimension size,
90      * that can range from a part of a single gridline to several complete gridlines.
91      */
92     return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
93     {
94         /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
95         int majorDim, middleDim, minorDim;
96         switch (gridOrdering)
97         {
98             case GridOrdering::YZX:
99                 majorDim  = YY;
100                 middleDim = ZZ;
101                 minorDim  = XX;
102                 break;
103
104             case GridOrdering::XYZ:
105                 majorDim  = XX;
106                 middleDim = YY;
107                 minorDim  = ZZ;
108                 break;
109
110             default: assert(false);
111         }
112
113         /* Global memory pointers */
114         const float* __restrict__ gm_splineValueMajor =
115                 a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[majorDim];
116         const float* __restrict__ gm_splineValueMiddle =
117                 a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[middleDim];
118         const float* __restrict__ gm_splineValueMinor =
119                 a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[minorDim];
120         // The Fourier grid is allocated as float values, even though
121         // it logically contains complex values. (It also can be
122         // the same memory as the real grid for in-place transforms.)
123         // The buffer underlying the accessor may have a size that is
124         // larger than the active grid, because it is allocated with
125         // reallocateDeviceBuffer. The size of that larger-than-needed
126         // grid can be an odd number of floats, even though actual
127         // grid code only accesses up to an even number of floats. If
128         // we would use the reinterpet method of the accessor to
129         // convert from float to float2, runtime boundary checks can
130         // fail because of this mismatch. So, we extract the
131         // underlying global_ptr and use that to construct
132         // cl::sycl::float2 values when needed.
133         cl::sycl::global_ptr<float> gm_fourierGrid = a_fourierGrid.get_pointer();
134
135         /* Various grid sizes and indices */
136         const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0;
137         const int localSizeMinor   = a_solveKernelParams[0].complexGridSizePadded[minorDim];
138         const int localSizeMiddle  = a_solveKernelParams[0].complexGridSizePadded[middleDim];
139         const int localCountMiddle = a_solveKernelParams[0].complexGridSize[middleDim];
140         const int localCountMinor  = a_solveKernelParams[0].complexGridSize[minorDim];
141         const int nMajor           = a_solveKernelParams[0].realGridSize[majorDim];
142         const int nMiddle          = a_solveKernelParams[0].realGridSize[middleDim];
143         const int nMinor           = a_solveKernelParams[0].realGridSize[minorDim];
144         const int maxkMajor        = (nMajor + 1) / 2;  // X or Y
145         const int maxkMiddle       = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
146         const int maxkMinor        = (nMinor + 1) / 2;  // Z or X => only check for YZX
147
148         const int threadLocalId     = itemIdx.get_local_linear_id();
149         const int gridLineSize      = localCountMinor;
150         const int gridLineIndex     = threadLocalId / gridLineSize;
151         const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
152         const int gridLinesPerBlock =
153                 cl::sycl::max(itemIdx.get_local_range(2) / size_t(gridLineSize), size_t(1));
154         const int activeWarps = (itemIdx.get_local_range(2) / subGroupSize);
155         const int indexMinor = itemIdx.get_group(2) * itemIdx.get_local_range(2) + gridLineCellIndex;
156         const int indexMiddle = itemIdx.get_group(1) * gridLinesPerBlock + gridLineIndex;
157         const int indexMajor  = itemIdx.get_group(0);
158
159         /* Optional outputs */
160         float energy = 0.0F;
161         float virxx  = 0.0F;
162         float virxy  = 0.0F;
163         float virxz  = 0.0F;
164         float viryy  = 0.0F;
165         float viryz  = 0.0F;
166         float virzz  = 0.0F;
167
168         assert(indexMajor < a_solveKernelParams[0].complexGridSize[majorDim]);
169         if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
170             & (gridLineIndex < gridLinesPerBlock))
171         {
172             /* The offset should be equal to the global thread index for coalesced access */
173             const int gridThreadIndex =
174                     (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
175
176             const int kMajor = indexMajor + localOffsetMajor;
177             /* Checking either X in XYZ, or Y in YZX cases */
178             const float mMajor = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);
179
180             const int kMiddle = indexMiddle + localOffsetMiddle;
181             float     mMiddle = kMiddle;
182             /* Checking Y in XYZ case */
183             if (gridOrdering == GridOrdering::XYZ)
184             {
185                 mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
186             }
187             const int kMinor = localOffsetMinor + indexMinor;
188             float     mMinor = kMinor;
189             /* Checking X in YZX case */
190             if (gridOrdering == GridOrdering::YZX)
191             {
192                 mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
193             }
194             /* We should skip the k-space point (0,0,0) */
195             const bool notZeroPoint = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
196
197             float mX, mY, mZ;
198             switch (gridOrdering)
199             {
200                 case GridOrdering::YZX:
201                     mX = mMinor;
202                     mY = mMajor;
203                     mZ = mMiddle;
204                     break;
205
206                 case GridOrdering::XYZ:
207                     mX = mMajor;
208                     mY = mMiddle;
209                     mZ = mMinor;
210                     break;
211
212                 default: assert(false);
213             }
214
215             /* 0.5 correction factor for the first and last components of a Z dimension */
216             float corner_fac = 1.0F;
217             switch (gridOrdering)
218             {
219                 case GridOrdering::YZX:
220                     if ((kMiddle == 0) | (kMiddle == maxkMiddle))
221                     {
222                         corner_fac = 0.5F;
223                     }
224                     break;
225
226                 case GridOrdering::XYZ:
227                     if ((kMinor == 0) | (kMinor == maxkMinor))
228                     {
229                         corner_fac = 0.5F;
230                     }
231                     break;
232
233                 default: assert(false);
234             }
235
236             if (notZeroPoint)
237             {
238                 const float mhxk = mX * a_solveKernelParams[0].recipBox[XX][XX];
239                 const float mhyk = mX * a_solveKernelParams[0].recipBox[XX][YY]
240                                    + mY * a_solveKernelParams[0].recipBox[YY][YY];
241                 const float mhzk = mX * a_solveKernelParams[0].recipBox[XX][ZZ]
242                                    + mY * a_solveKernelParams[0].recipBox[YY][ZZ]
243                                    + mZ * a_solveKernelParams[0].recipBox[ZZ][ZZ];
244
245                 const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
246                 assert(m2k != 0.0F);
247                 float denom = m2k * float(M_PI) * a_solveKernelParams[0].boxVolume
248                               * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
249                               * gm_splineValueMinor[kMinor];
250                 assert(sycl_2020::isfinite(denom));
251                 assert(denom != 0.0F);
252
253                 const float tmp1   = cl::sycl::exp(-a_solveKernelParams[0].ewaldFactor * m2k);
254                 const float etermk = a_solveKernelParams[0].elFactor * tmp1 / denom;
255
256                 // sycl::float2::load and store are buggy in hipSYCL,
257                 // but can probably be used after resolution of
258                 // https://github.com/illuhad/hipSYCL/issues/647
259                 cl::sycl::float2 gridValue;
260                 sycl_2020::loadToVec(
261                         gridThreadIndex, cl::sycl::global_ptr<const float>(gm_fourierGrid), &gridValue);
262                 const cl::sycl::float2 oldGridValue = gridValue;
263                 gridValue *= etermk;
264                 sycl_2020::storeFromVec(gridValue, gridThreadIndex, gm_fourierGrid);
265
266                 if (computeEnergyAndVirial)
267                 {
268                     const float tmp1k = 2.0F * cl::sycl::dot(gridValue, oldGridValue);
269
270                     float vfactor = (a_solveKernelParams[0].ewaldFactor + 1.0F / m2k) * 2.0F;
271                     float ets2    = corner_fac * tmp1k;
272                     energy        = ets2;
273
274                     float ets2vf = ets2 * vfactor;
275
276                     virxx = ets2vf * mhxk * mhxk - ets2;
277                     virxy = ets2vf * mhxk * mhyk;
278                     virxz = ets2vf * mhxk * mhzk;
279                     viryy = ets2vf * mhyk * mhyk - ets2;
280                     viryz = ets2vf * mhyk * mhzk;
281                     virzz = ets2vf * mhzk * mhzk - ets2;
282                 }
283             }
284         }
285
286         /* Optional energy/virial reduction */
287         if constexpr (computeEnergyAndVirial)
288         {
289             /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
290              * The idea is to reduce 7 energy/virial components into a single variable (aligned by
291              * 8). We will reduce everything into virxx.
292              */
293
294             /* We can only reduce warp-wise */
295             const int width = subGroupSize;
296             static_assert(subGroupSize >= 8);
297
298             sycl_2020::sub_group sg = itemIdx.get_sub_group();
299
300             /* Making pair sums */
301             virxx += sycl_2020::shift_left(sg, virxx, 1);
302             viryy += sycl_2020::shift_right(sg, viryy, 1);
303             virzz += sycl_2020::shift_left(sg, virzz, 1);
304             virxy += sycl_2020::shift_right(sg, virxy, 1);
305             virxz += sycl_2020::shift_left(sg, virxz, 1);
306             viryz += sycl_2020::shift_right(sg, viryz, 1);
307             energy += sycl_2020::shift_left(sg, energy, 1);
308             if (threadLocalId & 1)
309             {
310                 virxx = viryy; // virxx now holds virxx and viryy pair sums
311                 virzz = virxy; // virzz now holds virzz and virxy pair sums
312                 virxz = viryz; // virxz now holds virxz and viryz pair sums
313             }
314
315             /* Making quad sums */
316             virxx += sycl_2020::shift_left(sg, virxx, 2);
317             virzz += sycl_2020::shift_right(sg, virzz, 2);
318             virxz += sycl_2020::shift_left(sg, virxz, 2);
319             energy += sycl_2020::shift_right(sg, energy, 2);
320             if (threadLocalId & 2)
321             {
322                 virxx = virzz; // virxx now holds quad sums of virxx, virxy, virzz and virxy
323                 virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused paddings
324             }
325
326             /* Making octet sums */
327             virxx += sycl_2020::shift_left(sg, virxx, 4);
328             virxz += sycl_2020::shift_right(sg, virxz, 4);
329             if (threadLocalId & 4)
330             {
331                 virxx = virxz; // virxx now holds all 7 components' octet sums + unused paddings
332             }
333
334             /* We only need to reduce virxx now */
335 #pragma unroll
336             for (int delta = 8; delta < width; delta <<= 1)
337             {
338                 virxx += sycl_2020::shift_left(sg, virxx, delta);
339             }
340             /* Now first 7 threads of each warp have the full output contributions in virxx */
341
342             const int  componentIndex      = threadLocalId & (subGroupSize - 1);
343             const bool validComponentIndex = (componentIndex < c_virialAndEnergyCount);
344
345             if (validComponentIndex)
346             {
347                 const int warpIndex = threadLocalId / subGroupSize;
348                 sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
349             }
350             itemIdx.barrier(cl::sycl::access::fence_space::local_space);
351
352             /* Reduce to the single warp size */
353             const int targetIndex = threadLocalId;
354 #pragma unroll
355             for (int reductionStride = reductionBufferSize >> 1; reductionStride >= subGroupSize;
356                  reductionStride >>= 1)
357             {
358                 const int sourceIndex = targetIndex + reductionStride;
359                 if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
360                 {
361                     sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
362                 }
363                 itemIdx.barrier(cl::sycl::access::fence_space::local_space);
364             }
365
366             /* Now use shuffle again */
367             /* NOTE: This reduction assumes there are at least 4 warps (asserted).
368              *       To use fewer warps, add to the conditional:
369              *       && threadLocalId < activeWarps * stride
370              */
371             assert(activeWarps * stride >= subGroupSize);
372             if (threadLocalId < subGroupSize)
373             {
374                 float output = sm_virialAndEnergy[threadLocalId];
375 #pragma unroll
376                 for (int delta = stride; delta < subGroupSize; delta <<= 1)
377                 {
378                     output += sycl_2020::shift_left(sg, output, delta);
379                 }
380                 /* Final output */
381                 if (validComponentIndex)
382                 {
383                     assert(sycl_2020::isfinite(output));
384                     atomicFetchAdd(a_virialAndEnergy[componentIndex], output);
385                 }
386             }
387         }
388     };
389 }
390
391 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
392 PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::PmeSolveKernel()
393 {
394     reset();
395 }
396
397 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
398 void PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::setArg(size_t argIndex,
399                                                                                            void* arg)
400 {
401     if (argIndex == 0)
402     {
403         auto* params = reinterpret_cast<PmeGpuKernelParams*>(arg);
404
405         constParams_                             = &params->constants;
406         gridParams_                              = &params->grid;
407         solveKernelParams_.ewaldFactor           = params->grid.ewaldFactor;
408         solveKernelParams_.realGridSize          = params->grid.realGridSize;
409         solveKernelParams_.complexGridSize       = params->grid.complexGridSize;
410         solveKernelParams_.complexGridSizePadded = params->grid.complexGridSizePadded;
411         solveKernelParams_.splineValuesOffset    = params->grid.splineValuesOffset;
412         solveKernelParams_.recipBox[XX]          = params->current.recipBox[XX];
413         solveKernelParams_.recipBox[YY]          = params->current.recipBox[YY];
414         solveKernelParams_.recipBox[ZZ]          = params->current.recipBox[ZZ];
415         solveKernelParams_.boxVolume             = params->current.boxVolume;
416         solveKernelParams_.elFactor              = params->constants.elFactor;
417     }
418     else
419     {
420         GMX_RELEASE_ASSERT(argIndex == 0, "Trying to pass too many args to the solve kernel");
421     }
422 }
423
424 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
425 cl::sycl::event PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::launch(
426         const KernelLaunchConfig& config,
427         const DeviceStream&       deviceStream)
428 {
429     GMX_RELEASE_ASSERT(gridParams_, "Can not launch the kernel before setting its args");
430     GMX_RELEASE_ASSERT(constParams_, "Can not launch the kernel before setting its args");
431
432     using KernelNameType = PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>;
433
434     // SYCL has different multidimensional layout than OpenCL/CUDA.
435     const cl::sycl::range<3> localSize{ config.blockSize[2], config.blockSize[1], config.blockSize[0] };
436     const cl::sycl::range<3> groupRange{ config.gridSize[2], config.gridSize[1], config.gridSize[0] };
437     const cl::sycl::nd_range<3> range{ groupRange * localSize, localSize };
438
439     cl::sycl::queue q = deviceStream.stream();
440
441     cl::sycl::buffer<SolveKernelParams, 1> d_solveKernelParams(&solveKernelParams_, 1);
442     cl::sycl::event                        e = q.submit([&](cl::sycl::handler& cgh) {
443         auto kernel = makeSolveKernel<gridOrdering, computeEnergyAndVirial, subGroupSize>(
444                 cgh,
445                 gridParams_->d_splineModuli[gridIndex],
446                 d_solveKernelParams,
447                 constParams_->d_virialAndEnergy[gridIndex],
448                 gridParams_->d_fourierGrid[gridIndex]);
449         cgh.parallel_for<KernelNameType>(range, kernel);
450     });
451
452     // Delete set args, so we don't forget to set them before the next launch.
453     reset();
454
455     return e;
456 }
457
458 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
459 void PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::reset()
460 {
461     gridParams_  = nullptr;
462     constParams_ = nullptr;
463 }
464
465 //! Kernel class instantiations
466 /* Disable the "explicit template instantiation 'PmeSplineAndSpreadKernel<...>' will emit a vtable in every
467  * translation unit [-Wweak-template-vtables]" warning.
468  * It is only explicitly instantiated in this translation unit, so we should be safe.
469  */
470 #ifdef __clang__
471 #    pragma clang diagnostic push
472 #    pragma clang diagnostic ignored "-Wweak-template-vtables"
473 #endif
474
475 #define INSTANTIATE(subGroupSize)                                             \
476     template class PmeSolveKernel<GridOrdering::XYZ, false, 0, subGroupSize>; \
477     template class PmeSolveKernel<GridOrdering::XYZ, true, 0, subGroupSize>;  \
478     template class PmeSolveKernel<GridOrdering::YZX, false, 0, subGroupSize>; \
479     template class PmeSolveKernel<GridOrdering::YZX, true, 0, subGroupSize>;  \
480     template class PmeSolveKernel<GridOrdering::XYZ, false, 1, subGroupSize>; \
481     template class PmeSolveKernel<GridOrdering::XYZ, true, 1, subGroupSize>;  \
482     template class PmeSolveKernel<GridOrdering::YZX, false, 1, subGroupSize>; \
483     template class PmeSolveKernel<GridOrdering::YZX, true, 1, subGroupSize>;
484
485 #if GMX_SYCL_DPCPP
486 INSTANTIATE(16);
487 #elif GMX_SYCL_HIPSYCL
488 INSTANTIATE(32);
489 INSTANTIATE(64);
490 #endif
491
492 #ifdef __clang__
493 #    pragma clang diagnostic pop
494 #endif