From: Andrey Alekseenko Date: Fri, 27 Aug 2021 12:12:57 +0000 (+0200) Subject: Get rid of sycl::buffer::reinterpret X-Git-Url: http://biod.pnpi.spb.ru/gitweb/?a=commitdiff_plain;h=7d9f6fa189a6e99595eb150adee6af41efb186fe;p=alexxy%2Fgromacs.git Get rid of sycl::buffer::reinterpret This functionality is not properly supported in hipSYCL yet, and was only needed in order to use atomic accessors. After fully switching to atomic_ref, we can directly use buffers of Float3. Closes #4063. --- diff --git a/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp b/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp index 00957bb1b9..16be5bc7df 100644 --- a/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp +++ b/src/gromacs/nbnxm/sycl/nbnxm_sycl_kernel.cpp @@ -306,11 +306,11 @@ static inline float interpolateCoulombForceR(const DeviceAccessor itemIdx, - const int tidxi, - const int aidx, - DeviceAccessor a_f) +static inline void reduceForceJShuffle(Float3 f, + const cl::sycl::nd_item<1> itemIdx, + const int tidxi, + const int aidx, + DeviceAccessor a_f) { static_assert(c_clSize == 8 || c_clSize == 4); sycl_2020::sub_group sg = itemIdx.get_sub_group(); @@ -337,7 +337,7 @@ static inline void reduceForceJShuffle(Float3 f if (tidxi < 3) { - atomicFetchAdd(a_f[3 * aidx + tidxi], f[0]); + atomicFetchAdd(a_f[aidx][tidxi], f[0]); } } @@ -389,12 +389,12 @@ static inline float groupReduce(const cl::sycl::nd_item<1> itemIdx, * TODO: implement binary reduction flavor for the case where cl_Size is power of two. */ static inline void reduceForceJGeneric(cl::sycl::accessor sm_buf, - Float3 f, - const cl::sycl::nd_item<1> itemIdx, - const int tidxi, - const int tidxj, - const int aidx, - DeviceAccessor a_f) + Float3 f, + const cl::sycl::nd_item<1> itemIdx, + const int tidxi, + const int tidxj, + const int aidx, + DeviceAccessor a_f) { static constexpr int sc_fBufferStride = c_clSizeSq; int tidx = tidxi + tidxj * c_clSize; @@ -415,7 +415,7 @@ static inline void reduceForceJGeneric(cl::sycl::accessor sm_buf, Float3 f, - const cl::sycl::nd_item<1> itemIdx, - const int tidxi, - const int tidxj, - const int aidx, - DeviceAccessor a_f) + const cl::sycl::nd_item<1> itemIdx, + const int tidxi, + const int tidxj, + const int aidx, + DeviceAccessor a_f) { if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster)) { @@ -452,13 +452,13 @@ static inline void reduceForceJ(cl::sycl::accessor sm_buf, const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster], const bool calcFShift, - const cl::sycl::nd_item<1> itemIdx, - const int tidxi, - const int tidxj, - const int sci, - const int shift, - DeviceAccessor a_f, - DeviceAccessor a_fShift) + const cl::sycl::nd_item<1> itemIdx, + const int tidxi, + const int tidxj, + const int sci, + const int shift, + DeviceAccessor a_f, + DeviceAccessor a_fShift) { // must have power of two elements in fCiBuf static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster)); @@ -502,7 +502,7 @@ static inline void reduceForceIAndFShift(cl::sycl::accessor auto nbnxmKernel(cl::sycl::handler& cgh, DeviceAccessor a_xq, - DeviceAccessor a_f, + DeviceAccessor a_f, DeviceAccessor a_shiftVec, - DeviceAccessor a_fShift, + DeviceAccessor a_fShift, OptionalAccessor a_energyElec, OptionalAccessor a_energyVdw, DeviceAccessor a_plistCJ4, @@ -1114,12 +1114,6 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In const bool doPruneNBL = (plist->haveFreshList && !nb->didPrune[iloc]); const DeviceStream& deviceStream = *nb->deviceStreams[iloc]; - // Casting to float simplifies using atomic ops in the kernel - cl::sycl::buffer f(*adat->f.buffer_); - auto fAsFloat = f.reinterpret(f.get_count() * DIM); - cl::sycl::buffer fShift(*adat->fShift.buffer_); - auto fShiftAsFloat = fShift.reinterpret(fShift.get_count() * DIM); - cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL, stepWork.computeEnergy, nbp->elecType, @@ -1127,9 +1121,9 @@ void launchNbnxmKernel(NbnxmGpu* nb, const gmx::StepWorkload& stepWork, const In deviceStream, plist->nsci, adat->xq, - fAsFloat, + adat->f, adat->shiftVec, - fShiftAsFloat, + adat->fShift, adat->eElec, adat->eLJ, plist->cj4,