9627ce8d2c660ab257f0da9910736bc8d2cc8802
[alexxy/gromacs.git] / src / gromacs / mdlib / lincs_gpu_internal_sycl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  *
37  * \brief Implements LINCS kernels using SYCL
38  *
39  * This file contains SYCL kernels of LINCS constraints algorithm.
40  *
41  * \author Artem Zhmurov <zhmurov@gmail.com>
42  *
43  * \ingroup module_mdlib
44  */
45 #include "lincs_gpu_internal.h"
46
47 #include "gromacs/gpu_utils/devicebuffer.h"
48 #include "gromacs/gpu_utils/gmxsycl.h"
49 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
50 #include "gromacs/mdlib/lincs_gpu.h"
51 #include "gromacs/pbcutil/pbc_aiuc_sycl.h"
52 #include "gromacs/utility/gmxassert.h"
53 #include "gromacs/utility/template_mp.h"
54
55 namespace gmx
56 {
57
58 using cl::sycl::access::fence_space;
59 using cl::sycl::access::mode;
60 using cl::sycl::access::target;
61
62 /*! \brief Main kernel for LINCS constraints.
63  *
64  * See Hess et al., J. Comput. Chem. 18: 1463-1472 (1997) for the description of the algorithm.
65  *
66  * In GPU version, one thread is responsible for all computations for one constraint. The blocks are
67  * filled in a way that no constraint is coupled to the constraint from the next block. This is achieved
68  * by moving active threads to the next block, if the correspondent group of coupled constraints is to big
69  * to fit the current thread block. This may leave some 'dummy' threads in the end of the thread block, i.e.
70  * threads that are not required to do actual work. Since constraints from different blocks are not coupled,
71  * there is no need to synchronize across the device. However, extensive communication in a thread block
72  * are still needed.
73  *
74  * \todo Reduce synchronization overhead. Some ideas are:
75  *        1. Consider going to warp-level synchronization for the coupled constraints.
76  *        2. Move more data to local/shared memory and try to get rid of atomic operations (at least on
77  *           the device level).
78  *        3. Use analytical solution for matrix A inversion.
79  *        4. Introduce mapping of thread id to both single constraint and single atom, thus designating
80  *           Nth threads to deal with Nat <= Nth coupled atoms and Nc <= Nth coupled constraints.
81  *       See Issue #2885 for details (https://gitlab.com/gromacs/gromacs/-/issues/2885)
82  * \todo The use of __restrict__  for gm_xp and gm_v causes failure, probably because of the atomic
83  *       operations. Investigate this issue further.
84  *
85  * \tparam updateVelocities        Whether velocities should be updated this step.
86  * \tparam computeVirial           Whether virial tensor should be computed this step.
87  * \tparam haveCoupledConstraints  If there are coupled constraints (i.e. LINCS iterations are needed).
88  *
89  * \param[in]     cgh                           SYCL handler.
90  * \param[in]     numConstraintsThreads         Total number of threads.
91  * \param[in]     a_constraints                 List of constrained atoms.
92  * \param[in]     a_constraintsTargetLengths    Equilibrium distances for the constraints.
93  * \param[in]     a_coupledConstraintsCounts    Number of constraints, coupled with the current one.
94  * \param[in]     a_coupledConstraintsIndices   List of coupled with the current one.
95  * \param[in]     a_massFactors                 Mass factors.
96  * \param[in]     a_matrixA                     Elements of the coupling matrix.
97  * \param[in]     a_inverseMasses               1/mass for all atoms.
98  * \param[in]     numIterations                 Number of iterations used to correct the projection.
99  * \param[in]     expansionOrder                Order of expansion when inverting the matrix.
100  * \param[in]     a_x                           Unconstrained positions.
101  * \param[in,out] a_xp                          Positions at the previous step, will be updated.
102  * \param[in]     invdt                         Inverse timestep (needed to update velocities).
103  * \param[in,out] a_v                           Velocities of atoms, will be updated if \c updateVelocities.
104  * \param[in,out] a_virialScaled                Scaled virial tensor (6 floats: [XX, XY, XZ, YY, YZ, ZZ].
105  *                                              Will be updated if \c updateVirial.
106  * \param[in]     pbcAiuc                       Periodic boundary data.
107  */
108 template<bool updateVelocities, bool computeVirial, bool haveCoupledConstraints>
109 auto lincsKernel(cl::sycl::handler&                   cgh,
110                  const int                            numConstraintsThreads,
111                  DeviceAccessor<AtomPair, mode::read> a_constraints,
112                  DeviceAccessor<float, mode::read>    a_constraintsTargetLengths,
113                  OptionalAccessor<int, mode::read, haveCoupledConstraints> a_coupledConstraintsCounts,
114                  OptionalAccessor<int, mode::read, haveCoupledConstraints> a_coupledConstraintsIndices,
115                  OptionalAccessor<float, mode::read, haveCoupledConstraints>       a_massFactors,
116                  OptionalAccessor<float, mode::read_write, haveCoupledConstraints> a_matrixA,
117                  DeviceAccessor<float, mode::read>                                 a_inverseMasses,
118                  const int                                                         numIterations,
119                  const int                                                         expansionOrder,
120                  DeviceAccessor<Float3, mode::read>                                a_x,
121                  DeviceAccessor<float, mode::read_write>                           a_xp,
122                  const float                                                       invdt,
123                  OptionalAccessor<float, mode::read_write, updateVelocities>       a_v,
124                  OptionalAccessor<float, mode::read_write, computeVirial>          a_virialScaled,
125                  PbcAiuc                                                           pbcAiuc)
126 {
127     cgh.require(a_constraints);
128     cgh.require(a_constraintsTargetLengths);
129     if constexpr (haveCoupledConstraints)
130     {
131         cgh.require(a_coupledConstraintsCounts);
132         cgh.require(a_coupledConstraintsIndices);
133         cgh.require(a_massFactors);
134         cgh.require(a_matrixA);
135     }
136     cgh.require(a_inverseMasses);
137     cgh.require(a_x);
138     cgh.require(a_xp);
139     if constexpr (updateVelocities)
140     {
141         cgh.require(a_v);
142     }
143     if constexpr (computeVirial)
144     {
145         cgh.require(a_virialScaled);
146     }
147
148     // shmem buffer for local distances
149     auto sm_r = [&]() {
150         return cl::sycl::accessor<Float3, 1, mode::read_write, target::local>(
151                 cl::sycl::range<1>(c_threadsPerBlock), cgh);
152     }();
153
154     // shmem buffer for right-hand-side values
155     auto sm_rhs = [&]() {
156         return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
157                 cl::sycl::range<1>(c_threadsPerBlock * 2), cgh);
158     }();
159
160     // shmem buffer for virial components
161     auto sm_threadVirial = [&]() {
162         if constexpr (computeVirial)
163         {
164             return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
165                     cl::sycl::range<1>(c_threadsPerBlock * 6), cgh);
166         }
167         else
168         {
169             return nullptr;
170         }
171     }();
172
173     return [=](cl::sycl::nd_item<1> itemIdx) {
174         const int threadIndex   = itemIdx.get_global_linear_id();
175         const int threadInBlock = itemIdx.get_local_linear_id(); // Work-item index in work-group
176
177         AtomPair pair = a_constraints[threadIndex];
178         int      i    = pair.i;
179         int      j    = pair.j;
180
181         // Mass-scaled Lagrange multiplier
182         float lagrangeScaled = 0.0F;
183
184         float targetLength;
185         float inverseMassi;
186         float inverseMassj;
187         float sqrtReducedMass;
188
189         Float3 xi;
190         Float3 xj;
191         Float3 rc;
192
193         // i == -1 indicates dummy constraint at the end of the thread block.
194         bool isDummyThread = (i == -1);
195
196         // Everything computed for these dummies will be equal to zero
197         if (isDummyThread)
198         {
199             targetLength    = 0.0F;
200             inverseMassi    = 0.0F;
201             inverseMassj    = 0.0F;
202             sqrtReducedMass = 0.0F;
203
204             xi = Float3(0.0F, 0.0F, 0.0F);
205             xj = Float3(0.0F, 0.0F, 0.0F);
206             rc = Float3(0.0F, 0.0F, 0.0F);
207         }
208         else
209         {
210             // Collecting data
211             targetLength    = a_constraintsTargetLengths[threadIndex];
212             inverseMassi    = a_inverseMasses[i];
213             inverseMassj    = a_inverseMasses[j];
214             sqrtReducedMass = cl::sycl::rsqrt(inverseMassi + inverseMassj);
215
216             xi = a_x[i];
217             xj = a_x[j];
218
219             Float3 dx;
220             pbcDxAiucSycl(pbcAiuc, xi, xj, dx);
221
222             float rlen = cl::sycl::rsqrt(dx[XX] * dx[XX] + dx[YY] * dx[YY] + dx[ZZ] * dx[ZZ]);
223             rc         = rlen * dx;
224         }
225
226         sm_r[threadInBlock] = rc;
227         // Make sure that all r's are saved into shared memory
228         // before they are accessed in the loop below
229         itemIdx.barrier(fence_space::global_and_local);
230
231         /*
232          * Constructing LINCS matrix (A)
233          */
234         int coupledConstraintsCount = 0;
235         if constexpr (haveCoupledConstraints)
236         {
237             // Only non-zero values are saved (for coupled constraints)
238             coupledConstraintsCount = a_coupledConstraintsCounts[threadIndex];
239             for (int n = 0; n < coupledConstraintsCount; n++)
240             {
241                 int index = n * numConstraintsThreads + threadIndex;
242                 int c1    = a_coupledConstraintsIndices[index];
243
244                 Float3 rc1       = sm_r[c1];
245                 a_matrixA[index] = a_massFactors[index]
246                                    * (rc[XX] * rc1[XX] + rc[YY] * rc1[YY] + rc[ZZ] * rc1[ZZ]);
247             }
248         }
249
250         // Skipping in dummy threads
251         if (!isDummyThread)
252         {
253             xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
254             xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
255             xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
256             xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
257             xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
258             xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
259         }
260
261         Float3 dx;
262         pbcDxAiucSycl(pbcAiuc, xi, xj, dx);
263
264         float sol = sqrtReducedMass * ((rc[XX] * dx[XX] + rc[YY] * dx[YY] + rc[ZZ] * dx[ZZ]) - targetLength);
265
266         /*
267          *  Inverse matrix using a set of expansionOrder matrix multiplications
268          */
269
270         // This will use the same memory space as sm_r, which is no longer needed.
271         sm_rhs[threadInBlock] = sol;
272
273         // No need to iterate if there are no coupled constraints.
274         if constexpr (haveCoupledConstraints)
275         {
276             for (int rec = 0; rec < expansionOrder; rec++)
277             {
278                 // Making sure that all sm_rhs are saved before they are accessed in a loop below
279                 itemIdx.barrier(fence_space::global_and_local);
280                 float mvb = 0.0F;
281                 for (int n = 0; n < coupledConstraintsCount; n++)
282                 {
283                     int index = n * numConstraintsThreads + threadIndex;
284                     int c1    = a_coupledConstraintsIndices[index];
285                     // Convolute current right-hand-side with A
286                     // Different, non overlapping parts of sm_rhs[..] are read during odd and even iterations
287                     mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
288                 }
289                 // 'Switch' rhs vectors, save current result
290                 // These values will be accessed in the loop above during the next iteration.
291                 sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
292                 sol                                                         = sol + mvb;
293             }
294         }
295
296         // Current mass-scaled Lagrange multipliers
297         lagrangeScaled = sqrtReducedMass * sol;
298
299         // Save updated coordinates before correction for the rotational lengthening
300         Float3 tmp = rc * lagrangeScaled;
301
302         // Writing for all but dummy constraints
303         if (!isDummyThread)
304         {
305             /*
306              * Note: Using memory_scope::work_group for atomic_ref can be better here,
307              * but for now we re-use the existing function for memory_scope::device atomics.
308              */
309             atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
310             atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
311             atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
312             atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
313             atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
314             atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
315         }
316
317         /*
318          *  Correction for centripetal effects
319          */
320         for (int iter = 0; iter < numIterations; iter++)
321         {
322             // Make sure that all xp's are saved: atomic operation calls before are
323             // communicating current xp[..] values across thread block.
324             itemIdx.barrier(fence_space::global_and_local);
325
326             if (!isDummyThread)
327             {
328                 xi[XX] = atomicLoad(a_xp[i * DIM + XX]);
329                 xi[YY] = atomicLoad(a_xp[i * DIM + YY]);
330                 xi[ZZ] = atomicLoad(a_xp[i * DIM + ZZ]);
331                 xj[XX] = atomicLoad(a_xp[j * DIM + XX]);
332                 xj[YY] = atomicLoad(a_xp[j * DIM + YY]);
333                 xj[ZZ] = atomicLoad(a_xp[j * DIM + ZZ]);
334             }
335
336             Float3 dx;
337             pbcDxAiucSycl(pbcAiuc, xi, xj, dx);
338
339             float len2  = targetLength * targetLength;
340             float dlen2 = 2.0F * len2 - (dx[XX] * dx[XX] + dx[YY] * dx[YY] + dx[ZZ] * dx[ZZ]);
341
342             // TODO A little bit more effective but slightly less readable version of the below would be:
343             //      float proj = sqrtReducedMass*(targetLength - (dlen2 > 0.0f ? 1.0f : 0.0f)*dlen2*rsqrt(dlen2));
344             float proj;
345             if (dlen2 > 0.0F)
346             {
347                 proj = sqrtReducedMass * (targetLength - dlen2 * cl::sycl::rsqrt(dlen2));
348             }
349             else
350             {
351                 proj = sqrtReducedMass * targetLength;
352             }
353
354             sm_rhs[threadInBlock] = proj;
355             float sol             = proj;
356
357             /*
358              * Same matrix inversion as above is used for updated data
359              */
360             if constexpr (haveCoupledConstraints)
361             {
362                 for (int rec = 0; rec < expansionOrder; rec++)
363                 {
364                     // Make sure that all elements of rhs are saved into shared memory
365                     itemIdx.barrier(fence_space::global_and_local);
366                     float mvb = 0;
367                     for (int n = 0; n < coupledConstraintsCount; n++)
368                     {
369                         int index = n * numConstraintsThreads + threadIndex;
370                         int c1    = a_coupledConstraintsIndices[index];
371
372                         mvb = mvb + a_matrixA[index] * sm_rhs[c1 + c_threadsPerBlock * (rec % 2)];
373                     }
374
375                     sm_rhs[threadInBlock + c_threadsPerBlock * ((rec + 1) % 2)] = mvb;
376                     sol                                                         = sol + mvb;
377                 }
378             }
379
380             // Add corrections to Lagrange multipliers
381             float sqrtmu_sol = sqrtReducedMass * sol;
382             lagrangeScaled += sqrtmu_sol;
383
384             // Save updated coordinates for the next iteration
385             // Dummy constraints are skipped
386             if (!isDummyThread)
387             {
388                 Float3 tmp = rc * sqrtmu_sol;
389                 atomicFetchAdd(a_xp[i * DIM + XX], -tmp[XX] * inverseMassi);
390                 atomicFetchAdd(a_xp[i * DIM + YY], -tmp[YY] * inverseMassi);
391                 atomicFetchAdd(a_xp[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
392                 atomicFetchAdd(a_xp[j * DIM + XX], tmp[XX] * inverseMassj);
393                 atomicFetchAdd(a_xp[j * DIM + YY], tmp[YY] * inverseMassj);
394                 atomicFetchAdd(a_xp[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
395             }
396         }
397
398         // Updating particle velocities for all but dummy threads
399         if constexpr (updateVelocities)
400         {
401             if (!isDummyThread)
402             {
403                 Float3 tmp = rc * invdt * lagrangeScaled;
404                 atomicFetchAdd(a_v[i * DIM + XX], -tmp[XX] * inverseMassi);
405                 atomicFetchAdd(a_v[i * DIM + YY], -tmp[YY] * inverseMassi);
406                 atomicFetchAdd(a_v[i * DIM + ZZ], -tmp[ZZ] * inverseMassi);
407                 atomicFetchAdd(a_v[j * DIM + XX], tmp[XX] * inverseMassj);
408                 atomicFetchAdd(a_v[j * DIM + YY], tmp[YY] * inverseMassj);
409                 atomicFetchAdd(a_v[j * DIM + ZZ], tmp[ZZ] * inverseMassj);
410             }
411         }
412
413         if constexpr (computeVirial)
414         {
415             // Virial is computed from Lagrange multiplier (lagrangeScaled), target constrain length
416             // (targetLength) and the normalized vector connecting constrained atoms before
417             // the algorithm was applied (rc). The evaluation of virial in each thread is
418             // followed by basic reduction for the values inside single thread block.
419             // Then, the values are reduced across grid by atomicAdd(...).
420             //
421             // TODO Shuffle reduction.
422             // TODO Should be unified and/or done once when virial is actually needed.
423             // TODO Recursive version that removes atomicAdd(...)'s entirely is needed. Ideally,
424             //      one that works for any datatype.
425
426             // Save virial for each thread into the shared memory. Tensor is symmetrical, hence only
427             // 6 values are saved. Dummy threads will have zeroes in their virial: targetLength,
428             // lagrangeScaled and rc are all set to zero for them in the beginning of the kernel.
429             float mult                                             = targetLength * lagrangeScaled;
430             sm_threadVirial[0 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[XX];
431             sm_threadVirial[1 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[YY];
432             sm_threadVirial[2 * c_threadsPerBlock + threadInBlock] = mult * rc[XX] * rc[ZZ];
433             sm_threadVirial[3 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[YY];
434             sm_threadVirial[4 * c_threadsPerBlock + threadInBlock] = mult * rc[YY] * rc[ZZ];
435             sm_threadVirial[5 * c_threadsPerBlock + threadInBlock] = mult * rc[ZZ] * rc[ZZ];
436
437             itemIdx.barrier(fence_space::local_space);
438             // This casts unsigned into signed integers to avoid clang warnings
439             const int tib          = static_cast<int>(threadInBlock);
440             const int blockSize    = static_cast<int>(c_threadsPerBlock);
441             const int subGroupSize = itemIdx.get_sub_group().get_max_local_range()[0];
442
443             // Reduce up to one virial per thread block
444             // All blocks are divided by half, the first half of threads sums
445             // two virials. Then the first half is divided by two and the first half
446             // of it sums two values... The procedure continues until only one thread left.
447             // Only works if the threads per blocks is a power of two.
448             for (int divideBy = 2; divideBy <= blockSize; divideBy *= 2)
449             {
450                 int dividedAt = blockSize / divideBy;
451                 if (tib < dividedAt)
452                 {
453                     for (int d = 0; d < 6; d++)
454                     {
455                         sm_threadVirial[d * blockSize + tib] +=
456                                 sm_threadVirial[d * blockSize + (tib + dividedAt)];
457                     }
458                 }
459                 if (dividedAt > subGroupSize / 2)
460                 {
461                     itemIdx.barrier(fence_space::local_space);
462                 }
463                 else
464                 {
465                     subGroupBarrier(itemIdx);
466                 }
467             }
468             // First 6 threads in the block add the 6 components of virial to the global memory address
469             if (tib < 6)
470             {
471                 atomicFetchAdd(a_virialScaled[tib], sm_threadVirial[tib * blockSize]);
472             }
473         }
474     };
475 }
476
477 // SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
478 template<bool updateVelocities, bool computeVirial, bool haveCoupledConstraints>
479 class LincsKernelName;
480
481 template<bool updateVelocities, bool computeVirial, bool haveCoupledConstraints, class... Args>
482 static cl::sycl::event launchLincsKernel(const DeviceStream& deviceStream,
483                                          const int           numConstraintsThreads,
484                                          Args&&... args)
485 {
486     // Should not be needed for SYCL2020.
487     using kernelNameType = LincsKernelName<updateVelocities, computeVirial, haveCoupledConstraints>;
488
489     const cl::sycl::nd_range<1> rangeAllLincs(numConstraintsThreads, c_threadsPerBlock);
490     cl::sycl::queue             q = deviceStream.stream();
491
492     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
493         auto kernel = lincsKernel<updateVelocities, computeVirial, haveCoupledConstraints>(
494                 cgh, numConstraintsThreads, std::forward<Args>(args)...);
495         cgh.parallel_for<kernelNameType>(rangeAllLincs, kernel);
496     });
497
498     return e;
499 }
500
501 /*! \brief Select templated kernel and launch it. */
502 template<class... Args>
503 static inline cl::sycl::event
504 launchLincsKernel(bool updateVelocities, bool computeVirial, bool haveCoupledConstraints, Args&&... args)
505 {
506     return dispatchTemplatedFunction(
507             [&](auto updateVelocities_, auto computeVirial_, auto haveCoupledConstraints_) {
508                 return launchLincsKernel<updateVelocities_, computeVirial_, haveCoupledConstraints_>(
509                         std::forward<Args>(args)...);
510             },
511             updateVelocities,
512             computeVirial,
513             haveCoupledConstraints);
514 }
515
516
517 void launchLincsGpuKernel(LincsGpuKernelParameters*   kernelParams,
518                           const DeviceBuffer<Float3>& d_x,
519                           DeviceBuffer<Float3>        d_xp,
520                           const bool                  updateVelocities,
521                           DeviceBuffer<Float3>        d_v,
522                           const real                  invdt,
523                           const bool                  computeVirial,
524                           const DeviceStream&         deviceStream)
525 {
526     cl::sycl::buffer<Float3, 1> xp(*d_xp.buffer_);
527     auto                        d_xpAsFloat = xp.reinterpret<float, 1>(xp.get_count() * DIM);
528
529     cl::sycl::buffer<Float3, 1> v(*d_v.buffer_);
530     auto                        d_vAsFloat = v.reinterpret<float, 1>(v.get_count() * DIM);
531
532     launchLincsKernel(updateVelocities,
533                       computeVirial,
534                       kernelParams->haveCoupledConstraints,
535                       deviceStream,
536                       kernelParams->numConstraintsThreads,
537                       kernelParams->d_constraints,
538                       kernelParams->d_constraintsTargetLengths,
539                       kernelParams->d_coupledConstraintsCounts,
540                       kernelParams->d_coupledConstraintsIndices,
541                       kernelParams->d_massFactors,
542                       kernelParams->d_matrixA,
543                       kernelParams->d_inverseMasses,
544                       kernelParams->numIterations,
545                       kernelParams->expansionOrder,
546                       d_x,
547                       d_xpAsFloat,
548                       invdt,
549                       d_vAsFloat,
550                       kernelParams->d_virialScaled,
551                       kernelParams->pbcAiuc);
552     return;
553 }
554
555 } // namespace gmx