SYCL NBNXM: Use hand-crafted energy reduction
authorAndrey Alekseenko <al42and@gmail.com>
Thu, 24 Jun 2021 08:43:43 +0000 (08:43 +0000)
committerPaul Bauer <paul.bauer.q@gmail.com>
Thu, 24 Jun 2021 08:43:43 +0000 (08:43 +0000)
src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp

index 560575fdb5cdf95e0dd5859b29c6aed4384dac30..ae703bdb16f042f7ea63e5749805c86418ff1656 100644 (file)
@@ -342,6 +342,36 @@ static inline void reduceForceJShuffle(Float3                             f,
     }
 }
 
+// This function also requires sm_buf to have a length of at least 1.
+// The function returns:
+//     - for thread #0 in the group: sum of all valueToReduce in a group
+//     - for other threads: unspecified
+template<int subGroupSize, int groupSize>
+static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
+                                const unsigned int         tidxi,
+                                cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                float valueToReduce)
+{
+    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
+    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
+    sycl_2020::sub_group sg = itemIdx.get_sub_group();
+    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
+    // If we have two sub-groups, we should reduce across them.
+    if constexpr (numSubGroupsInGroup == 2)
+    {
+        if (tidxi == subGroupSize)
+        {
+            sm_buf[0] = valueToReduce;
+        }
+        itemIdx.barrier(fence_space::local_space);
+        if (tidxi == 0)
+        {
+            valueToReduce += sm_buf[0];
+        }
+    }
+    return valueToReduce;
+}
+
 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
  *
  * c_clSize consecutive threads hold the force components of a j-atom which we
@@ -1007,10 +1037,10 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
 
         if constexpr (doCalcEnergies)
         {
-            const float energyVdwGroup = sycl_2020::group_reduce(
-                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
-            const float energyElecGroup = sycl_2020::group_reduce(
-                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());
+            const float energyVdwGroup =
+                    groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
+            const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
+                    itemIdx, tidx, sm_reductionBuffer, energyElec);
 
             if (tidx == 0)
             {