SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
index 770732c70cc9eda458b02717ce98c64135425642..a7b2c6cde6b080ac4946855703037d4209aab454 100644 (file)
 #include "nbnxm_sycl_kernel_utils.h"
 #include "nbnxm_sycl_types.h"
 
+//! \brief Class name for NBNXM kernel
+template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
+class NbnxmKernel;
+
 namespace Nbnxm
 {
 
@@ -66,9 +70,9 @@ struct EnergyFunctionProperties {
     static constexpr bool elecEwaldTab =
             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
     static constexpr bool elecEwaldTwin =
-            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
-    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
-    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
+            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Use twin cut-off.
+    static constexpr bool elecEwald = (elecEwaldAna || elecEwaldTab);  ///< EL_EWALD_ANY
+    static constexpr bool vdwCombLB = (vdwType == VdwType::CutCombLB); ///< LJ_COMB && !LJ_COMB_GEOM
     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
@@ -83,9 +87,6 @@ struct EnergyFunctionProperties {
 template<enum VdwType vdwType>
 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
 
-template<enum ElecType elecType> // Yes, ElecType
-constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;
-
 template<enum ElecType elecType>
 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
 
@@ -100,17 +101,18 @@ using cl::sycl::access::fence_space;
 using cl::sycl::access::mode;
 using cl::sycl::access::target;
 
-static inline void convertSigmaEpsilonToC6C12(const float                  sigma,
-                                              const float                  epsilon,
-                                              cl::sycl::private_ptr<float> c6,
-                                              cl::sycl::private_ptr<float> c12)
+//! \brief Convert \p sigma and \p epsilon VdW parameters to \c c6,c12 pair.
+static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
 {
     const float sigma2 = sigma * sigma;
     const float sigma6 = sigma2 * sigma2 * sigma2;
-    *c6                = epsilon * sigma6;
-    *c12               = (*c6) * sigma6;
+    const float c6     = epsilon * sigma6;
+    const float c12    = c6 * sigma6;
+
+    return { c6, c12 };
 }
 
+//! \brief Calculate force and energy for a pair of atoms, VdW force-switch flavor.
 template<bool doCalcEnergies>
 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                  const shift_consts_t         repulsionShift,
@@ -147,25 +149,23 @@ static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
 
 //! \brief Fetch C6 grid contribution coefficients and return the product of these.
 template<enum VdwType vdwType>
-static inline float calculateLJEwaldC6Grid(const DeviceAccessor<float, mode::read> a_nbfpComb,
-                                           const int                               typeI,
-                                           const int                               typeJ)
+static inline float calculateLJEwaldC6Grid(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
+                                           const int                                typeI,
+                                           const int                                typeJ)
 {
     if constexpr (vdwType == VdwType::EwaldGeom)
     {
-        return a_nbfpComb[2 * typeI] * a_nbfpComb[2 * typeJ];
+        return a_nbfpComb[typeI][0] * a_nbfpComb[typeJ][0];
     }
     else
     {
         static_assert(vdwType == VdwType::EwaldLB);
         /* sigma and epsilon are scaled to give 6*C6 */
-        const float c6_i  = a_nbfpComb[2 * typeI];
-        const float c12_i = a_nbfpComb[2 * typeI + 1];
-        const float c6_j  = a_nbfpComb[2 * typeJ];
-        const float c12_j = a_nbfpComb[2 * typeJ + 1];
+        const Float2 c6c12_i = a_nbfpComb[typeI];
+        const Float2 c6c12_j = a_nbfpComb[typeJ];
 
-        const float sigma   = c6_i + c6_j;
-        const float epsilon = c12_i * c12_j;
+        const float sigma   = c6c12_i[0] + c6c12_j[0];
+        const float epsilon = c6c12_i[1] * c6c12_j[1];
 
         const float sigma2 = sigma * sigma;
         return epsilon * sigma2 * sigma2 * sigma2;
@@ -174,17 +174,17 @@ static inline float calculateLJEwaldC6Grid(const DeviceAccessor<float, mode::rea
 
 //! Calculate LJ-PME grid force contribution with geometric or LB combination rule.
 template<bool doCalcEnergies, enum VdwType vdwType>
-static inline void ljEwaldComb(const DeviceAccessor<float, mode::read> a_nbfpComb,
-                               const float                             sh_lj_ewald,
-                               const int                               typeI,
-                               const int                               typeJ,
-                               const float                             r2,
-                               const float                             r2Inv,
-                               const float                             lje_coeff2,
-                               const float                             lje_coeff6_6,
-                               const float                             int_bit,
-                               cl::sycl::private_ptr<float>            fInvR,
-                               cl::sycl::private_ptr<float>            eLJ)
+static inline void ljEwaldComb(const DeviceAccessor<Float2, mode::read> a_nbfpComb,
+                               const float                              sh_lj_ewald,
+                               const int                                typeI,
+                               const int                                typeJ,
+                               const float                              r2,
+                               const float                              r2Inv,
+                               const float                              lje_coeff2,
+                               const float                              lje_coeff6_6,
+                               const float                              int_bit,
+                               cl::sycl::private_ptr<float>             fInvR,
+                               cl::sycl::private_ptr<float>             eLJ)
 {
     const float c6grid = calculateLJEwaldC6Grid<vdwType>(a_nbfpComb, typeI, typeJ);
 
@@ -301,25 +301,30 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
     return lerp(left, right, fraction); // TODO: cl::sycl::mix
 }
 
-static inline void reduceForceJShuffle(Float3                                  f,
-                                       const cl::sycl::nd_item<1>              itemIdx,
-                                       const int                               tidxi,
-                                       const int                               aidx,
-                                       DeviceAccessor<float, mode::read_write> a_f)
+/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
+ *
+ * c_clSize consecutive threads hold the force components of a j-atom which we
+ * reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
+ */
+static inline void reduceForceJShuffle(Float3                                   f,
+                                       const cl::sycl::nd_item<3>               itemIdx,
+                                       const int                                tidxi,
+                                       const int                                aidx,
+                                       DeviceAccessor<Float3, mode::read_write> a_f)
 {
     static_assert(c_clSize == 8 || c_clSize == 4);
     sycl_2020::sub_group sg = itemIdx.get_sub_group();
 
-    f[0] += shuffleDown(f[0], 1, sg);
-    f[1] += shuffleUp(f[1], 1, sg);
-    f[2] += shuffleDown(f[2], 1, sg);
+    f[0] += sycl_2020::shift_left(sg, f[0], 1);
+    f[1] += sycl_2020::shift_right(sg, f[1], 1);
+    f[2] += sycl_2020::shift_left(sg, f[2], 1);
     if (tidxi & 1)
     {
         f[0] = f[1];
     }
 
-    f[0] += shuffleDown(f[0], 2, sg);
-    f[2] += shuffleUp(f[2], 2, sg);
+    f[0] += sycl_2020::shift_left(sg, f[0], 2);
+    f[2] += sycl_2020::shift_right(sg, f[2], 2);
     if (tidxi & 2)
     {
         f[0] = f[2];
@@ -327,35 +332,143 @@ static inline void reduceForceJShuffle(Float3                                  f
 
     if constexpr (c_clSize == 8)
     {
-        f[0] += shuffleDown(f[0], 4, sg);
+        f[0] += sycl_2020::shift_left(sg, f[0], 4);
+    }
+
+    if (tidxi < 3)
+    {
+        atomicFetchAdd(a_f[aidx][tidxi], f[0]);
+    }
+}
+
+/*!
+ * \brief Do workgroup-level reduction of a single \c float.
+ *
+ * While SYCL has \c sycl::reduce_over_group, it currently (oneAPI 2021.3.0) uses a very large
+ * shared memory buffer, which leads to a reduced occupancy.
+ *
+ * \note The caller must make sure there are no races when reusing the \p sm_buf.
+ *
+ * \tparam subGroupSize Size of a sub-group.
+ * \tparam groupSize Size of a work-group.
+ * \param itemIdx Current thread's \c sycl::nd_item.
+ * \param tidxi Current thread's linearized local index.
+ * \param sm_buf Accessor for local reduction buffer.
+ * \param valueToReduce Current thread's value. Must have length of at least 1.
+ * \return For thread with \p tidxi 0: sum of all \p valueToReduce. Other threads: unspecified.
+ */
+template<int subGroupSize, int groupSize>
+static inline float groupReduce(const cl::sycl::nd_item<3> itemIdx,
+                                const unsigned int         tidxi,
+                                cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                float valueToReduce)
+{
+    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
+    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
+    sycl_2020::sub_group sg = itemIdx.get_sub_group();
+    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
+    // If we have two sub-groups, we should reduce across them.
+    if constexpr (numSubGroupsInGroup == 2)
+    {
+        if (tidxi == subGroupSize)
+        {
+            sm_buf[0] = valueToReduce;
+        }
+        itemIdx.barrier(fence_space::local_space);
+        if (tidxi == 0)
+        {
+            valueToReduce += sm_buf[0];
+        }
     }
+    return valueToReduce;
+}
+
+/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
+ *
+ * c_clSize consecutive threads hold the force components of a j-atom which we
+ * reduced in cl_Size steps using shift and atomically accumulate them into \p a_f.
+ *
+ * TODO: implement binary reduction flavor for the case where cl_Size is power of two.
+ */
+static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                       Float3                                   f,
+                                       const cl::sycl::nd_item<3>               itemIdx,
+                                       const int                                tidxi,
+                                       const int                                tidxj,
+                                       const int                                aidx,
+                                       DeviceAccessor<Float3, mode::read_write> a_f)
+{
+    static constexpr int sc_fBufferStride = c_clSizeSq;
+    int                  tidx             = tidxi + tidxj * c_clSize;
+    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
+    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
+    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];
+
+    subGroupBarrier(itemIdx);
+
+    // reducing data 8-by-by elements on the leader of same threads as those storing above
+    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);
 
     if (tidxi < 3)
     {
-        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
+        float fSum = 0.0F;
+        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
+        {
+            fSum += sm_buf[sc_fBufferStride * tidxi + j];
+        }
+
+        atomicFetchAdd(a_f[aidx][tidxi], fSum);
+    }
+}
+
+
+/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
+ */
+static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                Float3                                                        f,
+                                const cl::sycl::nd_item<3>               itemIdx,
+                                const int                                tidxi,
+                                const int                                tidxj,
+                                const int                                aidx,
+                                DeviceAccessor<Float3, mode::read_write> a_f)
+{
+    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
+    {
+        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
+    }
+    else
+    {
+        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
     }
 }
 
 
 /*! \brief Final i-force reduction.
+ *
+ * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
+ * accumulating atomically into \p a_f.
+ * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
  *
  * This implementation works only with power of two array sizes.
  */
 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                          const bool   calcFShift,
-                                         const cl::sycl::nd_item<1>              itemIdx,
-                                         const int                               tidxi,
-                                         const int                               tidxj,
-                                         const int                               sci,
-                                         const int                               shift,
-                                         DeviceAccessor<float, mode::read_write> a_f,
-                                         DeviceAccessor<float, mode::read_write> a_fShift)
+                                         const cl::sycl::nd_item<3>               itemIdx,
+                                         const int                                tidxi,
+                                         const int                                tidxj,
+                                         const int                                sci,
+                                         const int                                shift,
+                                         DeviceAccessor<Float3, mode::read_write> a_f,
+                                         DeviceAccessor<Float3, mode::read_write> a_fShift)
 {
+    // must have power of two elements in fCiBuf
+    static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
+
     static constexpr int bufStride  = c_clSize * c_clSize;
     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
     const int            tidx       = tidxi + tidxj * c_clSize;
-    float                fShiftBuf  = 0;
+    float                fShiftBuf  = 0.0F;
     for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
     {
         const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
@@ -391,7 +504,7 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
         {
             const float f =
                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
-            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
+            atomicFetchAdd(a_f[aidx][tidxj], f);
             if (calcFShift)
             {
                 fShiftBuf += f;
@@ -407,21 +520,39 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
            storing the reduction result above. */
         if (tidxj < 3)
         {
-            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
+            if constexpr (c_clSize == 4)
+            {
+                /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
+                 * a compare-and-swap (CAS) loop. It has particularly poor performance when
+                 * updating the same memory location from the same work-group.
+                 * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
+                 * but it is unlikely to make a big difference and thus was not evaluated.
+                 */
+                auto sg = itemIdx.get_sub_group();
+                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
+                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
+                if (tidxi == 0)
+                {
+                    atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
+                }
+            }
+            else
+            {
+                atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
+            }
         }
     }
 }
 
-
 /*! \brief Main kernel for NBNXM.
  *
  */
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
 auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                  DeviceAccessor<Float4, mode::read>                        a_xq,
-                 DeviceAccessor<float, mode::read_write>                   a_f,
+                 DeviceAccessor<Float3, mode::read_write>                  a_f,
                  DeviceAccessor<Float3, mode::read>                        a_shiftVec,
-                 DeviceAccessor<float, mode::read_write>                   a_fShift,
+                 DeviceAccessor<Float3, mode::read_write>                  a_fShift,
                  OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
                  OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
@@ -429,8 +560,8 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
                  OptionalAccessor<Float2, mode::read, ljComb<vdwType>>       a_ljComb,
                  OptionalAccessor<int, mode::read, !ljComb<vdwType>>         a_atomTypes,
-                 OptionalAccessor<float, mode::read, !ljComb<vdwType>>       a_nbfp,
-                 OptionalAccessor<float, mode::read, ljEwald<vdwType>>       a_nbfpComb,
+                 OptionalAccessor<Float2, mode::read, !ljComb<vdwType>>      a_nbfp,
+                 OptionalAccessor<Float2, mode::read, ljEwald<vdwType>>      a_nbfpComb,
                  OptionalAccessor<float, mode::read, elecEwaldTab<elecType>> a_coulombTab,
                  const int                                                   numTypes,
                  const float                                                 rCoulombSq,
@@ -452,34 +583,34 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
 {
     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
 
-    cgh.require(a_xq);
-    cgh.require(a_f);
-    cgh.require(a_shiftVec);
-    cgh.require(a_fShift);
+    a_xq.bind(cgh);
+    a_f.bind(cgh);
+    a_shiftVec.bind(cgh);
+    a_fShift.bind(cgh);
     if constexpr (doCalcEnergies)
     {
-        cgh.require(a_energyElec);
-        cgh.require(a_energyVdw);
+        a_energyElec.bind(cgh);
+        a_energyVdw.bind(cgh);
     }
-    cgh.require(a_plistCJ4);
-    cgh.require(a_plistSci);
-    cgh.require(a_plistExcl);
+    a_plistCJ4.bind(cgh);
+    a_plistSci.bind(cgh);
+    a_plistExcl.bind(cgh);
     if constexpr (!props.vdwComb)
     {
-        cgh.require(a_atomTypes);
-        cgh.require(a_nbfp);
+        a_atomTypes.bind(cgh);
+        a_nbfp.bind(cgh);
     }
     else
     {
-        cgh.require(a_ljComb);
+        a_ljComb.bind(cgh);
     }
     if constexpr (props.vdwEwald)
     {
-        cgh.require(a_nbfpComb);
+        a_nbfpComb.bind(cgh);
     }
     if constexpr (props.elecEwaldTab)
     {
-        cgh.require(a_coulombTab);
+        a_coulombTab.bind(cgh);
     }
 
     // shmem buffer for i x+q pre-loading
@@ -521,23 +652,32 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
     constexpr bool doExclusionForces =
             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
 
-    constexpr int subGroupSize = c_clSize * c_clSize / 2;
-
-    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
+    // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
+    // Currently, this is ideally suited for 32-wide subgroup size but slightly less so for others,
+    // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
+    // Hence, the two are decoupled.
+    // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
+    constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
+#if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
+    gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
+#else
+    gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
+#endif
+
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
     {
         /* thread/block/warp id-s */
-        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
-        const unsigned        tidxi   = localId[0];
-        const unsigned        tidxj   = localId[1];
-        const unsigned        tidx    = tidxj * c_clSize + tidxi;
-        const unsigned        tidxz   = 0;
+        const unsigned tidxi = itemIdx.get_local_id(2);
+        const unsigned tidxj = itemIdx.get_local_id(1);
+        const unsigned tidx  = tidxj * c_clSize + tidxi;
+        const unsigned tidxz = 0;
 
-        // Group indexing was flat originally, no need to unflatten it.
         const unsigned bidx = itemIdx.get_group(0);
 
         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
-        // Better use sg.get_group_range, but too much of the logic relies on it anyway
-        const unsigned widx = tidx / subGroupSize;
+        // Could use sg.get_group_range to compute the imask & exclusion Idx, but too much of the logic relies on it anyway
+        // and in cases where prunedClusterPairSize != subGroupSize we can't use it anyway
+        const unsigned imeiIdx = tidx / prunedClusterPairSize;
 
         Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster]; // i force buffer
         for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
@@ -594,7 +734,8 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
         }
         if constexpr (doCalcEnergies && doExclusionForces)
         {
-            if (nbSci.shift == CENTRAL && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
+            if (nbSci.shift == gmx::c_centralShiftIndex
+                && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
             {
                 // we have the diagonal: add the charge and LJ self interaction energy term
                 for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
@@ -609,7 +750,7 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                     {
                         energyVdw +=
                                 a_nbfp[a_atomTypes[(sci * c_nbnxnGpuNumClusterPerSupercluster + i) * c_clSize + tidxi]
-                                       * (numTypes + 1) * 2];
+                                       * (numTypes + 1)][0];
                     }
                 }
                 /* divide the self term(s) equally over the j-threads, then multiply with the coefficients. */
@@ -630,22 +771,23 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                     energyElec /= epsFac * c_clSize;
                     energyElec *= -ewaldBeta * c_OneOverSqrtPi; /* last factor 1/sqrt(pi) */
                 }
-            } // (nbSci.shift == CENTRAL && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
+            } // (nbSci.shift == gmx::c_centralShiftIndex && a_plistCJ4[cij4Start].cj[0] == sci * c_nbnxnGpuNumClusterPerSupercluster)
         }     // (doCalcEnergies && doExclusionForces)
 
         // Only needed if (doExclusionForces)
-        const bool nonSelfInteraction = !(nbSci.shift == CENTRAL & tidxj <= tidxi);
+        const bool nonSelfInteraction = !(nbSci.shift == gmx::c_centralShiftIndex & tidxj <= tidxi);
 
         // loop over the j clusters = seen by any of the atoms in the current super-cluster
         for (int j4 = cij4Start + tidxz; j4 < cij4End; j4 += 1)
         {
-            unsigned imask = a_plistCJ4[j4].imei[widx].imask;
+            unsigned imask = a_plistCJ4[j4].imei[imeiIdx].imask;
             if (!doPruneNBL && !imask)
             {
                 continue;
             }
-            const int wexclIdx = a_plistCJ4[j4].imei[widx].excl_ind;
-            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (subGroupSize - 1)]; // sg.get_local_linear_id()
+            const int wexclIdx = a_plistCJ4[j4].imei[imeiIdx].excl_ind;
+            static_assert(gmx::isPowerOfTwo(prunedClusterPairSize));
+            const unsigned wexcl = a_plistExcl[wexclIdx].pair[tidx & (prunedClusterPairSize - 1)];
             for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
             {
                 const bool maskSet =
@@ -712,23 +854,21 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                         {
                             const float qi = xqi[3];
                             int         atomTypeI; // Only needed if (!props.vdwComb)
-                            float       c6, c12, sigma, epsilon;
+                            float       sigma, epsilon;
+                            Float2      c6c12;
 
                             if constexpr (!props.vdwComb)
                             {
                                 /* LJ 6*C6 and 12*C12 */
-                                atomTypeI     = sm_atomTypeI[i][tidxi];
-                                const int idx = (numTypes * atomTypeI + atomTypeJ) * 2;
-                                c6            = a_nbfp[idx]; // TODO: Make a_nbfm into float2
-                                c12           = a_nbfp[idx + 1];
+                                atomTypeI = sm_atomTypeI[i][tidxi];
+                                c6c12     = a_nbfp[numTypes * atomTypeI + atomTypeJ];
                             }
                             else
                             {
                                 const Float2 ljCombI = sm_ljCombI[i][tidxi];
                                 if constexpr (props.vdwCombGeom)
                                 {
-                                    c6  = ljCombI[0] * ljCombJ[0];
-                                    c12 = ljCombI[1] * ljCombJ[1];
+                                    c6c12 = Float2(ljCombI[0] * ljCombJ[0], ljCombI[1] * ljCombJ[1]);
                                 }
                                 else
                                 {
@@ -738,15 +878,24 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                                     epsilon = ljCombI[1] * ljCombJ[1];
                                     if constexpr (doCalcEnergies)
                                     {
-                                        convertSigmaEpsilonToC6C12(sigma, epsilon, &c6, &c12);
+                                        c6c12 = convertSigmaEpsilonToC6C12(sigma, epsilon);
                                     }
                                 } // props.vdwCombGeom
                             }     // !props.vdwComb
 
+                            // c6 and c12 are unused and garbage iff props.vdwCombLB && !doCalcEnergies
+                            const float c6  = c6c12[0];
+                            const float c12 = c6c12[1];
+
                             // Ensure distance do not become so small that r^-12 overflows
                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
+#if GMX_SYCL_HIPSYCL
+                            // No fast/native functions in some compilation passes
+                            const float rInv = cl::sycl::rsqrt(r2);
+#else
                             // SYCL-TODO: sycl::half_precision::rsqrt?
-                            const float rInv  = cl::sycl::native::rsqrt(r2);
+                            const float rInv = cl::sycl::native::rsqrt(r2);
+#endif
                             const float r2Inv = rInv * rInv;
                             float       r6Inv, fInvR, energyLJPair;
                             if constexpr (!props.vdwCombLB || doCalcEnergies)
@@ -843,7 +992,7 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                                 fInvR += qi * qj
                                          * (pairExclMask * r2Inv
                                             - interpolateCoulombForceR(
-                                                      a_coulombTab, coulombTabScale, r2 * rInv))
+                                                    a_coulombTab, coulombTabScale, r2 * rInv))
                                          * rInv;
                             }
 
@@ -856,7 +1005,7 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                                 if constexpr (props.elecRF)
                                 {
                                     energyElec +=
-                                            qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
+                                            qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
                                 }
                                 if constexpr (props.elecEwald)
                                 {
@@ -879,47 +1028,44 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                     maskJI += maskJI;
                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                 /* reduce j forces */
-                reduceForceJShuffle(fCjBuf, itemIdx, tidxi, aj, a_f);
+                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
             if constexpr (doPruneNBL)
             {
                 /* Update the imask with the new one which does not contain the
                  * out of range clusters anymore. */
-                a_plistCJ4[j4].imei[widx].imask = imask;
+                a_plistCJ4[j4].imei[imeiIdx].imask = imask;
             }
         } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)
 
         /* skip central shifts when summing shift forces */
