SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel_pruneonly.cpp
index 1562cbc6fddbf0bcb549e16580de564e97458c29..4b2ca582c3e5b98905313ca72428bb9c4527b406 100644 (file)
@@ -76,11 +76,11 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
                           const int   numParts,
                           const int   part)
 {
-    cgh.require(a_xq);
-    cgh.require(a_shiftVec);
-    cgh.require(a_plistCJ4);
-    cgh.require(a_plistSci);
-    cgh.require(a_plistIMask);
+    a_xq.bind(cgh);
+    a_shiftVec.bind(cgh);
+    a_plistCJ4.bind(cgh);
+    a_plistSci.bind(cgh);
+    a_plistIMask.bind(cgh);
 
     /* shmem buffer for i x+q pre-loading */
     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
@@ -98,16 +98,15 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
     constexpr int gmx_unused requiredSubGroupSize = (c_clSize == 4) ? 16 : warpSize;
 
     /* Requirements:
-     * Work group (block) must have range (c_clSize, c_clSize, ...) (for localId calculation, easy
+     * Work group (block) must have range (c_clSize, c_clSize, ...) (for itemIdx calculation, easy
      * to change). */
-    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
     {
-        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
         // thread/block/warp id-s
-        const unsigned tidxi = localId[0];
-        const unsigned tidxj = localId[1];
+        const unsigned tidxi = itemIdx.get_local_id(2);
+        const unsigned tidxj = itemIdx.get_local_id(1);
         const int      tidx  = tidxj * c_clSize + tidxi;
-        const unsigned tidxz = localId[2];
+        const unsigned tidxz = itemIdx.get_local_id(0);
         const unsigned bidx  = itemIdx.get_group(0);
 
         const sycl_2020::sub_group sg   = itemIdx.get_sub_group();
@@ -234,7 +233,7 @@ cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
      * - The 1D block-grid contains as many blocks as super-clusters.
      */
     const unsigned long         numBlocks = numSciInPart;
-    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, c_syclPruneKernelJ4Concurrency };
+    const cl::sycl::range<3>    blockSize{ c_syclPruneKernelJ4Concurrency, c_clSize, c_clSize };
     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
 
@@ -242,7 +241,7 @@ cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
 
     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = nbnxmKernelPruneOnly<haveFreshList>(cgh, std::forward<Args>(args)...);
-        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
+        cgh.parallel_for<kernelNameType>(range, kernel);
     });
 
     return e;