2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
37 * \brief Implements LINCS kernels using SYCL
39 * This file contains SYCL kernels of LINCS constraints algorithm.
41 * \author Artem Zhmurov <zhmurov@gmail.com>
43 * \ingroup module_mdlib
45 #include "lincs_gpu_internal.h"
47 #include "gromacs/gpu_utils/devicebuffer.h"
48 #include "gromacs/gpu_utils/gmxsycl.h"
49 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
50 #include "gromacs/mdlib/lincs_gpu.h"
51 #include "gromacs/pbcutil/pbc_aiuc_sycl.h"
52 #include "gromacs/utility/gmxassert.h"
53 #include "gromacs/utility/template_mp.h"
58 using cl::sycl::access::fence_space;
59 using cl::sycl::access::mode;
60 using cl::sycl::access::target;
62 /*! \brief Main kernel for LINCS constraints.
64 * See Hess et al., J. Comput. Chem. 18: 1463-1472 (1997) for the description of the algorithm.
66 * In the GPU version, one thread is responsible for all computations for one constraint. The blocks are
67 * filled in a way that no constraint is coupled to the constraint from the next block. This is achieved
68 * by moving active threads to the next block, if the corresponding group of coupled constraints is too big
69 * to fit the current thread block. This may leave some 'dummy' threads in the end of the thread block, i.e.
70 * threads that are not required to do actual work. Since constraints from different blocks are not coupled,
71 * there is no need to synchronize across the device. However, extensive communication in a thread block
74 * \todo Reduce synchronization overhead. Some ideas are:
75 * 1. Consider going to warp-level synchronization for the coupled constraints.
76 * 2. Move more data to local/shared memory and try to get rid of atomic operations (at least on
78 * 3. Use analytical solution for matrix A inversion.
79 * 4. Introduce mapping of thread id to both single constraint and single atom, thus designating
80 * Nth threads to deal with Nat <= Nth coupled atoms and Nc <= Nth coupled constraints.
81 * See Issue #2885 for details (https://gitlab.com/gromacs/gromacs/-/issues/2885)
82 * \todo The use of __restrict__ for gm_xp and gm_v causes failure, probably because of the atomic
83 * operations. Investigate this issue further.
85 * \tparam updateVelocities Whether velocities should be updated this step.
86 * \tparam computeVirial Whether virial tensor should be computed this step.
87 * \tparam haveCoupledConstraints If there are coupled constraints (i.e. LINCS iterations are needed).
89 * \param[in] cgh SYCL handler.
90 * \param[in] numConstraintsThreads Total number of threads.
91 * \param[in] a_constraints List of constrained atoms.
92 * \param[in] a_constraintsTargetLengths Equilibrium distances for the constraints.
93 * \param[in] a_coupledConstraintsCounts Number of constraints, coupled with the current one.
94 * \param[in] a_coupledConstraintsIndices List of constraints coupled with the current one.
95 * \param[in] a_massFactors Mass factors.
96 * \param[in] a_matrixA Elements of the coupling matrix.
97 * \param[in] a_inverseMasses 1/mass for all atoms.
98 * \param[in] numIterations Number of iterations used to correct the projection.
99 * \param[in] expansionOrder Order of expansion when inverting the matrix.
100 * \param[in] a_x Unconstrained positions.
101 * \param[in,out] a_xp Positions at the previous step, will be updated.
102 * \param[in] invdt Inverse timestep (needed to update velocities).
103 * \param[in,out] a_v Velocities of atoms, will be updated if \c updateVelocities.
104 * \param[in,out] a_virialScaled Scaled virial tensor (6 floats: [XX, XY, XZ, YY, YZ, ZZ]).
105 * Will be updated if \c computeVirial.
106 * \param[in] pbcAiuc Periodic boundary data.
108 template<bool updateVelocities, bool computeVirial, bool haveCoupledConstraints>
109 auto lincsKernel(cl::sycl::handler& cgh,
110 const int numConstraintsThreads,
111 DeviceAccessor<AtomPair, mode::read> a_constraints,
112 DeviceAccessor<float, mode::read> a_constraintsTargetLengths,
113 OptionalAccessor<int, mode::read, haveCoupledConstraints> a_coupledConstraintsCounts,
114 OptionalAccessor<int, mode::read, haveCoupledConstraints> a_coupledConstraintsIndices,
115 OptionalAccessor<float, mode::read, haveCoupledConstraints> a_massFactors,
116 OptionalAccessor<float, mode::read_write, haveCoupledConstraints> a_matrixA,
117 DeviceAccessor<float, mode::read> a_inverseMasses,
118 const int numIterations,
119 const int expansionOrder,
120 DeviceAccessor<Float3, mode::read> a_x,
121 DeviceAccessor<float, mode::read_write> a_xp,
123 OptionalAccessor<float, mode::read_write, updateVelocities> a_v,
124 OptionalAccessor<float, mode::read_write, computeVirial> a_virialScaled,
127 cgh.require(a_constraints);
128 cgh.require(a_constraintsTargetLengths);
129 if constexpr (haveCoupledConstraints)
131 cgh.require(a_coupledConstraintsCounts);
132 cgh.require(a_coupledConstraintsIndices);
133 cgh.require(a_massFactors);
134 cgh.require(a_matrixA);
136 cgh.require(a_inverseMasses);
139 if constexpr (updateVelocities)
143 if constexpr (computeVirial)
145 cgh.require(a_virialScaled);
148 // shmem buffer for local distances
150 return cl::sycl::accessor<Float3, 1, mode::read_write, target::local>(
151 cl::sycl::range<1>(c_threadsPerBlock), cgh);
154 // shmem buffer for right-hand-side values
155 auto sm_rhs = [&]() {
156 return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
157 cl::sycl::range<1>(c_threadsPerBlock), cgh);
160 // shmem buffer for virial components
161 auto sm_threadVirial = [&]() {
162 if constexpr (computeVirial)
164 return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
165 cl::sycl::range<1>(c_threadsPerBlock * 6), cgh);
173 return [=](cl::sycl::nd_item<1> itemIdx) {
174 const int threadIndex = itemIdx.get_global_linear_id();
175 const int threadInBlock = itemIdx.get_local_linear_id(); // Work-item index in work-group
177 AtomPair pair = a_constraints[threadIndex];
181 // Mass-scaled Lagrange multiplier
182 float lagrangeScaled = 0.0F;
187 float sqrtReducedMass;
193 // i == -1 indicates dummy constraint at the end of the thread block.
194 bool isDummyThread = (i == -1);
196 // Everything computed for these dummies will be equal to zero
202 sqrtReducedMass = 0.0F;
204 xi = Float3(0.0F, 0.0F, 0.0F);
205 xj = Float3(0.0F, 0.0F, 0.0F);
206 rc = Float3(0.0F, 0.0F, 0.0F);
211 targetLength = a_constraintsTargetLengths[threadIndex];
212 inverseMassi = a_inverseMasses[i];
213 inverseMassj = a_inverseMasses[j];
214 sqrtReducedMass = cl::sycl::rsqrt(inverseMassi + inverseMassj);
220 pbcDxAiucSycl(pbcAiuc, xi, xj, dx);
222 float rlen = cl::sycl::rsqrt(dx[XX] * dx[XX] + dx[YY] * dx[YY] + dx[ZZ] * dx[ZZ]);
226 sm_r[threadIndex] = rc;
227 // Make sure that all r's are saved into shared memory
228 // before they are accessed in the loop below
229 itemIdx.barrier(fence_space::global_and_local);
232 * Constructing LINCS matrix (A)
234 int coupledConstraintsCount = 0;
235 if constexpr (haveCoupledConstraints)
237 // Only non-zero values are saved (for coupled constraints)
238 coupledConstraintsCount = a_coupledConstraintsCounts[threadIndex];
239 for (int n = 0; n < coupledConstraintsCount; n++)
241 int index = n * numConstraintsThreads + threadIndex;
242 int c1 = a_coupledConstraintsIndices[index];
244 Float3 rc1 = sm_r[c1];
245 a_matrixA[index] = a_massFactors[index]
246 * (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]);
250 // Skipping in dummy threads
253 xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
254 xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
255 xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
256 xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
257 xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
258 xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
262 pbcDxAiucSycl(pbcAiuc, xi, xj, dx);
264 float sol = sqrtReducedMass * ((rc[XX] * dx[XX] + rc[YY] * dx[YY] + rc[ZZ] * dx[ZZ]) - targetLength);
267 * Inverse matrix using a set of expansionOrder matrix multiplications
270 // This will use the same memory space as sm_r, which is no longer needed.
271 sm_rhs[threadInBlock] = sol;
273 // No need to iterate if there are no coupled constraints.
274 if constexpr (haveCoupledConstraints)
276 for (int rec = 0; rec < expansionOrder; rec++)
278 // Making sure that all sm_rhs are saved before they are accessed in a loop below
279 itemIdx.barrier(fence_space::global_and_local);
281 for (int n = 0; n < coupledConstraintsCount; n++)
283 int index = n * numConstraintsThreads + threadIndex;
284 int c1 = a_coupledConstraintsIndices[index];
285 // Convolute current right-hand-side with A
286 // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
287 mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
289 // 'Switch' rhs vectors, save current result
290 // These values will be accessed in the loop above during the next iteration.
291 sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
296 // Current mass-scaled Lagrange multipliers
297 lagrangeScaled = sqrtReducedMass * sol;
299 // Save updated coordinates before correction for the rotational lengthening
300 Float3 tmp = rc * lagrangeScaled;
302 // Writing for all but dummy constraints
306 * Note: Using memory_scope::work_group for atomic_ref can be better here,
307 * but for now we re-use the existing function for memory_scope::device atomics.
309 atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
310 atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
311 atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
312 atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
313 atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
314 atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
318 * Correction for centripetal effects
320 for (int iter = 0; iter < numIterations; iter++)
322 // Make sure that all xp's are saved: atomic operation calls before are
323 // communicating current xp[..] values across thread block.
324 itemIdx.barrier(fence_space::global_and_local);
328 xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
329 xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
330 xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
331 xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
332 xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
333 xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
337 pbcDxAiucSycl(pbcAiuc, xi, xj, dx);
339 float len2 = targetLength * targetLength;
340 float dlen2 = 2.0F * len2 - (dx[XX] * dx[XX] + dx[YY] * dx[YY] + dx[ZZ] * dx[ZZ]);
342 // TODO A little bit more effective but slightly less readable version of the below would be:
343 // float proj = sqrtReducedMass*(targetLength - (dlen2 > 0.0f ? 1.0f : 0.0f)*dlen2*rsqrt(dlen2));
347 proj = sqrtReducedMass * (targetLength - dlen2 * cl::sycl::rsqrt(dlen2));
351 proj = sqrtReducedMass * targetLength;
354 sm_rhs[threadInBlock] = proj;
358 * Same matrix inversion as above is used for updated data
360 if constexpr (haveCoupledConstraints)
362 for (int rec = 0; rec < expansionOrder; rec++)
364 // Make sure that all elements of rhs are saved into shared memory
365 itemIdx.barrier(fence_space::global_and_local);
367 for (int n = 0; n < coupledConstraintsCount; n++)
369 int index = n * numConstraintsThreads + threadIndex;
370 int c1 = a_coupledConstraintsIndices[index];
372 mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
375 sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
380 // Add corrections to Lagrange multipliers
381 float sqrtmu_sol = sqrtReducedMass * sol;
382 lagrangeScaled += sqrtmu_sol;
384 // Save updated coordinates for the next iteration
385 // Dummy constraints are skipped
388 Float3 tmp = rc * sqrtmu_sol;
389 atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
390 atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
391 atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
392 atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
393 atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
394 atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
398 // Updating particle velocities for all but dummy threads
399 if constexpr (updateVelocities)
403 Float3 tmp = rc * invdt * lagrangeScaled;
404 atomicFetchAdd(a_v[i * DIM + XX], -tmp[XX] * inverseMassi);
405 atomicFetchAdd(a_v[i * DIM + YY], -tmp[YY] * inverseMassi);
406 atomicFetchAdd(a_v[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
407 atomicFetchAdd(a_v[j * DIM + XX], tmp[XX] * inverseMassj);
408 atomicFetchAdd(a_v[j * DIM + YY], tmp[YY] * inverseMassj);
409 atomicFetchAdd(a_v[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
413 if constexpr (computeVirial)
415 // Virial is computed from Lagrange multiplier (lagrangeScaled), target constrain length
416 // (targetLength) and the normalized vector connecting constrained atoms before
417 // the algorithm was applied (rc). The evaluation of virial in each thread is
418 // followed by basic reduction for the values inside single thread block.
419 // Then, the values are reduced across grid by atomicAdd(...).
421 // TODO Shuffle reduction.
422 // TODO Should be unified and/or done once when virial is actually needed.
423 // TODO Recursive version that removes atomicAdd(...)'s entirely is needed. Ideally,
424 // one that works for any datatype.
426 // Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
427 // 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
428 // lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
429 float mult = targetLength * lagrangeScaled;
430 sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
431 sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
432 sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
433 sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
434 sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
435 sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
437 itemIdx.barrier(fence_space::local_space);
438 // This casts unsigned into signed integers to avoid clang warnings
439 const int tib = static_cast<int>(threadInBlock);
440 const int blockSize = static_cast<int>(c_threadsPerBlock);
441 const int subGroupSize = itemIdx.get_sub_group().get_max_local_range()[0];
443 // Reduce up to one virial per thread block
444 // All blocks are divided by half, the first half of threads sums
445 // two virials. Then the first half is divided by two and the first half
446 // of it sums two values... The procedure continues until only one thread left.
447 // Only works if the threads per blocks is a power of two.
448 for (int divideBy = 2; divideBy <= blockSize; divideBy *= 2)
450 int dividedAt = blockSize / divideBy;
453 for (int d = 0; d < 6; d++)
455 sm_threadVirial[d * blockSize + tib] +=
456 sm_threadVirial[d * blockSize + (tib + dividedAt)];
459 if (dividedAt > subGroupSize / 2)
461 itemIdx.barrier(fence_space::local_space);
465 subGroupBarrier(itemIdx);
468 // First 6 threads in the block add the 6 components of virial to the global memory address
471 atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]);
477 // SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
//! \brief Tag type giving each instantiation of the LINCS kernel a unique
//! name, one per combination of the three boolean template parameters.
478 template<bool updateVelocities, bool computeVirial, bool haveCoupledConstraints>
479 class LincsKernelName;
//! \brief Submits the LINCS kernel for one compile-time combination of the
//! template flags and returns the resulting SYCL event.
//!
//! NOTE(review): the trailing parameter-pack declaration (presumably
//! 'Args&&... args'), the function's opening brace, and the final
//! 'return e;' are not visible in this excerpt; code kept verbatim.
481 template<bool updateVelocities, bool computeVirial, bool haveCoupledConstraints, class... Args>
482 static cl::sycl::event launchLincsKernel(const DeviceStream& deviceStream,
483 const int numConstraintsThreads,
// SYCL 1.2.1 requires a unique type to name the kernel (see LincsKernelName).
486 // Should not be needed for SYCL2020.
487 using kernelNameType = LincsKernelName<updateVelocities, computeVirial, haveCoupledConstraints>;
// One work-item per constraint "thread", grouped into work-groups of c_threadsPerBlock.
489 const cl::sycl::nd_range<1> rangeAllLincs(numConstraintsThreads, c_threadsPerBlock);
490 cl::sycl::queue q = deviceStream.stream();
492 cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
// Build the kernel functor (binds all accessors to the handler) and enqueue it.
493 auto kernel = lincsKernel<updateVelocities, computeVirial, haveCoupledConstraints>(
494 cgh, numConstraintsThreads, std::forward<Args>(args)...);
495 cgh.parallel_for<kernelNameType>(rangeAllLincs, kernel);
501 /*! \brief Select templated kernel and launch it. */
// Converts the three run-time booleans into compile-time template parameters
// via dispatchTemplatedFunction and forwards the remaining arguments to the
// templated launchLincsKernel overload.
//
// NOTE(review): the end of the lambda/call (including the updateVelocities and
// computeVirial arguments that should precede 'haveCoupledConstraints') is not
// visible in this excerpt; code kept verbatim.
502 template<class... Args>
503 static inline cl::sycl::event
504 launchLincsKernel(bool updateVelocities, bool computeVirial, bool haveCoupledConstraints, Args&&... args)
506 return dispatchTemplatedFunction(
507 [&](auto updateVelocities_, auto computeVirial_, auto haveCoupledConstraints_) {
508 return launchLincsKernel<updateVelocities_, computeVirial_, haveCoupledConstraints_>(
509 std::forward<Args>(args)...);
513 haveCoupledConstraints);
517 void launchLincsGpuKernel(LincsGpuKernelParameters* kernelParams,
518 const DeviceBuffer<Float3>& d_x,
519 DeviceBuffer<Float3> d_xp,
520 const bool updateVelocities,
521 DeviceBuffer<Float3> d_v,
523 const bool computeVirial,
524 const DeviceStream& deviceStream)
526 cl::sycl::buffer<Float3, 1> xp(*d_xp.buffer_);
527 auto d_xpAsFloat = xp.reinterpret<float, 1>(xp.get_count() * DIM);
529 cl::sycl::buffer<Float3, 1> v(*d_v.buffer_);
530 auto d_vAsFloat = v.reinterpret<float, 1>(v.get_count() * DIM);
532 launchLincsKernel(updateVelocities,
534 kernelParams->haveCoupledConstraints,
536 kernelParams->numConstraintsThreads,
537 kernelParams->d_constraints,
538 kernelParams->d_constraintsTargetLengths,
539 kernelParams->d_coupledConstraintsCounts,
540 kernelParams->d_coupledConstraintsIndices,
541 kernelParams->d_massFactors,
542 kernelParams->d_matrixA,
543 kernelParams->d_inverseMasses,
544 kernelParams->numIterations,
545 kernelParams->expansionOrder,
550 kernelParams->d_virialScaled,
551 kernelParams->pbcAiuc);