SYCL: remove (un)flatten
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel_pruneonly.cpp
index 1562cbc6fddbf0bcb549e16580de564e97458c29..b779ad79e3cc58d403293759c68ba773d753ab21 100644 (file)
@@ -98,16 +98,15 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
     constexpr int gmx_unused requiredSubGroupSize = (c_clSize == 4) ? 16 : warpSize;
 
     /* Requirements:
-     * Work group (block) must have range (c_clSize, c_clSize, ...) (for localId calculation, easy
+     * Work group (block) must have range (c_clSize, c_clSize, ...) (for itemIdx calculation, easy
      * to change). */
-    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
     {
-        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
         // thread/block/warp id-s
-        const unsigned tidxi = localId[0];
-        const unsigned tidxj = localId[1];
+        const unsigned tidxi = itemIdx.get_local_id(2);
+        const unsigned tidxj = itemIdx.get_local_id(1);
         const int      tidx  = tidxj * c_clSize + tidxi;
-        const unsigned tidxz = localId[2];
+        const unsigned tidxz = itemIdx.get_local_id(0);
         const unsigned bidx  = itemIdx.get_group(0);
 
         const sycl_2020::sub_group sg   = itemIdx.get_sub_group();
@@ -234,7 +233,7 @@ cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
      * - The 1D block-grid contains as many blocks as super-clusters.
      */
     const unsigned long         numBlocks = numSciInPart;
-    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, c_syclPruneKernelJ4Concurrency };
+    const cl::sycl::range<3>    blockSize{ c_syclPruneKernelJ4Concurrency, c_clSize, c_clSize };
     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
 
@@ -242,7 +241,7 @@ cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
 
     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = nbnxmKernelPruneOnly<haveFreshList>(cgh, std::forward<Args>(args)...);
-        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
+        cgh.parallel_for<kernelNameType>(range, kernel);
     });
 
     return e;