Make stepWorkload.useGpuXBufferOps flag consistent
[alexxy/gromacs.git] / src / gromacs / mdlib / settle_gpu_internal.cu
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019,2020,2021, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  *
37  * \brief CUDA-specific routines for the GPU implementation of SETTLE constraints algorithm.
38  *
39  *
40  * \author Artem Zhmurov <zhmurov@gmail.com>
41  *
42  * \ingroup module_mdlib
43  */
44 #include "gmxpre.h"
45
46 #include "settle_gpu_internal.h"
47
48 #include <assert.h>
49 #include <stdio.h>
50
51 #include <cmath>
52
53 #include <algorithm>
54
55 #include "gromacs/gpu_utils/cuda_arch_utils.cuh"
56 #include "gromacs/gpu_utils/cudautils.cuh"
57 #include "gromacs/gpu_utils/devicebuffer.h"
58 #include "gromacs/gpu_utils/gputraits.h"
59 #include "gromacs/gpu_utils/typecasts.cuh"
60 #include "gromacs/gpu_utils/vectype_ops.cuh"
61 #include "gromacs/math/functions.h"
62 #include "gromacs/math/vec.h"
63 #include "gromacs/pbcutil/pbc.h"
64 #include "gromacs/pbcutil/pbc_aiuc_cuda.cuh"
65
66 namespace gmx
67 {
68
69 //! Number of CUDA threads in a block
70 constexpr static int sc_threadsPerBlock = 256;
71
72 //! Maximum number of threads in a block (for __launch_bounds__)
73 constexpr static int sc_maxThreadsPerBlock = sc_threadsPerBlock;
74
75 /*! \brief SETTLE constraints kernel
76  *
77  * Each thread corresponds to a single constraints triangle (i.e. single water molecule).
78  *
79  * See original CPU version in settle.cpp
80  *
81  * \param [in]      numSettles       Number of constraints triangles (water molecules).
82  * \param [in]      gm_settles       Indexes of three atoms in the constraints triangle. The field .x of int3
83  *                                   data type corresponds to Oxygen, fields .y and .z are two hydrogen atoms.
84  * \param [in]      pars             Parameters for the algorithm (i.e. masses, target distances, etc.).
85  * \param [in]      gm_x             Coordinates of atoms before the timestep.
86  * \param [in,out]  gm_x             Coordinates of atoms after the timestep (constrained coordinates will be
87  *                                   saved here).
88  * \param [in]      invdt            Reciprocal timestep.
89  * \param [in]      gm_v             Velocities of the particles.
90  * \param [in]      gm_virialScaled  Virial tensor.
91  * \param [in]      pbcAiuc          Periodic boundary conditions data.
92  */
93 template<bool updateVelocities, bool computeVirial>
94 __launch_bounds__(sc_maxThreadsPerBlock) __global__
95         void settle_kernel(const int numSettles,
96                            const WaterMolecule* __restrict__ gm_settles,
97                            const SettleParameters pars,
98                            const float3* __restrict__ gm_x,
99                            float3* __restrict__ gm_xprime,
100                            float invdt,
101                            float3* __restrict__ gm_v,
102                            float* __restrict__ gm_virialScaled,
103                            const PbcAiuc pbcAiuc)
104 {
105     /* ******************************************************************* */
106     /*                                                                  ** */
107     /*    Original code by Shuichi Miyamoto, last update Oct. 1, 1992   ** */
108     /*                                                                  ** */
109     /*    Algorithm changes by Berk Hess:                               ** */
110     /*    2004-07-15 Convert COM to double precision to avoid drift     ** */
111     /*    2006-10-16 Changed velocity update to use differences         ** */
112     /*    2012-09-24 Use oxygen as reference instead of COM             ** */
113     /*    2016-02    Complete rewrite of the code for SIMD              ** */
114     /*    2020-06    Completely remove use of COM to minimize drift     ** */
115     /*                                                                  ** */
116     /*    Reference for the SETTLE algorithm                            ** */
117     /*           S. Miyamoto et al., J. Comp. Chem., 13, 952 (1992).    ** */
118     /*                                                                  ** */
119     /* ******************************************************************* */
120
121     constexpr float almost_zero = real(1e-12);
122
123     extern __shared__ float sm_threadVirial[];
124
125     int tid = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
126
127     if (tid < numSettles)
128     {
129         // These are the indexes of three atoms in a single 'water' molecule.
130         // TODO Can be reduced to one integer if atoms are consecutive in memory.
131         WaterMolecule indices = gm_settles[tid];
132
133         float3 x_ow1 = gm_x[indices.ow1];
134         float3 x_hw2 = gm_x[indices.hw2];
135         float3 x_hw3 = gm_x[indices.hw3];
136
137         float3 xprime_ow1 = gm_xprime[indices.ow1];
138         float3 xprime_hw2 = gm_xprime[indices.hw2];
139         float3 xprime_hw3 = gm_xprime[indices.hw3];
140
141         float3 dist21 = pbcDxAiuc(pbcAiuc, x_hw2, x_ow1);
142         float3 dist31 = pbcDxAiuc(pbcAiuc, x_hw3, x_ow1);
143         float3 doh2   = pbcDxAiuc(pbcAiuc, xprime_hw2, xprime_ow1);
144
145         float3 doh3 = pbcDxAiuc(pbcAiuc, xprime_hw3, xprime_ow1);
146
147         float3 a1 = (-doh2 - doh3) * pars.wh;
148
149         float3 b1 = doh2 + a1;
150
151         float3 c1 = doh3 + a1;
152
153         float xakszd = dist21.y * dist31.z - dist21.z * dist31.y;
154         float yakszd = dist21.z * dist31.x - dist21.x * dist31.z;
155         float zakszd = dist21.x * dist31.y - dist21.y * dist31.x;
156
157         float xaksxd = a1.y * zakszd - a1.z * yakszd;
158         float yaksxd = a1.z * xakszd - a1.x * zakszd;
159         float zaksxd = a1.x * yakszd - a1.y * xakszd;
160
161         float xaksyd = yakszd * zaksxd - zakszd * yaksxd;
162         float yaksyd = zakszd * xaksxd - xakszd * zaksxd;
163         float zaksyd = xakszd * yaksxd - yakszd * xaksxd;
164
165         float axlng = rsqrt(xaksxd * xaksxd + yaksxd * yaksxd + zaksxd * zaksxd);
166         float aylng = rsqrt(xaksyd * xaksyd + yaksyd * yaksyd + zaksyd * zaksyd);
167         float azlng = rsqrt(xakszd * xakszd + yakszd * yakszd + zakszd * zakszd);
168
169         // TODO {1,2,3} indexes should be swapped with {.x, .y, .z} components.
170         //      This way, we will be able to use vector ops more.
171         float3 trns1, trns2, trns3;
172
173         trns1.x = xaksxd * axlng;
174         trns2.x = yaksxd * axlng;
175         trns3.x = zaksxd * axlng;
176
177         trns1.y = xaksyd * aylng;
178         trns2.y = yaksyd * aylng;
179         trns3.y = zaksyd * aylng;
180
181         trns1.z = xakszd * azlng;
182         trns2.z = yakszd * azlng;
183         trns3.z = zakszd * azlng;
184
185
186         float2 b0d, c0d;
187
188         b0d.x = trns1.x * dist21.x + trns2.x * dist21.y + trns3.x * dist21.z;
189         b0d.y = trns1.y * dist21.x + trns2.y * dist21.y + trns3.y * dist21.z;
190
191         c0d.x = trns1.x * dist31.x + trns2.x * dist31.y + trns3.x * dist31.z;
192         c0d.y = trns1.y * dist31.x + trns2.y * dist31.y + trns3.y * dist31.z;
193
194         float3 b1d, c1d;
195
196         float a1d_z = trns1.z * a1.x + trns2.z * a1.y + trns3.z * a1.z;
197
198         b1d.x = trns1.x * b1.x + trns2.x * b1.y + trns3.x * b1.z;
199         b1d.y = trns1.y * b1.x + trns2.y * b1.y + trns3.y * b1.z;
200         b1d.z = trns1.z * b1.x + trns2.z * b1.y + trns3.z * b1.z;
201
202         c1d.x = trns1.x * c1.x + trns2.x * c1.y + trns3.x * c1.z;
203         c1d.y = trns1.y * c1.x + trns2.y * c1.y + trns3.y * c1.z;
204         c1d.z = trns1.z * c1.x + trns2.z * c1.y + trns3.z * c1.z;
205
206
207         float sinphi = a1d_z * rsqrt(pars.ra * pars.ra);
208         float tmp2   = 1.0F - sinphi * sinphi;
209
210         if (almost_zero > tmp2)
211         {
212             tmp2 = almost_zero;
213         }
214
215         float tmp    = rsqrt(tmp2);
216         float cosphi = tmp2 * tmp;
217         float sinpsi = (b1d.z - c1d.z) * pars.irc2 * tmp;
218         tmp2         = 1.0F - sinpsi * sinpsi;
219
220         float cospsi = tmp2 * rsqrt(tmp2);
221
222         float a2d_y = pars.ra * cosphi;
223         float b2d_x = -pars.rc * cospsi;
224         float t1    = -pars.rb * cosphi;
225         float t2    = pars.rc * sinpsi * sinphi;
226         float b2d_y = t1 - t2;
227         float c2d_y = t1 + t2;
228
229         /*     --- Step3  al,be,ga            --- */
230         float alpha  = b2d_x * (b0d.x - c0d.x) + b0d.y * b2d_y + c0d.y * c2d_y;
231         float beta   = b2d_x * (c0d.y - b0d.y) + b0d.x * b2d_y + c0d.x * c2d_y;
232         float gamma  = b0d.x * b1d.y - b1d.x * b0d.y + c0d.x * c1d.y - c1d.x * c0d.y;
233         float al2be2 = alpha * alpha + beta * beta;
234         tmp2         = (al2be2 - gamma * gamma);
235         float sinthe = (alpha * gamma - beta * tmp2 * rsqrt(tmp2)) * rsqrt(al2be2 * al2be2);
236
237         /*  --- Step4  A3' --- */
238         tmp2         = 1.0F - sinthe * sinthe;
239         float costhe = tmp2 * rsqrt(tmp2);
240
241         float3 a3d, b3d, c3d;
242
243         a3d.x = -a2d_y * sinthe;
244         a3d.y = a2d_y * costhe;
245         a3d.z = a1d_z;
246         b3d.x = b2d_x * costhe - b2d_y * sinthe;
247         b3d.y = b2d_x * sinthe + b2d_y * costhe;
248         b3d.z = b1d.z;
249         c3d.x = -b2d_x * costhe - c2d_y * sinthe;
250         c3d.y = -b2d_x * sinthe + c2d_y * costhe;
251         c3d.z = c1d.z;
252
253         /*    --- Step5  A3 --- */
254         float3 a3, b3, c3;
255
256         a3.x = trns1.x * a3d.x + trns1.y * a3d.y + trns1.z * a3d.z;
257         a3.y = trns2.x * a3d.x + trns2.y * a3d.y + trns2.z * a3d.z;
258         a3.z = trns3.x * a3d.x + trns3.y * a3d.y + trns3.z * a3d.z;
259
260         b3.x = trns1.x * b3d.x + trns1.y * b3d.y + trns1.z * b3d.z;
261         b3.y = trns2.x * b3d.x + trns2.y * b3d.y + trns2.z * b3d.z;
262         b3.z = trns3.x * b3d.x + trns3.y * b3d.y + trns3.z * b3d.z;
263
264         c3.x = trns1.x * c3d.x + trns1.y * c3d.y + trns1.z * c3d.z;
265         c3.y = trns2.x * c3d.x + trns2.y * c3d.y + trns2.z * c3d.z;
266         c3.z = trns3.x * c3d.x + trns3.y * c3d.y + trns3.z * c3d.z;
267
268
269         /* Compute and store the corrected new coordinate */
270         const float3 dxOw1 = a3 - a1;
271         const float3 dxHw2 = b3 - b1;
272         const float3 dxHw3 = c3 - c1;
273
274         gm_xprime[indices.ow1] = xprime_ow1 + dxOw1;
275         gm_xprime[indices.hw2] = xprime_hw2 + dxHw2;
276         gm_xprime[indices.hw3] = xprime_hw3 + dxHw3;
277
278         if (updateVelocities)
279         {
280             float3 v_ow1 = gm_v[indices.ow1];
281             float3 v_hw2 = gm_v[indices.hw2];
282             float3 v_hw3 = gm_v[indices.hw3];
283
284             /* Add the position correction divided by dt to the velocity */
285             v_ow1 = dxOw1 * invdt + v_ow1;
286             v_hw2 = dxHw2 * invdt + v_hw2;
287             v_hw3 = dxHw3 * invdt + v_hw3;
288
289             gm_v[indices.ow1] = v_ow1;
290             gm_v[indices.hw2] = v_hw2;
291             gm_v[indices.hw3] = v_hw3;
292         }
293
294         if (computeVirial)
295         {
296             float3 mdb = pars.mH * dxHw2;
297             float3 mdc = pars.mH * dxHw3;
298             float3 mdo = pars.mO * dxOw1 + mdb + mdc;
299
300             sm_threadVirial[0 * blockDim.x + threadIdx.x] =
301                     -(x_ow1.x * mdo.x + dist21.x * mdb.x + dist31.x * mdc.x);
302             sm_threadVirial[1 * blockDim.x + threadIdx.x] =
303                     -(x_ow1.x * mdo.y + dist21.x * mdb.y + dist31.x * mdc.y);
304             sm_threadVirial[2 * blockDim.x + threadIdx.x] =
305                     -(x_ow1.x * mdo.z + dist21.x * mdb.z + dist31.x * mdc.z);
306             sm_threadVirial[3 * blockDim.x + threadIdx.x] =
307                     -(x_ow1.y * mdo.y + dist21.y * mdb.y + dist31.y * mdc.y);
308             sm_threadVirial[4 * blockDim.x + threadIdx.x] =
309                     -(x_ow1.y * mdo.z + dist21.y * mdb.z + dist31.y * mdc.z);
310             sm_threadVirial[5 * blockDim.x + threadIdx.x] =
311                     -(x_ow1.z * mdo.z + dist21.z * mdb.z + dist31.z * mdc.z);
312         }
313     }
314     else
315     {
316         // Filling data for dummy threads with zeroes
317         if (computeVirial)
318         {
319             for (int d = 0; d < 6; d++)
320             {
321                 sm_threadVirial[d * blockDim.x + threadIdx.x] = 0.0F;
322             }
323         }
324     }
325     // Basic reduction for the values inside single thread block
326     // TODO what follows should be separated out as a standard virial reduction subroutine
327     if (computeVirial)
328     {
329         // This is to ensure that all threads saved the data before reduction starts
330         __syncthreads();
331         // This casts unsigned into signed integers to avoid clang warnings
332         int tib       = static_cast<int>(threadIdx.x);
333         int blockSize = static_cast<int>(blockDim.x);
334         // Reduce up to one virial per thread block
335         // All blocks are divided by half, the first half of threads sums
336         // two virials. Then the first half is divided by two and the first half
337         // of it sums two values... The procedure continues until only one thread left.
338         // Only works if the threads per blocks is a power of two.
339         for (int divideBy = 2; divideBy <= blockSize; divideBy *= 2)
340         {
341             int dividedAt = blockSize / divideBy;
342             if (tib < dividedAt)
343             {
344                 for (int d = 0; d < 6; d++)
345                 {
346                     sm_threadVirial[d * blockSize + tib] +=
347                             sm_threadVirial[d * blockSize + (tib + dividedAt)];
348                 }
349             }
350             if (dividedAt > warpSize / 2)
351             {
352                 __syncthreads();
353             }
354         }
355         // First 6 threads in the block add the 6 components of virial to the global memory address
356         if (tib < 6)
357         {
358             atomicAdd(&(gm_virialScaled[tib]), sm_threadVirial[tib * blockSize]);
359         }
360     }
361 }
362
363 /*! \brief Select templated kernel.
364  *
365  * Returns pointer to a CUDA kernel based on provided booleans.
366  *
367  * \param[in] updateVelocities  If the velocities should be constrained.
368  * \param[in] bCalcVir          If virial should be updated.
369  *
370  * \retrun                      Pointer to CUDA kernel
371  */
372 inline auto getSettleKernelPtr(const bool updateVelocities, const bool computeVirial)
373 {
374
375     auto kernelPtr = settle_kernel<true, true>;
376     if (updateVelocities && computeVirial)
377     {
378         kernelPtr = settle_kernel<true, true>;
379     }
380     else if (updateVelocities && !computeVirial)
381     {
382         kernelPtr = settle_kernel<true, false>;
383     }
384     else if (!updateVelocities && computeVirial)
385     {
386         kernelPtr = settle_kernel<false, true>;
387     }
388     else if (!updateVelocities && !computeVirial)
389     {
390         kernelPtr = settle_kernel<false, false>;
391     }
392     return kernelPtr;
393 }
394
395 void launchSettleGpuKernel(const int                          numSettles,
396                            const DeviceBuffer<WaterMolecule>& d_atomIds,
397                            const SettleParameters&            settleParameters,
398                            const DeviceBuffer<Float3>&        d_x,
399                            DeviceBuffer<Float3>               d_xp,
400                            const bool                         updateVelocities,
401                            DeviceBuffer<Float3>               d_v,
402                            const real                         invdt,
403                            const bool                         computeVirial,
404                            DeviceBuffer<float>                virialScaled,
405                            const PbcAiuc&                     pbcAiuc,
406                            const DeviceStream&                deviceStream)
407 {
408     static_assert(
409             gmx::isPowerOfTwo(sc_threadsPerBlock),
410             "Number of threads per block should be a power of two in order for reduction to work.");
411
412     auto kernelPtr = getSettleKernelPtr(updateVelocities, computeVirial);
413
414     KernelLaunchConfig config;
415     config.blockSize[0] = sc_threadsPerBlock;
416     config.blockSize[1] = 1;
417     config.blockSize[2] = 1;
418     config.gridSize[0]  = (numSettles + sc_threadsPerBlock - 1) / sc_threadsPerBlock;
419     config.gridSize[1]  = 1;
420     config.gridSize[2]  = 1;
421
422     // Shared memory is only used for virial reduction
423     if (computeVirial)
424     {
425         config.sharedMemorySize = sc_threadsPerBlock * 6 * sizeof(float);
426     }
427     else
428     {
429         config.sharedMemorySize = 0;
430     }
431
432     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr,
433                                                       config,
434                                                       &numSettles,
435                                                       &d_atomIds,
436                                                       &settleParameters,
437                                                       asFloat3Pointer(&d_x),
438                                                       asFloat3Pointer(&d_xp),
439                                                       &invdt,
440                                                       asFloat3Pointer(&d_v),
441                                                       &virialScaled,
442                                                       &pbcAiuc);
443
444     launchGpuKernel(kernelPtr,
445                     config,
446                     deviceStream,
447                     nullptr,
448                     "settle_kernel<updateVelocities, computeVirial>",
449                     kernelArgs);
450 }
451
452 } // namespace gmx