#include "gromacs/gpu_utils/cudautils.cuh"
#include "gromacs/gpu_utils/devicebuffer.h"
+#include "gromacs/gpu_utils/typecasts.cuh"
#include "gromacs/gpu_utils/vectype_ops.cuh"
#include "gromacs/math/vec.h"
#include "gromacs/mdtypes/group.h"
return kernelPtr;
}
-void LeapFrogGpu::integrate(const DeviceBuffer<float3> d_x,
- DeviceBuffer<float3> d_xp,
- DeviceBuffer<float3> d_v,
- const DeviceBuffer<float3> d_f,
+void LeapFrogGpu::integrate(DeviceBuffer<Float3> d_x,
+ DeviceBuffer<Float3> d_xp,
+ DeviceBuffer<Float3> d_v,
+ const DeviceBuffer<Float3> d_f,
const real dt,
const bool doTemperatureScaling,
gmx::ArrayRef<const t_grp_tcstat> tcstat,
"Fully anisotropic Parrinello-Rahman pressure coupling is not yet supported "
"in GPU version of Leap-Frog integrator.");
prVelocityScalingMatrixDiagonal_ =
- make_float3(dtPressureCouple * prVelocityScalingMatrix[XX][XX],
- dtPressureCouple * prVelocityScalingMatrix[YY][YY],
- dtPressureCouple * prVelocityScalingMatrix[ZZ][ZZ]);
+ Float3{ dtPressureCouple * prVelocityScalingMatrix[XX][XX],
+ dtPressureCouple * prVelocityScalingMatrix[YY][YY],
+ dtPressureCouple * prVelocityScalingMatrix[ZZ][ZZ] };
}
kernelPtr = selectLeapFrogKernelPtr(doTemperatureScaling, numTempScaleValues_, prVelocityScalingType);
}
+ // Checking the buffer types against the kernel argument types
+ static_assert(sizeof(*d_inverseMasses_) == sizeof(float));
const auto kernelArgs = prepareGpuKernelArguments(kernelPtr,
kernelLaunchConfig_,
&numAtoms_,
- &d_x,
- &d_xp,
- &d_v,
- &d_f,
+ asFloat3Pointer(&d_x),
+ asFloat3Pointer(&d_xp),
+ asFloat3Pointer(&d_v),
+ asFloat3Pointer(&d_f),
&d_inverseMasses_,
&dt,
&d_lambdas_,
reallocateDeviceBuffer(
&d_inverseMasses_, numAtoms_, &numInverseMasses_, &numInverseMassesAlloc_, deviceContext_);
copyToDeviceBuffer(
- &d_inverseMasses_, (float*)inverseMasses, 0, numAtoms_, deviceStream_, GpuApiCallBehavior::Sync, nullptr);
+ &d_inverseMasses_, inverseMasses, 0, numAtoms_, deviceStream_, GpuApiCallBehavior::Sync, nullptr);
// Temperature scale group map only used if there are more then one group
if (numTempScaleValues_ > 1)