SYCL: Reduce local memory usage of LINCS kernel
author Andrey Alekseenko <al42and@gmail.com>
Thu, 7 Oct 2021 14:02:38 +0000 (16:02 +0200)
committer Mark Abraham <mark.j.abraham@gmail.com>
Sat, 9 Oct 2021 00:50:20 +0000 (00:50 +0000)
Similar to the CUDA code, we can reuse a single local memory buffer for the inter-atomic distances, the right-hand-side values, and the virial components, instead of allocating three separate buffers.

Closes #4202.
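
The underlying pattern: allocate one local buffer sized for the largest of the uses, and separate the phases with work-group barriers. Below is a minimal standalone sketch of that pattern, not the LINCS kernel itself; the phase values and the helper sketchKernel are illustrative only.

    #include <CL/sycl.hpp>

    constexpr int c_threadsPerBlock = 256;
    // Size the buffer for the largest consumer: max(3, 2, 6) floats per thread.
    constexpr int elementsPerThread = 6;

    auto sketchKernel(cl::sycl::handler& cgh)
    {
        cl::sycl::accessor<float, 1, cl::sycl::access::mode::read_write, cl::sycl::access::target::local>
                sm_buffer{ cl::sycl::range<1>(c_threadsPerBlock * elementsPerThread), cgh };
        return [=](cl::sycl::nd_item<1> itemIdx) {
            const int tib = static_cast<int>(itemIdx.get_local_linear_id());
            sm_buffer[tib] = 1.0F; // phase one writes (stands in for the sm_r role)
            itemIdx.barrier(cl::sycl::access::fence_space::local_space);
            // ... phase one: threads read each other's values ...
            itemIdx.barrier(cl::sycl::access::fence_space::local_space);
            sm_buffer[tib] = 2.0F; // phase two reuses the memory (the sm_rhs role); safe,
                                   // because the barrier above ordered all phase-one reads first
        };
    }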

src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp

index 9627ce8d2c660ab257f0da9910736bc8d2cc8802..fcb83c01bd5806d63a804a5abd1d36cb6cf4a691 100644
@@ -145,30 +145,16 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
         cgh.require(a_virialScaled);
     }
 
-    // shmem buffer for local distances
-    auto sm_r = [&]() {
-        return cl::sycl::accessor<Float3, 1, mode::read_write, target::local>(
-                cl::sycl::range<1>(c_threadsPerBlock), cgh);
-    }();
-
-    // shmem buffer for right-hand-side values
-    auto sm_rhs = [&]() {
-        return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
-                cl::sycl::range<1>(c_threadsPerBlock * 2), cgh);
-    }();
-
-    // shmem buffer for virial components
-    auto sm_threadVirial = [&]() {
-        if constexpr (computeVirial)
-        {
-            return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
-                    cl::sycl::range<1>(c_threadsPerBlock * 6), cgh);
-        }
-        else
-        {
-            return nullptr;
-        }
-    }();
+    /* Shared local memory buffer. Corresponds to sm_r, sm_rhs, and sm_threadVirial in CUDA.
+     * sm_r: one Float3 (three floats) per thread.
+     * sm_rhs: two floats per thread.
+     * sm_threadVirial: six floats per thread.
+     * So, without virials we need max(3, 2) = 3 floats per thread; with virials, max(3, 2, 6) = 6.
+     */
+    static constexpr int smBufferElementsPerThread = computeVirial ? 6 : 3;
+    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buffer{
+        cl::sycl::range<1>(c_threadsPerBlock * smBufferElementsPerThread), cgh
+    };
 
     return [=](cl::sycl::nd_item<1> itemIdx) {
         const int threadIndex   = itemIdx.get_global_linear_id();
@@ -223,7 +209,9 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             rc         = rlen * dx;
         }
 
-        sm_r[threadInBlock] = rc;
+        sm_buffer[threadInBlock * DIM + XX] = rc[XX];
+        sm_buffer[threadInBlock * DIM + YY] = rc[YY];
+        sm_buffer[threadInBlock * DIM + ZZ] = rc[ZZ];
         // Make sure that all r's are saved into shared memory
         // before they are accessed in the loop below
         itemIdx.barrier(fence_space::global_and_local);
@@ -241,7 +229,7 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                 int index = n * numConstraintsThreads + threadIndex;
                 int c1    = a_coupledConstraintsIndices[index];
 
-                Float3 rc1       = sm_r[c1];
+                Float3 rc1{ sm_buffer[c1 * DIM + XX], sm_buffer[c1 * DIM + YY], sm_buffer[c1 * DIM + ZZ] };
                 a_matrixA[index] = a_massFactors[index]
                                    * (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]);
             }
@@ -267,15 +255,16 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
          *  Inverse matrix using a set of expansionOrder matrix multiplications
          */
 
-        // This will use the same memory space as sm_r, which is no longer needed.
-        sm_rhs[threadInBlock] = sol;
+        // This reuses the same buffer; the barrier ensures all reads of the old values are done.
+        itemIdx.barrier(fence_space::local_space);
+        sm_buffer[threadInBlock] = sol;
 
         // No need to iterate if there are no coupled constraints.
         if constexpr (haveCoupledConstraints)
         {
             for (int rec = 0; rec < expansionOrder; rec++)
             {
-                // Making sure that all sm_rhs are saved before they are accessed in a loop below
+                // Make sure that all sm_buffer values are saved before they are accessed in the loop below
                 itemIdx.barrier(fence_space::global_and_local);
                 float mvb = 0.0F;
                 for (int n = 0; n < coupledConstraintsCount; n++)
@@ -283,13 +272,14 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                     int index = n * numConstraintsThreads + threadIndex;
                     int c1    = a_coupledConstraintsIndices[index];
                     // Convolute current right-hand-side with A
-                    // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
-                    mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+                    // Different, non-overlapping parts of sm_buffer[..] are read during odd and even iterations
+                    mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
                 }
                 // 'Switch' rhs vectors, save current result
                 // These values will be accessed in the loop above during the next iteration.
-                sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
-                sol                                                         = sol + mvb;
+                sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+
+                sol = sol + mvb;
             }
         }
 
@@ -351,8 +341,8 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                 proj = sqrtReducedMass * targetLength;
             }
 
-            sm_rhs[threadInBlock] = proj;
-            float sol             = proj;
+            sm_buffer[threadInBlock] = proj;
+            float sol                = proj;
 
             /*
              * Same matrix inversion as above is used for updated data
@@ -369,11 +359,11 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                         int index = n * numConstraintsThreads + threadIndex;
                         int c1    = a_coupledConstraintsIndices[index];
 
-                        mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+                        mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
                     }
 
-                    sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
-                    sol                                                         = sol + mvb;
+                    sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+                    sol                                                            = sol + mvb;
                 }
             }
 
@@ -426,13 +416,15 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             // Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
             // 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
             // lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
-            float mult                                             = targetLength * lagrangeScaled;
-            sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
-            sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
-            sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
-            sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
-            sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
-            sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
+            // We reuse the same shared memory buffer, so a barrier ensures its old values are no longer needed:
+            itemIdx.barrier(fence_space::local_space);
+            float mult                                       = targetLength * lagrangeScaled;
+            sm_buffer[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
+            sm_buffer[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
+            sm_buffer[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
+            sm_buffer[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
+            sm_buffer[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
+            sm_buffer[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
 
             itemIdx.barrier(fence_space::local_space);
             // This casts unsigned into signed integers to avoid clang warnings
@@ -452,8 +444,7 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                 {
                     for (int d = 0; d < 6; d++)
                     {
-                        sm_threadVirial[d * blockSize + tib] +=
-                                sm_threadVirial[d * blockSize + (tib + dividedAt)];
+                        sm_buffer[d * blockSize + tib] += sm_buffer[d * blockSize + (tib + dividedAt)];
                     }
                 }
                 if (dividedAt > subGroupSize / 2)
@@ -468,7 +459,7 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             // First 6 threads in the block add the 6 components of virial to the global memory address
             if (tib < 6)
             {
-                atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]);
+                atomicFetchAdd(a_virialScaled[tib], sm_buffer[tib * blockSize]);
             }
         }
     };
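
For reference, the odd/even double-buffering used for the right-hand-side values above can be summarized as follows. This is a simplified sketch: matrixElement and neighborIndex stand in for the gather over coupled constraints, and the local-only fence is a simplification of the kernel's global_and_local barrier.

    // Each iteration reads one half of the first 2 * c_threadsPerBlock floats and
    // writes the other half, so reads and writes never touch the same elements.
    for (int rec = 0; rec < expansionOrder; rec++)
    {
        itemIdx.barrier(fence_space::local_space); // make the previous iteration's writes visible
        const int readOffset  = c_threadsPerBlock * (rec % 2);
        const int writeOffset = c_threadsPerBlock * ((rec + 1) % 2);
        const float mvb       = matrixElement * sm_buffer[readOffset + neighborIndex];
        sm_buffer[writeOffset + threadInBlock] = mvb;
        sol += mvb;
    }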