From: Andrey Alekseenko Date: Thu, 7 Oct 2021 14:02:38 +0000 (+0200) Subject: SYCL: Reduce local memory usage of LINCS kernel X-Git-Url: http://biod.pnpi.spb.ru/gitweb/?a=commitdiff_plain;h=139efd08f9428376d2dc8964c15f8278e8c50a14;p=alexxy%2Fgromacs.git SYCL: Reduce local memory usage of LINCS kernel Similar to the CUDA code, we can reuse the local memory buffer. Closes #4202. --- diff --git a/src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp b/src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp index 9627ce8d2c..fcb83c01bd 100644 --- a/src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp +++ b/src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp @@ -145,30 +145,16 @@ auto lincsKernel(cl::sycl::handler& cgh, cgh.require(a_virialScaled); } - // shmem buffer for local distances - auto sm_r = [&]() { - return cl::sycl::accessor( - cl::sycl::range<1>(c_threadsPerBlock), cgh); - }(); - - // shmem buffer for right-hand-side values - auto sm_rhs = [&]() { - return cl::sycl::accessor( - cl::sycl::range<1>(c_threadsPerBlock * 2), cgh); - }(); - - // shmem buffer for virial components - auto sm_threadVirial = [&]() { - if constexpr (computeVirial) - { - return cl::sycl::accessor( - cl::sycl::range<1>(c_threadsPerBlock * 6), cgh); - } - else - { - return nullptr; - } - }(); + /* Shared local memory buffer. Corresponds to sh_r, sm_rhs, and sm_threadVirial in CUDA. + * sh_r: one Float3 per thread. + * sh_rhs: two floats per thread. + * sm_threadVirial: six floats per thread. + * So, without virials we need max(1*3, 2) floats, and with virials we need max(1*3, 2, 6) floats. + */ + static constexpr int smBufferElementsPerThread = computeVirial ? 6 : 3; + cl::sycl::accessor sm_buffer{ + cl::sycl::range<1>(c_threadsPerBlock * smBufferElementsPerThread), cgh + }; return [=](cl::sycl::nd_item<1> itemIdx) { const int threadIndex = itemIdx.get_global_linear_id(); @@ -223,7 +209,9 @@ auto lincsKernel(cl::sycl::handler& cgh, rc = rlen * dx; } - sm_r[threadInBlock] = rc; + sm_buffer[threadInBlock * DIM + XX] = rc[XX]; + sm_buffer[threadInBlock * DIM + YY] = rc[YY]; + sm_buffer[threadInBlock * DIM + ZZ] = rc[ZZ]; // Make sure that all r's are saved into shared memory // before they are accessed in the loop below itemIdx.barrier(fence_space::global_and_local); @@ -241,7 +229,7 @@ auto lincsKernel(cl::sycl::handler& cgh, int index = n * numConstraintsThreads + threadIndex; int c1 = a_coupledConstraintsIndices[index]; - Float3 rc1 = sm_r[c1]; + Float3 rc1{ sm_buffer[c1 * DIM + XX], sm_buffer[c1 * DIM + YY], sm_buffer[c1 * DIM + ZZ] }; a_matrixA[index] = a_massFactors[index] * (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]); } @@ -267,15 +255,16 @@ auto lincsKernel(cl::sycl::handler& cgh, * Inverse matrix using a set of expansionOrder matrix multiplications */ - // This will use the same memory space as sm_r, which is no longer needed. - sm_rhs[threadInBlock] = sol; + // This will reuse the same buffer, because the old values are no longer needed. + itemIdx.barrier(fence_space::local_space); + sm_buffer[threadInBlock] = sol; // No need to iterate if there are no coupled constraints. if constexpr (haveCoupledConstraints) { for (int rec = 0; rec < expansionOrder; rec++) { - // Making sure that all sm_rhs are saved before they are accessed in a loop below + // Making sure that all sm_buffer values are saved before they are accessed in a loop below itemIdx.barrier(fence_space::global_and_local); float mvb = 0.0F; for (int n = 0; n < coupledConstraintsCount; n++) @@ -283,13 +272,14 @@ auto lincsKernel(cl::sycl::handler& cgh, int index = n * numConstraintsThreads + threadIndex; int c1 = a_coupledConstraintsIndices[index]; // Convolute current right-hand-side with A - // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations - mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)]; + // Different, non overlapping parts of sm_buffer[..] are read during odd and even iterations + mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)]; } // 'Switch' rhs vectors, save current result // These values will be accessed in the loop above during the next iteration. - sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb; - sol = sol + mvb; + sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb; + + sol = sol + mvb; } } @@ -351,8 +341,8 @@ auto lincsKernel(cl::sycl::handler& cgh, proj = sqrtReducedMass * targetLength; } - sm_rhs[threadInBlock] = proj; - float sol = proj; + sm_buffer[threadInBlock] = proj; + float sol = proj; /* * Same matrix inversion as above is used for updated data @@ -369,11 +359,11 @@ auto lincsKernel(cl::sycl::handler& cgh, int index = n * numConstraintsThreads + threadIndex; int c1 = a_coupledConstraintsIndices[index]; - mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)]; + mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)]; } - sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb; - sol = sol + mvb; + sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb; + sol = sol + mvb; } } @@ -426,13 +416,15 @@ auto lincsKernel(cl::sycl::handler& cgh, // Save virial for each thread into the shared memory. Tensor is symmetrical, hence only // 6 values are saved. Dummy threads will have zeroes in their virial: targetLength, // lagrangeScaled and rc are all set to zero for them in the beginning of the kernel. - float mult = targetLength * lagrangeScaled; - sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX]; - sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY]; - sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ]; - sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY]; - sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ]; - sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ]; + // We reuse the same shared memory buffer, so we make sure we don't need its old values: + itemIdx.barrier(fence_space::local_space); + float mult = targetLength * lagrangeScaled; + sm_buffer[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX]; + sm_buffer[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY]; + sm_buffer[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ]; + sm_buffer[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY]; + sm_buffer[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ]; + sm_buffer[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ]; itemIdx.barrier(fence_space::local_space); // This casts unsigned into signed integers to avoid clang warnings @@ -452,8 +444,7 @@ auto lincsKernel(cl::sycl::handler& cgh, { for (int d = 0; d < 6; d++) { - sm_threadVirial[d * blockSize + tib] += - sm_threadVirial[d * blockSize + (tib + dividedAt)]; + sm_buffer[d * blockSize + tib] += sm_buffer[d * blockSize + (tib + dividedAt)]; } } if (dividedAt > subGroupSize / 2) @@ -468,7 +459,7 @@ auto lincsKernel(cl::sycl::handler& cgh, // First 6 threads in the block add the 6 components of virial to the global memory address if (tib < 6) { - atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]); + atomicFetchAdd(a_virialScaled[tib], sm_buffer[tib * blockSize]); } } };