SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp
index 63be54335b0ccd6d9ea4e96976a8fb7dd9ecd85c..a7b2c6cde6b080ac4946855703037d4209aab454 100644 (file)
 #include "nbnxm_sycl_kernel_utils.h"
 #include "nbnxm_sycl_types.h"
 
+//! \brief Class name for NBNXM kernel
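+//! A unique class type is needed to name the kernel for \c parallel_for (a SYCL 1.2.1 requirement).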
+template<bool doPruneNBL, bool doCalcEnergies, enum Nbnxm::ElecType elecType, enum Nbnxm::VdwType vdwType>
+class NbnxmKernel;
+
 namespace Nbnxm
 {
 
@@ -66,9 +70,9 @@ struct EnergyFunctionProperties {
     static constexpr bool elecEwaldTab =
             (elecType == ElecType::EwaldTab || elecType == ElecType::EwaldTabTwin); ///< EL_EWALD_TAB
     static constexpr bool elecEwaldTwin =
-            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin);
-    static constexpr bool elecEwald        = (elecEwaldAna || elecEwaldTab); ///< EL_EWALD_ANY
-    static constexpr bool vdwCombLB        = (vdwType == VdwType::CutCombLB);
+            (elecType == ElecType::EwaldAnaTwin || elecType == ElecType::EwaldTabTwin); ///< Use twin cut-off.
+    static constexpr bool elecEwald = (elecEwaldAna || elecEwaldTab);  ///< EL_EWALD_ANY
+    static constexpr bool vdwCombLB = (vdwType == VdwType::CutCombLB); ///< LJ_COMB && !LJ_COMB_GEOM
     static constexpr bool vdwCombGeom      = (vdwType == VdwType::CutCombGeom); ///< LJ_COMB_GEOM
     static constexpr bool vdwComb          = (vdwCombLB || vdwCombGeom);        ///< LJ_COMB
     static constexpr bool vdwEwaldCombGeom = (vdwType == VdwType::EwaldGeom); ///< LJ_EWALD_COMB_GEOM
@@ -83,9 +87,6 @@ struct EnergyFunctionProperties {
 template<enum VdwType vdwType>
 constexpr bool ljComb = EnergyFunctionProperties<ElecType::Count, vdwType>().vdwComb;
 
-template<enum ElecType elecType> // Yes, ElecType
-constexpr bool vdwCutoffCheck = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwaldTwin;
-
 template<enum ElecType elecType>
 constexpr bool elecEwald = EnergyFunctionProperties<elecType, VdwType::Count>().elecEwald;
 
@@ -100,6 +101,7 @@ using cl::sycl::access::fence_space;
 using cl::sycl::access::mode;
 using cl::sycl::access::target;
 
+//! \brief Convert \p sigma and \p epsilon VdW parameters to \c c6,c12 pair.
 static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float epsilon)
 {
     const float sigma2 = sigma * sigma;
@@ -107,9 +109,10 @@ static inline Float2 convertSigmaEpsilonToC6C12(const float sigma, const float e
     const float c6     = epsilon * sigma6;
     const float c12    = c6 * sigma6;
 
-    return Float2(c6, c12);
+    return { c6, c12 };
 }
 
+//! \brief Calculate force and energy for a pair of atoms, VdW force-switch flavor.
 template<bool doCalcEnergies>
 static inline void ljForceSwitch(const shift_consts_t         dispersionShift,
                                  const shift_consts_t         repulsionShift,
@@ -303,11 +306,11 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
  * c_clSize consecutive threads hold the force components of a j-atom which we
  * reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
  */
-static inline void reduceForceJShuffle(Float3                             f,
-                                       const cl::sycl::nd_item<1>         itemIdx,
-                                       const int                          tidxi,
-                                       const int                          aidx,
-                                       DeviceAccessor<float, mode_atomic> a_f)
+static inline void reduceForceJShuffle(Float3                                   f,
+                                       const cl::sycl::nd_item<3>               itemIdx,
+                                       const int                                tidxi,
+                                       const int                                aidx,
+                                       DeviceAccessor<Float3, mode::read_write> a_f)
 {
     static_assert(c_clSize == 8 || c_clSize == 4);
     sycl_2020::sub_group sg = itemIdx.get_sub_group();
@@ -334,10 +337,52 @@ static inline void reduceForceJShuffle(Float3                             f,
 
     if (tidxi < 3)
     {
-        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
+        atomicFetchAdd(a_f[aidx][tidxi], f[0]);
     }
 }
 
+/*!
+ * \brief Do workgroup-level reduction of a single \c float.
+ *
+ * While SYCL has \c sycl::reduce_over_group, it currently (oneAPI 2021.3.0) uses a very large
+ * shared memory buffer, which leads to reduced occupancy.
+ *
+ * \note The caller must make sure there are no races when reusing the \p sm_buf.
+ *
+ * \tparam subGroupSize Size of a sub-group.
+ * \tparam groupSize Size of a work-group.
+ * \param itemIdx Current thread's \c sycl::nd_item.
+ * \param tidxi Current thread's linearized local index.
+ * \param sm_buf Accessor for the local reduction buffer; must have room for at least one float.
+ * \param valueToReduce Current thread's value.
+ * \return For the thread with \p tidxi == 0: the sum of all \p valueToReduce values. For other threads: unspecified.
+ */
+template<int subGroupSize, int groupSize>
+static inline float groupReduce(const cl::sycl::nd_item<3> itemIdx,
+                                const unsigned int         tidxi,
+                                cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                float valueToReduce)
+{
+    constexpr int numSubGroupsInGroup = groupSize / subGroupSize;
+    static_assert(numSubGroupsInGroup == 1 || numSubGroupsInGroup == 2);
+    sycl_2020::sub_group sg = itemIdx.get_sub_group();
+    valueToReduce           = sycl_2020::group_reduce(sg, valueToReduce, sycl_2020::plus<float>());
+    // If we have two sub-groups, we should reduce across them.
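+    // The leader of the second sub-group (tidxi == subGroupSize) publishes its partial sum via
+    // the local buffer, and the leader of the first sub-group (tidxi == 0) adds it to its own.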
+    if constexpr (numSubGroupsInGroup == 2)
+    {
+        if (tidxi == subGroupSize)
+        {
+            sm_buf[0] = valueToReduce;
+        }
+        itemIdx.barrier(fence_space::local_space);
+        if (tidxi == 0)
+        {
+            valueToReduce += sm_buf[0];
+        }
+    }
+    return valueToReduce;
+}
+
 /*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
  *
  * c_clSize consecutive threads hold the force components of a j-atom which we
@@ -346,18 +391,18 @@ static inline void reduceForceJShuffle(Float3                             f,
  * TODO: implement binary reduction flavor for the case where cl_Size is power of two.
  */
 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
-                                       Float3                             f,
-                                       const cl::sycl::nd_item<1>         itemIdx,
-                                       const int                          tidxi,
-                                       const int                          tidxj,
-                                       const int                          aidx,
-                                       DeviceAccessor<float, mode_atomic> a_f)
+                                       Float3                                   f,
+                                       const cl::sycl::nd_item<3>               itemIdx,
+                                       const int                                tidxi,
+                                       const int                                tidxj,
+                                       const int                                aidx,
+                                       DeviceAccessor<Float3, mode::read_write> a_f)
 {
     static constexpr int sc_fBufferStride = c_clSizeSq;
-    int                  tidx            = tidxi + tidxj * c_clSize;
-    sm_buf[0 * sc_fBufferStride + tidx]  = f[0];
-    sm_buf[1 * sc_fBufferStride + tidx]  = f[1];
-    sm_buf[2 * sc_fBufferStride + tidx]  = f[2];
+    int                  tidx             = tidxi + tidxj * c_clSize;
+    sm_buf[0 * sc_fBufferStride + tidx]   = f[0];
+    sm_buf[1 * sc_fBufferStride + tidx]   = f[1];
+    sm_buf[2 * sc_fBufferStride + tidx]   = f[2];
 
     subGroupBarrier(itemIdx);
 
@@ -372,7 +417,7 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
             fSum += sm_buf[sc_fBufferStride * tidxi + j];
         }
 
-        atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
+        atomicFetchAdd(a_f[aidx][tidxi], fSum);
     }
 }
 
@@ -381,11 +426,11 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
  */
 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                 Float3                                                        f,
-                                const cl::sycl::nd_item<1>         itemIdx,
-                                const int                          tidxi,
-                                const int                          tidxj,
-                                const int                          aidx,
-                                DeviceAccessor<float, mode_atomic> a_f)
+                                const cl::sycl::nd_item<3>               itemIdx,
+                                const int                                tidxi,
+                                const int                                tidxj,
+                                const int                                aidx,
+                                DeviceAccessor<Float3, mode::read_write> a_f)
 {
     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
     {
@@ -400,7 +445,7 @@ static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, t
 
 /*! \brief Final i-force reduction.
  *
- * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force componets stored in \p fCiBuf[]
+ * Reduce c_nbnxnGpuNumClusterPerSupercluster i-force components stored in \p fCiBuf[]
  * accumulating atomically into \p a_f.
  * If \p calcFShift is true, further reduce shift forces and atomically accumulate into \p a_fShift.
  *
@@ -409,13 +454,13 @@ static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, t
 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                          const bool   calcFShift,
-                                         const cl::sycl::nd_item<1>         itemIdx,
-                                         const int                          tidxi,
-                                         const int                          tidxj,
-                                         const int                          sci,
-                                         const int                          shift,
-                                         DeviceAccessor<float, mode_atomic> a_f,
-                                         DeviceAccessor<float, mode_atomic> a_fShift)
+                                         const cl::sycl::nd_item<3>               itemIdx,
+                                         const int                                tidxi,
+                                         const int                                tidxj,
+                                         const int                                sci,
+                                         const int                                shift,
+                                         DeviceAccessor<Float3, mode::read_write> a_f,
+                                         DeviceAccessor<Float3, mode::read_write> a_fShift)
 {
     // must have power of two elements in fCiBuf
     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
@@ -459,7 +504,7 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
         {
             const float f =
                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
-            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
+            atomicFetchAdd(a_f[aidx][tidxj], f);
             if (calcFShift)
             {
                 fShiftBuf += f;
@@ -475,7 +520,26 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
            storing the reduction result above. */
         if (tidxj < 3)
         {
-            atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
+            if constexpr (c_clSize == 4)
+            {
+                /* Intel Xe (Gen12LP) and earlier GPUs implement floating-point atomics via
+                 * a compare-and-swap (CAS) loop. It has particularly poor performance when
+                 * updating the same memory location from the same work-group.
+                 * Such optimization might be slightly beneficial for NVIDIA and AMD as well,
+                 * but it is unlikely to make a big difference and thus was not evaluated.
+                 */
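+                // Reduce the c_clSize per-lane partial sums for this component with sub-group
+                // shuffles, so that only the tidxi == 0 lane issues the atomic add (one atomic
+                // per component instead of c_clSize).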
+                auto sg = itemIdx.get_sub_group();
+                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 1);
+                fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
+                if (tidxi == 0)
+                {
+                    atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
+                }
+            }
+            else
+            {
+                atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
+            }
         }
     }
 }
@@ -484,13 +548,13 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
  *
  */
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
-auto nbnxmKernel(cl::sycl::handler&                                   cgh,
-                 DeviceAccessor<Float4, mode::read>                   a_xq,
-                 DeviceAccessor<float, mode_atomic>                   a_f,
-                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
-                 DeviceAccessor<float, mode_atomic>                   a_fShift,
-                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
-                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
+auto nbnxmKernel(cl::sycl::handler&                                        cgh,
+                 DeviceAccessor<Float4, mode::read>                        a_xq,
+                 DeviceAccessor<Float3, mode::read_write>                  a_f,
+                 DeviceAccessor<Float3, mode::read>                        a_shiftVec,
+                 DeviceAccessor<Float3, mode::read_write>                  a_fShift,
+                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
+                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
@@ -519,34 +583,34 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
 {
     static constexpr EnergyFunctionProperties<elecType, vdwType> props;
 
-    cgh.require(a_xq);
-    cgh.require(a_f);
-    cgh.require(a_shiftVec);
-    cgh.require(a_fShift);
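+    // Bind the placeholder device accessors used by this kernel to the command-group handler.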
+    a_xq.bind(cgh);
+    a_f.bind(cgh);
+    a_shiftVec.bind(cgh);
+    a_fShift.bind(cgh);
     if constexpr (doCalcEnergies)
     {
-        cgh.require(a_energyElec);
-        cgh.require(a_energyVdw);
+        a_energyElec.bind(cgh);
+        a_energyVdw.bind(cgh);
     }
-    cgh.require(a_plistCJ4);
-    cgh.require(a_plistSci);
-    cgh.require(a_plistExcl);
+    a_plistCJ4.bind(cgh);
+    a_plistSci.bind(cgh);
+    a_plistExcl.bind(cgh);
     if constexpr (!props.vdwComb)
     {
-        cgh.require(a_atomTypes);
-        cgh.require(a_nbfp);
+        a_atomTypes.bind(cgh);
+        a_nbfp.bind(cgh);
     }
     else
     {
-        cgh.require(a_ljComb);
+        a_ljComb.bind(cgh);
     }
     if constexpr (props.vdwEwald)
     {
-        cgh.require(a_nbfpComb);
+        a_nbfpComb.bind(cgh);
     }
     if constexpr (props.elecEwaldTab)
     {
-        cgh.require(a_coulombTab);
+        a_coulombTab.bind(cgh);
     }
 
     // shmem buffer for i x+q pre-loading
@@ -589,9 +653,10 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
             (props.elecEwald || props.elecRF || props.vdwEwald || (props.elecCutoff && doCalcEnergies));
 
     // The post-prune j-i cluster-pair organization is linked to how exclusion and interaction mask data is stored.
-    // Currently this is ideally suited for 32-wide subgroup size but slightly less so for others,
+    // Currently, this is ideally suited for 32-wide subgroup size but slightly less so for others,
     // e.g. subGroupSize > prunedClusterPairSize on AMD GCN / CDNA.
     // Hence, the two are decoupled.
+    // When changing this code, please update requiredSubGroupSizeForNbnxm in src/gromacs/hardware/device_management_sycl.cpp.
     constexpr int prunedClusterPairSize = c_clSize * c_splitClSize;
 #if defined(HIPSYCL_PLATFORM_ROCM) // SYCL-TODO AMD RDNA/RDNA2 has 32-wide exec; how can we check for that?
     gmx_unused constexpr int subGroupSize = c_clSize * c_clSize;
@@ -599,16 +664,14 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
     gmx_unused constexpr int subGroupSize = prunedClusterPairSize;
 #endif
 
-    return [=](cl::sycl::nd_item<1> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
+    return [=](cl::sycl::nd_item<3> itemIdx) [[intel::reqd_sub_group_size(subGroupSize)]]
     {
         /* thread/block/warp id-s */
-        const cl::sycl::id<3> localId = unflattenId<c_clSize, c_clSize>(itemIdx.get_local_id());
-        const unsigned        tidxi   = localId[0];
-        const unsigned        tidxj   = localId[1];
-        const unsigned        tidx    = tidxj * c_clSize + tidxi;
-        const unsigned        tidxz   = 0;
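+        // Dimension 2 of a SYCL nd_range varies fastest (like CUDA's threadIdx.x), so tidxi
+        // indexes consecutive work-items within a row.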
+        const unsigned tidxi = itemIdx.get_local_id(2);
+        const unsigned tidxj = itemIdx.get_local_id(1);
+        const unsigned tidx  = tidxj * c_clSize + tidxi;
+        const unsigned tidxz = 0;
 
-        // Group indexing was flat originally, no need to unflatten it.
         const unsigned bidx = itemIdx.get_group(0);
 
         const sycl_2020::sub_group sg = itemIdx.get_sub_group();
@@ -942,7 +1005,7 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                                 if constexpr (props.elecRF)
                                 {
                                     energyElec +=
-                                            qi * qj * (pairExclMask * rInv + 0.5f * twoKRf * r2 - cRF);
+                                            qi * qj * (pairExclMask * rInv + 0.5F * twoKRf * r2 - cRF);
                                 }
                                 if constexpr (props.elecEwald)
                                 {
@@ -976,36 +1039,33 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
         } // for (int j4 = cij4Start; j4 < cij4End; j4 += 1)
 
         /* skip central shifts when summing shift forces */
-        const bool doCalcShift = (calcShift && !(nbSci.shift == gmx::c_centralShiftIndex));
+        const bool doCalcShift = (calcShift && nbSci.shift != gmx::c_centralShiftIndex);
 
         reduceForceIAndFShift(
                 sm_reductionBuffer, fCiBuf, doCalcShift, itemIdx, tidxi, tidxj, sci, nbSci.shift, a_f, a_fShift);
 
         if constexpr (doCalcEnergies)
         {
-            const float energyVdwGroup = sycl_2020::group_reduce(
-                    itemIdx.get_group(), energyVdw, 0.0F, sycl_2020::plus<float>());
-            const float energyElecGroup = sycl_2020::group_reduce(
-                    itemIdx.get_group(), energyElec, 0.0F, sycl_2020::plus<float>());
+            const float energyVdwGroup =
+                    groupReduce<subGroupSize, c_clSizeSq>(itemIdx, tidx, sm_reductionBuffer, energyVdw);
+            itemIdx.barrier(fence_space::local_space); // Prevent the race on sm_reductionBuffer.
+            const float energyElecGroup = groupReduce<subGroupSize, c_clSizeSq>(
+                    itemIdx, tidx, sm_reductionBuffer, energyElec);
 
             if (tidx == 0)
             {
-                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
-                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
+                atomicFetchAdd(a_energyVdw[0], energyVdwGroup);
+                atomicFetchAdd(a_energyElec[0], energyElecGroup);
             }
         }
     };
 }
 
-// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
-template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
-class NbnxmKernelName;
-
+//! \brief NBNXM kernel launch code.
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType, class... Args>
 cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int numSci, Args&&... args)
 {
-    // Should not be needed for SYCL2020.
-    using kernelNameType = NbnxmKernelName<doPruneNBL, doCalcEnergies, elecType, vdwType>;
+    using kernelNameType = NbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>;
 
     /* Kernel launch config:
      * - The thread block dimensions match the size of i-clusters, j-clusters,
@@ -1013,7 +1073,7 @@ cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int nu
      * - The 1D block-grid contains as many blocks as super-clusters.
      */
     const int                   numBlocks = numSci;
-    const cl::sycl::range<3>    blockSize{ c_clSize, c_clSize, 1 };
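+    // Dimensions are ordered {z, j, i} so that the i-index is the fastest-varying one,
+    // matching itemIdx.get_local_id(2) in the kernel.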
+    const cl::sycl::range<3>    blockSize{ 1, c_clSize, c_clSize };
     const cl::sycl::range<3>    globalSize{ numBlocks * blockSize[0], blockSize[1], blockSize[2] };
     const cl::sycl::nd_range<3> range{ globalSize, blockSize };
 
@@ -1022,12 +1082,13 @@ cl::sycl::event launchNbnxmKernel(const DeviceStream& deviceStream, const int nu
     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
         auto kernel = nbnxmKernel<doPruneNBL, doCalcEnergies, elecType, vdwType>(
                 cgh, std::forward<Args>(args)...);
-        cgh.parallel_for<kernelNameType>(flattenNDRange(range), kernel);
+        cgh.parallel_for<kernelNameType>(range, kernel);
     });
 
     return e;
 }
 
+//! \brief Select templated kernel and launch it.
 template<class... Args>
 cl::sycl::event chooseAndLaunchNbnxmKernel(bool          doPruneNBL,
                                            bool          doCalcEnergies,
@@ -1054,12 +1115,6 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In
     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
 
-    // Casting to float simplifies using atomic ops in the kernel
-    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
-    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
-    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
-    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
-
     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                    stepWork.computeEnergy,
                                                    nbp->elecType,
@@ -1067,9 +1122,9 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In
                                                    deviceStream,
                                                    plist->nsci,
                                                    adat->xq,
-                                                   fAsFloat,
+                                                   adat->f,
                                                    adat->shiftVec,
-                                                   fShiftAsFloat,
+                                                   adat->fShift,
                                                    adat->eElec,
                                                    adat->eLJ,
                                                    plist->cj4,