cgh.require(a_virialScaled);
}
- // shmem buffer for local distances
- auto sm_r = [&]() {
- return cl::sycl::accessor<Float3, 1, mode::read_write, target::local>(
- cl::sycl::range<1>(c_threadsPerBlock), cgh);
- }();
-
- // shmem buffer for right-hand-side values
- auto sm_rhs = [&]() {
- return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
- cl::sycl::range<1>(c_threadsPerBlock * 2), cgh);
- }();
-
- // shmem buffer for virial components
- auto sm_threadVirial = [&]() {
- if constexpr (computeVirial)
- {
- return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
- cl::sycl::range<1>(c_threadsPerBlock * 6), cgh);
- }
- else
- {
- return nullptr;
- }
- }();
+ /* Shared local memory buffer. Corresponds to sm_r, sm_rhs, and sm_threadVirial in CUDA.
+ * sm_r: one Float3 (i.e. three floats) per thread.
+ * sm_rhs: two floats per thread.
+ * sm_threadVirial: six floats per thread.
+ * So, without virials we need max(3, 2) = 3 floats per thread, and with virials max(3, 2, 6) = 6.
+ */
+ static constexpr int smBufferElementsPerThread = computeVirial ? 6 : 3;
+ cl::sycl::accessor<float, 1, mode::read_write, target::local> sm_buffer{
+ cl::sycl::range<1>(c_threadsPerBlock * smBufferElementsPerThread), cgh
+ };
return [=](cl::sycl::nd_item<1> itemIdx) {
const int threadIndex = itemIdx.get_global_linear_id();
rc = rlen * dx;
}
- sm_r[threadInBlock] = rc;
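+ // The shared buffer now holds plain floats, so rc is stored component by component.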
+ sm_buffer[threadInBlock * DIM + XX] = rc[XX];
+ sm_buffer[threadInBlock * DIM + YY] = rc[YY];
+ sm_buffer[threadInBlock * DIM + ZZ] = rc[ZZ];
// Make sure that all r's are saved into shared memory
// before they are accessed in the loop below
itemIdx.barrier(fence_space::global_and_local);
int index = n * numConstraintsThreads + threadIndex;
int c1 = a_coupledConstraintsIndices[index];
- Float3 rc1 = sm_r[c1];
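+ // Gather the coupled constraint's r vector back into a Float3 from the flat buffer.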
+ Float3 rc1{ sm_buffer[c1 * DIM + XX], sm_buffer[c1 * DIM + YY], sm_buffer[c1 * DIM + ZZ] };
a_matrixA[index] = a_massFactors[index]
* (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]);
}
* Inverse matrix using a set of expansionOrder matrix multiplications
*/
- // This will use the same memory space as sm_r, which is no longer needed.
- sm_rhs[threadInBlock] = sol;
+ // The same buffer is now reused for the right-hand-side values. The barrier below makes sure
+ // that no thread is still reading the r values stored above when they get overwritten.
+ itemIdx.barrier(fence_space::local_space);
+ sm_buffer[threadInBlock] = sol;
// No need to iterate if there are no coupled constraints.
if constexpr (haveCoupledConstraints)
{
for (int rec = 0; rec < expansionOrder; rec++)
{
- // Making sure that all sm_rhs are saved before they are accessed in a loop below
+ // Make sure that all sm_buffer values are saved before they are accessed in the loop below
itemIdx.barrier(fence_space::global_and_local);
float mvb = 0.0F;
for (int n = 0; n < coupledConstraintsCount; n++)
int index = n * numConstraintsThreads + threadIndex;
int c1 = a_coupledConstraintsIndices[index];
// Convolute current right-hand-side with A
- // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
- mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+ // Different, non-overlapping parts of sm_buffer[..] are read during odd and even iterations
+ mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
}
// 'Switch' rhs vectors, save current result
// These values will be accessed in the loop above during the next iteration.
- sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
- sol = sol + mvb;
+ sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+
+ sol = sol + mvb;
}
}
proj = sqrtReducedMass * targetLength;
}
- sm_rhs[threadInBlock] = proj;
- float sol = proj;
+ sm_buffer[threadInBlock] = proj;
+ float sol = proj;
/*
* Same matrix inversion as above is used for updated data
int index = n * numConstraintsThreads + threadIndex;
int c1 = a_coupledConstraintsIndices[index];
- mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
+ mvb = mvb + a_matrixA[index] * sm_buffer[c1 + c_threadsPerBlock * (rec % 2)];
}
- sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
- sol = sol + mvb;
+ sm_buffer[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
+ sol = sol + mvb;
}
}
// Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
// 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
// lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
- float mult = targetLength * lagrangeScaled;
- sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
- sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
- sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
- sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
- sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
- sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
+ // The same shared memory buffer is reused once more; the barrier makes sure that all threads
+ // are done reading its old values before they are overwritten with the virial components:
+ itemIdx.barrier(fence_space::local_space);
+ float mult = targetLength * lagrangeScaled;
+ sm_buffer[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
+ sm_buffer[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
+ sm_buffer[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
+ sm_buffer[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
+ sm_buffer[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
+ sm_buffer[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
itemIdx.barrier(fence_space::local_space);
// This casts unsigned into signed integers to avoid clang warnings
{
for (int d = 0; d < 6; d++)
{
- sm_threadVirial[d * blockSize + tib] +=
- sm_threadVirial[d * blockSize + (tib + dividedAt)];
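+ // Fold the upper half of the per-thread virial sums into the lower half.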
+ sm_buffer[d * blockSize + tib] += sm_buffer[d * blockSize + (tib + dividedAt)];
}
}
if (dividedAt > subGroupSize / 2)
// First 6 threads in the block add the 6 components of virial to the global memory address
if (tib < 6)
{
- atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]);
+ atomicFetchAdd(a_virialScaled[tib], sm_buffer[tib * blockSize]);
}
}
};