SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel_pruneonly.cpp
index a2cc1f8a0ddc41f140ad364ff6e0fc5c1b5f52d2..4b2ca582c3e5b98905313ca72428bb9c4527b406 100644 (file)
@@ -54,6 +54,10 @@ using cl::sycl::access::fence_space;
 using cl::sycl::access::mode;
 using cl::sycl::access::target;
 
+//! \brief Class name for NBNXM prune-only kernel
+template<bool haveFreshList>
+class NbnxmKernelPruneOnly;
+
 namespace Nbnxm
 {
 
@@ -72,11 +76,11 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
                           const int   numParts,
                           const int   part)
 {
-    cgh.require(a_xq);
-    cgh.require(a_shiftVec);
-    cgh.require(a_plistCJ4);
-    cgh.require(a_plistSci);
-    cgh.require(a_plistIMask);
+    a_xq.bind(cgh);
+    a_shiftVec.bind(cgh);
+    a_plistCJ4.bind(cgh);
+    a_plistSci.bind(cgh);
+    a_plistIMask.bind(cgh);
 
     /* shmem buffer for i x+q pre-loading */
     cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(
@@ -94,16 +98,15 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
     constexpr int gmx_unused requiredSubGroupSize = (c_clSize == 4) ? 16 : warpSize;
 
     /* Requirements:
-     * Work group (block) must have range (c_clSize, c_clSize, ...) (for localId calculation, easy
+     * Work group (block) must have range (c_clSize, c_clSize, ...) (for itemIdx calculation, easy
      * to change). */
-    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(requiredSubGroupSize)]]
     {
-        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
         // thread/block/warp id-s
-        const unsigned tidxi = localId[0];
-        const unsigned tidxj = localId[1];
+        const unsigned tidxi = itemIdx.get_local_id(2);
+        const unsigned tidxj = itemIdx.get_local_id(1);
         const int      tidx  = tidxj * c_clSize + tidxi;
-        const unsigned tidxz = localId[2];
+        const unsigned tidxz = itemIdx.get_local_id(0);
         const unsigned bidx  = itemIdx.get_group(0);
 
         const sycl_2020::sub_group sg   = itemIdx.get_sub_group();
@@ -216,17 +219,13 @@ auto nbnxmKernelPruneOnly(cl::sycl::handler&                            cgh,
     };
 }
 
-// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
-template<bool haveFreshList>
-class NbnxmKernelPruneOnlyName;
-
+//! \brief Leap Frog SYCL prune-only kernel launch code.
 template<bool haveFreshList, class... Args>
 cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
                                            const int           numSciInPart,
                                            Args&&... args)
 {
-    // Should not be needed for SYCL2020.
-    using kernelNameType = NbnxmKernelPruneOnlyName<haveFreshList>;
+    using kernelNameType = NbnxmKernelPruneOnly<haveFreshList>;
 
     /* Kernel launch config:
      * - The thread block dimensions match the size of i-clusters, j-clusters,
@@ -234,7 +233,7 @@ cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
      * - The 1D block-grid contains as many blocks as super-clusters.
      */
     const unsigned long         numBlocks = numSciInPart;
-    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, c_syclPruneKernelJ4Concurrency };
+    const cl::sycl::range<3>    blockSize{ c_syclPruneKernelJ4Concurrency, c_clSize, c_clSize };
     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
 
@@ -242,12 +241,13 @@ cl::sycl::event launchNbnxmKernelPruneOnly(const DeviceStream& deviceStream,
 
     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = nbnxmKernelPruneOnly<haveFreshList>(cgh, std::forward<Args>(args)...);
-        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
+        cgh.parallel_for<kernelNameType>(range, kernel);
     });
 
     return e;
 }
 
+//! \brief Select templated kernel and launch it.
 template<class... Args>
 cl::sycl::event chooseAndLaunchNbnxmKernelPruneOnly(bool haveFreshList, Args&&... args)
 {