Get rid of sycl::buffer::reinterpret
authorAndrey Alekseenko <al42and@gmail.com>
Fri, 27 Aug 2021 12:12:57 +0000 (14:12 +0200)
committerAndrey Alekseenko <al42and@gmail.com>
Sat, 28 Aug 2021 09:29:41 +0000 (12:29 +0300)
This functionality is not properly supported in hipSYCL yet, and was
only needed in order to use atomic accessors. After fully switching to
atomic_ref, we can directly use buffers of Float3.

Closes #4063.

src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp

index 00957bb1b9cfd6c2cacccd599e032ff7993f23f3..16be5bc7df8ec5f8f38ce0ef64a92e19c2e1dadf 100644 (file)
@@ -306,11 +306,11 @@ static inline float interpolateCoulombForceR(const DeviceAccessor<float, mode::r
  * c_clSize consecutive threads hold the force components of a j-atom which we
  * reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
  */
-static inline void reduceForceJShuffle(Float3                                  f,
-                                       const cl::sycl::nd_item<1>              itemIdx,
-                                       const int                               tidxi,
-                                       const int                               aidx,
-                                       DeviceAccessor<float, mode::read_write> a_f)
+static inline void reduceForceJShuffle(Float3                                   f,
+                                       const cl::sycl::nd_item<1>               itemIdx,
+                                       const int                                tidxi,
+                                       const int                                aidx,
+                                       DeviceAccessor<Float3, mode::read_write> a_f)
 {
     static_assert(c_clSize == 8 || c_clSize == 4);
     sycl_2020::sub_group sg = itemIdx.get_sub_group();
@@ -337,7 +337,7 @@ static inline void reduceForceJShuffle(Float3                                  f
 
     if (tidxi < 3)
     {
-        atomicFetchAdd(a_f[3 * aidx + tidxi], f[0]);
+        atomicFetchAdd(a_f[aidx][tidxi], f[0]);
     }
 }
 
@@ -389,12 +389,12 @@ static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx,
  * TODO: implement binary reduction flavor for the case where cl_Size is power of two.
  */
 static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
-                                       Float3                                  f,
-                                       const cl::sycl::nd_item<1>              itemIdx,
-                                       const int                               tidxi,
-                                       const int                               tidxj,
-                                       const int                               aidx,
-                                       DeviceAccessor<float, mode::read_write> a_f)
+                                       Float3                                   f,
+                                       const cl::sycl::nd_item<1>               itemIdx,
+                                       const int                                tidxi,
+                                       const int                                tidxj,
+                                       const int                                aidx,
+                                       DeviceAccessor<Float3, mode::read_write> a_f)
 {
     static constexpr int sc_fBufferStride = c_clSizeSq;
     int                  tidx             = tidxi + tidxj * c_clSize;
@@ -415,7 +415,7 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
             fSum += sm_buf[sc_fBufferStride * tidxi + j];
         }
 
-        atomicFetchAdd(a_f[3 * aidx + tidxi], fSum);
+        atomicFetchAdd(a_f[aidx][tidxi], fSum);
     }
 }
 
@@ -424,11 +424,11 @@ static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_w
  */
 static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                 Float3                                                        f,
-                                const cl::sycl::nd_item<1>              itemIdx,
-                                const int                               tidxi,
-                                const int                               tidxj,
-                                const int                               aidx,
-                                DeviceAccessor<float, mode::read_write> a_f)
+                                const cl::sycl::nd_item<1>               itemIdx,
+                                const int                                tidxi,
+                                const int                                tidxj,
+                                const int                                aidx,
+                                DeviceAccessor<Float3, mode::read_write> a_f)
 {
     if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
     {
@@ -452,13 +452,13 @@ static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, t
 static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
                                          const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
                                          const bool   calcFShift,
-                                         const cl::sycl::nd_item<1>              itemIdx,
-                                         const int                               tidxi,
-                                         const int                               tidxj,
-                                         const int                               sci,
-                                         const int                               shift,
-                                         DeviceAccessor<float, mode::read_write> a_f,
-                                         DeviceAccessor<float, mode::read_write> a_fShift)
+                                         const cl::sycl::nd_item<1>               itemIdx,
+                                         const int                                tidxi,
+                                         const int                                tidxj,
+                                         const int                                sci,
+                                         const int                                shift,
+                                         DeviceAccessor<Float3, mode::read_write> a_f,
+                                         DeviceAccessor<Float3, mode::read_write> a_fShift)
 {
     // must have power of two elements in fCiBuf
     static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
@@ -502,7 +502,7 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
         {
             const float f =
                     sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
-            atomicFetchAdd(a_f[3 * aidx + tidxj], f);
+            atomicFetchAdd(a_f[aidx][tidxj], f);
             if (calcFShift)
             {
                 fShiftBuf += f;
@@ -531,12 +531,12 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
                 fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
                 if (tidxi == 0)
                 {
-                    atomicFetchAdd(a_fShift[3 * shift + tidxj], fShiftBuf);
+                    atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
                 }
             }
             else
             {
-                atomicFetchAdd(a_fShift[3 * shift + tidxj], fShiftBuf);
+                atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
             }
         }
     }
@@ -548,9 +548,9 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read
 template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
 auto nbnxmKernel(cl::sycl::handler&                                        cgh,
                  DeviceAccessor<Float4, mode::read>                        a_xq,
-                 DeviceAccessor<float, mode::read_write>                   a_f,
+                 DeviceAccessor<Float3, mode::read_write>                  a_f,
                  DeviceAccessor<Float3, mode::read>                        a_shiftVec,
-                 DeviceAccessor<float, mode::read_write>                   a_fShift,
+                 DeviceAccessor<Float3, mode::read_write>                  a_fShift,
                  OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
                  OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
                  DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
@@ -1114,12 +1114,6 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In
     const bool          doPruneNBL   = (plist->haveFreshList && !nb->didPrune[iloc]);
     const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
 
-    // Casting to float simplifies using atomic ops in the kernel
-    cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
-    auto                        fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
-    cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
-    auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
-
     cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
                                                    stepWork.computeEnergy,
                                                    nbp->elecType,
@@ -1127,9 +1121,9 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In
                                                    deviceStream,
                                                    plist->nsci,
                                                    adat->xq,
-                                                   fAsFloat,
+                                                   adat->f,
                                                    adat->shiftVec,
-                                                   fShiftAsFloat,
+                                                   adat->fShift,
                                                    adat->eElec,
                                                    adat->eLJ,
                                                    plist->cj4,