From: Andrey Alekseenko Date: Tue, 2 Nov 2021 12:53:09 +0000 (+0100) Subject: Avoid allocating SYCL buffer on each call to PME solve X-Git-Url: http://biod.pnpi.spb.ru/gitweb/?p=alexxy%2Fgromacs.git;a=commitdiff_plain;h=5eeff915b3e0bf1a03a80789d119a764943aded7 Avoid allocating SYCL buffer on each call to PME solve Refs #4153 --- diff --git a/src/gromacs/ewald/pme_solve_sycl.cpp b/src/gromacs/ewald/pme_solve_sycl.cpp index 5a0125c44d..b93dd456b3 100644 --- a/src/gromacs/ewald/pme_solve_sycl.cpp +++ b/src/gromacs/ewald/pme_solve_sycl.cpp @@ -62,14 +62,13 @@ using cl::sycl::access::mode; * \tparam subGroupSize Describes the width of a SYCL subgroup */ template -auto makeSolveKernel(cl::sycl::handler& cgh, - DeviceAccessor a_splineModuli, - DeviceAccessor a_solveKernelParams, +auto makeSolveKernel(cl::sycl::handler& cgh, + DeviceAccessor a_splineModuli, + SolveKernelParams solveKernelParams, OptionalAccessor a_virialAndEnergy, DeviceAccessor a_fourierGrid) { a_splineModuli.bind(cgh); - a_solveKernelParams.bind(cgh); if constexpr (computeEnergyAndVirial) { a_virialAndEnergy.bind(cgh); @@ -112,11 +111,11 @@ auto makeSolveKernel(cl::sycl::handler& cgh, /* Global memory pointers */ const float* __restrict__ gm_splineValueMajor = - a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[majorDim]; + a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[majorDim]; const float* __restrict__ gm_splineValueMiddle = - a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[middleDim]; + a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[middleDim]; const float* __restrict__ gm_splineValueMinor = - a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[minorDim]; + a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[minorDim]; // The Fourier grid is allocated as float values, even though // it logically contains complex values. (It also can be // the same memory as the real grid for in-place transforms.) @@ -134,13 +133,13 @@ auto makeSolveKernel(cl::sycl::handler& cgh, /* Various grid sizes and indices */ const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0; - const int localSizeMinor = a_solveKernelParams[0].complexGridSizePadded[minorDim]; - const int localSizeMiddle = a_solveKernelParams[0].complexGridSizePadded[middleDim]; - const int localCountMiddle = a_solveKernelParams[0].complexGridSize[middleDim]; - const int localCountMinor = a_solveKernelParams[0].complexGridSize[minorDim]; - const int nMajor = a_solveKernelParams[0].realGridSize[majorDim]; - const int nMiddle = a_solveKernelParams[0].realGridSize[middleDim]; - const int nMinor = a_solveKernelParams[0].realGridSize[minorDim]; + const int localSizeMinor = solveKernelParams.complexGridSizePadded[minorDim]; + const int localSizeMiddle = solveKernelParams.complexGridSizePadded[middleDim]; + const int localCountMiddle = solveKernelParams.complexGridSize[middleDim]; + const int localCountMinor = solveKernelParams.complexGridSize[minorDim]; + const int nMajor = solveKernelParams.realGridSize[majorDim]; + const int nMiddle = solveKernelParams.realGridSize[middleDim]; + const int nMinor = solveKernelParams.realGridSize[minorDim]; const int maxkMajor = (nMajor + 1) / 2; // X or Y const int maxkMiddle = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX const int maxkMinor = (nMinor + 1) / 2; // Z or X => only check for YZX @@ -165,7 +164,7 @@ auto makeSolveKernel(cl::sycl::handler& cgh, float viryz = 0.0F; float virzz = 0.0F; - assert(indexMajor < a_solveKernelParams[0].complexGridSize[majorDim]); + assert(indexMajor < solveKernelParams.complexGridSize[majorDim]); if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor) & (gridLineIndex < gridLinesPerBlock)) { @@ -235,23 +234,22 @@ auto makeSolveKernel(cl::sycl::handler& cgh, if (notZeroPoint) { - const float mhxk = mX * a_solveKernelParams[0].recipBox[XX][XX]; - const float mhyk = mX * a_solveKernelParams[0].recipBox[XX][YY] - + mY * a_solveKernelParams[0].recipBox[YY][YY]; - const float mhzk = mX * a_solveKernelParams[0].recipBox[XX][ZZ] - + mY * a_solveKernelParams[0].recipBox[YY][ZZ] - + mZ * a_solveKernelParams[0].recipBox[ZZ][ZZ]; + const float mhxk = mX * solveKernelParams.recipBox[XX][XX]; + const float mhyk = mX * solveKernelParams.recipBox[XX][YY] + + mY * solveKernelParams.recipBox[YY][YY]; + const float mhzk = mX * solveKernelParams.recipBox[XX][ZZ] + + mY * solveKernelParams.recipBox[YY][ZZ] + + mZ * solveKernelParams.recipBox[ZZ][ZZ]; const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk; assert(m2k != 0.0F); - float denom = m2k * float(M_PI) * a_solveKernelParams[0].boxVolume - * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle] - * gm_splineValueMinor[kMinor]; + float denom = m2k * float(M_PI) * solveKernelParams.boxVolume * gm_splineValueMajor[kMajor] + * gm_splineValueMiddle[kMiddle] * gm_splineValueMinor[kMinor]; assert(sycl_2020::isfinite(denom)); assert(denom != 0.0F); - const float tmp1 = cl::sycl::exp(-a_solveKernelParams[0].ewaldFactor * m2k); - const float etermk = a_solveKernelParams[0].elFactor * tmp1 / denom; + const float tmp1 = cl::sycl::exp(-solveKernelParams.ewaldFactor * m2k); + const float etermk = solveKernelParams.elFactor * tmp1 / denom; // sycl::float2::load and store are buggy in hipSYCL, // but can probably be used after resolution of @@ -267,7 +265,7 @@ auto makeSolveKernel(cl::sycl::handler& cgh, { const float tmp1k = 2.0F * cl::sycl::dot(gridValue, oldGridValue); - float vfactor = (a_solveKernelParams[0].ewaldFactor + 1.0F / m2k) * 2.0F; + float vfactor = (solveKernelParams.ewaldFactor + 1.0F / m2k) * 2.0F; float ets2 = corner_fac * tmp1k; energy = ets2; @@ -438,12 +436,11 @@ cl::sycl::event PmeSolveKernel d_solveKernelParams(&solveKernelParams_, 1); - cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) { + cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) { auto kernel = makeSolveKernel( cgh, gridParams_->d_splineModuli[gridIndex], - d_solveKernelParams, + solveKernelParams_, constParams_->d_virialAndEnergy[gridIndex], gridParams_->d_fourierGrid[gridIndex]); cgh.parallel_for(range, kernel);