Set up build with hipSYCL
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
index 770732c70cc9eda458b02717ce98c64135425642..a372afa77542b646d53bf6a2e413db20558a7f3b 100644 (file)
@@ -301,25 +301,25 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
     return lerp(left, right, fraction); // TODO: cl::sycl::mix
 }
 
-static inline void reduceForceJShuffle(Float3                                  f,
-                                       const cl::sycl::nd_item<1>              itemIdx,
-                                       const int                               tidxi,
-                                       const int                               aidx,
-                                       DeviceAccessor<float, mode::read_write> a_f)
+static inline void reduceForceJShuffle(Float3                             f,
+                                       const cl::sycl::nd_item<1>         itemIdx,
+                                       const int                          tidxi,
+                                       const int                          aidx,
+                                       DeviceAccessor<float, mode_atomic> a_f)
 {
     static_assert(c_clSize == 8 || c_clSize == 4);
     sycl_2020::sub_group sg = itemIdx.get_sub_group();
 
-    f[0] += shuffleDown(f[0], 1, sg);
-    f[1] += shuffleUp(f[1], 1, sg);
-    f[2] += shuffleDown(f[2], 1, sg);
+    f[0] += sycl_2020::shift_left(sg, f[0], 1);
+    f[1] += sycl_2020::shift_right(sg, f[1], 1);
+    f[2] += sycl_2020::shift_left(sg, f[2], 1);
     if (tidxi & 1)
     {
         f[0] = f[1];
     }
 
-    f[0] += shuffleDown(f[0], 2, sg);
-    f[2] += shuffleUp(f[2], 2, sg);
+    f[0] += sycl_2020::shift_left(sg, f[0], 2);
+    f[2] += sycl_2020::shift_right(sg, f[2], 2);
     if (tidxi & 2)
     {
         f[0] = f[2];
@@ -327,7 +327,7 @@ static inline void reduceForceJShuffle(Float3                                  f
 
     if constexpr (c_clSize == 8)
     {
-        f[0] += shuffleDown(f[0], 4, sg);
+        f[0] += sycl_2020::shift_left(sg, f[0], 4);
     }
 
     if (tidxi < 3)
@@ -344,13 +344,13 @@ static inline void reduceForceJShuffle(Float3                                  f
 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                          const bool   calcFShift,
-                                         const cl::sycl::nd_item<1>              itemIdx,
-                                         const int                               tidxi,
-                                         const int                               tidxj,
-                                         const int                               sci,
-                                         const int                               shift,
-                                         DeviceAccessor<float, mode::read_write> a_f,
-                                         DeviceAccessor<float, mode::read_write> a_fShift)
+                                         const cl::sycl::nd_item<1>         itemIdx,
+                                         const int                          tidxi,
+                                         const int                          tidxj,
+                                         const int                          sci,
+                                         const int                          shift,
+                                         DeviceAccessor<float, mode_atomic> a_f,
+                                         DeviceAccessor<float, mode_atomic> a_fShift)
 {
     static constexpr int bufStride  = c_clSize * c_clSize;
     static constexpr int clSizeLog2 = gmx::StaticLog2<c_clSize>::value;
@@ -417,13 +417,13 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
  *
  */
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
-auto nbnxmKernel(cl::sycl::handler&                                        cgh,
-                 DeviceAccessor<Float4, mode::read>                        a_xq,
-                 DeviceAccessor<float, mode::read_write>                   a_f,
-                 DeviceAccessor<Float3, mode::read>                        a_shiftVec,
-                 DeviceAccessor<float, mode::read_write>                   a_fShift,
-                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
-                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
+auto nbnxmKernel(cl::sycl::handler&                                   cgh,
+                 DeviceAccessor<Float4, mode::read>                   a_xq,
+                 DeviceAccessor<float, mode_atomic>                   a_f,
+                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
+                 DeviceAccessor<float, mode_atomic>                   a_fShift,
+                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
+                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
@@ -745,8 +745,13 @@ auto nbnxmKernel(cl::sycl::handler&                                        cgh,
 
                             // Ensure distance do not become so small that r^-12 overflows
                             r2 = std::max(r2, c_nbnxnMinDistanceSquared);
+#if GMX_SYCL_HIPSYCL
+                            // No fast/native functions in some compilation passes
+                            const float rInv = cl::sycl::rsqrt(r2);
+#else
                             // SYCL-TODO: sycl::half_precision::rsqrt?
-                            const float rInv  = cl::sycl::native::rsqrt(r2);
+                            const float rInv = cl::sycl::native::rsqrt(r2);
+#endif
                             const float r2Inv = rInv * rInv;
                             float       r6Inv, fInvR, energyLJPair;
                             if constexpr (!props.vdwCombLB || doCalcEnergies)