* \tparam subGroupSize Describes the width of a SYCL subgroup
*/
template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int subGroupSize>
-auto makeSolveKernel(cl::sycl::handler& cgh,
- DeviceAccessor<float, mode::read> a_splineModuli,
- DeviceAccessor<SolveKernelParams, mode::read> a_solveKernelParams,
+auto makeSolveKernel(cl::sycl::handler& cgh,
+ DeviceAccessor<float, mode::read> a_splineModuli,
+ SolveKernelParams solveKernelParams,
OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
DeviceAccessor<float, mode::read_write> a_fourierGrid)
{
a_splineModuli.bind(cgh);
- a_solveKernelParams.bind(cgh);
if constexpr (computeEnergyAndVirial)
{
a_virialAndEnergy.bind(cgh);
/* Global memory pointers */
const float* __restrict__ gm_splineValueMajor =
- a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[majorDim];
+ a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[majorDim];
const float* __restrict__ gm_splineValueMiddle =
- a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[middleDim];
+ a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[middleDim];
const float* __restrict__ gm_splineValueMinor =
- a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[minorDim];
+ a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[minorDim];
// The Fourier grid is allocated as float values, even though
// it logically contains complex values. (It also can be
// the same memory as the real grid for in-place transforms.)
/* Various grid sizes and indices */
const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0;
- const int localSizeMinor = a_solveKernelParams[0].complexGridSizePadded[minorDim];
- const int localSizeMiddle = a_solveKernelParams[0].complexGridSizePadded[middleDim];
- const int localCountMiddle = a_solveKernelParams[0].complexGridSize[middleDim];
- const int localCountMinor = a_solveKernelParams[0].complexGridSize[minorDim];
- const int nMajor = a_solveKernelParams[0].realGridSize[majorDim];
- const int nMiddle = a_solveKernelParams[0].realGridSize[middleDim];
- const int nMinor = a_solveKernelParams[0].realGridSize[minorDim];
+ const int localSizeMinor = solveKernelParams.complexGridSizePadded[minorDim];
+ const int localSizeMiddle = solveKernelParams.complexGridSizePadded[middleDim];
+ const int localCountMiddle = solveKernelParams.complexGridSize[middleDim];
+ const int localCountMinor = solveKernelParams.complexGridSize[minorDim];
+ const int nMajor = solveKernelParams.realGridSize[majorDim];
+ const int nMiddle = solveKernelParams.realGridSize[middleDim];
+ const int nMinor = solveKernelParams.realGridSize[minorDim];
const int maxkMajor = (nMajor + 1) / 2; // X or Y
const int maxkMiddle = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
const int maxkMinor = (nMinor + 1) / 2; // Z or X => only check for YZX
float viryz = 0.0F;
float virzz = 0.0F;
- assert(indexMajor < a_solveKernelParams[0].complexGridSize[majorDim]);
+ assert(indexMajor < solveKernelParams.complexGridSize[majorDim]);
if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
& (gridLineIndex < gridLinesPerBlock))
{
if (notZeroPoint)
{
- const float mhxk = mX * a_solveKernelParams[0].recipBox[XX][XX];
- const float mhyk = mX * a_solveKernelParams[0].recipBox[XX][YY]
- + mY * a_solveKernelParams[0].recipBox[YY][YY];
- const float mhzk = mX * a_solveKernelParams[0].recipBox[XX][ZZ]
- + mY * a_solveKernelParams[0].recipBox[YY][ZZ]
- + mZ * a_solveKernelParams[0].recipBox[ZZ][ZZ];
+ const float mhxk = mX * solveKernelParams.recipBox[XX][XX];
+ const float mhyk = mX * solveKernelParams.recipBox[XX][YY]
+ + mY * solveKernelParams.recipBox[YY][YY];
+ const float mhzk = mX * solveKernelParams.recipBox[XX][ZZ]
+ + mY * solveKernelParams.recipBox[YY][ZZ]
+ + mZ * solveKernelParams.recipBox[ZZ][ZZ];
const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
assert(m2k != 0.0F);
- float denom = m2k * float(M_PI) * a_solveKernelParams[0].boxVolume
- * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
- * gm_splineValueMinor[kMinor];
+ float denom = m2k * float(M_PI) * solveKernelParams.boxVolume * gm_splineValueMajor[kMajor]
+ * gm_splineValueMiddle[kMiddle] * gm_splineValueMinor[kMinor];
assert(sycl_2020::isfinite(denom));
assert(denom != 0.0F);
- const float tmp1 = cl::sycl::exp(-a_solveKernelParams[0].ewaldFactor * m2k);
- const float etermk = a_solveKernelParams[0].elFactor * tmp1 / denom;
+ const float tmp1 = cl::sycl::exp(-solveKernelParams.ewaldFactor * m2k);
+ const float etermk = solveKernelParams.elFactor * tmp1 / denom;
// sycl::float2::load and store are buggy in hipSYCL,
// but can probably be used after resolution of
{
const float tmp1k = 2.0F * cl::sycl::dot(gridValue, oldGridValue);
- float vfactor = (a_solveKernelParams[0].ewaldFactor + 1.0F / m2k) * 2.0F;
+ float vfactor = (solveKernelParams.ewaldFactor + 1.0F / m2k) * 2.0F;
float ets2 = corner_fac * tmp1k;
energy = ets2;
cl::sycl::queue q = deviceStream.stream();
- cl::sycl::buffer<SolveKernelParams, 1> d_solveKernelParams(&solveKernelParams_, 1);
- cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
+ cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
auto kernel = makeSolveKernel<gridOrdering, computeEnergyAndVirial, subGroupSize>(
cgh,
gridParams_->d_splineModuli[gridIndex],
- d_solveKernelParams,
+ solveKernelParams_,
constParams_->d_virialAndEnergy[gridIndex],
gridParams_->d_fourierGrid[gridIndex]);
cgh.parallel_for<KernelNameType>(range, kernel);