Get rid of sycl::buffer::reinterpret
authorAndrey Alekseenko <al42and@gmail.com>
Thu, 21 Oct 2021 14:20:38 +0000 (16:20 +0200)
committerAndrey Alekseenko <al42and@gmail.com>
Fri, 22 Oct 2021 14:42:41 +0000 (16:42 +0200)
It is not fully supported with hipSYCL, and, while it does work in
practice, we better avoid it.

Refs #4063

src/gromacs/mdlib/lincs_gpu_internal_sycl.cpp

index fcb83c01bd5806d63a804a5abd1d36cb6cf4a691..658a45b32c7f6fa7ac9aa8c86a09ad4941aaaafe 100644 (file)
@@ -118,9 +118,9 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
                  const int                                                         numIterations,
                  const int                                                         expansionOrder,
                  DeviceAccessor<Float3, mode::read>                                a_x,
-                 DeviceAccessor<float, mode::read_write>                           a_xp,
+                 DeviceAccessor<Float3, mode::read_write>                          a_xp,
                  const float                                                       invdt,
-                 OptionalAccessor<float, mode::read_write, updateVelocities>       a_v,
+                 OptionalAccessor<Float3, mode::read_write, updateVelocities>      a_v,
                  OptionalAccessor<float, mode::read_write, computeVirial>          a_virialScaled,
                  PbcAiuc                                                           pbcAiuc)
 {
@@ -238,12 +238,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
         // Skipping in dummy threads
         if (!isDummyThread)
         {
-            xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
-            xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
-            xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
-            xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
-            xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
-            xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
+            xi[XX] = atomicLoad(a_xp[i][XX]);
+            xi[YY] = atomicLoad(a_xp[i][YY]);
+            xi[ZZ] = atomicLoad(a_xp[i][ZZ]);
+            xj[XX] = atomicLoad(a_xp[j][XX]);
+            xj[YY] = atomicLoad(a_xp[j][YY]);
+            xj[ZZ] = atomicLoad(a_xp[j][ZZ]);
         }
 
         Float3 dx;
@@ -296,12 +296,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
              * Note: Using memory_scope::work_group for atomic_ref can be better here,
              * but for now we re-use the existing function for memory_scope::device atomics.
              */
-            atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
-            atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
-            atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
-            atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
-            atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
-            atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+            atomicFetchAdd(a_xp[i][XX], -tmp[XX] * inverseMassi);
+            atomicFetchAdd(a_xp[i][YY], -tmp[YY] * inverseMassi);
+            atomicFetchAdd(a_xp[i][ZZ], -tmp[ZZ] * inverseMassi);
+            atomicFetchAdd(a_xp[j][XX], tmp[XX] * inverseMassj);
+            atomicFetchAdd(a_xp[j][YY], tmp[YY] * inverseMassj);
+            atomicFetchAdd(a_xp[j][ZZ], tmp[ZZ] * inverseMassj);
         }
 
         /*
@@ -315,12 +315,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
 
             if (!isDummyThread)
             {
-                xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
-                xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
-                xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
-                xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
-                xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
-                xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
+                xi[XX] = atomicLoad(a_xp[i][XX]);
+                xi[YY] = atomicLoad(a_xp[i][YY]);
+                xi[ZZ] = atomicLoad(a_xp[i][ZZ]);
+                xj[XX] = atomicLoad(a_xp[j][XX]);
+                xj[YY] = atomicLoad(a_xp[j][YY]);
+                xj[ZZ] = atomicLoad(a_xp[j][ZZ]);
             }
 
             Float3 dx;
@@ -376,12 +376,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             if (!isDummyThread)
             {
                 Float3 tmp = rc * sqrtmu_sol;
-                atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
-                atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
-                atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
-                atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
-                atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
-                atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+                atomicFetchAdd(a_xp[i][XX], -tmp[XX] * inverseMassi);
+                atomicFetchAdd(a_xp[i][YY], -tmp[YY] * inverseMassi);
+                atomicFetchAdd(a_xp[i][ZZ], -tmp[ZZ] * inverseMassi);
+                atomicFetchAdd(a_xp[j][XX], tmp[XX] * inverseMassj);
+                atomicFetchAdd(a_xp[j][YY], tmp[YY] * inverseMassj);
+                atomicFetchAdd(a_xp[j][ZZ], tmp[ZZ] * inverseMassj);
             }
         }
 
@@ -391,12 +391,12 @@ auto lincsKernel(cl::sycl::handler&                   cgh,
             if (!isDummyThread)
             {
                 Float3 tmp = rc * invdt * lagrangeScaled;
-                atomicFetchAdd(a_v[i * DIM + XX], -tmp[XX] * inverseMassi);
-                atomicFetchAdd(a_v[i * DIM + YY], -tmp[YY] * inverseMassi);
-                atomicFetchAdd(a_v[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
-                atomicFetchAdd(a_v[j * DIM + XX], tmp[XX] * inverseMassj);
-                atomicFetchAdd(a_v[j * DIM + YY], tmp[YY] * inverseMassj);
-                atomicFetchAdd(a_v[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+                atomicFetchAdd(a_v[i][XX], -tmp[XX] * inverseMassi);
+                atomicFetchAdd(a_v[i][YY], -tmp[YY] * inverseMassi);
+                atomicFetchAdd(a_v[i][ZZ], -tmp[ZZ] * inverseMassi);
+                atomicFetchAdd(a_v[j][XX], tmp[XX] * inverseMassj);
+                atomicFetchAdd(a_v[j][YY], tmp[YY] * inverseMassj);
+                atomicFetchAdd(a_v[j][ZZ], tmp[ZZ] * inverseMassj);
             }
         }
 
@@ -514,12 +514,6 @@ void launchLincsGpuKernel(LincsGpuKernelParameters*   kernelParams,
                           const bool                  computeVirial,
                           const DeviceStream&         deviceStream)
 {
-    cl::sycl::buffer<Float3, 1> xp(*d_xp.buffer_);
-    auto                        d_xpAsFloat = xp.reinterpret<float, 1>(xp.get_count() * DIM);
-
-    cl::sycl::buffer<Float3, 1> v(*d_v.buffer_);
-    auto                        d_vAsFloat = v.reinterpret<float, 1>(v.get_count() * DIM);
-
     launchLincsKernel(updateVelocities,
                       computeVirial,
                       kernelParams->haveCoupledConstraints,
@@ -535,9 +529,9 @@ void launchLincsGpuKernel(LincsGpuKernelParameters*   kernelParams,
                       kernelParams->numIterations,
                       kernelParams->expansionOrder,
                       d_x,
-                      d_xpAsFloat,
+                      d_xp,
                       invdt,
-                      d_vAsFloat,
+                      d_v,
                       kernelParams->d_virialScaled,
                       kernelParams->pbcAiuc);
     return;