#include "gromacs/utility/fatalerror.h"
#include "gromacs/utility/template_mp.h"
+//! \brief Class name for leap-frog kernel
+template<gmx::NumTempScaleValues numTempScaleValues, gmx::VelocityScalingType velocityScaling>
+class LeapFrogKernel;
+
namespace gmx
{
* \param[in] dt Timestep.
* \param[in] a_lambdas Temperature scaling factors (one per group).
* \param[in] a_tempScaleGroups Mapping of atoms into groups.
- * \param[in] prVelocityScalingMatrixDiagonal Diagonal elements of Parrinello-Rahman velocity scaling matrix
+ * \param[in] prVelocityScalingMatrixDiagonal Diagonal elements of Parrinello-Rahman velocity scaling matrix.
*/
template<NumTempScaleValues numTempScaleValues, VelocityScalingType velocityScaling>
auto leapFrogKernel(
OptionalAccessor<unsigned short, mode::read, numTempScaleValues == NumTempScaleValues::Multiple> a_tempScaleGroups,
Float3 prVelocityScalingMatrixDiagonal)
{
- cgh.require(a_x);
- cgh.require(a_xp);
- cgh.require(a_v);
- cgh.require(a_f);
- cgh.require(a_inverseMasses);
+ a_x.bind(cgh);
+ a_xp.bind(cgh);
+ a_v.bind(cgh);
+ a_f.bind(cgh);
+ a_inverseMasses.bind(cgh);
if constexpr (numTempScaleValues != NumTempScaleValues::None)
{
- cgh.require(a_lambdas);
+ a_lambdas.bind(cgh);
}
if constexpr (numTempScaleValues == NumTempScaleValues::Multiple)
{
- cgh.require(a_tempScaleGroups);
+ a_tempScaleGroups.bind(cgh);
}
return [=](cl::sycl::id<1> itemIdx) {
};
}
-// SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
-template<NumTempScaleValues numTempScaleValues, VelocityScalingType velocityScaling>
-class LeapFrogKernelName;
-
+//! \brief Leap Frog SYCL kernel launch code.
template<NumTempScaleValues numTempScaleValues, VelocityScalingType velocityScaling, class... Args>
static cl::sycl::event launchLeapFrogKernel(const DeviceStream& deviceStream, int numAtoms, Args&&... args)
{
// Should not be needed for SYCL2020.
- using kernelNameType = LeapFrogKernelName<numTempScaleValues, velocityScaling>;
+ using kernelNameType = LeapFrogKernel<numTempScaleValues, velocityScaling>;
const cl::sycl::range<1> rangeAllAtoms(numAtoms);
cl::sycl::queue q = deviceStream.stream();
return e;
}
+//! Convert \p doTemperatureScaling and \p numTempScaleValues to \ref NumTempScaleValues.
static NumTempScaleValues getTempScalingType(bool doTemperatureScaling, int numTempScaleValues)
{
if (!doTemperatureScaling)