2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2021, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
37 * \brief Implements PME GPU Fourier grid solving in SYCL.
39 * \author Mark Abraham <mark.j.abraham@gmail.com>
44 #include "pme_solve_sycl.h"
48 #include "gromacs/gpu_utils/gmxsycl.h"
49 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
50 #include "gromacs/math/units.h"
52 #include "pme_gpu_constants.h"
54 using cl::sycl::access::mode;
57 * PME complex grid solver kernel function.
59 * \tparam gridOrdering Specifies the dimension ordering of the complex grid.
60 * \tparam computeEnergyAndVirial Tells if the reciprocal energy and virial should be
62 * \tparam subGroupSize Describes the width of a SYCL subgroup
// Creates the PME Fourier-space solve kernel for one command group.
//
// Binds a_splineModuli, a_solveKernelParams, a_fourierGrid (and, only when
// computeEnergyAndVirial is true, a_virialAndEnergy) to cgh, allocates a
// local-memory buffer of c_solveMaxWarpsPerBlock * stride floats for the
// energy/virial reduction, and returns a lambda taking cl::sycl::nd_item<3>.
// Each work-item handles one complex cell of a_fourierGrid: it computes the
// reciprocal-space factor etermk for that cell and writes the cell back in
// place. When computeEnergyAndVirial is true, it also reduces the 7
// energy/virial components over the work-group and adds them atomically to
// a_virialAndEnergy.
//
// NOTE(review): this chunk does not show every line of the function. Missing
// lines include:
//   - the declaration of `stride`;
//   - the switch headers and case bodies;
//   - the local energy/virial declarations;
//   - the update of gridValue before it is stored;
//   - the closing braces.
// Statements above about those parts are inferred from the visible code.
64 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int subGroupSize>
65 auto makeSolveKernel(cl::sycl::handler& cgh,
66 DeviceAccessor<float, mode::read> a_splineModuli,
67 DeviceAccessor<SolveKernelParams, mode::read> a_solveKernelParams,
68 OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
69 DeviceAccessor<float, mode::read_write> a_fourierGrid)
// Attach the accessors to this command group. The energy/virial accessor is
// only bound when computeEnergyAndVirial is true.
71 a_splineModuli.bind(cgh);
72 a_solveKernelParams.bind(cgh);
73 if constexpr (computeEnergyAndVirial)
75 a_virialAndEnergy.bind(cgh);
77 a_fourierGrid.bind(cgh);
79 /* Reduce 7 outputs per warp (sub-group) in local (shared) memory */
81 8; // this is c_virialAndEnergyCount==7 rounded up to power of 2 for convenience, hence the assert
82 static_assert(c_virialAndEnergyCount == 7);
83 const int reductionBufferSize = c_solveMaxWarpsPerBlock * stride;
84 cl::sycl::accessor<float, 1, mode::read_write, cl::sycl::target::local> sm_virialAndEnergy(
85 cl::sycl::range<1>(reductionBufferSize), cgh);
87 /* Each thread works on one cell of the Fourier space complex 3D grid (gm_fourierGrid).
88 * Each block handles up to c_solveMaxWarpsPerBlock * subGroupSize cells -
89 * depending on the grid contiguous dimension size,
90 * that can range from a part of a single gridline to several complete gridlines.
92 return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
// The required sub-group size fixes the shuffle width assumed by the
// energy/virial reduction below (width = subGroupSize).
94 /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
95 int majorDim, middleDim, minorDim;
98 case GridOrdering::YZX:
104 case GridOrdering::XYZ:
110 default: assert(false);
113 /* Global memory pointers */
// Each dimension's spline moduli start at splineValuesOffset[dim] within a_splineModuli.
114 const float* __restrict__ gm_splineValueMajor =
115 a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[majorDim];
116 const float* __restrict__ gm_splineValueMiddle =
117 a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[middleDim];
118 const float* __restrict__ gm_splineValueMinor =
119 a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[minorDim];
120 // The Fourier grid is allocated as float values, even though
121 // it logically contains complex values. (It also can be
122 // the same memory as the real grid for in-place transforms.)
123 // The buffer underlying the accessor may have a size that is
124 // larger than the active grid, because it is allocated with
125 // reallocateDeviceBuffer. The size of that larger-than-needed
126 // grid can be an odd number of floats, even though actual
127 // grid code only accesses up to an even number of floats. If
128 // we used the reinterpret method of the accessor to
129 // convert from float to float2, runtime boundary checks could
130 // fail because of this mismatch. So, we extract the
131 // underlying global_ptr and use that to construct
132 // cl::sycl::float2 values when needed.
133 cl::sycl::global_ptr<float> gm_fourierGrid = a_fourierGrid.get_pointer();
135 /* Various grid sizes and indices */
136 const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0;
137 const int localSizeMinor = a_solveKernelParams[0].complexGridSizePadded[minorDim];
138 const int localSizeMiddle = a_solveKernelParams[0].complexGridSizePadded[middleDim];
139 const int localCountMiddle = a_solveKernelParams[0].complexGridSize[middleDim];
140 const int localCountMinor = a_solveKernelParams[0].complexGridSize[minorDim];
141 const int nMajor = a_solveKernelParams[0].realGridSize[majorDim];
142 const int nMiddle = a_solveKernelParams[0].realGridSize[middleDim];
143 const int nMinor = a_solveKernelParams[0].realGridSize[minorDim];
144 const int maxkMajor = (nMajor + 1) / 2; // X or Y
145 const int maxkMiddle = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
146 const int maxkMinor = (nMinor + 1) / 2; // Z or X => only check for YZX
// Work-item mapping:
//   - ND-range dimension 2 runs along the minor (contiguous) grid dimension.
//   - If a gridline is shorter than the work-group, one work-group covers
//     gridLinesPerBlock consecutive gridlines of the middle dimension
//     (ND-range dimension 1).
//   - ND-range dimension 0 selects the major index.
148 const int threadLocalId = itemIdx.get_local_linear_id();
149 const int gridLineSize = localCountMinor;
150 const int gridLineIndex = threadLocalId / gridLineSize;
151 const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
152 const int gridLinesPerBlock =
153 cl::sycl::max(itemIdx.get_local_range(2) / size_t(gridLineSize), size_t(1));
154 const int activeWarps = (itemIdx.get_local_range(2) / subGroupSize);
155 const int indexMinor = itemIdx.get_group(2) * itemIdx.get_local_range(2) + gridLineCellIndex;
156 const int indexMiddle = itemIdx.get_group(1) * gridLinesPerBlock + gridLineIndex;
157 const int indexMajor = itemIdx.get_group(0);
159 /* Optional outputs */
168 assert(indexMajor < a_solveKernelParams[0].complexGridSize[majorDim]);
// Skip work-items that fall outside the complex grid or past the last full
// gridline handled by this work-group. The conditions are combined with
// bitwise &, which does not short-circuit (presumably to stay branch-free).
169 if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
170 & (gridLineIndex < gridLinesPerBlock))
172 /* The offset should be equal to the global thread index for coalesced access */
173 const int gridThreadIndex =
174 (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
// Indices k >= maxk = (n + 1) / 2 are folded to the negative frequency k - n.
// The major index is always folded. The middle index is folded only for XYZ,
// and the minor index only for YZX, so Z (middle in YZX, minor in XYZ) is
// never folded.
176 const int kMajor = indexMajor + localOffsetMajor;
177 /* Checking either X in XYZ, or Y in YZX cases */
178 const float mMajor = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);
180 const int kMiddle = indexMiddle + localOffsetMiddle;
181 float mMiddle = kMiddle;
182 /* Checking Y in XYZ case */
183 if (gridOrdering == GridOrdering::XYZ)
185 mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
187 const int kMinor = localOffsetMinor + indexMinor;
188 float mMinor = kMinor;
189 /* Checking X in YZX case */
190 if (gridOrdering == GridOrdering::YZX)
192 mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
194 /* We should skip the k-space point (0,0,0) */
195 const bool notZeroPoint = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
198 switch (gridOrdering)
200 case GridOrdering::YZX:
206 case GridOrdering::XYZ:
212 default: assert(false);
215 /* 0.5 correction factor for the first and last components of a Z dimension */
216 float corner_fac = 1.0F;
217 switch (gridOrdering)
219 case GridOrdering::YZX:
220 if ((kMiddle == 0) | (kMiddle == maxkMiddle))
226 case GridOrdering::XYZ:
227 if ((kMinor == 0) | (kMinor == maxkMinor))
233 default: assert(false);
// Reciprocal-space vector (mhxk, mhyk, mhzk) from (mX, mY, mZ) and recipBox.
// Only the recipBox entries with column index >= row index are read.
238 const float mhxk = mX * a_solveKernelParams[0].recipBox[XX][XX];
239 const float mhyk = mX * a_solveKernelParams[0].recipBox[XX][YY]
240 + mY * a_solveKernelParams[0].recipBox[YY][YY];
241 const float mhzk = mX * a_solveKernelParams[0].recipBox[XX][ZZ]
242 + mY * a_solveKernelParams[0].recipBox[YY][ZZ]
243 + mZ * a_solveKernelParams[0].recipBox[ZZ][ZZ];
245 const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
// denom = m2k * pi * boxVolume * (product of the three spline moduli).
// etermk = elFactor * exp(-ewaldFactor * m2k) / denom.
247 float denom = m2k * float(M_PI) * a_solveKernelParams[0].boxVolume
248 * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
249 * gm_splineValueMinor[kMinor];
250 assert(sycl_2020::isfinite(denom));
251 assert(denom != 0.0F);
253 const float tmp1 = cl::sycl::exp(-a_solveKernelParams[0].ewaldFactor * m2k);
254 const float etermk = a_solveKernelParams[0].elFactor * tmp1 / denom;
256 // sycl::float2::load and store are buggy in hipSYCL,
257 // but can probably be used after resolution of
258 // https://github.com/illuhad/hipSYCL/issues/647
259 cl::sycl::float2 gridValue;
260 sycl_2020::loadToVec(
261 gridThreadIndex, cl::sycl::global_ptr<const float>(gm_fourierGrid), &gridValue);
262 const cl::sycl::float2 oldGridValue = gridValue;
// NOTE(review): the statement that updates gridValue (presumably multiplying it
// by etermk) is not visible in this chunk. The store below writes the updated
// value back to the same cell it was loaded from.
264 sycl_2020::storeFromVec(gridValue, gridThreadIndex, gm_fourierGrid);
266 if (computeEnergyAndVirial)
// dot() of the two (re, im) pairs equals Re(gridValue * conj(oldGridValue)).
268 const float tmp1k = 2.0F * cl::sycl::dot(gridValue, oldGridValue);
270 float vfactor = (a_solveKernelParams[0].ewaldFactor + 1.0F / m2k) * 2.0F;
271 float ets2 = corner_fac * tmp1k;
274 float ets2vf = ets2 * vfactor;
276 virxx = ets2vf * mhxk * mhxk - ets2;
277 virxy = ets2vf * mhxk * mhyk;
278 virxz = ets2vf * mhxk * mhzk;
279 viryy = ets2vf * mhyk * mhyk - ets2;
280 viryz = ets2vf * mhyk * mhzk;
281 virzz = ets2vf * mhzk * mhzk - ets2;
286 /* Optional energy/virial reduction */
287 if constexpr (computeEnergyAndVirial)
289 /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
290 * The idea is to reduce 7 energy/virial components into a single variable (aligned by
291 * 8). We will reduce everything into virxx.
294 /* We can only reduce warp-wise */
295 const int width = subGroupSize;
296 static_assert(subGroupSize >= 8);
298 sycl_2020::sub_group sg = itemIdx.get_sub_group();
300 /* Making pair sums */
301 virxx += sycl_2020::shift_left(sg, virxx, 1);
302 viryy += sycl_2020::shift_right(sg, viryy, 1);
303 virzz += sycl_2020::shift_left(sg, virzz, 1);
304 virxy += sycl_2020::shift_right(sg, virxy, 1);
305 virxz += sycl_2020::shift_left(sg, virxz, 1);
306 viryz += sycl_2020::shift_right(sg, viryz, 1);
307 energy += sycl_2020::shift_left(sg, energy, 1);
308 if (threadLocalId & 1)
310 virxx = viryy; // virxx now holds virxx and viryy pair sums
311 virzz = virxy; // virzz now holds virzz and virxy pair sums
312 virxz = viryz; // virxz now holds virxz and viryz pair sums
315 /* Making quad sums */
316 virxx += sycl_2020::shift_left(sg, virxx, 2);
317 virzz += sycl_2020::shift_right(sg, virzz, 2);
318 virxz += sycl_2020::shift_left(sg, virxz, 2);
319 energy += sycl_2020::shift_right(sg, energy, 2);
320 if (threadLocalId & 2)
322 virxx = virzz; // virxx now holds quad sums of virxx, viryy, virzz and virxy
323 virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused paddings
326 /* Making octet sums */
327 virxx += sycl_2020::shift_left(sg, virxx, 4);
328 virxz += sycl_2020::shift_right(sg, virxz, 4);
329 if (threadLocalId & 4)
331 virxx = virxz; // virxx now holds all 7 components' octet sums + unused paddings
334 /* We only need to reduce virxx now */
336 for (int delta = 8; delta < width; delta <<= 1)
338 virxx += sycl_2020::shift_left(sg, virxx, delta);
340 /* Now first 7 threads of each warp have the full output contributions in virxx */
// Each sub-group writes its 7 totals into its own stride-sized slot of the
// local buffer.
342 const int componentIndex = threadLocalId & (subGroupSize - 1);
343 const bool validComponentIndex = (componentIndex < c_virialAndEnergyCount);
345 if (validComponentIndex)
347 const int warpIndex = threadLocalId / subGroupSize;
348 sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
350 itemIdx.barrier(cl::sycl::access::fence_space::local_space);
352 /* Reduce to the single warp size */
// Tree reduction over the local buffer. Each step adds the upper half onto the
// lower half (halving reductionStride) until subGroupSize entries remain.
// Entries at or beyond activeWarps * stride are never read. A work-group
// barrier follows every step.
353 const int targetIndex = threadLocalId;
355 for (int reductionStride = reductionBufferSize >> 1; reductionStride >= subGroupSize;
356 reductionStride >>= 1)
358 const int sourceIndex = targetIndex + reductionStride;
359 if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
361 sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
363 itemIdx.barrier(cl::sycl::access::fence_space::local_space);
366 /* Now use shuffle again */
367 /* NOTE: This reduction assumes activeWarps * stride >= subGroupSize (asserted), i.e. at least 4 warps for 32-wide sub-groups.
368 * To use fewer warps, add to the conditional:
369 * && threadLocalId < activeWarps * stride
371 assert(activeWarps * stride >= subGroupSize);
372 if (threadLocalId < subGroupSize)
374 float output = sm_virialAndEnergy[threadLocalId];
376 for (int delta = stride; delta < subGroupSize; delta <<= 1)
378 output += sycl_2020::shift_left(sg, output, delta);
// Work-items 0..6 of the work-group hold one component each, so each
// component gets exactly one atomic add per work-group.
381 if (validComponentIndex)
383 assert(sycl_2020::isfinite(output));
384 atomicFetchAdd(a_virialAndEnergy[componentIndex], output);
//! Constructs the solve-kernel functor.
// NOTE(review): the constructor body is not visible in this chunk. launch()
// asserts that gridParams_ and constParams_ are set, so the constructor
// presumably initializes them to nullptr (e.g. via reset()) — confirm.
391 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
392 PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::PmeSolveKernel()
397 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
398 void PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::setArg(size_t argIndex,
403 auto* params = reinterpret_cast<PmeGpuKernelParams*>(arg);
405 constParams_ = ¶ms->constants;
406 gridParams_ = ¶ms->grid;
407 solveKernelParams_.ewaldFactor = params->grid.ewaldFactor;
408 solveKernelParams_.realGridSize = params->grid.realGridSize;
409 solveKernelParams_.complexGridSize = params->grid.complexGridSize;
410 solveKernelParams_.complexGridSizePadded = params->grid.complexGridSizePadded;
411 solveKernelParams_.splineValuesOffset = params->grid.splineValuesOffset;
412 solveKernelParams_.recipBox[XX] = params->current.recipBox[XX];
413 solveKernelParams_.recipBox[YY] = params->current.recipBox[YY];
414 solveKernelParams_.recipBox[ZZ] = params->current.recipBox[ZZ];
415 solveKernelParams_.boxVolume = params->current.boxVolume;
416 solveKernelParams_.elFactor = params->constants.elFactor;
420 GMX_RELEASE_ASSERT(argIndex == 0, "Trying to pass too many args to the solve kernel");
// Submits the solve kernel to deviceStream's queue using the arguments stored
// by setArg(). Aborts via GMX_RELEASE_ASSERT if setArg() has not been called.
// The cl::sycl::event of the submission is stored in `e`.
// NOTE(review): some lines of this function are not visible in this chunk:
// some kernel arguments, the statement(s) after the final comment and the
// return statement (presumably `return e;`).
424 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
425 cl::sycl::event PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::launch(
426 const KernelLaunchConfig& config,
427 const DeviceStream& deviceStream)
429 GMX_RELEASE_ASSERT(gridParams_, "Can not launch the kernel before setting its args");
430 GMX_RELEASE_ASSERT(constParams_, "Can not launch the kernel before setting its args");
432 using KernelNameType = PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>;
434 // SYCL has different multidimensional layout than OpenCL/CUDA.
// The launch-config dimensions are reversed: config index 0 (x) becomes SYCL
// dimension 2, which makeSolveKernel uses for the minor (contiguous) grid index.
435 const cl::sycl::range<3> localSize{ config.blockSize[2], config.blockSize[1], config.blockSize[0] };
436 const cl::sycl::range<3> groupRange{ config.gridSize[2], config.gridSize[1], config.gridSize[0] };
437 const cl::sycl::nd_range<3> range{ groupRange * localSize, localSize };
439 cl::sycl::queue q = deviceStream.stream();
// The solve parameters are passed through a one-element buffer that wraps the
// host object solveKernelParams_.
// NOTE(review): under the SYCL buffer synchronization rules, destroying a
// buffer constructed from a host pointer blocks until all kernels using it
// have completed. When d_solveKernelParams goes out of scope, launch() likely
// waits for this kernel to finish — confirm whether that is intended.
441 cl::sycl::buffer<SolveKernelParams, 1> d_solveKernelParams(&solveKernelParams_, 1);
442 cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
443 auto kernel = makeSolveKernel<gridOrdering, computeEnergyAndVirial, subGroupSize>(
445 gridParams_->d_splineModuli[gridIndex],
447 constParams_->d_virialAndEnergy[gridIndex],
448 gridParams_->d_fourierGrid[gridIndex]);
449 cgh.parallel_for<KernelNameType>(range, kernel);
452 // Delete set args, so we don't forget to set them before the next launch.
458 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int gridIndex, int subGroupSize>
459 void PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex, subGroupSize>::reset()
461 gridParams_ = nullptr;
462 constParams_ = nullptr;
465 //! Kernel class instantiations
466 /* Disable the "explicit template instantiation 'PmeSolveKernel<...>' will emit a vtable in every
467 * translation unit [-Wweak-template-vtables]" warning.
468 * It is only explicitly instantiated in this translation unit, so we should be safe.
471 # pragma clang diagnostic push
472 # pragma clang diagnostic ignored "-Wweak-template-vtables"
//! Explicitly instantiates the four ordering/energy variants of the solve kernel for one grid index.
#define INSTANTIATE_FOR_GRID_INDEX(gridIndex, subGroupSize)                             \
    template class PmeSolveKernel<GridOrdering::XYZ, false, gridIndex, subGroupSize>; \
    template class PmeSolveKernel<GridOrdering::XYZ, true, gridIndex, subGroupSize>;  \
    template class PmeSolveKernel<GridOrdering::YZX, false, gridIndex, subGroupSize>; \
    template class PmeSolveKernel<GridOrdering::YZX, true, gridIndex, subGroupSize>;

//! Explicitly instantiates the solve kernel for both grids (indices 0 and 1) at one sub-group size.
#define INSTANTIATE(subGroupSize)               \
    INSTANTIATE_FOR_GRID_INDEX(0, subGroupSize) \
    INSTANTIATE_FOR_GRID_INDEX(1, subGroupSize)
487 #elif GMX_SYCL_HIPSYCL
493 # pragma clang diagnostic pop