Avoid allocating SYCL buffer on each call to PME solve
authorAndrey Alekseenko <al42and@gmail.com>
Tue, 2 Nov 2021 12:53:09 +0000 (13:53 +0100)
committerAndrey Alekseenko <al42and@gmail.com>
Wed, 3 Nov 2021 06:49:42 +0000 (07:49 +0100)
Refs #4153

src/gromacs/ewald/pme_solve_sycl.cpp

index 5a0125c44d286ddcc71bc29f6dd5d80fd5024388..b93dd456b3b788354efdab185ea3288a67b6d692 100644 (file)
@@ -62,14 +62,13 @@ using cl::sycl::access::mode;
  * \tparam     subGroupSize             Describes the width of a SYCL subgroup
  */
 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, int subGroupSize>
-auto makeSolveKernel(cl::sycl::handler&                            cgh,
-                     DeviceAccessor<float, mode::read>             a_splineModuli,
-                     DeviceAccessor<SolveKernelParams, mode::read> a_solveKernelParams,
+auto makeSolveKernel(cl::sycl::handler&                cgh,
+                     DeviceAccessor<float, mode::read> a_splineModuli,
+                     SolveKernelParams                 solveKernelParams,
                      OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
                      DeviceAccessor<float, mode::read_write> a_fourierGrid)
 {
     a_splineModuli.bind(cgh);
-    a_solveKernelParams.bind(cgh);
     if constexpr (computeEnergyAndVirial)
     {
         a_virialAndEnergy.bind(cgh);
@@ -112,11 +111,11 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
 
         /* Global memory pointers */
         const float* __restrict__ gm_splineValueMajor =
-                a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[majorDim];
+                a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[majorDim];
         const float* __restrict__ gm_splineValueMiddle =
-                a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[middleDim];
+                a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[middleDim];
         const float* __restrict__ gm_splineValueMinor =
-                a_splineModuli.get_pointer() + a_solveKernelParams[0].splineValuesOffset[minorDim];
+                a_splineModuli.get_pointer() + solveKernelParams.splineValuesOffset[minorDim];
         // The Fourier grid is allocated as float values, even though
         // it logically contains complex values. (It also can be
         // the same memory as the real grid for in-place transforms.)
@@ -134,13 +133,13 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
 
         /* Various grid sizes and indices */
         const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0;
-        const int localSizeMinor   = a_solveKernelParams[0].complexGridSizePadded[minorDim];
-        const int localSizeMiddle  = a_solveKernelParams[0].complexGridSizePadded[middleDim];
-        const int localCountMiddle = a_solveKernelParams[0].complexGridSize[middleDim];
-        const int localCountMinor  = a_solveKernelParams[0].complexGridSize[minorDim];
-        const int nMajor           = a_solveKernelParams[0].realGridSize[majorDim];
-        const int nMiddle          = a_solveKernelParams[0].realGridSize[middleDim];
-        const int nMinor           = a_solveKernelParams[0].realGridSize[minorDim];
+        const int localSizeMinor   = solveKernelParams.complexGridSizePadded[minorDim];
+        const int localSizeMiddle  = solveKernelParams.complexGridSizePadded[middleDim];
+        const int localCountMiddle = solveKernelParams.complexGridSize[middleDim];
+        const int localCountMinor  = solveKernelParams.complexGridSize[minorDim];
+        const int nMajor           = solveKernelParams.realGridSize[majorDim];
+        const int nMiddle          = solveKernelParams.realGridSize[middleDim];
+        const int nMinor           = solveKernelParams.realGridSize[minorDim];
         const int maxkMajor        = (nMajor + 1) / 2;  // X or Y
         const int maxkMiddle       = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
         const int maxkMinor        = (nMinor + 1) / 2;  // Z or X => only check for YZX
@@ -165,7 +164,7 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
         float viryz  = 0.0F;
         float virzz  = 0.0F;
 
-        assert(indexMajor < a_solveKernelParams[0].complexGridSize[majorDim]);
+        assert(indexMajor < solveKernelParams.complexGridSize[majorDim]);
         if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
             & (gridLineIndex < gridLinesPerBlock))
         {
@@ -235,23 +234,22 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
 
             if (notZeroPoint)
             {
-                const float mhxk = mX * a_solveKernelParams[0].recipBox[XX][XX];
-                const float mhyk = mX * a_solveKernelParams[0].recipBox[XX][YY]
-                                   + mY * a_solveKernelParams[0].recipBox[YY][YY];
-                const float mhzk = mX * a_solveKernelParams[0].recipBox[XX][ZZ]
-                                   + mY * a_solveKernelParams[0].recipBox[YY][ZZ]
-                                   + mZ * a_solveKernelParams[0].recipBox[ZZ][ZZ];
+                const float mhxk = mX * solveKernelParams.recipBox[XX][XX];
+                const float mhyk = mX * solveKernelParams.recipBox[XX][YY]
+                                   + mY * solveKernelParams.recipBox[YY][YY];
+                const float mhzk = mX * solveKernelParams.recipBox[XX][ZZ]
+                                   + mY * solveKernelParams.recipBox[YY][ZZ]
+                                   + mZ * solveKernelParams.recipBox[ZZ][ZZ];
 
                 const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
                 assert(m2k != 0.0F);
-                float denom = m2k * float(M_PI) * a_solveKernelParams[0].boxVolume
-                              * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
-                              * gm_splineValueMinor[kMinor];
+                float denom = m2k * float(M_PI) * solveKernelParams.boxVolume * gm_splineValueMajor[kMajor]
+                              * gm_splineValueMiddle[kMiddle] * gm_splineValueMinor[kMinor];
                 assert(sycl_2020::isfinite(denom));
                 assert(denom != 0.0F);
 
-                const float tmp1   = cl::sycl::exp(-a_solveKernelParams[0].ewaldFactor * m2k);
-                const float etermk = a_solveKernelParams[0].elFactor * tmp1 / denom;
+                const float tmp1   = cl::sycl::exp(-solveKernelParams.ewaldFactor * m2k);
+                const float etermk = solveKernelParams.elFactor * tmp1 / denom;
 
                 // sycl::float2::load and store are buggy in hipSYCL,
                 // but can probably be used after resolution of
@@ -267,7 +265,7 @@ auto makeSolveKernel(cl::sycl::handler&                            cgh,
                 {
                     const float tmp1k = 2.0F * cl::sycl::dot(gridValue, oldGridValue);
 
-                    float vfactor = (a_solveKernelParams[0].ewaldFactor + 1.0F / m2k) * 2.0F;
+                    float vfactor = (solveKernelParams.ewaldFactor + 1.0F / m2k) * 2.0F;
                     float ets2    = corner_fac * tmp1k;
                     energy        = ets2;
 
@@ -438,12 +436,11 @@ cl::sycl::event PmeSolveKernel<gridOrdering, computeEnergyAndVirial, gridIndex,
 
     cl::sycl::queue q = deviceStream.stream();
 
-    cl::sycl::buffer<SolveKernelParams, 1> d_solveKernelParams(&solveKernelParams_, 1);
-    cl::sycl::event                        e = q.submit([&](cl::sycl::handler& cgh) {
+    cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = makeSolveKernel<gridOrdering, computeEnergyAndVirial, subGroupSize>(
                 cgh,
                 gridParams_->d_splineModuli[gridIndex],
-                d_solveKernelParams,
+                solveKernelParams_,
                 constParams_->d_virialAndEnergy[gridIndex],
                 gridParams_->d_fourierGrid[gridIndex]);
         cgh.parallel_for<KernelNameType>(range, kernel);