A few other minor changes to the SYCL version of DeviceBuffer.
This is a prerequisite for adding support for USM.
Refs #3847, #3965.
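The substance of the change: DeviceAccessor gains a bind(cgh) method that currently just forwards to cgh.require(), and the placeholder type used for disabled OptionalAccessor arguments is replaced by NullAccessor, whose stub bind() asserts if ever reached. Call sites can then use a_foo.bind(cgh) uniformly. A minimal sketch of the resulting pattern follows; the kernel builder and its arguments (exampleKernel, a_input, a_extra, haveExtraOutput) are hypothetical, while bind(), DeviceAccessor, OptionalAccessor, and NullAccessor come from this patch:

    // Hypothetical kernel builder illustrating the new accessor-binding pattern.
    template<bool haveExtraOutput>
    static auto exampleKernel(cl::sycl::handler& cgh,
                              DeviceAccessor<float, cl::sycl::access::mode::read> a_input,
                              OptionalAccessor<float, cl::sycl::access::mode::read_write, haveExtraOutput> a_extra)
    {
        a_input.bind(cgh); // currently forwards to cgh.require(a_input)
        if constexpr (haveExtraOutput)
        {
            a_extra.bind(cgh); // a_extra is a real DeviceAccessor only in this branch;
                               // when disabled it is a NullAccessor whose bind() asserts
        }
        return [=](cl::sycl::id<1> itemIdx) {
            // ... kernel body ...
        };
    }

Funneling all cgh.require() calls through bind() leaves a single point to reimplement later when buffer accessors are replaced by USM pointers, without touching every kernel.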
12 files changed:
constexpr int splineParamsSize = atomsPerBlock * DIM * order;
constexpr int gridlineIndicesSize = atomsPerBlock * DIM;
- cgh.require(a_gridA);
- cgh.require(a_coefficientsA);
- cgh.require(a_forces);
+ a_gridA.bind(cgh);
+ a_coefficientsA.bind(cgh);
+ a_forces.bind(cgh);
if constexpr (numGrids == 2)
{
- cgh.require(a_gridB);
- cgh.require(a_coefficientsB);
+ a_gridB.bind(cgh);
+ a_coefficientsB.bind(cgh);
}
if constexpr (readGlobal)
{
- cgh.require(a_theta);
- cgh.require(a_dtheta);
- cgh.require(a_gridlineIndices);
+ a_theta.bind(cgh);
+ a_dtheta.bind(cgh);
+ a_gridlineIndices.bind(cgh);
- cgh.require(a_coordinates);
- cgh.require(a_fractShiftsTable);
- cgh.require(a_gridlineIndicesTable);
+ a_coordinates.bind(cgh);
+ a_fractShiftsTable.bind(cgh);
+ a_gridlineIndicesTable.bind(cgh);
}
// Gridline indices, ivec
OptionalAccessor<float, mode::read_write, computeEnergyAndVirial> a_virialAndEnergy,
DeviceAccessor<float, mode::read_write> a_fourierGrid)
{
- cgh.require(a_splineModuli);
- cgh.require(a_solveKernelParams);
+ a_splineModuli.bind(cgh);
+ a_solveKernelParams.bind(cgh);
if constexpr (computeEnergyAndVirial)
{
- cgh.require(a_virialAndEnergy);
+ a_virialAndEnergy.bind(cgh);
- cgh.require(a_fourierGrid);
+ a_fourierGrid.bind(cgh);
/* Reduce 7 outputs per warp in the shared memory */
const int stride =
if constexpr (spreadCharges)
{
- cgh.require(a_realGrid_0);
+ a_realGrid_0.bind(cgh);
}
if constexpr (writeGlobal || computeSplines)
{
}
if constexpr (computeSplines && writeGlobal)
{
}
if constexpr (writeGlobal)
{
- cgh.require(a_gridlineIndices);
+ a_gridlineIndices.bind(cgh);
}
if constexpr (computeSplines)
{
- cgh.require(a_fractShiftsTable);
- cgh.require(a_gridlineIndicesTable);
- cgh.require(a_coordinates);
+ a_fractShiftsTable.bind(cgh);
+ a_gridlineIndicesTable.bind(cgh);
+ a_coordinates.bind(cgh);
- cgh.require(a_coefficients_0);
+ a_coefficients_0.bind(cgh);
if constexpr (numGrids == 2 && spreadCharges)
{
- cgh.require(a_realGrid_1);
- cgh.require(a_coefficients_1);
+ a_realGrid_1.bind(cgh);
+ a_coefficients_1.bind(cgh);
}
// Gridline indices, ivec
static_assert(mode == cl::sycl::access::mode::read,
"Can not create non-read-only accessor from a const DeviceBuffer");
}
+ void bind(cl::sycl::handler& cgh) { cgh.require(*this); }
private:
//! Helper function to get sycl::buffer object from DeviceBuffer wrapper, with a sanity check.
namespace gmx::internal
{
-//! A "blackhole" class to be used when we want to ignore an argument to a function.
-struct EmptyClassThatIgnoresConstructorArguments
+//! A non-functional class that can be used instead of real accessors
+template<class T>
+struct NullAccessor
- template<class... Args>
- [[maybe_unused]] EmptyClassThatIgnoresConstructorArguments(Args&&... /*args*/)
- {
- }
+ NullAccessor(const DeviceBuffer<T>& /*buffer*/) {}
//! Allow casting to nullptr
constexpr operator std::nullptr_t() const { return nullptr; }
+ //! Placeholder implementation of \c cl::sycl::accessor::get_pointer.
+ T* get_pointer() const noexcept { return nullptr; }
+ void bind(cl::sycl::handler& /*cgh*/) { assert(false); }
};
} // namespace gmx::internal
*/
template<class T, cl::sycl::access::mode mode, bool enabled>
using OptionalAccessor =
- std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::EmptyClassThatIgnoresConstructorArguments>;
+ std::conditional_t<enabled, DeviceAccessor<T, mode>, gmx::internal::NullAccessor<T>>;
#endif // #ifndef DOXYGEN
if constexpr (usingHipSycl)
{
// hipSYCL does not support reinterpret but allows using Float3 directly.
template<typename ValueType>
void destroyParamLookupTable(DeviceBuffer<ValueType>* deviceBuffer, DeviceTexture* /* deviceTexture */)
{
- deviceBuffer->buffer_.reset(nullptr);
+ freeDeviceBuffer(deviceBuffer);
}
#endif // GMX_GPU_UTILS_DEVICEBUFFER_SYCL_H
DeviceAccessor<int, cl::sycl::access::mode::read> a_cell,
const int atomStart)
{
- cgh.require(a_nbnxmForce);
+ a_nbnxmForce.bind(cgh);
if constexpr (addRvecForce)
{
- cgh.require(a_rvecForceToAdd);
+ a_rvecForceToAdd.bind(cgh);
- cgh.require(a_forceTotal);
- cgh.require(a_cell);
+ a_forceTotal.bind(cgh);
+ a_cell.bind(cgh);
return [=](cl::sycl::id<1> itemIdx) {
// Set to nbnxm force, then perhaps accumulate further to it
OptionalAccessor<unsigned short, mode::read, numTempScaleValues == NumTempScaleValues::Multiple> a_tempScaleGroups,
Float3 prVelocityScalingMatrixDiagonal)
{
- cgh.require(a_x);
- cgh.require(a_xp);
- cgh.require(a_v);
- cgh.require(a_f);
- cgh.require(a_inverseMasses);
+ a_x.bind(cgh);
+ a_xp.bind(cgh);
+ a_v.bind(cgh);
+ a_f.bind(cgh);
+ a_inverseMasses.bind(cgh);
if constexpr (numTempScaleValues != NumTempScaleValues::None)
{
- cgh.require(a_lambdas);
+ a_lambdas.bind(cgh);
}
if constexpr (numTempScaleValues == NumTempScaleValues::Multiple)
{
- cgh.require(a_tempScaleGroups);
+ a_tempScaleGroups.bind(cgh);
}
return [=](cl::sycl::id<1> itemIdx) {
OptionalAccessor<float, mode::read_write, computeVirial> a_virialScaled,
PbcAiuc pbcAiuc)
{
- cgh.require(a_constraints);
- cgh.require(a_constraintsTargetLengths);
+ a_constraints.bind(cgh);
+ a_constraintsTargetLengths.bind(cgh);
if constexpr (haveCoupledConstraints)
{
- cgh.require(a_coupledConstraintsCounts);
- cgh.require(a_coupledConstraintsIndices);
- cgh.require(a_massFactors);
- cgh.require(a_matrixA);
+ a_coupledConstraintsCounts.bind(cgh);
+ a_coupledConstraintsIndices.bind(cgh);
+ a_massFactors.bind(cgh);
+ a_matrixA.bind(cgh);
- cgh.require(a_inverseMasses);
- cgh.require(a_x);
- cgh.require(a_xp);
+ a_inverseMasses.bind(cgh);
+ a_x.bind(cgh);
+ a_xp.bind(cgh);
if constexpr (updateVelocities)
{
}
if constexpr (computeVirial)
{
- cgh.require(a_virialScaled);
+ a_virialScaled.bind(cgh);
}
/* Shared local memory buffer. Corresponds to sh_r, sm_rhs, and sm_threadVirial in CUDA.
OptionalAccessor<float, mode::read_write, computeVirial> a_virialScaled,
PbcAiuc pbcAiuc)
{
- cgh.require(a_settles);
- cgh.require(a_x);
- cgh.require(a_xp);
+ a_settles.bind(cgh);
+ a_x.bind(cgh);
+ a_xp.bind(cgh);
if constexpr (updateVelocities)
{
}
if constexpr (computeVirial)
{
- cgh.require(a_virialScaled);
+ a_virialScaled.bind(cgh);
}
// shmem buffer for i x+q pre-loading
DeviceAccessor<Float3, cl::sycl::access::mode::read_write> a_x,
const ScalingMatrix scalingMatrix)
{
return [=](cl::sycl::id<1> itemIdx) {
Float3 x = a_x[itemIdx];
int numAtomsPerCell,
int columnsOffset)
{
- cgh.require(a_xq);
- cgh.require(a_x);
- cgh.require(a_atomIndex);
- cgh.require(a_numAtoms);
- cgh.require(a_cellIndex);
+ a_xq.bind(cgh);
+ a_x.bind(cgh);
+ a_atomIndex.bind(cgh);
+ a_numAtoms.bind(cgh);
+ a_cellIndex.bind(cgh);
return [=](cl::sycl::id<2> itemIdx) {
// Map cell-level parallelism to y component of block index.
{
static constexpr EnergyFunctionProperties<elecType, vdwType> props;
- cgh.require(a_xq);
- cgh.require(a_f);
- cgh.require(a_shiftVec);
- cgh.require(a_fShift);
+ a_xq.bind(cgh);
+ a_f.bind(cgh);
+ a_shiftVec.bind(cgh);
+ a_fShift.bind(cgh);
if constexpr (doCalcEnergies)
{
- cgh.require(a_energyElec);
- cgh.require(a_energyVdw);
+ a_energyElec.bind(cgh);
+ a_energyVdw.bind(cgh);
- cgh.require(a_plistCJ4);
- cgh.require(a_plistSci);
- cgh.require(a_plistExcl);
+ a_plistCJ4.bind(cgh);
+ a_plistSci.bind(cgh);
+ a_plistExcl.bind(cgh);
if constexpr (!props.vdwComb)
{
- cgh.require(a_atomTypes);
- cgh.require(a_nbfp);
+ a_atomTypes.bind(cgh);
+ a_nbfp.bind(cgh);
}
if constexpr (props.vdwEwald)
{
- cgh.require(a_nbfpComb);
+ a_nbfpComb.bind(cgh);
}
if constexpr (props.elecEwaldTab)
{
- cgh.require(a_coulombTab);
+ a_coulombTab.bind(cgh);
}
// shmem buffer for i x+q pre-loading
const int numParts,
const int part)
{
- cgh.require(a_xq);
- cgh.require(a_shiftVec);
- cgh.require(a_plistCJ4);
- cgh.require(a_plistSci);
- cgh.require(a_plistIMask);
+ a_xq.bind(cgh);
+ a_shiftVec.bind(cgh);
+ a_plistCJ4.bind(cgh);
+ a_plistSci.bind(cgh);
+ a_plistIMask.bind(cgh);
/* shmem buffer for i x+q pre-loading */
cl::sycl::accessor<Float4, 2, mode::read_write, target::local> sm_xq(