SYCL: Use acc.bind(cgh) instead of cgh.require(acc)
[alexxy/gromacs.git] / src / gromacs / mdlib / settle_gpu_internal_sycl.cpp
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  *
37  * \brief SYCL-specific routines for the GPU implementation of SETTLE constraints algorithm.
38  *
39  * \author Artem Zhmurov <zhmurov@gmail.com>
40  *
41  * \ingroup module_mdlib
42  */
43
44 #include "settle_gpu_internal.h"
45
46 #include "gromacs/gpu_utils/devicebuffer.h"
47 #include "gromacs/gpu_utils/sycl_kernel_utils.h"
48 #include "gromacs/pbcutil/pbc_aiuc_sycl.h"
49 #include "gromacs/utility/gmxassert.h"
50 #include "gromacs/utility/template_mp.h"
51
52 namespace gmx
53 {
54
55 using cl::sycl::access::fence_space;
56 using cl::sycl::access::mode;
57 using cl::sycl::access::target;
58
//! Work-group size (local range) used for every SETTLE kernel launch in this file.
static constexpr int sc_workGroupSize = 256;
61
62 //! \brief Function returning the SETTLE kernel lambda.
63 template<bool updateVelocities, bool computeVirial>
64 auto settleKernel(cl::sycl::handler&                                           cgh,
65                   const int                                                    numSettles,
66                   DeviceAccessor<WaterMolecule, mode::read>                    a_settles,
67                   SettleParameters                                             pars,
68                   DeviceAccessor<Float3, mode::read>                           a_x,
69                   DeviceAccessor<Float3, mode::read_write>                     a_xp,
70                   float                                                        invdt,
71                   OptionalAccessor<Float3, mode::read_write, updateVelocities> a_v,
72                   OptionalAccessor<float, mode::read_write, computeVirial>     a_virialScaled,
73                   PbcAiuc                                                      pbcAiuc)
74 {
75     a_settles.bind(cgh);
76     a_x.bind(cgh);
77     a_xp.bind(cgh);
78     if constexpr (updateVelocities)
79     {
80         a_v.bind(cgh);
81     }
82     if constexpr (computeVirial)
83     {
84         a_virialScaled.bind(cgh);
85     }
86
87     // shmem buffer for i x+q pre-loading
88     auto sm_threadVirial = [&]() {
89         if constexpr (computeVirial)
90         {
91             return cl::sycl::accessor<float, 1, mode::read_write, target::local>(
92                     cl::sycl::range<1>(sc_workGroupSize * 6), cgh);
93         }
94         else
95         {
96             return nullptr;
97         }
98     }();
99
100     return [=](cl::sycl::nd_item<1> itemIdx) {
101         constexpr float almost_zero = real(1e-12);
102         const int       settleIdx   = itemIdx.get_global_linear_id();
103         const int       threadIdx = itemIdx.get_local_linear_id(); // Work-item index in work-group
104         assert(itemIdx.get_local_range(0) == sc_workGroupSize);
105         // These are the indexes of three atoms in a single 'water' molecule.
106         // TODO Can be reduced to one integer if atoms are consecutive in memory.
107         if (settleIdx < numSettles)
108         {
109             WaterMolecule indices = a_settles[settleIdx];
110
111             const Float3 x_ow1 = a_x[indices.ow1];
112             const Float3 x_hw2 = a_x[indices.hw2];
113             const Float3 x_hw3 = a_x[indices.hw3];
114
115             const Float3 xprime_ow1 = a_xp[indices.ow1];
116             const Float3 xprime_hw2 = a_xp[indices.hw2];
117             const Float3 xprime_hw3 = a_xp[indices.hw3];
118
119             Float3 dist21;
120             pbcDxAiucSycl(pbcAiuc, x_hw2, x_ow1, dist21);
121             Float3 dist31;
122             pbcDxAiucSycl(pbcAiuc, x_hw3, x_ow1, dist31);
123             Float3 doh2;
124             pbcDxAiucSycl(pbcAiuc, xprime_hw2, xprime_ow1, doh2);
125
126             Float3 doh3;
127             pbcDxAiucSycl(pbcAiuc, xprime_hw3, xprime_ow1, doh3);
128
129             Float3 a1 = (doh2 + doh3) * (-pars.wh);
130
131             Float3 b1 = doh2 + a1;
132
133             Float3 c1 = doh3 + a1;
134
135             float xakszd = dist21[YY] * dist31[ZZ] - dist21[ZZ] * dist31[YY];
136             float yakszd = dist21[ZZ] * dist31[XX] - dist21[XX] * dist31[ZZ];
137             float zakszd = dist21[XX] * dist31[YY] - dist21[YY] * dist31[XX];
138
139             float xaksxd = a1[YY] * zakszd - a1[ZZ] * yakszd;
140             float yaksxd = a1[ZZ] * xakszd - a1[XX] * zakszd;
141             float zaksxd = a1[XX] * yakszd - a1[YY] * xakszd;
142
143             float xaksyd = yakszd * zaksxd - zakszd * yaksxd;
144             float yaksyd = zakszd * xaksxd - xakszd * zaksxd;
145             float zaksyd = xakszd * yaksxd - yakszd * xaksxd;
146
147             float axlng = cl::sycl::rsqrt(xaksxd * xaksxd + yaksxd * yaksxd + zaksxd * zaksxd);
148             float aylng = cl::sycl::rsqrt(xaksyd * xaksyd + yaksyd * yaksyd + zaksyd * zaksyd);
149             float azlng = cl::sycl::rsqrt(xakszd * xakszd + yakszd * yakszd + zakszd * zakszd);
150
151             // TODO {1,2,3} indexes should be swapped with {.x, .y, .z} components.
152             //      This way, we will be able to use vector ops more.
153             Float3 trns1, trns2, trns3;
154
155             trns1[XX] = xaksxd * axlng;
156             trns2[XX] = yaksxd * axlng;
157             trns3[XX] = zaksxd * axlng;
158
159             trns1[YY] = xaksyd * aylng;
160             trns2[YY] = yaksyd * aylng;
161             trns3[YY] = zaksyd * aylng;
162
163             trns1[ZZ] = xakszd * azlng;
164             trns2[ZZ] = yakszd * azlng;
165             trns3[ZZ] = zakszd * azlng;
166
167
168             Float2 b0d, c0d;
169
170             b0d[XX] = trns1[XX] * dist21[XX] + trns2[XX] * dist21[YY] + trns3[XX] * dist21[ZZ];
171             b0d[YY] = trns1[YY] * dist21[XX] + trns2[YY] * dist21[YY] + trns3[YY] * dist21[ZZ];
172
173             c0d[XX] = trns1[XX] * dist31[XX] + trns2[XX] * dist31[YY] + trns3[XX] * dist31[ZZ];
174             c0d[YY] = trns1[YY] * dist31[XX] + trns2[YY] * dist31[YY] + trns3[YY] * dist31[ZZ];
175
176             Float3 b1d, c1d;
177
178             float a1d_z = trns1[ZZ] * a1[XX] + trns2[ZZ] * a1[YY] + trns3[ZZ] * a1[ZZ];
179
180             b1d[XX] = trns1[XX] * b1[XX] + trns2[XX] * b1[YY] + trns3[XX] * b1[ZZ];
181             b1d[YY] = trns1[YY] * b1[XX] + trns2[YY] * b1[YY] + trns3[YY] * b1[ZZ];
182             b1d[ZZ] = trns1[ZZ] * b1[XX] + trns2[ZZ] * b1[YY] + trns3[ZZ] * b1[ZZ];
183
184             c1d[XX] = trns1[XX] * c1[XX] + trns2[XX] * c1[YY] + trns3[XX] * c1[ZZ];
185             c1d[YY] = trns1[YY] * c1[XX] + trns2[YY] * c1[YY] + trns3[YY] * c1[ZZ];
186             c1d[ZZ] = trns1[ZZ] * c1[XX] + trns2[ZZ] * c1[YY] + trns3[ZZ] * c1[ZZ];
187
188
189             const float sinphi = a1d_z * cl::sycl::rsqrt(pars.ra * pars.ra);
190             float       tmp2   = 1.0F - sinphi * sinphi;
191
192             if (almost_zero > tmp2)
193             {
194                 tmp2 = almost_zero;
195             }
196
197             const float tmp    = cl::sycl::rsqrt(tmp2);
198             const float cosphi = tmp2 * tmp;
199             const float sinpsi = (b1d[ZZ] - c1d[ZZ]) * pars.irc2 * tmp;
200             tmp2               = 1.0F - sinpsi * sinpsi;
201
202             const float cospsi = tmp2 * cl::sycl::rsqrt(tmp2);
203
204             const float a2d_y = pars.ra * cosphi;
205             const float b2d_x = -pars.rc * cospsi;
206             const float t1    = -pars.rb * cosphi;
207             const float t2    = pars.rc * sinpsi * sinphi;
208             const float b2d_y = t1 - t2;
209             const float c2d_y = t1 + t2;
210
211             /*     --- Step3  al,be,ga            --- */
212             const float alpha = b2d_x * (b0d[XX] - c0d[XX]) + b0d[YY] * b2d_y + c0d[YY] * c2d_y;
213             const float beta  = b2d_x * (c0d[YY] - b0d[YY]) + b0d[XX] * b2d_y + c0d[XX] * c2d_y;
214             const float gamma =
215                     b0d[XX] * b1d[YY] - b1d[XX] * b0d[YY] + c0d[XX] * c1d[YY] - c1d[XX] * c0d[YY];
216             const float al2be2 = alpha * alpha + beta * beta;
217             tmp2               = (al2be2 - gamma * gamma);
218             const float sinthe = (alpha * gamma - beta * tmp2 * cl::sycl::rsqrt(tmp2))
219                                  * cl::sycl::rsqrt(al2be2 * al2be2);
220
221             /*  --- Step4  A3' --- */
222             tmp2         = 1.0F - sinthe * sinthe;
223             float costhe = tmp2 * cl::sycl::rsqrt(tmp2);
224
225             Float3 a3d, b3d, c3d;
226
227             a3d[XX] = -a2d_y * sinthe;
228             a3d[YY] = a2d_y * costhe;
229             a3d[ZZ] = a1d_z;
230             b3d[XX] = b2d_x * costhe - b2d_y * sinthe;
231             b3d[YY] = b2d_x * sinthe + b2d_y * costhe;
232             b3d[ZZ] = b1d[ZZ];
233             c3d[XX] = -b2d_x * costhe - c2d_y * sinthe;
234             c3d[YY] = -b2d_x * sinthe + c2d_y * costhe;
235             c3d[ZZ] = c1d[ZZ];
236
237             /*    --- Step5  A3 --- */
238             Float3 a3, b3, c3;
239
240             a3[XX] = trns1[XX] * a3d[XX] + trns1[YY] * a3d[YY] + trns1[ZZ] * a3d[ZZ];
241             a3[YY] = trns2[XX] * a3d[XX] + trns2[YY] * a3d[YY] + trns2[ZZ] * a3d[ZZ];
242             a3[ZZ] = trns3[XX] * a3d[XX] + trns3[YY] * a3d[YY] + trns3[ZZ] * a3d[ZZ];
243
244             b3[XX] = trns1[XX] * b3d[XX] + trns1[YY] * b3d[YY] + trns1[ZZ] * b3d[ZZ];
245             b3[YY] = trns2[XX] * b3d[XX] + trns2[YY] * b3d[YY] + trns2[ZZ] * b3d[ZZ];
246             b3[ZZ] = trns3[XX] * b3d[XX] + trns3[YY] * b3d[YY] + trns3[ZZ] * b3d[ZZ];
247
248             c3[XX] = trns1[XX] * c3d[XX] + trns1[YY] * c3d[YY] + trns1[ZZ] * c3d[ZZ];
249             c3[YY] = trns2[XX] * c3d[XX] + trns2[YY] * c3d[YY] + trns2[ZZ] * c3d[ZZ];
250             c3[ZZ] = trns3[XX] * c3d[XX] + trns3[YY] * c3d[YY] + trns3[ZZ] * c3d[ZZ];
251
252
253             /* Compute and store the corrected new coordinate */
254             const Float3 dxOw1 = a3 - a1;
255             const Float3 dxHw2 = b3 - b1;
256             const Float3 dxHw3 = c3 - c1;
257
258             a_xp[indices.ow1] = xprime_ow1 + dxOw1;
259             a_xp[indices.hw2] = xprime_hw2 + dxHw2;
260             a_xp[indices.hw3] = xprime_hw3 + dxHw3;
261
262             if constexpr (updateVelocities)
263             {
264                 Float3 v_ow1 = a_v[indices.ow1];
265                 Float3 v_hw2 = a_v[indices.hw2];
266                 Float3 v_hw3 = a_v[indices.hw3];
267
268                 /* Add the position correction divided by dt to the velocity */
269                 v_ow1 = dxOw1 * invdt + v_ow1;
270                 v_hw2 = dxHw2 * invdt + v_hw2;
271                 v_hw3 = dxHw3 * invdt + v_hw3;
272
273                 a_v[indices.ow1] = v_ow1;
274                 a_v[indices.hw2] = v_hw2;
275                 a_v[indices.hw3] = v_hw3;
276             }
277
278             if constexpr (computeVirial)
279             {
280                 Float3 mdb = pars.mH * dxHw2;
281                 Float3 mdc = pars.mH * dxHw3;
282                 Float3 mdo = pars.mO * dxOw1 + mdb + mdc;
283
284                 sm_threadVirial[0 * sc_workGroupSize + threadIdx] =
285                         -(x_ow1[0] * mdo[0] + dist21[0] * mdb[0] + dist31[0] * mdc[0]);
286                 sm_threadVirial[1 * sc_workGroupSize + threadIdx] =
287                         -(x_ow1[0] * mdo[1] + dist21[0] * mdb[1] + dist31[0] * mdc[1]);
288                 sm_threadVirial[2 * sc_workGroupSize + threadIdx] =
289                         -(x_ow1[0] * mdo[2] + dist21[0] * mdb[2] + dist31[0] * mdc[2]);
290                 sm_threadVirial[3 * sc_workGroupSize + threadIdx] =
291                         -(x_ow1[1] * mdo[1] + dist21[1] * mdb[1] + dist31[1] * mdc[1]);
292                 sm_threadVirial[4 * sc_workGroupSize + threadIdx] =
293                         -(x_ow1[1] * mdo[2] + dist21[1] * mdb[2] + dist31[1] * mdc[2]);
294                 sm_threadVirial[5 * sc_workGroupSize + threadIdx] =
295                         -(x_ow1[2] * mdo[2] + dist21[2] * mdb[2] + dist31[2] * mdc[2]);
296             }
297         }
298         else // settleIdx < numSettles
299         {
300             // Filling data for dummy threads with zeroes
301             if constexpr (computeVirial)
302             {
303                 for (int d = 0; d < 6; d++)
304                 {
305                     sm_threadVirial[d * sc_workGroupSize + threadIdx] = 0.0F;
306                 }
307             }
308         }
309
310         // Basic reduction for the values inside single thread block
311         // TODO what follows should be separated out as a standard virial reduction subroutine
312         if constexpr (computeVirial)
313         {
314             // This is to ensure that all threads saved the data before reduction starts
315             subGroupBarrier(itemIdx);
316             constexpr int blockSize    = sc_workGroupSize;
317             const int     subGroupSize = itemIdx.get_sub_group().get_max_local_range()[0];
318             // Reduce up to one virial per thread block
319             // All blocks are divided by half, the first half of threads sums
320             // two virials. Then the first half is divided by two and the first half
321             // of it sums two values... The procedure continues until only one thread left.
322             // Only works if the threads per blocks is a power of two, hence the assertion.
323             static_assert(gmx::isPowerOfTwo(sc_workGroupSize));
324             for (int divideBy = 2; divideBy <= blockSize; divideBy *= 2)
325             {
326                 int dividedAt = blockSize / divideBy;
327                 if (threadIdx < dividedAt)
328                 {
329                     for (int d = 0; d < 6; d++)
330                     {
331                         sm_threadVirial[d * blockSize + threadIdx] +=
332                                 sm_threadVirial[d * blockSize + (threadIdx + dividedAt)];
333                     }
334                 }
335                 if (dividedAt > subGroupSize / 2)
336                 {
337                     subGroupBarrier(itemIdx);
338                 }
339             }
340             // First 6 threads in the block add the 6 components of virial to the global memory address
341             if (threadIdx < 6)
342             {
343                 atomicFetchAdd(a_virialScaled[threadIdx], sm_threadVirial[threadIdx * blockSize]);
344             }
345         }
346     };
347 }
348
/*! \brief Kernel name type, distinct for every (updateVelocities, computeVirial) combination.
 *
 * SYCL 1.2.1 requires a unique type to name each kernel; with SYCL 2020 this
 * should no longer be needed. The type is only declared, never defined.
 */
template<bool doUpdateVelocities, bool doComputeVirial>
class SettleKernelName;
352
353 //! \brief SETTLE SYCL kernel launch code.
354 template<bool updateVelocities, bool computeVirial, class... Args>
355 static cl::sycl::event launchSettleKernel(const DeviceStream& deviceStream, int numSettles, Args&&... args)
356 {
357     // Should not be needed for SYCL2020.
358     using kernelNameType = SettleKernelName<updateVelocities, computeVirial>;
359
360     const int numSettlesRoundedUp =
361             static_cast<int>((numSettles + sc_workGroupSize - 1) / sc_workGroupSize) * sc_workGroupSize;
362     const cl::sycl::nd_range<1> rangeAllSettles(numSettlesRoundedUp, sc_workGroupSize);
363     cl::sycl::queue             q = deviceStream.stream();
364
365     cl::sycl::event e = q.submit([&](cl::sycl::handler& cgh) {
366         auto kernel = settleKernel<updateVelocities, computeVirial>(
367                 cgh, numSettles, std::forward<Args>(args)...);
368         cgh.parallel_for<kernelNameType>(rangeAllSettles, kernel);
369     });
370
371     return e;
372 }
373
374 /*! \brief Select templated kernel and launch it. */
375 template<class... Args>
376 static inline cl::sycl::event launchSettleKernel(bool updateVelocities, bool computeVirial, Args&&... args)
377 {
378     return dispatchTemplatedFunction(
379             [&](auto updateVelocities_, auto computeVirial_) {
380                 return launchSettleKernel<updateVelocities_, computeVirial_>(std::forward<Args>(args)...);
381             },
382             updateVelocities,
383             computeVirial);
384 }
385
386
387 void launchSettleGpuKernel(const int                          numSettles,
388                            const DeviceBuffer<WaterMolecule>& d_settles,
389                            const SettleParameters&            settleParameters,
390                            const DeviceBuffer<Float3>&        d_x,
391                            DeviceBuffer<Float3>               d_xp,
392                            const bool                         updateVelocities,
393                            DeviceBuffer<Float3>               d_v,
394                            const real                         invdt,
395                            const bool                         computeVirial,
396                            DeviceBuffer<float>                virialScaled,
397                            const PbcAiuc&                     pbcAiuc,
398                            const DeviceStream&                deviceStream)
399 {
400
401     launchSettleKernel(updateVelocities,
402                        computeVirial,
403                        deviceStream,
404                        numSettles,
405                        d_settles,
406                        settleParameters,
407                        d_x,
408                        d_xp,
409                        invdt,
410                        d_v,
411                        virialScaled,
412                        pbcAiuc);
413     return;
414 }
415
416 } // namespace gmx