SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
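Kernels now register their placeholder accessors by calling bind(cgh) on the
accessor wrapper instead of passing the accessor to cgh.require(). A minimal
sketch of what such a bind() can look like, assuming a hypothetical
BindableAccessor wrapper (the actual DeviceAccessor / OptionalAccessor helpers
are defined elsewhere in the GROMACS SYCL utility headers):

    #include <CL/sycl.hpp>

    // Hypothetical wrapper around a SYCL placeholder accessor: bind() just
    // forwards to the standard placeholder registration call, so kernel
    // generators can write a_foo.bind(cgh) instead of cgh.require(a_foo).
    template<class T, cl::sycl::access::mode Mode>
    class BindableAccessor :
        public cl::sycl::accessor<T, 1, Mode, cl::sycl::access::target::global_buffer, cl::sycl::access::placeholder::true_t>
    {
    public:
        using Base = cl::sycl::accessor<T, 1, Mode, cl::sycl::access::target::global_buffer, cl::sycl::access::placeholder::true_t>;
        using Base::Base; // inherit the accessor constructors
        void bind(cl::sycl::handler& cgh) { cgh.require(*this); } // slices to the base accessor, which is what require() expects
    };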
[alexxy/gromacs.git] / src / gromacs / mdlib / lincs_gpu_internal_sycl.cpp
index 3c39f43f4b6b967329b9b124b7c72b8d64427315..60b10f3f8430a8eab42760e77ae0d8f3c820f55c 100644 (file)
@@ -118,57 +118,43 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                  const int                                                         numIterations,
                  const int                                                         expansionOrder,
                  DeviceAccessor<Float3, mode::read>                                a_x,
-                 DeviceAccessor<float, mode::read_write>                           a_xp,
+                 DeviceAccessor<Float3, mode::read_write>                          a_xp,
                  const float                                                       invdt,
-                 OptionalAccessor<float, mode::read_write, updateVelocities>       a_v,
+                 OptionalAccessor<Float3, mode::read_write, updateVelocities>      a_v,
                  OptionalAccessor<float, mode::read_write, computeVirial>          a_virialScaled,
                  PbcAiuc                                                           pbcAiuc)
 {
-    cgh.require(a_constraints);
-    cgh.require(a_constraintsTargetLengths);
+    a_constraints.bind(cgh);
+    a_constraintsTargetLengths.bind(cgh);
     if constexpr (haveCoupledConstraints)
     {
-        cgh.require(a_coupledConstraintsCounts);
-        cgh.require(a_coupledConstraintsIndices);
-        cgh.require(a_massFactors);
-        cgh.require(a_matrixA);
+        a_coupledConstraintsCounts.bind(cgh);
+        a_coupledConstraintsIndices.bind(cgh);
+        a_massFactors.bind(cgh);
+        a_matrixA.bind(cgh);
     }
-    cgh.require(a_inverseMasses);
-    cgh.require(a_x);
-    cgh.require(a_xp);
+    a_inverseMasses.bind(cgh);
+    a_x.bind(cgh);
+    a_xp.bind(cgh);
     if constexpr (updateVelocities)
     {
-        cgh.require(a_v);
+        a_v.bind(cgh);
     }
     if constexpr (computeVirial)
     {
-        cgh.require(a_virialScaled);
+        a_virialScaled.bind(cgh);
     }
 
-    // shmem buffer for local distances
-    auto sm_r = [&]() {
-        return cl::sycl::accessor<Float3, 1, mode::read_write, target::local>(
-                cl::sycl::range<1>(c_threadsPerBlock), cgh);
-    }();
-
-    // shmem buffer for right-hand-side values
-    auto sm_rhs = [&]() {
-        return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
-                cl::sycl::range<1>(c_threadsPerBlock), cgh);
-    }();
-
-    // shmem buffer for virial components
-    auto sm_threadVirial = [&]() {
-        if constexpr (computeVirial)
-        {
-            return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
-                    cl::sycl::range<1>(c_threadsPerBlock * 6), cgh);
-        }
-        else
-        {
-            return nullptr;
-        }
-    }();
+    /* Shared local memory buffer. Corresponds to sm_r, sm_rhs, and sm_threadVirial in the CUDA version.
+     * sm_r: one Float3 per thread.
+     * sm_rhs: two floats per thread.
+     * sm_threadVirial: six floats per thread.
+     * So, without virials we need max(1*3, 2) = 3 floats per thread, and with virials max(1*3, 2, 6) = 6.
+     */
+    static constexpr int smBufferElementsPerThread = computeVirial ? 6 : 3;
+    cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buffer{
+        cl::sycl::range<1>(c_threadsPerBlock * smBufferElementsPerThread), cgh
+    };
 
     return [=](cl::sycl::nd_item<1> itemIdx) {
         const int threadIndex   = itemIdx.get_global_linear_id();
@@ -223,7 +209,9 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             rc         = rlen * dx;
         }
 
-        sm_r[threadIndex] = rc;
+        sm_buffer[threadInBlock * DIM + XX] = rc[XX];
+        sm_buffer[threadInBlock * DIM + YY] = rc[YY];
+        sm_buffer[threadInBlock * DIM + ZZ] = rc[ZZ];
         // Make sure that all r's are saved into shared memory
         // before they are accessed in the loop below
         itemIdx.barrier(fence_space::global_and_local);
@@ -241,7 +229,7 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                 int index = n * numConstraintsThreads + threadIndex;
                 int c1    = a_coupledConstraintsIndices[index];
 
-                Float3 rc1       = sm_r[c1];
+                Float3 rc1{ sm_buffer[c1 * DIM + XX], sm_buffer[c1 * DIM + YY], sm_buffer[c1 * DIM + ZZ] };
                 a_matrixA[index] = a_massFactors[index]
                                    * (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]);
             }
@@ -250,12 +238,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
         // Skipping in dummy threads
         if (!isDummyThread)
         {
-            xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
-            xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
-            xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
-            xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
-            xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
-            xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
+            xi[XX] = atomicLoad(a_xp[i][XX]);
+            xi[YY] = atomicLoad(a_xp[i][YY]);
+            xi[ZZ] = atomicLoad(a_xp[i][ZZ]);
+            xj[XX] = atomicLoad(a_xp[j][XX]);
+            xj[YY] = atomicLoad(a_xp[j][YY]);
+            xj[ZZ] = atomicLoad(a_xp[j][ZZ]);
         }
 
         Float3 dx;
@@ -267,15 +255,16 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
          *  Inverse matrix using a set of expansionOrder matrix multiplications
          */
 
-        // This will use the same memory space as sm_r, which is no longer needed.
-        sm_rhs[threadInBlock] = sol;
+        // Reuse the same shared buffer; the barrier below ensures no thread is still reading the old r values before they are overwritten.
+        itemIdx.barrier(fence_space::local_space);
+        sm_buffer[threadInBlock] = sol;
 
         // No need to iterate if there are no coupled constraints.
         if constexpr (haveCoupledConstraints)
         {
             for (int rec = 0; rec < expansionOrder; rec++)
             {
-                // Making sure that all sm_rhs are saved before they are accessed in a loop below
+                // Making sure that all sm_buffer values are saved before they are accessed in the loop below
                 itemIdx.barrier(fence_space::global_and_local);
                 float mvb = 0.0F;
                 for (int n = 0; n < coupledConstraintsCount; n++)
@@ -283,13 +272,14 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                     int index = n * numConstraintsThreads + threadIndex;
                     int c1    = a_coupledConstraintsIndices[index];
                     // Convolute current right-hand-side with A
-                    // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
-                    mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+                    // Different, non-overlapping parts of sm_buffer[..] are read during odd and even iterations
+                    mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
                 }
                 // 'Switch' rhs vectors, save current result
                 // These values will be accessed in the loop above during the next iteration.
-                sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
-                sol                                                         = sol + mvb;
+                sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+
+                sol = sol + mvb;
             }
         }
 
@@ -306,12 +296,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
              * Note: Using memory_scope::work_group for atomic_ref can be better here,
              * but for now we re-use the existing function for memory_scope::device atomics.
              */
-            atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
-            atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
-            atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
-            atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
-            atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
-            atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+            atomicFetchAdd(a_xp[i][XX], -tmp[XX] * inverseMassi);
+            atomicFetchAdd(a_xp[i][YY], -tmp[YY] * inverseMassi);
+            atomicFetchAdd(a_xp[i][ZZ], -tmp[ZZ] * inverseMassi);
+            atomicFetchAdd(a_xp[j][XX], tmp[XX] * inverseMassj);
+            atomicFetchAdd(a_xp[j][YY], tmp[YY] * inverseMassj);
+            atomicFetchAdd(a_xp[j][ZZ], tmp[ZZ] * inverseMassj);
         }
 
         /*
@@ -325,12 +315,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
 
             if (!isDummyThread)
             {
-                xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
-                xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
-                xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
-                xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
-                xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
-                xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
+                xi[XX] = atomicLoad(a_xp[i][XX]);
+                xi[YY] = atomicLoad(a_xp[i][YY]);
+                xi[ZZ] = atomicLoad(a_xp[i][ZZ]);
+                xj[XX] = atomicLoad(a_xp[j][XX]);
+                xj[YY] = atomicLoad(a_xp[j][YY]);
+                xj[ZZ] = atomicLoad(a_xp[j][ZZ]);
             }
 
             Float3 dx;
@@ -351,8 +341,8 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                 proj = sqrtReducedMass * targetLength;
             }
 
-            sm_rhs[threadInBlock] = proj;
-            float sol             = proj;
+            sm_buffer[threadInBlock] = proj;
+            float sol                = proj;
 
             /*
              * Same matrix inversion as above is used for updated data
@@ -369,11 +359,11 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                         int index = n * numConstraintsThreads + threadIndex;
                         int c1    = a_coupledConstraintsIndices[index];
 
-                        mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+                        mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
                     }
 
-                    sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
-                    sol                                                         = sol + mvb;
+                    sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+                    sol                                                            = sol + mvb;
                 }
             }
 
@@ -386,12 +376,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             if (!isDummyThread)
             {
                 Float3 tmp = rc * sqrtmu_sol;
-                atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
-                atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
-                atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
-                atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
-                atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
-                atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+                atomicFetchAdd(a_xp[i][XX], -tmp[XX] * inverseMassi);
+                atomicFetchAdd(a_xp[i][YY], -tmp[YY] * inverseMassi);
+                atomicFetchAdd(a_xp[i][ZZ], -tmp[ZZ] * inverseMassi);
+                atomicFetchAdd(a_xp[j][XX], tmp[XX] * inverseMassj);
+                atomicFetchAdd(a_xp[j][YY], tmp[YY] * inverseMassj);
+                atomicFetchAdd(a_xp[j][ZZ], tmp[ZZ] * inverseMassj);
             }
         }
 
@@ -401,12 +391,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             if (!isDummyThread)
             {
                 Float3 tmp = rc * invdt * lagrangeScaled;
-                atomicFetchAdd(a_v[i * DIM + XX], -tmp[XX] * inverseMassi);
-                atomicFetchAdd(a_v[i * DIM + YY], -tmp[YY] * inverseMassi);
-                atomicFetchAdd(a_v[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
-                atomicFetchAdd(a_v[j * DIM + XX], tmp[XX] * inverseMassj);
-                atomicFetchAdd(a_v[j * DIM + YY], tmp[YY] * inverseMassj);
-                atomicFetchAdd(a_v[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+                atomicFetchAdd(a_v[i][XX], -tmp[XX] * inverseMassi);
+                atomicFetchAdd(a_v[i][YY], -tmp[YY] * inverseMassi);
+                atomicFetchAdd(a_v[i][ZZ], -tmp[ZZ] * inverseMassi);
+                atomicFetchAdd(a_v[j][XX], tmp[XX] * inverseMassj);
+                atomicFetchAdd(a_v[j][YY], tmp[YY] * inverseMassj);
+                atomicFetchAdd(a_v[j][ZZ], tmp[ZZ] * inverseMassj);
             }
         }
 
@@ -426,13 +416,15 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             // Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
             // 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
             // lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
-            float mult                                             = targetLength * lagrangeScaled;
-            sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
-            sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
-            sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
-            sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
-            sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
-            sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
+            // The same shared memory buffer is reused, so make sure all threads are done with its old values before overwriting:
+            itemIdx.barrier(fence_space::local_space);
+            float mult                                       = targetLength * lagrangeScaled;
+            sm_buffer[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
+            sm_buffer[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
+            sm_buffer[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
+            sm_buffer[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
+            sm_buffer[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
+            sm_buffer[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
 
             itemIdx.barrier(fence_space::local_space);
             // This casts unsigned into signed integers to avoid clang warnings
@@ -452,8 +444,7 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                 {
                     for (int d = 0; d < 6; d++)
                     {
-                        sm_threadVirial[d * blockSize + tib] +=
-                                sm_threadVirial[d * blockSize + (tib + dividedAt)];
+                        sm_buffer[d * blockSize + tib] += sm_buffer[d * blockSize + (tib + dividedAt)];
                     }
                 }
                 if (dividedAt > subGroupSize / 2)
@@ -468,7 +459,7 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             // First 6 threads in the block add the 6 components of virial to the global memory address
             if (tib < 6)
             {
-                atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]);
+                atomicFetchAdd(a_virialScaled[tib], sm_buffer[tib * blockSize]);
             }
         }
     };
@@ -523,12 +514,6 @@ void launchLincsGpuKernel(LincsGpuKernelParameters*   kernelParams,
                           const bool                  computeVirial,
                           const DeviceStream&         deviceStream)
 {
-    cl::sycl::buffer<Float3, 1> xp(*d_xp.buffer_);
-    auto                        d_xpAsFloat = xp.reinterpret<float, 1>(xp.get_count() * DIM);
-
-    cl::sycl::buffer<Float3, 1> v(*d_v.buffer_);
-    auto                        d_vAsFloat = v.reinterpret<float, 1>(v.get_count() * DIM);
-
     launchLincsKernel(updateVelocities,
                       computeVirial,
                       kernelParams->haveCoupledConstraints,
@@ -544,9 +529,9 @@ void launchLincsGpuKernel(LincsGpuKernelParameters*   kernelParams,
                       kernelParams->numIterations,
                       kernelParams->expansionOrder,
                       d_x,
-                      d_xpAsFloat,
+                      d_xp,
                       invdt,
-                      d_vAsFloat,
+                      d_v,
                       kernelParams->d_virialScaled,
                       kernelParams->pbcAiuc);
     return;