+// This function also requires sm_buf to have a length of at least 1.
+// The function returns:
+// - for thread #0 in the group: sum of all valueToReduce in a group
+// - for other threads: unspecified
+template<int subGroupSize, int groupSize>
+static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
+ const unsigned int tidxi,
+ cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+ float valueToReduce)
+{
+ constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
+ static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
+ sycl_2020::sub_group sg = itemIdx.get_sub_group();
+ valueToReduce = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
+ // If we have two sub-groups, we should reduce across them.
+ if constexpr (numSubGroupsInGroup == 2)
+ {
+ if (tidxi == subGroupSize)
+ {
+ sm_buf[0] = valueToReduce;
+ }
+ itemIdx.barrier(fence_space::local_space);
+ if (tidxi == 0)
+ {
+ valueToReduce += sm_buf[0];
+ }
+ }
+ return valueToReduce;
+}
+