From 6eae295b15a92e524a244328915ea07f473d1a02 Mon Sep 17 00:00:00 2001
From: Andrey Alekseenko
Date: Thu, 24 Jun 2021 08:43:43 +0000
Subject: [PATCH] SYCL NBNXM: Use hand-crafted energy reduction

---
 src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp | 57 ++++++++++++++++++--
 1 file changed, 53 insertions(+), 4 deletions(-)

diff --git a/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp b/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
index 560575fdb5..ae703bdb16 100644
--- a/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
+++ b/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
@@ -342,6 +342,55 @@ static inline void reduceForceJShuffle(Float3 f,
     }
 }
 
+/*! \brief Sum a \c float over all work-items of a work-group.
+ *
+ * Each sub-group is reduced with \c sycl_2020::group_reduce first. With two sub-groups per
+ * work-group, the work-item with \p tidxi equal to \p subGroupSize stores its sub-group's sum in
+ * \c sm_buf[0], and the work-item with \p tidxi 0 adds it to its own sum after a local barrier.
+ *
+ * In the two-sub-group case the function contains work-group barriers, so all work-items of the
+ * work-group must call it. The trailing barrier keeps a following call that reuses \p sm_buf from
+ * overwriting \c sm_buf[0] before work-item 0 has read it.
+ *
+ * \tparam subGroupSize  Size of a sub-group.
+ * \tparam groupSize     Size of the work-group; must contain exactly one or two sub-groups.
+ * \param[in] itemIdx        Current work-item's \c nd_item.
+ * \param[in] tidxi          Current work-item's index, compared against 0 and \p subGroupSize.
+ * \param[in] sm_buf         Scratch buffer for the exchange; must have a length of at least 1.
+ * \param[in] valueToReduce  Current work-item's contribution to the sum.
+ * \returns For work-item #0 in the group: sum of all \p valueToReduce in the group.
+ *          For other work-items: unspecified.
+ */
+template<int subGroupSize, int groupSize>
+static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
+                                const unsigned int         tidxi,
+                                cl::sycl::accessor<float, 1, cl::sycl::access::mode::read_write,
+                                                   cl::sycl::access::target::local> sm_buf,
+                                float valueToReduce)
+{
+    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
+    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
+    sycl_2020::sub_group sg = itemIdx.get_sub_group();
+    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
+    // If we have two sub-groups, we should reduce across them.
+    if constexpr (numSubGroupsInGroup == 2)
+    {
+        if (tidxi == subGroupSize)
+        {
+            sm_buf[0] = valueToReduce;
+        }
+        itemIdx.barrier(fence_space::local_space);
+        if (tidxi == 0)
+        {
+            valueToReduce += sm_buf[0];
+        }
+        // The kernel reduces two energies back to back through the same buffer: make sure
+        // sm_buf[0] has been consumed before any work-item can overwrite it in the next call.
+        itemIdx.barrier(fence_space::local_space);
+    }
+    return valueToReduce;
+}
+
 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
* * c_clSize consecutive threads hold the force components of a j-atom which we @@ -1007,10 +1037,10 @@ auto nbnxmKernel(cl::sycl::handler& cgh, if constexpr (doCalcEnergies) { - const float energyVdwGroup = sycl_2020::group_reduce( - itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus()); - const float energyElecGroup = sycl_2020::group_reduce( - itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus()); + const float energyVdwGroup = + groupReduce(itemIdx, tidx, sm_reductionBuffer, energyVdw); + const float energyElecGroup = groupReduce( + itemIdx, tidx, sm_reductionBuffer, energyElec); if (tidx == 0) { -- 2.22.0