Implement generic j-reduction in the nbnxm SYCL kernels
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
index 8b3b81b7f6aa22460852955e42136cbf3537b0c9..63be54335b0ccd6d9ea4e96976a8fb7dd9ecd85c 100644 (file)
@@ -298,7 +298,7 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
     return lerp(left, right, fraction); // TODO: cl::sycl::mix
 }
 
-/*! \brief Reduce c_clSize j-force components and atomically accumulate into a_f.
+/*! \brief Reduce c_clSize j-force components using shifts and atomically accumulate into a_f.
  *
  * c_clSize consecutive threads hold the force components of a j-atom which we
  * reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
@@ -338,6 +338,65 @@ static inline void reduceForceJShuffle(Float3                             f,
     }
 }
 
+/*! \brief Reduce c_clSize j-force components using local memory and atomically accumulate into a_f.
+ *
+ * c_clSize consecutive threads hold the force components of a j-atom which we
+ * reduced in cl_Size steps using shift and atomically accumulate them into \p a_f.
+ *
+ * TODO: implement binary reduction flavor for the case where cl_Size is power of two.
+ */
+static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                       Float3                             f,
+                                       const cl::sycl::nd_item<1>         itemIdx,
+                                       const int                          tidxi,
+                                       const int                          tidxj,
+                                       const int                          aidx,
+                                       DeviceAccessor<float, mode_atomic> a_f)
+{
+    static constexpr int sc_fBufferStride = c_clSizeSq;
+    int                  tidx            = tidxi + tidxj * c_clSize;
+    sm_buf[0 * sc_fBufferStride + tidx]  = f[0];
+    sm_buf[1 * sc_fBufferStride + tidx]  = f[1];
+    sm_buf[2 * sc_fBufferStride + tidx]  = f[2];
+
+    subGroupBarrier(itemIdx);
+
+    // reducing data 8-by-by elements on the leader of same threads as those storing above
+    assert(itemIdx.get_sub_group().get_local_range().size() >= c_clSize);
+
+    if (tidxi < 3)
+    {
+        float fSum = 0.0F;
+        for (int j = tidxj * c_clSize; j < (tidxj + 1) * c_clSize; j++)
+        {
+            fSum += sm_buf[sc_fBufferStride * tidxi + j];
+        }
+
+        atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
+    }
+}
+
+
+/*! \brief Reduce c_clSize j-force components using either shifts or local memory and atomically accumulate into a_f.
+ */
+static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
+                                Float3                                                        f,
+                                const cl::sycl::nd_item<1>         itemIdx,
+                                const int                          tidxi,
+                                const int                          tidxj,
+                                const int                          aidx,
+                                DeviceAccessor<float, mode_atomic> a_f)
+{
+    if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
+    {
+        reduceForceJGeneric(sm_buf, f, itemIdx, tidxi, tidxj, aidx, a_f);
+    }
+    else
+    {
+        reduceForceJShuffle(f, itemIdx, tidxi, aidx, a_f);
+    }
+}
+
 
 /*! \brief Final i-force reduction.
  *
@@ -364,7 +423,7 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
     static constexpr int bufStride  = c_clSize * c_clSize;
     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
     const int            tidx       = tidxi + tidxj * c_clSize;
-    float                fShiftBuf  = 0;
+    float                fShiftBuf  = 0.0F;
     for (int ciOffset = 0; ciOffset < c_nbnxnGpuNumClusterPerSupercluster; ciOffset++)
     {
         const int aidx = (sci * c_nbnxnGpuNumClusterPerSupercluster + ciOffset) * c_clSize + tidxi;
@@ -906,7 +965,7 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
                     maskJI += maskJI;
                 } // for (int i = 0; i < c_nbnxnGpuNumClusterPerSupercluster; i++)
                 /* reduce j forces */
-                reduceForceJShuffle(fCjBuf, itemIdx, tidxi, aj, a_f);
+                reduceForceJ(sm_reductionBuffer, fCjBuf, itemIdx, tidxi, tidxj, aj, a_f);
             } // for (int jm = 0; jm < c_nbnxnGpuJgroupSize; jm++)
             if constexpr (doPruneNBL)
             {