2 * This file is part of the GROMACS molecular simulation package.
4 * Copyright (c) 2016,2017,2018,2019, by the GROMACS development team, led by
5 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
6 * and including many others, as listed in the AUTHORS file in the
7 * top-level source directory and at http://www.gromacs.org.
9 * GROMACS is free software; you can redistribute it and/or
10 * modify it under the terms of the GNU Lesser General Public License
11 * as published by the Free Software Foundation; either version 2.1
12 * of the License, or (at your option) any later version.
14 * GROMACS is distributed in the hope that it will be useful,
15 * but WITHOUT ANY WARRANTY; without even the implied warranty of
16 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
17 * Lesser General Public License for more details.
19 * You should have received a copy of the GNU Lesser General Public
20 * License along with GROMACS; if not, see
21 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
22 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24 * If you want to redistribute modifications to GROMACS, please
25 * consider that scientific software is very special. Version
26 * control is crucial - bugs must be traceable. We will be happy to
27 * consider code for inclusion in the official distribution, but
28 * derived work must not be called official GROMACS. Details are found
29 * in the README & COPYING files - if they are missing, get the
30 * official version at http://www.gromacs.org.
32 * To help us fund GROMACS development, we humbly ask that you cite
33 * the research papers on the package. Check out http://www.gromacs.org.
37 * \brief Implements PME GPU Fourier grid solving in CUDA.
39 * \author Aleksei Iupinov <a.yupinov@gmail.com>
46 #include <math_constants.h>
48 #include "gromacs/gpu_utils/cuda_arch_utils.cuh"
53 * PME complex grid solver kernel function.
55 * \tparam[in] gridOrdering Specifies the dimension ordering of the complex grid.
56 * \tparam[in] computeEnergyAndVirial Tells if the reciprocal energy and virial should be computed.
57 * \param[in]  kernelParams           Input PME CUDA data in constant memory.
59 template<GridOrdering gridOrdering, bool computeEnergyAndVirial>
60 __launch_bounds__(c_solveMaxThreadsPerBlock) CLANG_DISABLE_OPTIMIZATION_ATTRIBUTE __global__
61 void pme_solve_kernel(const struct PmeGpuCudaKernelParams kernelParams)
// One thread handles one cell of the complex Fourier grid: it scales the cell
// by the Ewald convolution term and, when computeEnergyAndVirial is set,
// contributes to a block-wide energy/virial reduction.
// NOTE(review): this excerpt elides several original source lines (switch-case
// bodies, the energy/virial accumulator declarations, corner_fac assignment,
// some braces). Comments below describe only the code that is visible here;
// elided spots are marked and should be confirmed against the full file.
63 /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
64 int majorDim, middleDim, minorDim;
// Map the (major, middle, minor) storage order onto physical X/Y/Z axes based
// on the compile-time grid ordering. NOTE(review): the per-case assignments
// are not visible in this excerpt.
67 case GridOrdering::YZX:
73 case GridOrdering::XYZ:
79 default: assert(false);
82 /* Global memory pointers */
// Read-only B-spline moduli for each grid dimension; __restrict__ lets the
// compiler route these loads through the read-only data path.
83 const float* __restrict__ gm_splineValueMajor =
84 kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[majorDim];
85 const float* __restrict__ gm_splineValueMiddle =
86 kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[middleDim];
87 const float* __restrict__ gm_splineValueMinor =
88 kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[minorDim];
// Output buffer holding c_virialAndEnergyCount (== 7, asserted below) floats:
// 6 virial components plus the energy.
89 float* __restrict__ gm_virialAndEnergy = kernelParams.constants.d_virialAndEnergy;
// The complex Fourier grid, viewed as (re, im) float pairs.
90 float2* __restrict__ gm_grid = (float2*)kernelParams.grid.d_fourierGrid;
92 /* Various grid sizes and indices */
93 const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0; // unused
94 const int localSizeMinor = kernelParams.grid.complexGridSizePadded[minorDim];
95 const int localSizeMiddle = kernelParams.grid.complexGridSizePadded[middleDim];
96 const int localCountMiddle = kernelParams.grid.complexGridSize[middleDim];
97 const int localCountMinor = kernelParams.grid.complexGridSize[minorDim];
98 const int nMajor = kernelParams.grid.realGridSize[majorDim];
99 const int nMiddle = kernelParams.grid.realGridSize[middleDim];
100 const int nMinor = kernelParams.grid.realGridSize[nMinor == 0 ? minorDim : minorDim];
// NOTE(review): the line above is reproduced from the excerpt context; see
// original line 100 — it reads realGridSize[minorDim].
101 const int maxkMajor = (nMajor + 1) / 2; // X or Y
102 const int maxkMiddle = (nMiddle + 1) / 2; // Y OR Z => only check for !YZX
103 const int maxkMinor = (nMinor + 1) / 2; // Z or X => only check for YZX
105 /* Each thread works on one cell of the Fourier space complex 3D grid (gm_grid).
106 * Each block handles up to c_solveMaxThreadsPerBlock cells -
107 * depending on the grid contiguous dimension size,
108 * that can range from a part of a single gridline to several complete gridlines.
110 const int threadLocalId = threadIdx.x;
111 const int gridLineSize = localCountMinor;
112 const int gridLineIndex = threadLocalId / gridLineSize;
113 const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
114 const int gridLinesPerBlock = max(blockDim.x / gridLineSize, 1);
115 const int activeWarps = (blockDim.x / warp_size);
// The minor index varies fastest across consecutive threads so that
// neighboring threads access neighboring complex cells (coalesced).
116 const int indexMinor = blockIdx.x * blockDim.x + gridLineCellIndex;
117 const int indexMiddle = blockIdx.y * gridLinesPerBlock + gridLineIndex;
118 const int indexMajor = blockIdx.z;
120 /* Optional outputs */
// NOTE(review): the declarations of the energy/virial accumulators
// (energy, virxx, virxy, virxz, viryy, viryz, virzz) are elided here.
129 assert(indexMajor < kernelParams.grid.complexGridSize[majorDim]);
// Combined bounds check; bitwise & instead of && keeps the test branchless.
130 if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor)
131 & (gridLineIndex < gridLinesPerBlock))
133 /* The offset should be equal to the global thread index for coalesced access */
134 const int gridIndex = (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
135 float2* __restrict__ gm_gridCell = gm_grid + gridIndex;
137 const int kMajor = indexMajor + localOffsetMajor;
138 /* Checking either X in XYZ, or Y in YZX cases */
// Fold the wave-vector index into the negative-frequency half: k -> k - n
// once k reaches the half-grid boundary maxk.
139 const float mMajor = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);
141 const int kMiddle = indexMiddle + localOffsetMiddle;
142 float mMiddle = kMiddle;
143 /* Checking Y in XYZ case */
144 if (gridOrdering == GridOrdering::XYZ)
146 mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
148 const int kMinor = localOffsetMinor + indexMinor;
149 float mMinor = kMinor;
150 /* Checking X in YZX case */
151 if (gridOrdering == GridOrdering::YZX)
153 mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
155 /* We should skip the k-space point (0,0,0) */
156 const bool notZeroPoint = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
// Permute (mMajor, mMiddle, mMinor) into physical (mX, mY, mZ).
// NOTE(review): the per-case assignments are not visible in this excerpt.
159 switch (gridOrdering)
161 case GridOrdering::YZX:
167 case GridOrdering::XYZ:
173 default: assert(false);
176 /* 0.5 correction factor for the first and last components of a Z dimension */
177 float corner_fac = 1.0f;
178 switch (gridOrdering)
180 case GridOrdering::YZX:
181 if ((kMiddle == 0) | (kMiddle == maxkMiddle))
187 case GridOrdering::XYZ:
188 if ((kMinor == 0) | (kMinor == maxkMinor))
194 default: assert(false);
// Reciprocal-space vector m*h; only the upper-triangular elements of
// recipBox are referenced here.
199 const float mhxk = mX * kernelParams.current.recipBox[XX][XX];
200 const float mhyk = mX * kernelParams.current.recipBox[XX][YY]
201 + mY * kernelParams.current.recipBox[YY][YY];
202 const float mhzk = mX * kernelParams.current.recipBox[XX][ZZ]
203 + mY * kernelParams.current.recipBox[YY][ZZ]
204 + mZ * kernelParams.current.recipBox[ZZ][ZZ];
206 const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;
208 // TODO: use LDG/textures for gm_splineValue
209 float denom = m2k * float(CUDART_PI_F) * kernelParams.current.boxVolume
210 * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle]
211 * gm_splineValueMinor[kMinor];
// Device-side asserts trap on failure and surface at the next sync point.
212 assert(isfinite(denom));
213 assert(denom != 0.0f);
215 const float tmp1 = expf(-kernelParams.grid.ewaldFactor * m2k);
216 const float etermk = kernelParams.constants.elFactor * tmp1 / denom;
// Scale the complex grid value in place by the Ewald convolution term,
// keeping the old value for the energy contribution below.
218 float2 gridValue = *gm_gridCell;
219 const float2 oldGridValue = gridValue;
220 gridValue.x *= etermk;
221 gridValue.y *= etermk;
222 *gm_gridCell = gridValue;
224 if (computeEnergyAndVirial)
// new * old = etermk * |old|^2 — the per-point energy contribution.
// NOTE(review): the left-hand side of this expression is elided here.
227 2.0f * (gridValue.x * oldGridValue.x + gridValue.y * oldGridValue.y);
229 float vfactor = (kernelParams.grid.ewaldFactor + 1.0f / m2k) * 2.0f;
230 float ets2 = corner_fac * tmp1k;
233 float ets2vf = ets2 * vfactor;
235 virxx = ets2vf * mhxk * mhxk - ets2;
236 virxy = ets2vf * mhxk * mhyk;
237 virxz = ets2vf * mhxk * mhzk;
238 viryy = ets2vf * mhyk * mhyk - ets2;
239 viryz = ets2vf * mhyk * mhzk;
240 virzz = ets2vf * mhzk * mhzk - ets2;
245 /* Optional energy/virial reduction */
246 if (computeEnergyAndVirial)
248 /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
249 * The idea is to reduce 7 energy/virial components into a single variable (aligned by 8).
250 * We will reduce everything into virxx.
253 /* We can only reduce warp-wise */
254 const int width = warp_size;
// Explicit full-warp mask: required by the *_sync shuffle intrinsics since
// Volta's independent thread scheduling (no implicit warp synchrony).
255 const unsigned int activeMask = c_fullWarpMask;
257 /* Making pair sums */
// Alternating down/up shuffles interleave the partial sums so that after
// the pair/quad/octet stages each group of 8 lanes holds all 7 components.
258 virxx += __shfl_down_sync(activeMask, virxx, 1, width);
259 viryy += __shfl_up_sync(activeMask, viryy, 1, width);
260 virzz += __shfl_down_sync(activeMask, virzz, 1, width);
261 virxy += __shfl_up_sync(activeMask, virxy, 1, width);
262 virxz += __shfl_down_sync(activeMask, virxz, 1, width);
263 viryz += __shfl_up_sync(activeMask, viryz, 1, width);
264 energy += __shfl_down_sync(activeMask, energy, 1, width);
265 if (threadLocalId & 1)
267 virxx = viryy; // virxx now holds virxx and viryy pair sums
268 virzz = virxy; // virzz now holds virzz and virxy pair sums
269 virxz = viryz; // virxz now holds virxz and viryz pair sums
272 /* Making quad sums */
273 virxx += __shfl_down_sync(activeMask, virxx, 2, width);
274 virzz += __shfl_up_sync(activeMask, virzz, 2, width);
275 virxz += __shfl_down_sync(activeMask, virxz, 2, width);
276 energy += __shfl_up_sync(activeMask, energy, 2, width);
277 if (threadLocalId & 2)
279 virxx = virzz; // virxx now holds quad sums of virxx, virxy, virzz and virxy
280 virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused paddings
283 /* Making octet sums */
284 virxx += __shfl_down_sync(activeMask, virxx, 4, width);
285 virxz += __shfl_up_sync(activeMask, virxz, 4, width);
286 if (threadLocalId & 4)
288 virxx = virxz; // virxx now holds all 7 components' octet sums + unused paddings
291 /* We only need to reduce virxx now */
// Standard tree reduction over the remaining lanes (8, 16, ...).
293 for (int delta = 8; delta < width; delta <<= 1)
295 virxx += __shfl_down_sync(activeMask, virxx, delta, width);
297 /* Now first 7 threads of each warp have the full output contributions in virxx */
299 const int componentIndex = threadLocalId & (warp_size - 1);
300 const bool validComponentIndex = (componentIndex < c_virialAndEnergyCount);
301 /* Reduce 7 outputs per warp in the shared memory */
// NOTE(review): the declaration head "const int stride =" is elided here;
// only its initializer line is visible.
303 8; // this is c_virialAndEnergyCount==7 rounded up to power of 2 for convenience, hence the assert
304 assert(c_virialAndEnergyCount == 7);
305 const int reductionBufferSize = (c_solveMaxThreadsPerBlock / warp_size) * stride;
306 __shared__ float sm_virialAndEnergy[reductionBufferSize];
308 if (validComponentIndex)
310 const int warpIndex = threadLocalId / warp_size;
// One slot of 'stride' (8) floats per warp; only the first 7 are meaningful.
311 sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
315 /* Reduce to the single warp size */
316 const int targetIndex = threadLocalId;
// Halving tree reduction in shared memory down to one warp's worth of data.
// NOTE(review): the __syncthreads() barriers between iterations are elided
// in this excerpt — confirm against the full file.
318 for (int reductionStride = reductionBufferSize >> 1; reductionStride >= warp_size;
319 reductionStride >>= 1)
321 const int sourceIndex = targetIndex + reductionStride;
322 if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
324 // TODO: the second conditional is only needed on first iteration, actually - see if compiler eliminates it!
325 sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
330 /* Now use shuffle again */
331 /* NOTE: This reduction assumes there are at least 4 warps (asserted).
332 * To use fewer warps, add to the conditional:
333 * && threadLocalId < activeWarps * stride
335 assert(activeWarps * stride >= warp_size);
336 if (threadLocalId < warp_size)
338 float output = sm_virialAndEnergy[threadLocalId];
// Final warp-level reduction with stride-multiples of 8, so lane i < 7
// ends up with the total for component i.
340 for (int delta = stride; delta < warp_size; delta <<= 1)
342 output += __shfl_down_sync(activeMask, output, delta, warp_size);
345 if (validComponentIndex)
347 assert(isfinite(output));
// One atomic per component per block accumulates into the global buffer.
348 atomicAdd(gm_virialAndEnergy + componentIndex, output);
354 //! Kernel instantiations
// Explicit instantiations for both grid orderings (YZX / XYZ), each with and
// without energy/virial computation, so all four kernel variants are emitted
// in this translation unit.
355 template __global__ void pme_solve_kernel<GridOrdering::YZX, true>(const PmeGpuCudaKernelParams);
356 template __global__ void pme_solve_kernel<GridOrdering::YZX, false>(const PmeGpuCudaKernelParams);
357 template __global__ void pme_solve_kernel<GridOrdering::XYZ, true>(const PmeGpuCudaKernelParams);
358 template __global__ void pme_solve_kernel<GridOrdering::XYZ, false>(const PmeGpuCudaKernelParams);