Two sets of coefficients for Coulomb FEP PME on GPU
[alexxy/gromacs.git] / src / gromacs / ewald / pme_solve.cu
1 /*
2  * This file is part of the GROMACS molecular simulation package.
3  *
4  * Copyright (c) 2016,2017,2018,2019,2020, by the GROMACS development team, led by
5  * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6  * and including many others, as listed in the AUTHORS file in the
7  * top-level source directory and at http://www.gromacs.org.
8  *
9  * GROMACS is free software; you can redistribute it and/or
10  * modify it under the terms of the GNU Lesser General Public License
11  * as published by the Free Software Foundation; either version 2.1
12  * of the License, or (at your option) any later version.
13  *
14  * GROMACS is distributed in the hope that it will be useful,
15  * but WITHOUT ANY WARRANTY; without even the implied warranty of
16  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
17  * Lesser General Public License for more details.
18  *
19  * You should have received a copy of the GNU Lesser General Public
20  * License along with GROMACS; if not, see
21  * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22  * Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA.
23  *
24  * If you want to redistribute modifications to GROMACS, please
25  * consider that scientific software is very special. Version
26  * control is crucial - bugs must be traceable. We will be happy to
27  * consider code for inclusion in the official distribution, but
28  * derived work must not be called official GROMACS. Details are found
29  * in the README & COPYING files - if they are missing, get the
30  * official version at http://www.gromacs.org.
31  *
32  * To help us fund GROMACS development, we humbly ask that you cite
33  * the research papers on the package. Check out http://www.gromacs.org.
34  */
35
36 /*! \internal \file
37  *  \brief Implements PME GPU Fourier grid solving in CUDA.
38  *
39  *  \author Aleksei Iupinov <a.yupinov@gmail.com>
40  */
41
42 #include "gmxpre.h"
43
44 #include <cassert>
45
46 #include <math_constants.h>
47
48 #include "gromacs/gpu_utils/cuda_arch_utils.cuh"
49
50 #include "pme.cuh"
51
52 /*! \brief
53  * PME complex grid solver kernel function.
54  *
55  * \tparam[in] gridOrdering             Specifies the dimension ordering of the complex grid.
56  * \tparam[in] computeEnergyAndVirial   Tells if the reciprocal energy and virial should be computed.
57  * \tparam[in] gridIndex                The index of the grid to use in the kernel.
58  * \param[in]  kernelParams             Input PME CUDA data in constant memory.
59  */
60 template<GridOrdering gridOrdering, bool computeEnergyAndVirial, const int gridIndex>
61 __launch_bounds__(c_solveMaxThreadsPerBlock) CLANG_DISABLE_OPTIMIZATION_ATTRIBUTE __global__
62         void pme_solve_kernel(const struct PmeGpuCudaKernelParams kernelParams)
63 {
64     /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
65     int majorDim, middleDim, minorDim;
66     switch (gridOrdering)
67     {
68         case GridOrdering::YZX:
69             majorDim  = YY;
70             middleDim = ZZ;
71             minorDim  = XX;
72             break;
73
74         case GridOrdering::XYZ:
75             majorDim  = XX;
76             middleDim = YY;
77             minorDim  = ZZ;
78             break;
79
80         default: assert(false);
81     }
82
83     /* Global memory pointers */
84     const float* __restrict__ gm_splineValueMajor = kernelParams.grid.d_splineModuli[gridIndex]
85                                                     + kernelParams.grid.splineValuesOffset[majorDim];
86     const float* __restrict__ gm_splineValueMiddle = kernelParams.grid.d_splineModuli[gridIndex]
87                                                      + kernelParams.grid.splineValuesOffset[middleDim];
88     const float* __restrict__ gm_splineValueMinor = kernelParams.grid.d_splineModuli[gridIndex]
89                                                     + kernelParams.grid.splineValuesOffset[minorDim];
90     float* __restrict__ gm_virialAndEnergy = kernelParams.constants.d_virialAndEnergy[gridIndex];
91     float2* __restrict__ gm_grid           = (float2*)kernelParams.grid.d_fourierGrid[gridIndex];
92
93     /* Various grid sizes and indices */
94     const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0; // unused
95     const int localSizeMinor   = kernelParams.grid.complexGridSizePadded[minorDim];
96     const int localSizeMiddle  = kernelParams.grid.complexGridSizePadded[middleDim];
97     const int localCountMiddle = kernelParams.grid.complexGridSize[middleDim];
98     const int localCountMinor  = kernelParams.grid.complexGridSize[minorDim];
99     const int nMajor           = kernelParams.grid.realGridSize[majorDim];
100     const int nMiddle          = kernelParams.grid.realGridSize[middleDim];
101     const int nMinor           = kernelParams.grid.realGridSize[minorDim];
102     const int maxkMajor        = (nMajor + 1) / 2;  // X or Y
103     const int maxkMiddle       = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
104     const int maxkMinor        = (nMinor + 1) / 2;  // Z or X => only check for YZX
105
106     /* Each thread works on one cell of the Fourier space complex 3D grid (gm_grid).
107      * Each block handles up to c_solveMaxThreadsPerBlock cells -
108      * depending on the grid contiguous dimension size,
109      * that can range from a part of a single gridline to several complete gridlines.
110      */
111     const int threadLocalId     = threadIdx.x;
112     const int gridLineSize      = localCountMinor;
113     const int gridLineIndex     = threadLocalId / gridLineSize;
114     const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
115     const int gridLinesPerBlock = max(blockDim.x / gridLineSize, 1);
116     const int activeWarps       = (blockDim.x / warp_size);
117     const int indexMinor        = blockIdx.x * blockDim.x + gridLineCellIndex;
118     const int indexMiddle       = blockIdx.y * gridLinesPerBlock + gridLineIndex;
119     const int indexMajor        = blockIdx.z;
120
121     /* Optional outputs */
122     float energy = 0.0f;
123     float virxx  = 0.0f;
124     float virxy  = 0.0f;
125     float virxz  = 0.0f;
126     float viryy  = 0.0f;
127     float viryz  = 0.0f;
128     float virzz  = 0.0f;
129
130     assert(indexMajor < kernelParams.grid.complexGridSize[majorDim]);
131     if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
132         & (gridLineIndex < gridLinesPerBlock))
133     {
134         /* The offset should be equal to the global thread index for coalesced access */
135         const int gridThreadIndex =
136                 (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
137         float2* __restrict__ gm_gridCell = gm_grid + gridThreadIndex;
138
139         const int kMajor = indexMajor + localOffsetMajor;
140         /* Checking either X in XYZ, or Y in YZX cases */
141         const float mMajor = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);
142
143         const int kMiddle = indexMiddle + localOffsetMiddle;
144         float     mMiddle = kMiddle;
145         /* Checking Y in XYZ case */
146         if (gridOrdering == GridOrdering::XYZ)
147         {
148             mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
149         }
150         const int kMinor = localOffsetMinor + indexMinor;
151         float     mMinor = kMinor;
152         /* Checking X in YZX case */
153         if (gridOrdering == GridOrdering::YZX)
154         {
155             mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
156         }
157         /* We should skip the k-space point (0,0,0) */
158         const bool notZeroPoint = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
159
160         float mX, mY, mZ;
161         switch (gridOrdering)
162         {
163             case GridOrdering::YZX:
164                 mX = mMinor;
165                 mY = mMajor;
166                 mZ = mMiddle;
167                 break;
168
169             case GridOrdering::XYZ:
170                 mX = mMajor;
171                 mY = mMiddle;
172                 mZ = mMinor;
173                 break;
174
175             default: assert(false);
176         }
177
178         /* 0.5 correction factor for the first and last components of a Z dimension */
179         float corner_fac = 1.0f;
180         switch (gridOrdering)
181         {
182             case GridOrdering::YZX:
183                 if ((kMiddle == 0) | (kMiddle == maxkMiddle))
184                 {
185                     corner_fac = 0.5f;
186                 }
187                 break;
188
189             case GridOrdering::XYZ:
190                 if ((kMinor == 0) | (kMinor == maxkMinor))
191                 {
192                     corner_fac = 0.5f;
193                 }
194                 break;
195
196             default: assert(false);
197         }
198
199         if (notZeroPoint)
200         {
201             const float mhxk = mX * kernelParams.current.recipBox[XX][XX];
202             const float mhyk = mX * kernelParams.current.recipBox[XX][YY]
203                                + mY * kernelParams.current.recipBox[YY][YY];
204             const float mhzk = mX * kernelParams.current.recipBox[XX][ZZ]
205                                + mY * kernelParams.current.recipBox[YY][ZZ]
206                                + mZ * kernelParams.current.recipBox[ZZ][ZZ];
207
208             const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
209             assert(m2k != 0.0f);
210             // TODO: use LDG/textures for gm_splineValue
211             float denom = m2k * float(CUDART_PI_F) * kernelParams.current.boxVolume
212                           * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
213                           * gm_splineValueMinor[kMinor];
214             assert(isfinite(denom));
215             assert(denom != 0.0f);
216
217             const float tmp1   = expf(-kernelParams.grid.ewaldFactor * m2k);
218             const float etermk = kernelParams.constants.elFactor * tmp1 / denom;
219
220             float2       gridValue    = *gm_gridCell;
221             const float2 oldGridValue = gridValue;
222             gridValue.x *= etermk;
223             gridValue.y *= etermk;
224             *gm_gridCell = gridValue;
225
226             if (computeEnergyAndVirial)
227             {
228                 const float tmp1k =
229                         2.0f * (gridValue.x * oldGridValue.x + gridValue.y * oldGridValue.y);
230
231                 float vfactor = (kernelParams.grid.ewaldFactor + 1.0f / m2k) * 2.0f;
232                 float ets2    = corner_fac * tmp1k;
233                 energy        = ets2;
234
235                 float ets2vf = ets2 * vfactor;
236
237                 virxx = ets2vf * mhxk * mhxk - ets2;
238                 virxy = ets2vf * mhxk * mhyk;
239                 virxz = ets2vf * mhxk * mhzk;
240                 viryy = ets2vf * mhyk * mhyk - ets2;
241                 viryz = ets2vf * mhyk * mhzk;
242                 virzz = ets2vf * mhzk * mhzk - ets2;
243             }
244         }
245     }
246
247     /* Optional energy/virial reduction */
248     if (computeEnergyAndVirial)
249     {
250         /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
251          * The idea is to reduce 7 energy/virial components into a single variable (aligned by 8).
252          * We will reduce everything into virxx.
253          */
254
255         /* We can only reduce warp-wise */
256         const int          width      = warp_size;
257         const unsigned int activeMask = c_fullWarpMask;
258
259         /* Making pair sums */
260         virxx += __shfl_down_sync(activeMask, virxx, 1, width);
261         viryy += __shfl_up_sync(activeMask, viryy, 1, width);
262         virzz += __shfl_down_sync(activeMask, virzz, 1, width);
263         virxy += __shfl_up_sync(activeMask, virxy, 1, width);
264         virxz += __shfl_down_sync(activeMask, virxz, 1, width);
265         viryz += __shfl_up_sync(activeMask, viryz, 1, width);
266         energy += __shfl_down_sync(activeMask, energy, 1, width);
267         if (threadLocalId & 1)
268         {
269             virxx = viryy; // virxx now holds virxx and viryy pair sums
270             virzz = virxy; // virzz now holds virzz and virxy pair sums
271             virxz = viryz; // virxz now holds virxz and viryz pair sums
272         }
273
274         /* Making quad sums */
275         virxx += __shfl_down_sync(activeMask, virxx, 2, width);
276         virzz += __shfl_up_sync(activeMask, virzz, 2, width);
277         virxz += __shfl_down_sync(activeMask, virxz, 2, width);
278         energy += __shfl_up_sync(activeMask, energy, 2, width);
279         if (threadLocalId & 2)
280         {
281             virxx = virzz;  // virxx now holds quad sums of virxx, virxy, virzz and virxy
282             virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused paddings
283         }
284
285         /* Making octet sums */
286         virxx += __shfl_down_sync(activeMask, virxx, 4, width);
287         virxz += __shfl_up_sync(activeMask, virxz, 4, width);
288         if (threadLocalId & 4)
289         {
290             virxx = virxz; // virxx now holds all 7 components' octet sums + unused paddings
291         }
292
293         /* We only need to reduce virxx now */
294 #pragma unroll
295         for (int delta = 8; delta < width; delta <<= 1)
296         {
297             virxx += __shfl_down_sync(activeMask, virxx, delta, width);
298         }
299         /* Now first 7 threads of each warp have the full output contributions in virxx */
300
301         const int  componentIndex      = threadLocalId & (warp_size - 1);
302         const bool validComponentIndex = (componentIndex < c_virialAndEnergyCount);
303         /* Reduce 7 outputs per warp in the shared memory */
304         const int stride =
305                 8; // this is c_virialAndEnergyCount==7 rounded up to power of 2 for convenience, hence the assert
306         assert(c_virialAndEnergyCount == 7);
307         const int        reductionBufferSize = (c_solveMaxThreadsPerBlock / warp_size) * stride;
308         __shared__ float sm_virialAndEnergy[reductionBufferSize];
309
310         if (validComponentIndex)
311         {
312             const int warpIndex                                     = threadLocalId / warp_size;
313             sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
314         }
315         __syncthreads();
316
317         /* Reduce to the single warp size */
318         const int targetIndex = threadLocalId;
319 #pragma unroll
320         for (int reductionStride = reductionBufferSize >> 1; reductionStride >= warp_size;
321              reductionStride >>= 1)
322         {
323             const int sourceIndex = targetIndex + reductionStride;
324             if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
325             {
326                 // TODO: the second conditional is only needed on first iteration, actually - see if compiler eliminates it!
327                 sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
328             }
329             __syncthreads();
330         }
331
332         /* Now use shuffle again */
333         /* NOTE: This reduction assumes there are at least 4 warps (asserted).
334          *       To use fewer warps, add to the conditional:
335          *       && threadLocalId < activeWarps * stride
336          */
337         assert(activeWarps * stride >= warp_size);
338         if (threadLocalId < warp_size)
339         {
340             float output = sm_virialAndEnergy[threadLocalId];
341 #pragma unroll
342             for (int delta = stride; delta < warp_size; delta <<= 1)
343             {
344                 output += __shfl_down_sync(activeMask, output, delta, warp_size);
345             }
346             /* Final output */
347             if (validComponentIndex)
348             {
349                 assert(isfinite(output));
350                 atomicAdd(gm_virialAndEnergy + componentIndex, output);
351             }
352         }
353     }
354 }
355
356 //! Kernel instantiations
357 template __global__ void pme_solve_kernel<GridOrdering::YZX, true, 0>(const PmeGpuCudaKernelParams);
358 template __global__ void pme_solve_kernel<GridOrdering::YZX, false, 0>(const PmeGpuCudaKernelParams);
359 template __global__ void pme_solve_kernel<GridOrdering::XYZ, true, 0>(const PmeGpuCudaKernelParams);
360 template __global__ void pme_solve_kernel<GridOrdering::XYZ, false, 0>(const PmeGpuCudaKernelParams);
361 template __global__ void pme_solve_kernel<GridOrdering::YZX, true, 1>(const PmeGpuCudaKernelParams);
362 template __global__ void pme_solve_kernel<GridOrdering::YZX, false, 1>(const PmeGpuCudaKernelParams);
363 template __global__ void pme_solve_kernel<GridOrdering::XYZ, true, 1>(const PmeGpuCudaKernelParams);
364 template __global__ void pme_solve_kernel<GridOrdering::XYZ, false, 1>(const PmeGpuCudaKernelParams);