const int numIterations,
const int expansionOrder,
DeviceAccessor<Float3, mode::read> a_x,
- DeviceAccessor<float, mode::read_write> a_xp,
+ DeviceAccessor<Float3, mode::read_write> a_xp,
const float invdt,
- OptionalAccessor<float, mode::read_write, updateVelocities> a_v,
+ OptionalAccessor<Float3, mode::read_write, updateVelocities> a_v,
OptionalAccessor<float, mode::read_write, computeVirial> a_virialScaled,
PbcAiuc pbcAiuc)
{
cgh.require(a_virialScaled);
}
- // shmem buffer for local distances
- auto sm_r = [&]() {
- return cl::sycl::accessor<Float3, 1, mode::read_write, target::local>(
- cl::sycl::range<1>(c_threadsPerBlock), cgh);
- }();
-
- // shmem buffer for right-hand-side values
- auto sm_rhs = [&]() {
- return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
- cl::sycl::range<1>(c_threadsPerBlock), cgh);
- }();
-
- // shmem buffer for virial components
- auto sm_threadVirial = [&]() {
- if constexpr (computeVirial)
- {
- return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
- cl::sycl::range<1>(c_threadsPerBlock * 6), cgh);
- }
- else
- {
- return nullptr;
- }
- }();
+ /* Shared local memory buffer. Corresponds to sm_r, sm_rhs, and sm_threadVirial in CUDA.
+ * sm_r: one Float3 per thread.
+ * sm_rhs: two floats per thread.
+ * sm_threadVirial: six floats per thread.
+ * So, without virials we need max(1*3, 2) floats, and with virials we need max(1*3, 2, 6) floats.
+ */
+ static constexpr int smBufferElementsPerThread = computeVirial ? 6 : 3;
+ cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buffer{
+ cl::sycl::range<1>(c_threadsPerBlock * smBufferElementsPerThread), cgh
+ };
return [=](cl::sycl::nd_item<1> itemIdx) {
const int threadIndex = itemIdx.get_global_linear_id();
rc = rlen * dx;
}
- sm_r[threadIndex] = rc;
+ sm_buffer[threadInBlock * DIM + XX] = rc[XX];
+ sm_buffer[threadInBlock * DIM + YY] = rc[YY];
+ sm_buffer[threadInBlock * DIM + ZZ] = rc[ZZ];
// Make sure that all r's are saved into shared memory
// before they are accessed in the loop below
itemIdx.barrier(fence_space::global_and_local);
int index = n * numConstraintsThreads + threadIndex;
int c1 = a_coupledConstraintsIndices[index];
- Float3 rc1 = sm_r[c1];
+ Float3 rc1{ sm_buffer[c1 * DIM + XX], sm_buffer[c1 * DIM + YY], sm_buffer[c1 * DIM + ZZ] };
a_matrixA[index] = a_massFactors[index]
* (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]);
}
// Skipping in dummy threads
if (!isDummyThread)
{
- xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
- xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
- xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
- xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
- xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
- xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
+ xi[XX] = atomicLoad(a_xp[i][XX]);
+ xi[YY] = atomicLoad(a_xp[i][YY]);
+ xi[ZZ] = atomicLoad(a_xp[i][ZZ]);
+ xj[XX] = atomicLoad(a_xp[j][XX]);
+ xj[YY] = atomicLoad(a_xp[j][YY]);
+ xj[ZZ] = atomicLoad(a_xp[j][ZZ]);
}
Float3 dx;
* Inverse matrix using a set of expansionOrder matrix multiplications
*/
- // This will use the same memory space as sm_r, which is no longer needed.
- sm_rhs[threadInBlock] = sol;
+ // This will reuse the same buffer, because the old values are no longer needed.
+ itemIdx.barrier(fence_space::local_space);
+ sm_buffer[threadInBlock] = sol;
// No need to iterate if there are no coupled constraints.
if constexpr (haveCoupledConstraints)
{
for (int rec = 0; rec < expansionOrder; rec++)
{
- // Making sure that all sm_rhs are saved before they are accessed in a loop below
+ // Making sure that all sm_buffer values are saved before they are accessed in a loop below
itemIdx.barrier(fence_space::global_and_local);
float mvb = 0.0F;
for (int n = 0; n < coupledConstraintsCount; n++)
int index = n * numConstraintsThreads + threadIndex;
int c1 = a_coupledConstraintsIndices[index];
// Convolute current right-hand-side with A
- // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
- mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+ // Different, non-overlapping parts of sm_buffer[..] are read during odd and even iterations
+ mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
}
// 'Switch' rhs vectors, save current result
// These values will be accessed in the loop above during the next iteration.
- sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
- sol = sol + mvb;
+ sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+
+ sol = sol + mvb;
}
}
* Note: Using memory_scope::work_group for atomic_ref can be better here,
* but for now we re-use the existing function for memory_scope::device atomics.
*/
- atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
- atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
- atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
- atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
- atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
- atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+ atomicFetchAdd(a_xp[i][XX], -tmp[XX] * inverseMassi);
+ atomicFetchAdd(a_xp[i][YY], -tmp[YY] * inverseMassi);
+ atomicFetchAdd(a_xp[i][ZZ], -tmp[ZZ] * inverseMassi);
+ atomicFetchAdd(a_xp[j][XX], tmp[XX] * inverseMassj);
+ atomicFetchAdd(a_xp[j][YY], tmp[YY] * inverseMassj);
+ atomicFetchAdd(a_xp[j][ZZ], tmp[ZZ] * inverseMassj);
}
/*
if (!isDummyThread)
{
- xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
- xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
- xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
- xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
- xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
- xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
+ xi[XX] = atomicLoad(a_xp[i][XX]);
+ xi[YY] = atomicLoad(a_xp[i][YY]);
+ xi[ZZ] = atomicLoad(a_xp[i][ZZ]);
+ xj[XX] = atomicLoad(a_xp[j][XX]);
+ xj[YY] = atomicLoad(a_xp[j][YY]);
+ xj[ZZ] = atomicLoad(a_xp[j][ZZ]);
}
Float3 dx;
proj = sqrtReducedMass * targetLength;
}
- sm_rhs[threadInBlock] = proj;
- float sol = proj;
+ sm_buffer[threadInBlock] = proj;
+ float sol = proj;
/*
* Same matrix inversion as above is used for updated data
int index = n * numConstraintsThreads + threadIndex;
int c1 = a_coupledConstraintsIndices[index];
- mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+ mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
}
- sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
- sol = sol + mvb;
+ sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+ sol = sol + mvb;
}
}
if (!isDummyThread)
{
Float3 tmp = rc * sqrtmu_sol;
- atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
- atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
- atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
- atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
- atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
- atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+ atomicFetchAdd(a_xp[i][XX], -tmp[XX] * inverseMassi);
+ atomicFetchAdd(a_xp[i][YY], -tmp[YY] * inverseMassi);
+ atomicFetchAdd(a_xp[i][ZZ], -tmp[ZZ] * inverseMassi);
+ atomicFetchAdd(a_xp[j][XX], tmp[XX] * inverseMassj);
+ atomicFetchAdd(a_xp[j][YY], tmp[YY] * inverseMassj);
+ atomicFetchAdd(a_xp[j][ZZ], tmp[ZZ] * inverseMassj);
}
}
if (!isDummyThread)
{
Float3 tmp = rc * invdt * lagrangeScaled;
- atomicFetchAdd(a_v[i * DIM + XX], -tmp[XX] * inverseMassi);
- atomicFetchAdd(a_v[i * DIM + YY], -tmp[YY] * inverseMassi);
- atomicFetchAdd(a_v[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
- atomicFetchAdd(a_v[j * DIM + XX], tmp[XX] * inverseMassj);
- atomicFetchAdd(a_v[j * DIM + YY], tmp[YY] * inverseMassj);
- atomicFetchAdd(a_v[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
+ atomicFetchAdd(a_v[i][XX], -tmp[XX] * inverseMassi);
+ atomicFetchAdd(a_v[i][YY], -tmp[YY] * inverseMassi);
+ atomicFetchAdd(a_v[i][ZZ], -tmp[ZZ] * inverseMassi);
+ atomicFetchAdd(a_v[j][XX], tmp[XX] * inverseMassj);
+ atomicFetchAdd(a_v[j][YY], tmp[YY] * inverseMassj);
+ atomicFetchAdd(a_v[j][ZZ], tmp[ZZ] * inverseMassj);
}
}
// Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
// 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
// lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
- float mult = targetLength * lagrangeScaled;
- sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
- sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
- sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
- sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
- sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
- sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
+ // We reuse the same shared memory buffer, so we make sure we don't need its old values:
+ itemIdx.barrier(fence_space::local_space);
+ float mult = targetLength * lagrangeScaled;
+ sm_buffer[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
+ sm_buffer[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
+ sm_buffer[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
+ sm_buffer[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
+ sm_buffer[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
+ sm_buffer[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
itemIdx.barrier(fence_space::local_space);
// This casts unsigned into signed integers to avoid clang warnings
{
for (int d = 0; d < 6; d++)
{
- sm_threadVirial[d * blockSize + tib] +=
- sm_threadVirial[d * blockSize + (tib + dividedAt)];
+ sm_buffer[d * blockSize + tib] += sm_buffer[d * blockSize + (tib + dividedAt)];
}
}
if (dividedAt > subGroupSize / 2)
// First 6 threads in the block add the 6 components of virial to the global memory address
if (tib < 6)
{
- atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]);
+ atomicFetchAdd(a_virialScaled[tib], sm_buffer[tib * blockSize]);
}
}
};
const bool computeVirial,
const DeviceStream& deviceStream)
{
- cl::sycl::buffer<Float3, 1> xp(*d_xp.buffer_);
- auto d_xpAsFloat = xp.reinterpret<float, 1>(xp.get_count() * DIM);
-
- cl::sycl::buffer<Float3, 1> v(*d_v.buffer_);
- auto d_vAsFloat = v.reinterpret<float, 1>(v.get_count() * DIM);
-
launchLincsKernel(updateVelocities,
computeVirial,
kernelParams->haveCoupledConstraints,
kernelParams->numIterations,
kernelParams->expansionOrder,
d_x,
- d_xpAsFloat,
+ d_xp,
invdt,
- d_vAsFloat,
+ d_v,
kernelParams->d_virialScaled,
kernelParams->pbcAiuc);
return;