SYCL: remove (un)flatten
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
index 16be5bc7df8ec5f8f38ce0ef64a92e19c2e1dadf..39811fd11d26f9f4f2086f1c382a3c13998535b4 100644 (file)
@@ -307,7 +307,7 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
  * reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
  */
 static inline void reduceForceJShuffle(Float3                                   f,
-                                       const cl::sycl::nd_item<1>               itemIdx,
+                                       const cl::sycl::nd_item<3>               itemIdx,
                                        const int                                tidxi,
                                        const int                                aidx,
                                        DeviceAccessor<Float3, mode::read_write> a_f)
@@ -356,7 +356,7 @@ static inline void reduceForceJShuffle(Float3
  * \return For thread with \p tidxi 0: sum of all \p valueToReduce. Other threads: unspecified.
  */
 template<int subGroupSize, int groupSize>
-static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
+static inline float groupReduce(const cl::sycl::nd_item<3> itemIdx,
                                 const unsigned int         tidxi,
                                 cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                 float valueToReduce)
@@ -390,7 +390,7 @@ static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
  */
 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                        Float3                                   f,
-                                       const cl::sycl::nd_item<1>               itemIdx,
+                                       const cl::sycl::nd_item<3>               itemIdx,
                                        const int                                tidxi,
                                        const int                                tidxj,
                                        const int                                aidx,
@@ -424,7 +424,7 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
  */
 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                 Float3                                                        f,
-                                const cl::sycl::nd_item<1>               itemIdx,
+                                const cl::sycl::nd_item<3>               itemIdx,
                                 const int                                tidxi,
                                 const int                                tidxj,
                                 const int                                aidx,
@@ -452,7 +452,7 @@ static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, t
 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                          const bool   calcFShift,
-                                         const cl::sycl::nd_item<1>               itemIdx,
+                                         const cl::sycl::nd_item<3>               itemIdx,
                                          const int                                tidxi,
                                          const int                                tidxj,
                                          const int                                sci,
@@ -662,16 +662,14 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
 #endif
 
-    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
     {
         /* thread/block/warp id-s */
-        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
-        const unsigned        tidxi   = localId[0];
-        const unsigned        tidxj   = localId[1];
-        const unsigned        tidx    = tidxj * c_clSize + tidxi;
-        const unsigned        tidxz   = 0;
+        const unsigned tidxi = itemIdx.get_local_id(2);
+        const unsigned tidxj = itemIdx.get_local_id(1);
+        const unsigned tidx  = tidxj * c_clSize + tidxi;
+        const unsigned tidxz = 0;
 
-        // Group indexing was flat originally, no need to unflatten it.
         const unsigned bidx = itemIdx.get_group(0);
 
         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
@@ -1072,7 +1070,7 @@ cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int nu
      * - The 1D block-grid contains as many blocks as super-clusters.
      */
     const int                   numBlocks = numSci;
-    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
+    const cl::sycl::range<3>    blockSize{ 1, c_clSize, c_clSize };
     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
 
@@ -1081,7 +1079,7 @@ cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int nu
     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                 cgh, std::forward<Args>(args)...);
-        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
+        cgh.parallel_for<kernelNameType>(range, kernel);
     });
 
     return e;