2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
37 * \brief SYCL-specific routines for the GPU implementation of SETTLE constraints algorithm.
39 * \author Artem Zhmurov <zhmurov@gmail.com>
41 * \ingroup module_mdlib
44 #include "settle_gpu_internal.h"
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
48 #include "gromacs/pbcutil/pbc_aiuc_sycl.h"
49 #include "gromacs/utility/gmxassert.h"
50 #include "gromacs/utility/template_mp.h"
55 using cl::sycl::access::fence_space;
56 using cl::sycl::access::mode;
57 using cl::sycl::access::target;
/*! \brief Number of work-items in a work-group of the SETTLE kernel.
 *
 * The in-kernel virial reduction halves the active range each step, so this
 * value must be a power of two (checked with a static_assert in the kernel).
 */
static constexpr int sc_workGroupSize = 256;
62 template<bool updateVelocities, bool computeVirial>
63 auto settleKernel(cl::sycl::handler& cgh,
65 DeviceAccessor<WaterMolecule, mode::read> a_settles,
66 SettleParameters pars,
67 DeviceAccessor<Float3, mode::read> a_x,
68 DeviceAccessor<Float3, mode::read_write> a_xp,
70 OptionalAccessor<Float3, mode::read_write, updateVelocities> a_v,
71 OptionalAccessor<float, mode_atomic, computeVirial> a_virialScaled,
74 cgh.require(a_settles);
77 if constexpr (updateVelocities)
81 if constexpr (computeVirial)
83 cgh.require(a_virialScaled);
86 // shmem buffer for i x+q pre-loading
87 auto sm_threadVirial = [&]() {
88 if constexpr (computeVirial)
90 return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
91 cl::sycl::range<1>(sc_workGroupSize * 6), cgh);
99 return [=](cl::sycl::nd_item<1> itemIdx) {
100 constexpr float almost_zero = real(1e-12);
101 const int settleIdx = itemIdx.get_global_linear_id();
102 const int threadIdx = itemIdx.get_local_linear_id(); // Work-item index in work-group
103 assert(itemIdx.get_local_range(0) == sc_workGroupSize);
104 // These are the indexes of three atoms in a single 'water' molecule.
105 // TODO Can be reduced to one integer if atoms are consecutive in memory.
106 if (settleIdx < numSettles)
108 WaterMolecule indices = a_settles[settleIdx];
110 const Float3 x_ow1 = a_x[indices.ow1];
111 const Float3 x_hw2 = a_x[indices.hw2];
112 const Float3 x_hw3 = a_x[indices.hw3];
114 const Float3 xprime_ow1 = a_xp[indices.ow1];
115 const Float3 xprime_hw2 = a_xp[indices.hw2];
116 const Float3 xprime_hw3 = a_xp[indices.hw3];
119 pbcDxAiucSycl(pbcAiuc, x_hw2, x_ow1, dist21);
121 pbcDxAiucSycl(pbcAiuc, x_hw3, x_ow1, dist31);
123 pbcDxAiucSycl(pbcAiuc, xprime_hw2, xprime_ow1, doh2);
126 pbcDxAiucSycl(pbcAiuc, xprime_hw3, xprime_ow1, doh3);
128 Float3 a1 = (doh2 + doh3) * (-pars.wh);
130 Float3 b1 = doh2 + a1;
132 Float3 c1 = doh3 + a1;
134 float xakszd = dist21[YY] * dist31[ZZ] - dist21[ZZ] * dist31[YY];
135 float yakszd = dist21[ZZ] * dist31[XX] - dist21[XX] * dist31[ZZ];
136 float zakszd = dist21[XX] * dist31[YY] - dist21[YY] * dist31[XX];
138 float xaksxd = a1[YY] * zakszd - a1[ZZ] * yakszd;
139 float yaksxd = a1[ZZ] * xakszd - a1[XX] * zakszd;
140 float zaksxd = a1[XX] * yakszd - a1[YY] * xakszd;
142 float xaksyd = yakszd * zaksxd - zakszd * yaksxd;
143 float yaksyd = zakszd * xaksxd - xakszd * zaksxd;
144 float zaksyd = xakszd * yaksxd - yakszd * xaksxd;
146 float axlng = cl::sycl::rsqrt(xaksxd * xaksxd + yaksxd * yaksxd + zaksxd * zaksxd);
147 float aylng = cl::sycl::rsqrt(xaksyd * xaksyd + yaksyd * yaksyd + zaksyd * zaksyd);
148 float azlng = cl::sycl::rsqrt(xakszd * xakszd + yakszd * yakszd + zakszd * zakszd);
150 // TODO {1,2,3} indexes should be swapped with {.x, .y, .z} components.
151 // This way, we will be able to use vector ops more.
152 Float3 trns1, trns2, trns3;
154 trns1[XX] = xaksxd * axlng;
155 trns2[XX] = yaksxd * axlng;
156 trns3[XX] = zaksxd * axlng;
158 trns1[YY] = xaksyd * aylng;
159 trns2[YY] = yaksyd * aylng;
160 trns3[YY] = zaksyd * aylng;
162 trns1[ZZ] = xakszd * azlng;
163 trns2[ZZ] = yakszd * azlng;
164 trns3[ZZ] = zakszd * azlng;
169 b0d[XX] = trns1[XX] * dist21[XX] + trns2[XX] * dist21[YY] + trns3[XX] * dist21[ZZ];
170 b0d[YY] = trns1[YY] * dist21[XX] + trns2[YY] * dist21[YY] + trns3[YY] * dist21[ZZ];
172 c0d[XX] = trns1[XX] * dist31[XX] + trns2[XX] * dist31[YY] + trns3[XX] * dist31[ZZ];
173 c0d[YY] = trns1[YY] * dist31[XX] + trns2[YY] * dist31[YY] + trns3[YY] * dist31[ZZ];
177 float a1d_z = trns1[ZZ] * a1[XX] + trns2[ZZ] * a1[YY] + trns3[ZZ] * a1[ZZ];
179 b1d[XX] = trns1[XX] * b1[XX] + trns2[XX] * b1[YY] + trns3[XX] * b1[ZZ];
180 b1d[YY] = trns1[YY] * b1[XX] + trns2[YY] * b1[YY] + trns3[YY] * b1[ZZ];
181 b1d[ZZ] = trns1[ZZ] * b1[XX] + trns2[ZZ] * b1[YY] + trns3[ZZ] * b1[ZZ];
183 c1d[XX] = trns1[XX] * c1[XX] + trns2[XX] * c1[YY] + trns3[XX] * c1[ZZ];
184 c1d[YY] = trns1[YY] * c1[XX] + trns2[YY] * c1[YY] + trns3[YY] * c1[ZZ];
185 c1d[ZZ] = trns1[ZZ] * c1[XX] + trns2[ZZ] * c1[YY] + trns3[ZZ] * c1[ZZ];
188 const float sinphi = a1d_z * cl::sycl::rsqrt(pars.ra * pars.ra);
189 float tmp2 = 1.0F - sinphi * sinphi;
191 if (almost_zero > tmp2)
196 const float tmp = cl::sycl::rsqrt(tmp2);
197 const float cosphi = tmp2 * tmp;
198 const float sinpsi = (b1d[ZZ] - c1d[ZZ]) * pars.irc2 * tmp;
199 tmp2 = 1.0F - sinpsi * sinpsi;
201 const float cospsi = tmp2 * cl::sycl::rsqrt(tmp2);
203 const float a2d_y = pars.ra * cosphi;
204 const float b2d_x = -pars.rc * cospsi;
205 const float t1 = -pars.rb * cosphi;
206 const float t2 = pars.rc * sinpsi * sinphi;
207 const float b2d_y = t1 - t2;
208 const float c2d_y = t1 + t2;
210 /* --- Step3 al,be,ga --- */
211 const float alpha = b2d_x * (b0d[XX] - c0d[XX]) + b0d[YY] * b2d_y + c0d[YY] * c2d_y;
212 const float beta = b2d_x * (c0d[YY] - b0d[YY]) + b0d[XX] * b2d_y + c0d[XX] * c2d_y;
214 b0d[XX] * b1d[YY] - b1d[XX] * b0d[YY] + c0d[XX] * c1d[YY] - c1d[XX] * c0d[YY];
215 const float al2be2 = alpha * alpha + beta * beta;
216 tmp2 = (al2be2 - gamma * gamma);
217 const float sinthe = (alpha * gamma - beta * tmp2 * cl::sycl::rsqrt(tmp2))
218 * cl::sycl::rsqrt(al2be2 * al2be2);
220 /* --- Step4 A3' --- */
221 tmp2 = 1.0F - sinthe * sinthe;
222 float costhe = tmp2 * cl::sycl::rsqrt(tmp2);
224 Float3 a3d, b3d, c3d;
226 a3d[XX] = -a2d_y * sinthe;
227 a3d[YY] = a2d_y * costhe;
229 b3d[XX] = b2d_x * costhe - b2d_y * sinthe;
230 b3d[YY] = b2d_x * sinthe + b2d_y * costhe;
232 c3d[XX] = -b2d_x * costhe - c2d_y * sinthe;
233 c3d[YY] = -b2d_x * sinthe + c2d_y * costhe;
236 /* --- Step5 A3 --- */
239 a3[XX] = trns1[XX] * a3d[XX] + trns1[YY] * a3d[YY] + trns1[ZZ] * a3d[ZZ];
240 a3[YY] = trns2[XX] * a3d[XX] + trns2[YY] * a3d[YY] + trns2[ZZ] * a3d[ZZ];
241 a3[ZZ] = trns3[XX] * a3d[XX] + trns3[YY] * a3d[YY] + trns3[ZZ] * a3d[ZZ];
243 b3[XX] = trns1[XX] * b3d[XX] + trns1[YY] * b3d[YY] + trns1[ZZ] * b3d[ZZ];
244 b3[YY] = trns2[XX] * b3d[XX] + trns2[YY] * b3d[YY] + trns2[ZZ] * b3d[ZZ];
245 b3[ZZ] = trns3[XX] * b3d[XX] + trns3[YY] * b3d[YY] + trns3[ZZ] * b3d[ZZ];
247 c3[XX] = trns1[XX] * c3d[XX] + trns1[YY] * c3d[YY] + trns1[ZZ] * c3d[ZZ];
248 c3[YY] = trns2[XX] * c3d[XX] + trns2[YY] * c3d[YY] + trns2[ZZ] * c3d[ZZ];
249 c3[ZZ] = trns3[XX] * c3d[XX] + trns3[YY] * c3d[YY] + trns3[ZZ] * c3d[ZZ];
252 /* Compute and store the corrected new coordinate */
253 const Float3 dxOw1 = a3 - a1;
254 const Float3 dxHw2 = b3 - b1;
255 const Float3 dxHw3 = c3 - c1;
257 a_xp[indices.ow1] = xprime_ow1 + dxOw1;
258 a_xp[indices.hw2] = xprime_hw2 + dxHw2;
259 a_xp[indices.hw3] = xprime_hw3 + dxHw3;
261 if constexpr (updateVelocities)
263 Float3 v_ow1 = a_v[indices.ow1];
264 Float3 v_hw2 = a_v[indices.hw2];
265 Float3 v_hw3 = a_v[indices.hw3];
267 /* Add the position correction divided by dt to the velocity */
268 v_ow1 = dxOw1 * invdt + v_ow1;
269 v_hw2 = dxHw2 * invdt + v_hw2;
270 v_hw3 = dxHw3 * invdt + v_hw3;
272 a_v[indices.ow1] = v_ow1;
273 a_v[indices.hw2] = v_hw2;
274 a_v[indices.hw3] = v_hw3;
277 if constexpr (computeVirial)
279 Float3 mdb = pars.mH * dxHw2;
280 Float3 mdc = pars.mH * dxHw3;
281 Float3 mdo = pars.mO * dxOw1 + mdb + mdc;
283 sm_threadVirial[0 * sc_workGroupSize + threadIdx] =
284 -(x_ow1[0] * mdo[0] + dist21[0] * mdb[0] + dist31[0] * mdc[0]);
285 sm_threadVirial[1 * sc_workGroupSize + threadIdx] =
286 -(x_ow1[0] * mdo[1] + dist21[0] * mdb[1] + dist31[0] * mdc[1]);
287 sm_threadVirial[2 * sc_workGroupSize + threadIdx] =
288 -(x_ow1[0] * mdo[2] + dist21[0] * mdb[2] + dist31[0] * mdc[2]);
289 sm_threadVirial[3 * sc_workGroupSize + threadIdx] =
290 -(x_ow1[1] * mdo[1] + dist21[1] * mdb[1] + dist31[1] * mdc[1]);
291 sm_threadVirial[4 * sc_workGroupSize + threadIdx] =
292 -(x_ow1[1] * mdo[2] + dist21[1] * mdb[2] + dist31[1] * mdc[2]);
293 sm_threadVirial[5 * sc_workGroupSize + threadIdx] =
294 -(x_ow1[2] * mdo[2] + dist21[2] * mdb[2] + dist31[2] * mdc[2]);
297 else // settleIdx < numSettles
299 // Filling data for dummy threads with zeroes
300 if constexpr (computeVirial)
302 for (int d = 0; d < 6; d++)
304 sm_threadVirial[d * sc_workGroupSize + threadIdx] = 0.0F;
309 // Basic reduction for the values inside single thread block
310 // TODO what follows should be separated out as a standard virial reduction subroutine
311 if constexpr (computeVirial)
313 // This is to ensure that all threads saved the data before reduction starts
314 subGroupBarrier(itemIdx);
315 constexpr int blockSize = sc_workGroupSize;
316 const int subGroupSize = itemIdx.get_sub_group().get_max_local_range()[0];
317 // Reduce up to one virial per thread block
318 // All blocks are divided by half, the first half of threads sums
319 // two virials. Then the first half is divided by two and the first half
320 // of it sums two values... The procedure continues until only one thread left.
321 // Only works if the threads per blocks is a power of two, hence the assertion.
322 static_assert(gmx::isPowerOfTwo(sc_workGroupSize));
323 for (int divideBy = 2; divideBy <= blockSize; divideBy *= 2)
325 int dividedAt = blockSize / divideBy;
326 if (threadIdx < dividedAt)
328 for (int d = 0; d < 6; d++)
330 sm_threadVirial[d * blockSize + threadIdx] +=
331 sm_threadVirial[d * blockSize + (threadIdx + dividedAt)];
334 if (dividedAt > subGroupSize / 2)
336 subGroupBarrier(itemIdx);
339 // First 6 threads in the block add the 6 components of virial to the global memory address
342 atomicFetchAdd(a_virialScaled, threadIdx, sm_threadVirial[threadIdx * blockSize]);
/*! \brief Tag type giving each settleKernel instantiation a unique kernel name.
 *
 * SYCL 1.2.1 requires providing a unique type for a kernel. Should not be needed for SYCL2020.
 */
template<bool withVelocityUpdate, bool withVirial>
class SettleKernelName;
352 template<bool updateVelocities, bool computeVirial, class... Args>
353 static cl::sycl::event launchSettleKernel(const DeviceStream& deviceStream, int numSettles, Args&&... args)
355 // Should not be needed for SYCL2020.
356 using kernelNameType = SettleKernelName<updateVelocities, computeVirial>;
358 const int numSettlesRoundedUp =
359 static_cast<int>((numSettles + sc_workGroupSize - 1) / sc_workGroupSize) * sc_workGroupSize;
360 const cl::sycl::nd_range<1> rangeAllSettles(numSettlesRoundedUp, sc_workGroupSize);
361 cl::sycl::queue q = deviceStream.stream();
363 cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
364 auto kernel = settleKernel<updateVelocities, computeVirial>(
365 cgh, numSettles, std::forward<Args>(args)...);
366 cgh.parallel_for<kernelNameType>(rangeAllSettles, kernel);
372 /*! \brief Select templated kernel and launch it. */
373 template<class... Args>
374 static inline cl::sycl::event launchSettleKernel(bool updateVelocities, bool computeVirial, Args&&... args)
376 return dispatchTemplatedFunction(
377 [&](auto updateVelocities_, auto computeVirial_) {
378 return launchSettleKernel<updateVelocities_, computeVirial_>(std::forward<Args>(args)...);
385 void launchSettleGpuKernel(const int numSettles,
386 const DeviceBuffer<WaterMolecule>& d_settles,
387 const SettleParameters& settleParameters,
388 const DeviceBuffer<Float3>& d_x,
389 DeviceBuffer<Float3> d_xp,
390 const bool updateVelocities,
391 DeviceBuffer<Float3> d_v,
393 const bool computeVirial,
394 DeviceBuffer<float> virialScaled,
395 const PbcAiuc& pbcAiuc,
396 const DeviceStream& deviceStream)
399 launchSettleKernel(updateVelocities,