SYCL: Fully switch to atomic_ref
[alexxy/gromacs.git] / src / gromacs / nbnxm / sycl / nbnxm_sycl_kernel.cpp
index 7304922c09f19ae5ea9a54a8b490190b837c040f..00957bb1b9cfd6c2cacccd599e032ff7993f23f3 100644 (file)
@@ -306,11 +306,11 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
  * c_clSize consecutive threads hold the force components of a j-atom which we
  * reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
  */
-static inline void reduceForceJShuffle(Float3                             f,
-                                       const cl::sycl::nd_item<1>         itemIdx,
-                                       const int                          tidxi,
-                                       const int                          aidx,
-                                       DeviceAccessor<float, mode_atomic> a_f)
+static inline void reduceForceJShuffle(Float3                                  f,
+                                       const cl::sycl::nd_item<1>              itemIdx,
+                                       const int                               tidxi,
+                                       const int                               aidx,
+                                       DeviceAccessor<float, mode::read_write> a_f)
 {
     static_assert(c_clSize == 8 || c_clSize == 4);
     sycl_2020::sub_group sg = itemIdx.get_sub_group();
@@ -337,7 +337,7 @@ static inline void reduceForceJShuffle(Float3                             f,
 
     if (tidxi < 3)
     {
-        atomicFetchAdd(a_f, 3 * aidx + tidxi, f[0]);
+        atomicFetchAdd(a_f[3 * aidx + tidxi], f[0]);
     }
 }
 
@@ -389,12 +389,12 @@ static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
  * TODO: implement binary reduction flavor for the case where cl_Size is power of two.
  */
 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
-                                       Float3                             f,
-                                       const cl::sycl::nd_item<1>         itemIdx,
-                                       const int                          tidxi,
-                                       const int                          tidxj,
-                                       const int                          aidx,
-                                       DeviceAccessor<float, mode_atomic> a_f)
+                                       Float3                                  f,
+                                       const cl::sycl::nd_item<1>              itemIdx,
+                                       const int                               tidxi,
+                                       const int                               tidxj,
+                                       const int                               aidx,
+                                       DeviceAccessor<float, mode::read_write> a_f)
 {
     static constexpr int sc_fBufferStride = c_clSizeSq;
     int                  tidx             = tidxi + tidxj * c_clSize;
@@ -415,7 +415,7 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
             fSum += sm_buf[sc_fBufferStride * tidxi + j];
         }
 
-        atomicFetchAdd(a_f, 3 * aidx + tidxi, fSum);
+        atomicFetchAdd(a_f[3 * aidx + tidxi], fSum);
     }
 }
 
@@ -424,11 +424,11 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
  */
 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                 Float3                                                        f,
-                                const cl::sycl::nd_item<1>         itemIdx,
-                                const int                          tidxi,
-                                const int                          tidxj,
-                                const int                          aidx,
-                                DeviceAccessor<float, mode_atomic> a_f)
+                                const cl::sycl::nd_item<1>              itemIdx,
+                                const int                               tidxi,
+                                const int                               tidxj,
+                                const int                               aidx,
+                                DeviceAccessor<float, mode::read_write> a_f)
 {
     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
     {
@@ -452,13 +452,13 @@ static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, t
 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                          const bool   calcFShift,
-                                         const cl::sycl::nd_item<1>         itemIdx,
-                                         const int                          tidxi,
-                                         const int                          tidxj,
-                                         const int                          sci,
-                                         const int                          shift,
-                                         DeviceAccessor<float, mode_atomic> a_f,
-                                         DeviceAccessor<float, mode_atomic> a_fShift)
+                                         const cl::sycl::nd_item<1>              itemIdx,
+                                         const int                               tidxi,
+                                         const int                               tidxj,
+                                         const int                               sci,
+                                         const int                               shift,
+                                         DeviceAccessor<float, mode::read_write> a_f,
+                                         DeviceAccessor<float, mode::read_write> a_fShift)
 {
     // must have power of two elements in fCiBuf
     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
@@ -502,7 +502,7 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
         {
             const float f =
                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
-            atomicFetchAdd(a_f, 3 * aidx + tidxj, f);
+            atomicFetchAdd(a_f[3 * aidx + tidxj], f);
             if (calcFShift)
             {
                 fShiftBuf += f;
@@ -531,12 +531,12 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
                 if (tidxi == 0)
                 {
-                    atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
+                    atomicFetchAdd(a_fShift[3 * shift + tidxj], fShiftBuf);
                 }
             }
             else
             {
-                atomicFetchAdd(a_fShift, 3 * shift + tidxj, fShiftBuf);
+                atomicFetchAdd(a_fShift[3 * shift + tidxj], fShiftBuf);
             }
         }
     }
@@ -546,13 +546,13 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
  *
  */
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
-auto nbnxmKernel(cl::sycl::handler&                                   cgh,
-                 DeviceAccessor<Float4, mode::read>                   a_xq,
-                 DeviceAccessor<float, mode_atomic>                   a_f,
-                 DeviceAccessor<Float3, mode::read>                   a_shiftVec,
-                 DeviceAccessor<float, mode_atomic>                   a_fShift,
-                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyElec,
-                 OptionalAccessor<float, mode_atomic, doCalcEnergies> a_energyVdw,
+auto nbnxmKernel(cl::sycl::handler&                                        cgh,
+                 DeviceAccessor<Float4, mode::read>                        a_xq,
+                 DeviceAccessor<float, mode::read_write>                   a_f,
+                 DeviceAccessor<Float3, mode::read>                        a_shiftVec,
+                 DeviceAccessor<float, mode::read_write>                   a_fShift,
+                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
+                 OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
                  DeviceAccessor<nbnxn_sci_t, mode::read>                                 a_plistSci,
                  DeviceAccessor<nbnxn_excl_t, mode::read>                    a_plistExcl,
@@ -1053,8 +1053,8 @@ auto nbnxmKernel(cl::sycl::handler&                                   cgh,
 
             if (tidx == 0)
             {
-                atomicFetchAdd(a_energyVdw, 0, energyVdwGroup);
-                atomicFetchAdd(a_energyElec, 0, energyElecGroup);
+                atomicFetchAdd(a_energyVdw[0], energyVdwGroup);
+                atomicFetchAdd(a_energyElec[0], energyElecGroup);
             }
         }
     };