-        const bool doCalcShift = (calcShift && !(nbSci.shift == CENTRAL));
+        const bool doCalcShift = (calcShift && nbSci.shift != gmx::c_centralShiftIndex);
 
         reduceForceIAndFShift(
                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
 
         if constexpr (doCalcEnergies)
         {
-            const float energyVdwGroup = sycl_2020::group_reduce(
-                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
-            const float energyElecGroup = sycl_2020::group_reduce(
-                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());
+            const float energyVdwGroup =
+                    groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
+            itemIdx.barrier(fence_space::local_space); // Prevent the race on sm_reductionBuffer.
+            const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
+                    itemIdx, tidx, sm_reductionBuffer, energyElec);
 
             if (tidx == 0)
             {
-                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
-                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
+                atomicFetchAdd(a_energyVdw[0], energyVdwGroup);
+                atomicFetchAdd(a_energyElec[0], energyElecGroup);
             }
         }
     };
 }
 
-// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
-template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
-class NbnxmKernelName;
-
+//! \brief NBNXM kernel launch code.
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
 {
-    // Should not be needed for SYCL2020.
-    using kernelNameType = NbnxmKernelName<doPruneNBL, doCalcEnergies, elecType, vdwType>;
+    using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;
 
     /* Kernel launch config:
      * - The thread block dimensions match the size of i-clusters, j-clusters,
@@ -927,7 +1073,7 @@ cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int nu
      * - The 1D block-grid contains as many blocks as super-clusters.
      */
     const int                   numBlocks = numSci;
-    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
+    const cl::sycl::range<3>    blockSize{ 1, c_clSize, c_clSize };
     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
 
@@ -936,12 +1082,13 @@ cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int nu
     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                 cgh, std::forward<Args>(args)...);
-        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
+        cgh.parallel_for<kernelNameType>(range, kernel);
     });
 
     return e;
 }
 
+//! \brief Select templated kernel and launch it.
 template<class... Args>
 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                            bool          doCalcEnergies,
@@ -962,18 +1109,12 @@ cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
 
 void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const InteractionLocality iloc)
 {
-    NBAtomData*         adat         = nb->atdat;
+    NBAtomDataGpu*      adat         = nb->atdat;
     NBParamGpu*         nbp          = nb->nbparam;
     gpu_plist*          plist        = nb->plist[iloc];
     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
 
-    // Casting to float simplifies using atomic ops in the kernel
-    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
-    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
-    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
-    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
-
     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                    stepWork.computeEnergy,
                                                    nbp->elecType,
@@ -981,9 +1122,9 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In
                                                    deviceStream,
                                                    plist->nsci,
                                                    adat->xq,
-                                                   fAsFloat,
+                                                   adat->f,
                                                    adat->shiftVec,
-                                                   fShiftAsFloat,
+                                                   adat->fShift,
                                                    adat->eElec,
                                                    adat->eLJ,
                                                    plist->cj4,