* c_clSize consecutive threads hold the force components of a j-atom which we
* reduced in log2(cl_Size) steps using shift and atomically accumulate them into \p a_f.
*/
-static inline void reduceForceJShuffle(Float3 f,
- const cl::sycl::nd_item<1> itemIdx,
- const int tidxi,
- const int aidx,
- DeviceAccessor<float, mode::read_write> a_f)
+static inline void reduceForceJShuffle(Float3 f,
+ const cl::sycl::nd_item<1> itemIdx,
+ const int tidxi,
+ const int aidx,
+ DeviceAccessor<Float3, mode::read_write> a_f)
{
static_assert(c_clSize == 8 || c_clSize == 4);
sycl_2020::sub_group sg = itemIdx.get_sub_group();
if (tidxi < 3)
{
- atomicFetchAdd(a_f[3 * aidx + tidxi], f[0]);
+ atomicFetchAdd(a_f[aidx][tidxi], f[0]);
}
}
* TODO: implement binary reduction flavor for the case where cl_Size is power of two.
*/
static inline void reduceForceJGeneric(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
- Float3 f,
- const cl::sycl::nd_item<1> itemIdx,
- const int tidxi,
- const int tidxj,
- const int aidx,
- DeviceAccessor<float, mode::read_write> a_f)
+ Float3 f,
+ const cl::sycl::nd_item<1> itemIdx,
+ const int tidxi,
+ const int tidxj,
+ const int aidx,
+ DeviceAccessor<Float3, mode::read_write> a_f)
{
static constexpr int sc_fBufferStride = c_clSizeSq;
int tidx = tidxi + tidxj * c_clSize;
fSum += sm_buf[sc_fBufferStride * tidxi + j];
}
- atomicFetchAdd(a_f[3 * aidx + tidxi], fSum);
+ atomicFetchAdd(a_f[aidx][tidxi], fSum);
}
}
*/
static inline void reduceForceJ(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
Float3 f,
- const cl::sycl::nd_item<1> itemIdx,
- const int tidxi,
- const int tidxj,
- const int aidx,
- DeviceAccessor<float, mode::read_write> a_f)
+ const cl::sycl::nd_item<1> itemIdx,
+ const int tidxi,
+ const int tidxj,
+ const int aidx,
+ DeviceAccessor<Float3, mode::read_write> a_f)
{
if constexpr (!gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster))
{
static inline void reduceForceIAndFShift(cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buf,
const Float3 fCiBuf[c_nbnxnGpuNumClusterPerSupercluster],
const bool calcFShift,
- const cl::sycl::nd_item<1> itemIdx,
- const int tidxi,
- const int tidxj,
- const int sci,
- const int shift,
- DeviceAccessor<float, mode::read_write> a_f,
- DeviceAccessor<float, mode::read_write> a_fShift)
+ const cl::sycl::nd_item<1> itemIdx,
+ const int tidxi,
+ const int tidxj,
+ const int sci,
+ const int shift,
+ DeviceAccessor<Float3, mode::read_write> a_f,
+ DeviceAccessor<Float3, mode::read_write> a_fShift)
{
// must have power of two elements in fCiBuf
static_assert(gmx::isPowerOfTwo(c_nbnxnGpuNumClusterPerSupercluster));
{
const float f =
sm_buf[tidxj * bufStride + tidxi] + sm_buf[tidxj * bufStride + c_clSize + tidxi];
- atomicFetchAdd(a_f[3 * aidx + tidxj], f);
+ atomicFetchAdd(a_f[aidx][tidxj], f);
if (calcFShift)
{
fShiftBuf += f;
fShiftBuf += sycl_2020::shift_left(sg, fShiftBuf, 2);
if (tidxi == 0)
{
- atomicFetchAdd(a_fShift[3 * shift + tidxj], fShiftBuf);
+ atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
}
}
else
{
- atomicFetchAdd(a_fShift[3 * shift + tidxj], fShiftBuf);
+ atomicFetchAdd(a_fShift[shift][tidxj], fShiftBuf);
}
}
}
template<bool doPruneNBL, bool doCalcEnergies, enum ElecType elecType, enum VdwType vdwType>
auto nbnxmKernel(cl::sycl::handler& cgh,
DeviceAccessor<Float4, mode::read> a_xq,
- DeviceAccessor<float, mode::read_write> a_f,
+ DeviceAccessor<Float3, mode::read_write> a_f,
DeviceAccessor<Float3, mode::read> a_shiftVec,
- DeviceAccessor<float, mode::read_write> a_fShift,
+ DeviceAccessor<Float3, mode::read_write> a_fShift,
OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyElec,
OptionalAccessor<float, mode::read_write, doCalcEnergies> a_energyVdw,
DeviceAccessor<nbnxn_cj4_t, doPruneNBL ? mode::read_write : mode::read> a_plistCJ4,
const bool doPruneNBL = (plist->haveFreshList && !nb->didPrune[iloc]);
const DeviceStream& deviceStream = *nb->deviceStreams[iloc];
- // Casting to float simplifies using atomic ops in the kernel
- cl::sycl::buffer<Float3, 1> f(*adat->f.buffer_);
- auto fAsFloat = f.reinterpret<float, 1>(f.get_count() * DIM);
- cl::sycl::buffer<Float3, 1> fShift(*adat->fShift.buffer_);
- auto fShiftAsFloat = fShift.reinterpret<float, 1>(fShift.get_count() * DIM);
-
cl::sycl::event e = chooseAndLaunchNbnxmKernel(doPruneNBL,
stepWork.computeEnergy,
nbp->elecType,
deviceStream,
plist->nsci,
adat->xq,
- fAsFloat,
+ adat->f,
adat->shiftVec,
- fShiftAsFloat,
+ adat->fShift,
adat->eElec,
adat->eLJ,
plist->cj4,