Apply clang-format to source tree
[alexxy/gromacs.git] / src / gromacs / mdlib / settle_cuda.cu
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2019, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35 /*! \internal \file
36  *
37  * \brief Implements SETTLE using CUDA
38  *
39  * This file contains implementation of SETTLE constraints algorithm
40  * using CUDA, including class initialization, data-structures management
41  * and GPU kernel.
42  *
43  * \note Management of CUDA stream and periodic boundary should be unified with LINCS
44  *       and removed from here once constraints are fully integrated with update module.
45  * \todo Reconsider naming to use "gpu" suffix instead of "cuda".
46  *
47  * \author Artem Zhmurov <zhmurov@gmail.com>
48  *
49  * \ingroup module_mdlib
50  */
51 #include "gmxpre.h"
52
53 #include "settle_cuda.cuh"
54
55 #include <assert.h>
56 #include <stdio.h>
57
58 #include <cmath>
59
60 #include <algorithm>
61
62 #include "gromacs/gpu_utils/cuda_arch_utils.cuh"
63 #include "gromacs/gpu_utils/cudautils.cuh"
64 #include "gromacs/gpu_utils/devicebuffer.h"
65 #include "gromacs/gpu_utils/gputraits.cuh"
66 #include "gromacs/gpu_utils/vectype_ops.cuh"
67 #include "gromacs/math/vec.h"
68 #include "gromacs/pbcutil/pbc.h"
69 #include "gromacs/pbcutil/pbc_aiuc_cuda.cuh"
70
71 namespace gmx
72 {
73
74 //! Number of CUDA threads in a block
75 constexpr static int c_threadsPerBlock = 256;
76 //! Maximum number of threads in a block (for __launch_bounds__)
77 constexpr static int c_maxThreadsPerBlock = c_threadsPerBlock;
78
79 /*! \brief SETTLE constraints kernel
80  *
81  * Each thread corresponds to a single constraints triangle (i.e. single water molecule).
82  *
83  * See original CPU version in settle.cpp
84  *
85  * \param [in]      numSettles       Number of constraints triangles (water molecules).
86  * \param [in]      gm_settles       Indexes of three atoms in the constraints triangle. The field .x of int3
87  *                                   data type corresponds to Oxygen, fields .y and .z are two hydrogen atoms.
88  * \param [in]      pars             Parameters for the algorithm (i.e. masses, target distances, etc.).
89  * \param [in]      gm_x             Coordinates of atoms before the timestep.
90  * \param [in,out]  gm_x             Coordinates of atoms after the timestep (constrained coordinates will be
91  *                                   saved here).
92  * \param [in]      pbcAiuc          Periodic boundary conditions data.
93  * \param [in]      invdt            Reciprocal timestep.
94  * \param [in]      gm_v             Velocities of the particles.
95  * \param [in]      gm_virialScaled  Virial tensor.
96  */
97 template<bool updateVelocities, bool computeVirial>
98 __launch_bounds__(c_maxThreadsPerBlock) __global__
99         void settle_kernel(const int numSettles,
100                            const int3* __restrict__ gm_settles,
101                            const SettleParameters pars,
102                            const float3* __restrict__ gm_x,
103                            float3* __restrict__ gm_xprime,
104                            const PbcAiuc pbcAiuc,
105                            float         invdt,
106                            float3* __restrict__ gm_v,
107                            float* __restrict__ gm_virialScaled)
108 {
109     /* ******************************************************************* */
110     /*                                                                  ** */
111     /*    Original code by Shuichi Miyamoto, last update Oct. 1, 1992   ** */
112     /*                                                                  ** */
113     /*    Algorithm changes by Berk Hess:                               ** */
114     /*    2004-07-15 Convert COM to double precision to avoid drift     ** */
115     /*    2006-10-16 Changed velocity update to use differences         ** */
116     /*    2012-09-24 Use oxygen as reference instead of COM             ** */
117     /*    2016-02    Complete rewrite of the code for SIMD              ** */
118     /*                                                                  ** */
119     /*    Reference for the SETTLE algorithm                            ** */
120     /*           S. Miyamoto et al., J. Comp. Chem., 13, 952 (1992).    ** */
121     /*                                                                  ** */
122     /* ******************************************************************* */
123
124     constexpr float almost_zero = real(1e-12);
125
126     extern __shared__ float sm_threadVirial[];
127
128     int tid = static_cast<int>(blockIdx.x * blockDim.x + threadIdx.x);
129
130     if (tid < numSettles)
131     {
132         // These are the indexes of three atoms in a single 'water' molecule.
133         // TODO Can be reduced to one integer if atoms are consecutive in memory.
134         int3 indices = gm_settles[tid];
135
136         float3 x_ow1 = gm_x[indices.x];
137         float3 x_hw2 = gm_x[indices.y];
138         float3 x_hw3 = gm_x[indices.z];
139
140         float3 xprime_ow1 = gm_xprime[indices.x];
141         float3 xprime_hw2 = gm_xprime[indices.y];
142         float3 xprime_hw3 = gm_xprime[indices.z];
143
144         float3 dist21 = pbcDxAiuc(pbcAiuc, x_hw2, x_ow1);
145         float3 dist31 = pbcDxAiuc(pbcAiuc, x_hw3, x_ow1);
146         float3 doh2   = pbcDxAiuc(pbcAiuc, xprime_hw2, xprime_ow1);
147
148         float3 sh_hw2 = xprime_hw2 - (xprime_ow1 + doh2);
149         xprime_hw2    = xprime_hw2 - sh_hw2;
150
151         float3 doh3 = pbcDxAiuc(pbcAiuc, xprime_hw3, xprime_ow1);
152
153         float3 sh_hw3 = xprime_hw3 - (xprime_ow1 + doh3);
154         xprime_hw3    = xprime_hw3 - sh_hw3;
155
156         float3 a1  = (-doh2 - doh3) * pars.wh;
157         float3 com = xprime_ow1 - a1;
158
159         float3 b1 = xprime_hw2 - com;
160
161         float3 c1 = xprime_hw3 - com;
162
163         float xakszd = dist21.y * dist31.z - dist21.z * dist31.y;
164         float yakszd = dist21.z * dist31.x - dist21.x * dist31.z;
165         float zakszd = dist21.x * dist31.y - dist21.y * dist31.x;
166
167         float xaksxd = a1.y * zakszd - a1.z * yakszd;
168         float yaksxd = a1.z * xakszd - a1.x * zakszd;
169         float zaksxd = a1.x * yakszd - a1.y * xakszd;
170
171         float xaksyd = yakszd * zaksxd - zakszd * yaksxd;
172         float yaksyd = zakszd * xaksxd - xakszd * zaksxd;
173         float zaksyd = xakszd * yaksxd - yakszd * xaksxd;
174
175         float axlng = rsqrt(xaksxd * xaksxd + yaksxd * yaksxd + zaksxd * zaksxd);
176         float aylng = rsqrt(xaksyd * xaksyd + yaksyd * yaksyd + zaksyd * zaksyd);
177         float azlng = rsqrt(xakszd * xakszd + yakszd * yakszd + zakszd * zakszd);
178
179         // TODO {1,2,3} indexes should be swapped with {.x, .y, .z} components.
180         //      This way, we will be able to use vector ops more.
181         float3 trns1, trns2, trns3;
182
183         trns1.x = xaksxd * axlng;
184         trns2.x = yaksxd * axlng;
185         trns3.x = zaksxd * axlng;
186
187         trns1.y = xaksyd * aylng;
188         trns2.y = yaksyd * aylng;
189         trns3.y = zaksyd * aylng;
190
191         trns1.z = xakszd * azlng;
192         trns2.z = yakszd * azlng;
193         trns3.z = zakszd * azlng;
194
195
196         float2 b0d, c0d;
197
198         b0d.x = trns1.x * dist21.x + trns2.x * dist21.y + trns3.x * dist21.z;
199         b0d.y = trns1.y * dist21.x + trns2.y * dist21.y + trns3.y * dist21.z;
200
201         c0d.x = trns1.x * dist31.x + trns2.x * dist31.y + trns3.x * dist31.z;
202         c0d.y = trns1.y * dist31.x + trns2.y * dist31.y + trns3.y * dist31.z;
203
204         float3 b1d, c1d;
205
206         float a1d_z = trns1.z * a1.x + trns2.z * a1.y + trns3.z * a1.z;
207
208         b1d.x = trns1.x * b1.x + trns2.x * b1.y + trns3.x * b1.z;
209         b1d.y = trns1.y * b1.x + trns2.y * b1.y + trns3.y * b1.z;
210         b1d.z = trns1.z * b1.x + trns2.z * b1.y + trns3.z * b1.z;
211
212         c1d.x = trns1.x * c1.x + trns2.x * c1.y + trns3.x * c1.z;
213         c1d.y = trns1.y * c1.x + trns2.y * c1.y + trns3.y * c1.z;
214         c1d.z = trns1.z * c1.x + trns2.z * c1.y + trns3.z * c1.z;
215
216
217         float sinphi = a1d_z * rsqrt(pars.ra * pars.ra);
218         float tmp2   = 1.0f - sinphi * sinphi;
219
220         if (almost_zero > tmp2)
221         {
222             tmp2 = almost_zero;
223         }
224
225         float tmp    = rsqrt(tmp2);
226         float cosphi = tmp2 * tmp;
227         float sinpsi = (b1d.z - c1d.z) * pars.irc2 * tmp;
228         tmp2         = 1.0f - sinpsi * sinpsi;
229
230         float cospsi = tmp2 * rsqrt(tmp2);
231
232         float a2d_y = pars.ra * cosphi;
233         float b2d_x = -pars.rc * cospsi;
234         float t1    = -pars.rb * cosphi;
235         float t2    = pars.rc * sinpsi * sinphi;
236         float b2d_y = t1 - t2;
237         float c2d_y = t1 + t2;
238
239         /*     --- Step3  al,be,ga            --- */
240         float alpha  = b2d_x * (b0d.x - c0d.x) + b0d.y * b2d_y + c0d.y * c2d_y;
241         float beta   = b2d_x * (c0d.y - b0d.y) + b0d.x * b2d_y + c0d.x * c2d_y;
242         float gamma  = b0d.x * b1d.y - b1d.x * b0d.y + c0d.x * c1d.y - c1d.x * c0d.y;
243         float al2be2 = alpha * alpha + beta * beta;
244         tmp2         = (al2be2 - gamma * gamma);
245         float sinthe = (alpha * gamma - beta * tmp2 * rsqrt(tmp2)) * rsqrt(al2be2 * al2be2);
246
247         /*  --- Step4  A3' --- */
248         tmp2         = 1.0f - sinthe * sinthe;
249         float costhe = tmp2 * rsqrt(tmp2);
250
251         float3 a3d, b3d, c3d;
252
253         a3d.x = -a2d_y * sinthe;
254         a3d.y = a2d_y * costhe;
255         a3d.z = a1d_z;
256         b3d.x = b2d_x * costhe - b2d_y * sinthe;
257         b3d.y = b2d_x * sinthe + b2d_y * costhe;
258         b3d.z = b1d.z;
259         c3d.x = -b2d_x * costhe - c2d_y * sinthe;
260         c3d.y = -b2d_x * sinthe + c2d_y * costhe;
261         c3d.z = c1d.z;
262
263         /*    --- Step5  A3 --- */
264         float3 a3, b3, c3;
265
266         a3.x = trns1.x * a3d.x + trns1.y * a3d.y + trns1.z * a3d.z;
267         a3.y = trns2.x * a3d.x + trns2.y * a3d.y + trns2.z * a3d.z;
268         a3.z = trns3.x * a3d.x + trns3.y * a3d.y + trns3.z * a3d.z;
269
270         b3.x = trns1.x * b3d.x + trns1.y * b3d.y + trns1.z * b3d.z;
271         b3.y = trns2.x * b3d.x + trns2.y * b3d.y + trns2.z * b3d.z;
272         b3.z = trns3.x * b3d.x + trns3.y * b3d.y + trns3.z * b3d.z;
273
274         c3.x = trns1.x * c3d.x + trns1.y * c3d.y + trns1.z * c3d.z;
275         c3.y = trns2.x * c3d.x + trns2.y * c3d.y + trns2.z * c3d.z;
276         c3.z = trns3.x * c3d.x + trns3.y * c3d.y + trns3.z * c3d.z;
277
278
279         /* Compute and store the corrected new coordinate */
280         xprime_ow1 = com + a3;
281         xprime_hw2 = com + b3 + sh_hw2;
282         xprime_hw3 = com + c3 + sh_hw3;
283
284         gm_xprime[indices.x] = xprime_ow1;
285         gm_xprime[indices.y] = xprime_hw2;
286         gm_xprime[indices.z] = xprime_hw3;
287
288
289         if (updateVelocities || computeVirial)
290         {
291
292             float3 da = a3 - a1;
293             float3 db = b3 - b1;
294             float3 dc = c3 - c1;
295
296             if (updateVelocities)
297             {
298
299                 float3 v_ow1 = gm_v[indices.x];
300                 float3 v_hw2 = gm_v[indices.y];
301                 float3 v_hw3 = gm_v[indices.z];
302
303                 /* Add the position correction divided by dt to the velocity */
304                 v_ow1 = da * invdt + v_ow1;
305                 v_hw2 = db * invdt + v_hw2;
306                 v_hw3 = dc * invdt + v_hw3;
307
308                 gm_v[indices.x] = v_ow1;
309                 gm_v[indices.y] = v_hw2;
310                 gm_v[indices.z] = v_hw3;
311             }
312
313             if (computeVirial)
314             {
315
316                 float3 mdb = pars.mH * db;
317                 float3 mdc = pars.mH * dc;
318                 float3 mdo = pars.mO * da + mdb + mdc;
319
320                 sm_threadVirial[0 * blockDim.x + threadIdx.x] =
321                         -(x_ow1.x * mdo.x + dist21.x * mdb.x + dist31.x * mdc.x);
322                 sm_threadVirial[1 * blockDim.x + threadIdx.x] =
323                         -(x_ow1.x * mdo.y + dist21.x * mdb.y + dist31.x * mdc.y);
324                 sm_threadVirial[2 * blockDim.x + threadIdx.x] =
325                         -(x_ow1.x * mdo.z + dist21.x * mdb.z + dist31.x * mdc.z);
326                 sm_threadVirial[3 * blockDim.x + threadIdx.x] =
327                         -(x_ow1.y * mdo.y + dist21.y * mdb.y + dist31.y * mdc.y);
328                 sm_threadVirial[4 * blockDim.x + threadIdx.x] =
329                         -(x_ow1.y * mdo.z + dist21.y * mdb.z + dist31.y * mdc.z);
330                 sm_threadVirial[5 * blockDim.x + threadIdx.x] =
331                         -(x_ow1.z * mdo.z + dist21.z * mdb.z + dist31.z * mdc.z);
332             }
333         }
334     }
335     else
336     {
337         // Filling data for dummy threads with zeroes
338         if (computeVirial)
339         {
340             for (int d = 0; d < 6; d++)
341             {
342                 sm_threadVirial[d * blockDim.x + threadIdx.x] = 0.0f;
343             }
344         }
345     }
346     // Basic reduction for the values inside single thread block
347     // TODO what follows should be separated out as a standard virial reduction subroutine
348     if (computeVirial)
349     {
350         // This is to ensure that all threads saved the data before reduction starts
351         __syncthreads();
352         // This casts unsigned into signed integers to avoid clang warnings
353         int tib       = static_cast<int>(threadIdx.x);
354         int blockSize = static_cast<int>(blockDim.x);
355         // Reduce up to one virial per thread block
356         // All blocks are divided by half, the first half of threads sums
357         // two virials. Then the first half is divided by two and the first half
358         // of it sums two values... The procedure continues until only one thread left.
359         // Only works if the threads per blocks is a power of two.
360         for (int divideBy = 2; divideBy <= blockSize; divideBy *= 2)
361         {
362             int dividedAt = blockSize / divideBy;
363             if (tib < dividedAt)
364             {
365                 for (int d = 0; d < 6; d++)
366                 {
367                     sm_threadVirial[d * blockSize + tib] +=
368                             sm_threadVirial[d * blockSize + (tib + dividedAt)];
369                 }
370             }
371             if (dividedAt > warpSize / 2)
372             {
373                 __syncthreads();
374             }
375         }
376         // First 6 threads in the block add the 6 components of virial to the global memory address
377         if (tib < 6)
378         {
379             atomicAdd(&(gm_virialScaled[tib]), sm_threadVirial[tib * blockSize]);
380         }
381     }
382
383     return;
384 }
385
386 /*! \brief Select templated kernel.
387  *
388  * Returns pointer to a CUDA kernel based on provided booleans.
389  *
390  * \param[in] updateVelocities  If the velocities should be constrained.
391  * \param[in] bCalcVir          If virial should be updated.
392  *
393  * \retrun                      Pointer to CUDA kernel
394  */
395 inline auto getSettleKernelPtr(const bool updateVelocities, const bool computeVirial)
396 {
397
398     auto kernelPtr = settle_kernel<true, true>;
399     if (updateVelocities && computeVirial)
400     {
401         kernelPtr = settle_kernel<true, true>;
402     }
403     else if (updateVelocities && !computeVirial)
404     {
405         kernelPtr = settle_kernel<true, false>;
406     }
407     else if (!updateVelocities && computeVirial)
408     {
409         kernelPtr = settle_kernel<false, true>;
410     }
411     else if (!updateVelocities && !computeVirial)
412     {
413         kernelPtr = settle_kernel<false, false>;
414     }
415     return kernelPtr;
416 }
417
418 void SettleCuda::apply(const float3* d_x,
419                        float3*       d_xp,
420                        const bool    updateVelocities,
421                        float3*       d_v,
422                        const real    invdt,
423                        const bool    computeVirial,
424                        tensor        virialScaled)
425 {
426
427     ensureNoPendingCudaError("In CUDA version SETTLE");
428
429     // Early exit if no settles
430     if (numSettles_ == 0)
431     {
432         return;
433     }
434
435     if (computeVirial)
436     {
437         // Fill with zeros so the values can be reduced to it
438         // Only 6 values are needed because virial is symmetrical
439         clearDeviceBufferAsync(&d_virialScaled_, 0, 6, commandStream_);
440     }
441
442     auto kernelPtr = getSettleKernelPtr(updateVelocities, computeVirial);
443
444     KernelLaunchConfig config;
445     config.blockSize[0] = c_threadsPerBlock;
446     config.blockSize[1] = 1;
447     config.blockSize[2] = 1;
448     config.gridSize[0]  = (numSettles_ + c_threadsPerBlock - 1) / c_threadsPerBlock;
449     config.gridSize[1]  = 1;
450     config.gridSize[2]  = 1;
451     // Shared memory is only used for virial reduction
452     if (computeVirial)
453     {
454         config.sharedMemorySize = c_threadsPerBlock * 6 * sizeof(float);
455     }
456     else
457     {
458         config.sharedMemorySize = 0;
459     }
460     config.stream = commandStream_;
461
462     const auto kernelArgs = prepareGpuKernelArguments(kernelPtr, config, &numSettles_, &d_atomIds_,
463                                                       &settleParameters_, &d_x, &d_xp, &pbcAiuc_,
464                                                       &invdt, &d_v, &d_virialScaled_);
465
466     launchGpuKernel(kernelPtr, config, nullptr, "settle_kernel<updateVelocities, computeVirial>", kernelArgs);
467
468     if (computeVirial)
469     {
470         copyFromDeviceBuffer(h_virialScaled_.data(), &d_virialScaled_, 0, 6, commandStream_,
471                              GpuApiCallBehavior::Sync, nullptr);
472
473         // Mapping [XX, XY, XZ, YY, YZ, ZZ] internal format to a tensor object
474         virialScaled[XX][XX] += h_virialScaled_[0];
475         virialScaled[XX][YY] += h_virialScaled_[1];
476         virialScaled[XX][ZZ] += h_virialScaled_[2];
477
478         virialScaled[YY][XX] += h_virialScaled_[1];
479         virialScaled[YY][YY] += h_virialScaled_[3];
480         virialScaled[YY][ZZ] += h_virialScaled_[4];
481
482         virialScaled[ZZ][XX] += h_virialScaled_[2];
483         virialScaled[ZZ][YY] += h_virialScaled_[4];
484         virialScaled[ZZ][ZZ] += h_virialScaled_[5];
485     }
486
487     return;
488 }
489
490 SettleCuda::SettleCuda(const gmx_mtop_t& mtop, CommandStream commandStream) :
491     commandStream_(commandStream)
492 {
493     static_assert(sizeof(real) == sizeof(float),
494                   "Real numbers should be in single precision in GPU code.");
495     static_assert(
496             c_threadsPerBlock > 0 && ((c_threadsPerBlock & (c_threadsPerBlock - 1)) == 0),
497             "Number of threads per block should be a power of two in order for reduction to work.");
498
499     // This is to prevent the assertion failure for the systems without water
500     int totalSettles = 0;
501     for (unsigned mt = 0; mt < mtop.moltype.size(); mt++)
502     {
503         const int        nral1           = 1 + NRAL(F_SETTLE);
504         InteractionList  interactionList = mtop.moltype.at(mt).ilist[F_SETTLE];
505         std::vector<int> iatoms          = interactionList.iatoms;
506         totalSettles += iatoms.size() / nral1;
507     }
508     if (totalSettles == 0)
509     {
510         return;
511     }
512
513     // TODO This should be lifted to a separate subroutine that gets the values of Oxygen and
514     // Hydrogen masses, checks if they are consistent across the topology and if there is no more
515     // than two values for each mass if the free energy perturbation is enabled. In later case,
516     // masses may need to be updated on a regular basis (i.e. in set(...) method).
517     // TODO Do the checks for FEP
518     real mO = -1.0;
519     real mH = -1.0;
520
521     for (unsigned mt = 0; mt < mtop.moltype.size(); mt++)
522     {
523         const int        nral1           = 1 + NRAL(F_SETTLE);
524         InteractionList  interactionList = mtop.moltype.at(mt).ilist[F_SETTLE];
525         std::vector<int> iatoms          = interactionList.iatoms;
526         for (unsigned i = 0; i < iatoms.size() / nral1; i++)
527         {
528             int3 settler;
529             settler.x  = iatoms[i * nral1 + 1]; // Oxygen index
530             settler.y  = iatoms[i * nral1 + 2]; // First hydrogen index
531             settler.z  = iatoms[i * nral1 + 3]; // Second hydrogen index
532             t_atom ow1 = mtop.moltype.at(mt).atoms.atom[settler.x];
533             t_atom hw2 = mtop.moltype.at(mt).atoms.atom[settler.y];
534             t_atom hw3 = mtop.moltype.at(mt).atoms.atom[settler.z];
535
536             if (mO < 0)
537             {
538                 mO = ow1.m;
539             }
540             if (mH < 0)
541             {
542                 mH = hw2.m;
543             }
544             GMX_RELEASE_ASSERT(mO == ow1.m,
545                                "Topology has different values for oxygen mass. Should be identical "
546                                "in order to use SETTLE.");
547             GMX_RELEASE_ASSERT(hw2.m == hw3.m && hw2.m == mH,
548                                "Topology has different values for hydrogen mass. Should be "
549                                "identical in order to use SETTLE.");
550         }
551     }
552
553     GMX_RELEASE_ASSERT(mO > 0, "Could not find oxygen mass in the topology. Needed in SETTLE.");
554     GMX_RELEASE_ASSERT(mH > 0, "Could not find hydrogen mass in the topology. Needed in SETTLE.");
555
556     // TODO Very similar to SETTLE initialization on CPU. Should be lifted to a separate method
557     // (one that gets dOH and dHH values and checks them for consistency)
558     int settle_type = -1;
559     for (unsigned mt = 0; mt < mtop.moltype.size(); mt++)
560     {
561         const int       nral1           = 1 + NRAL(F_SETTLE);
562         InteractionList interactionList = mtop.moltype.at(mt).ilist[F_SETTLE];
563         for (int i = 0; i < interactionList.size(); i += nral1)
564         {
565             if (settle_type == -1)
566             {
567                 settle_type = interactionList.iatoms[i];
568             }
569             else if (interactionList.iatoms[i] != settle_type)
570             {
571                 gmx_fatal(FARGS,
572                           "The [molecules] section of your topology specifies more than one block "
573                           "of\n"
574                           "a [moleculetype] with a [settles] block. Only one such is allowed.\n"
575                           "If you are trying to partition your solvent into different *groups*\n"
576                           "(e.g. for freezing, T-coupling, etc.), you are using the wrong "
577                           "approach. Index\n"
578                           "files specify groups. Otherwise, you may wish to change the least-used\n"
579                           "block of molecules with SETTLE constraints into 3 normal constraints.");
580             }
581         }
582     }
583
584     GMX_RELEASE_ASSERT(settle_type >= 0, "settle_init called without settles");
585
586     real dOH = mtop.ffparams.iparams[settle_type].settle.doh;
587     real dHH = mtop.ffparams.iparams[settle_type].settle.dhh;
588
589     initSettleParameters(&settleParameters_, mO, mH, dOH, dHH);
590
591     allocateDeviceBuffer(&d_virialScaled_, 6, nullptr);
592     h_virialScaled_.resize(6);
593 }
594
595 SettleCuda::~SettleCuda()
596 {
597     // Early exit if there is no settles
598     if (numSettles_ == 0)
599     {
600         return;
601     }
602     freeDeviceBuffer(&d_virialScaled_);
603     if (numAtomIdsAlloc_ > 0)
604     {
605         freeDeviceBuffer(&d_atomIds_);
606     }
607 }
608
609 void SettleCuda::set(const t_idef& idef, const t_mdatoms gmx_unused& md)
610 {
611     const int nral1     = 1 + NRAL(F_SETTLE);
612     t_ilist   il_settle = idef.il[F_SETTLE];
613     t_iatom*  iatoms    = il_settle.iatoms;
614     numSettles_         = il_settle.nr / nral1;
615
616     reallocateDeviceBuffer(&d_atomIds_, numSettles_, &numAtomIds_, &numAtomIdsAlloc_, nullptr);
617     h_atomIds_.resize(numSettles_);
618     for (int i = 0; i < numSettles_; i++)
619     {
620         int3 settler;
621         settler.x        = iatoms[i * nral1 + 1]; // Oxygen index
622         settler.y        = iatoms[i * nral1 + 2]; // First hydrogen index
623         settler.z        = iatoms[i * nral1 + 3]; // Second hydrogen index
624         h_atomIds_.at(i) = settler;
625     }
626     copyToDeviceBuffer(&d_atomIds_, h_atomIds_.data(), 0, numSettles_, commandStream_,
627                        GpuApiCallBehavior::Sync, nullptr);
628 }
629
630 void SettleCuda::setPbc(const t_pbc* pbc)
631 {
632     setPbcAiuc(pbc->ndim_ePBC, pbc->box, &pbcAiuc_);
633 }
634
635 } // namespace gmx