/*
 * This file is part of the GROMACS molecular simulation package.
 *
 * Copyright (c) 2016,2017, by the GROMACS development team, led by
 * Mark Abraham, David van der Spoel, Berk Hess, and Erik Lindahl,
 * and including many others, as listed in the AUTHORS file in the
 * top-level source directory and at http://www.gromacs.org.
 *
 * GROMACS is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * GROMACS is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with GROMACS; if not, see
 * http://www.gnu.org/licenses, or write to the Free Software Foundation,
 * Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * If you want to redistribute modifications to GROMACS, please
 * consider that scientific software is very special. Version
 * control is crucial - bugs must be traceable. We will be happy to
 * consider code for inclusion in the official distribution, but
 * derived work must not be called official GROMACS. Details are found
 * in the README & COPYING files - if they are missing, get the
 * official version at http://www.gromacs.org.
 *
 * To help us fund GROMACS development, we humbly ask that you cite
 * the research papers on the package. Check out http://www.gromacs.org.
 */

/*! \internal \file
 * \brief Implements PME GPU Fourier grid solving in CUDA.
 *
 * \author Aleksei Iupinov <a.yupinov@gmail.com>
 */
#include "gmxpre.h"

#include <cassert>

#include <cmath>

#include <algorithm>

#include "gromacs/gpu_utils/cuda_arch_utils.cuh"
#include "gromacs/gpu_utils/cudautils.cuh"
#include "gromacs/utility/exceptions.h"
#include "gromacs/utility/gmxassert.h"

#include "pme.cuh"
#include "pme-timings.cuh"

//! Solving kernel max block width in warps, picked among powers of 2 (2, 4, 8, 16) for max. occupancy and min. runtime
//! (560Ti (CC 2.1), 660Ti (CC 3.0) and 750 (CC 5.0))
constexpr int c_solveMaxWarpsPerBlock = 8;
//! Solving kernel max block size in threads
constexpr int c_solveMaxThreadsPerBlock = (c_solveMaxWarpsPerBlock * warp_size);
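// For reference: on current NVIDIA hardware warp_size == 32, so this evaluates to 8 * 32 = 256 threads per block.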
// CUDA 6.5 cannot compile an enum class as a template kernel parameter,
// so we replace it with a duplicate simple enum on older toolkits
#if GMX_CUDA_VERSION >= 7000
using GridOrderingInternal = GridOrdering;
#else
//! Duplicate of GridOrdering for use as a kernel template parameter with older CUDA toolkits
enum GridOrderingInternal
{
    YZX,
    XYZ
};
#endif

/*! \brief
 * PME complex grid solver kernel function.
 *
 * \tparam[in] gridOrdering             Specifies the dimension ordering of the complex grid.
 * \tparam[in] computeEnergyAndVirial   Tells if the reciprocal energy and virial should be computed.
 * \param[in]  kernelParams             Input PME CUDA data in constant memory.
 */
template<
    GridOrderingInternal gridOrdering,
    bool computeEnergyAndVirial
    >
__launch_bounds__(c_solveMaxThreadsPerBlock)
__global__ void pme_solve_kernel(const struct PmeGpuCudaKernelParams kernelParams)
{
    /* This kernel supports 2 different grid dimension orderings: YZX and XYZ */
    int majorDim, middleDim, minorDim;
    switch (gridOrdering)
    {
        case GridOrderingInternal::YZX:
            majorDim  = YY;
            middleDim = ZZ;
            minorDim  = XX;
            break;

        case GridOrderingInternal::XYZ:
            majorDim  = XX;
            middleDim = YY;
            minorDim  = ZZ;
            break;

        default:
            assert(false);
    }

    /* Global memory pointers */
    const float * __restrict__ gm_splineValueMajor  = kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[majorDim];
    const float * __restrict__ gm_splineValueMiddle = kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[middleDim];
    const float * __restrict__ gm_splineValueMinor  = kernelParams.grid.d_splineModuli + kernelParams.grid.splineValuesOffset[minorDim];
    float * __restrict__       gm_virialAndEnergy   = kernelParams.constants.d_virialAndEnergy;
    float2 * __restrict__      gm_grid              = (float2 *)kernelParams.grid.d_fourierGrid;

    /* Various grid sizes and indices */
    const int localOffsetMinor = 0, localOffsetMajor = 0, localOffsetMiddle = 0; // unused
    const int localSizeMinor   = kernelParams.grid.complexGridSizePadded[minorDim];
    const int localSizeMiddle  = kernelParams.grid.complexGridSizePadded[middleDim];
    const int localCountMiddle = kernelParams.grid.complexGridSize[middleDim];
    const int localCountMinor  = kernelParams.grid.complexGridSize[minorDim];
    const int nMajor           = kernelParams.grid.realGridSize[majorDim];
    const int nMiddle          = kernelParams.grid.realGridSize[middleDim];
    const int nMinor           = kernelParams.grid.realGridSize[minorDim];
    const int maxkMajor        = (nMajor + 1) / 2;  // X or Y
    const int maxkMiddle       = (nMiddle + 1) / 2; // Y or Z => only check for !YZX
    const int maxkMinor        = (nMinor + 1) / 2;  // Z or X => only check for YZX
    /* Each thread works on one cell of the Fourier space complex 3D grid (gm_grid).
     * Each block handles up to c_solveMaxThreadsPerBlock cells;
     * depending on the size of the contiguous grid dimension,
     * that can range from a part of a single grid line to several complete grid lines.
     */
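    /* Illustrative example (hypothetical sizes): with localCountMinor == 25 and blockDim.x == 256,
     * each block covers 256 / 25 = 10 whole grid lines (250 active threads, the remaining 6 idle);
     * a minor dimension larger than the block is instead split across blocks along blockIdx.x.
     */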
    const int threadLocalId     = threadIdx.x;
    const int gridLineSize      = localCountMinor;
    const int gridLineIndex     = threadLocalId / gridLineSize;
    const int gridLineCellIndex = threadLocalId - gridLineSize * gridLineIndex;
    const int gridLinesPerBlock = blockDim.x / gridLineSize;
    const int activeWarps       = (blockDim.x / warp_size);
    const int indexMinor        = blockIdx.x * blockDim.x + gridLineCellIndex;
    const int indexMiddle       = blockIdx.y * gridLinesPerBlock + gridLineIndex;
    const int indexMajor        = blockIdx.z;
    /* Optional outputs */
    float energy = 0.0f;
    float virxx  = 0.0f;
    float virxy  = 0.0f;
    float virxz  = 0.0f;
    float viryy  = 0.0f;
    float viryz  = 0.0f;
    float virzz  = 0.0f;

    assert(indexMajor < kernelParams.grid.complexGridSize[majorDim]);
    if ((indexMiddle < localCountMiddle) & (indexMinor < localCountMinor) & (gridLineIndex < gridLinesPerBlock))
    {
        /* The offset should be equal to the global thread index for coalesced access */
        const int             gridIndex   = (indexMajor * localSizeMiddle + indexMiddle) * localSizeMinor + indexMinor;
        float2 * __restrict__ gm_gridCell = gm_grid + gridIndex;

        const int   kMajor = indexMajor + localOffsetMajor;
        /* Checking either X in XYZ, or Y in YZX cases */
        const float mMajor = (kMajor < maxkMajor) ? kMajor : (kMajor - nMajor);

        const int kMiddle  = indexMiddle + localOffsetMiddle;
        float     mMiddle  = kMiddle;
        /* Checking Y in XYZ case */
        if (gridOrdering == GridOrderingInternal::XYZ)
        {
            mMiddle = (kMiddle < maxkMiddle) ? kMiddle : (kMiddle - nMiddle);
        }
        const int kMinor = localOffsetMinor + indexMinor;
        float     mMinor = kMinor;
        /* Checking X in YZX case */
        if (gridOrdering == GridOrderingInternal::YZX)
        {
            mMinor = (kMinor < maxkMinor) ? kMinor : (kMinor - nMinor);
        }
        /* We should skip the k-space point (0,0,0) */
        const bool notZeroPoint = (kMinor > 0) | (kMajor > 0) | (kMiddle > 0);
        float mX, mY, mZ;
        switch (gridOrdering)
        {
            case GridOrderingInternal::YZX:
                mX = mMinor;
                mY = mMajor;
                mZ = mMiddle;
                break;

            case GridOrderingInternal::XYZ:
                mX = mMajor;
                mY = mMiddle;
                mZ = mMinor;
                break;

            default:
                assert(false);
        }

        /* 0.5 correction factor for the first and last components of the Z dimension */
        float corner_fac = 1.0f;
        switch (gridOrdering)
        {
            case GridOrderingInternal::YZX:
                if ((kMiddle == 0) | (kMiddle == maxkMiddle))
                {
                    corner_fac = 0.5f;
                }
                break;

            case GridOrderingInternal::XYZ:
                if ((kMinor == 0) | (kMinor == maxkMinor))
                {
                    corner_fac = 0.5f;
                }
                break;

            default:
                assert(false);
        }

        if (notZeroPoint)
        {
            const float mhxk = mX * kernelParams.current.recipBox[XX][XX];
            const float mhyk = mX * kernelParams.current.recipBox[XX][YY] + mY * kernelParams.current.recipBox[YY][YY];
            const float mhzk = mX * kernelParams.current.recipBox[XX][ZZ] + mY * kernelParams.current.recipBox[YY][ZZ] + mZ * kernelParams.current.recipBox[ZZ][ZZ];

            const float m2k = mhxk * mhxk + mhyk * mhyk + mhzk * mhzk;

            // TODO: use LDG/textures for gm_splineValue
            float denom = m2k * float(M_PI) * kernelParams.current.boxVolume * gm_splineValueMajor[kMajor] * gm_splineValueMiddle[kMiddle] * gm_splineValueMinor[kMinor];
            assert(isfinite(denom));
            assert(denom != 0.0f);
            const float tmp1   = expf(-kernelParams.grid.ewaldFactor * m2k);
            const float etermk = kernelParams.constants.elFactor * tmp1 / denom;
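            /* Written out, the factor computed above is the smooth-PME reciprocal-space term
             *   etermk = elFactor * exp(-ewaldFactor * m^2) / (pi * V * m^2 * B(m)),
             * where m^2 = |mh|^2, V is the box volume, and B(m) is the product of the B-spline moduli
             * along the three dimensions (cf. Essmann et al., J. Chem. Phys. 103, 8577 (1995)).
             */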

            float2       gridValue    = *gm_gridCell;
            const float2 oldGridValue = gridValue;
            gridValue.x *= etermk;
            gridValue.y *= etermk;
            *gm_gridCell = gridValue;
            if (computeEnergyAndVirial)
            {
                const float tmp1k = 2.0f * (gridValue.x * oldGridValue.x + gridValue.y * oldGridValue.y);

                float vfactor = (kernelParams.grid.ewaldFactor + 1.0f / m2k) * 2.0f;
                float ets2    = corner_fac * tmp1k;
                energy = ets2;

                float ets2vf = ets2 * vfactor;
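                /* The assignments below write out the per-mode virial; component-wise this is
                 *   vir_ab = E(m) * (vfactor * mh_a * mh_b - delta_ab),
                 * with E(m) = ets2 = corner_fac * tmp1k being the energy contribution of this k-space point.
                 */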
                virxx = ets2vf * mhxk * mhxk - ets2;
                virxy = ets2vf * mhxk * mhyk;
                virxz = ets2vf * mhxk * mhzk;
                viryy = ets2vf * mhyk * mhyk - ets2;
                viryz = ets2vf * mhyk * mhzk;
                virzz = ets2vf * mhzk * mhzk - ets2;
            }
        }
    }
    /* Optional energy/virial reduction */
    if (computeEnergyAndVirial)
    {
#if (GMX_PTX_ARCH >= 300)
        /* A tricky shuffle reduction inspired by reduce_force_j_warp_shfl.
         * The idea is to reduce the 7 energy/virial components into a single variable (padded to 8 for alignment).
         * We will reduce everything into virxx.
         */
        /* We can only reduce warp-wise */
        const int          width      = warp_size;
        const unsigned int activeMask = c_fullWarpMask;

        /* Making pair sums */
        virxx  += gmx_shfl_down_sync(activeMask, virxx, 1, width);
        viryy  += gmx_shfl_up_sync  (activeMask, viryy, 1, width);
        virzz  += gmx_shfl_down_sync(activeMask, virzz, 1, width);
        virxy  += gmx_shfl_up_sync  (activeMask, virxy, 1, width);
        virxz  += gmx_shfl_down_sync(activeMask, virxz, 1, width);
        viryz  += gmx_shfl_up_sync  (activeMask, viryz, 1, width);
        energy += gmx_shfl_down_sync(activeMask, energy, 1, width);
        if (threadLocalId & 1)
        {
            virxx = viryy; // virxx now holds virxx and viryy pair sums
            virzz = virxy; // virzz now holds virzz and virxy pair sums
            virxz = viryz; // virxz now holds virxz and viryz pair sums
        }
        /* Making quad sums */
        virxx  += gmx_shfl_down_sync(activeMask, virxx, 2, width);
        virzz  += gmx_shfl_up_sync  (activeMask, virzz, 2, width);
        virxz  += gmx_shfl_down_sync(activeMask, virxz, 2, width);
        energy += gmx_shfl_up_sync  (activeMask, energy, 2, width);
        if (threadLocalId & 2)
        {
            virxx = virzz;  // virxx now holds quad sums of virxx, viryy, virzz and virxy
            virxz = energy; // virxz now holds quad sums of virxz, viryz, energy and unused padding
        }
        /* Making octet sums */
        virxx += gmx_shfl_down_sync(activeMask, virxx, 4, width);
        virxz += gmx_shfl_up_sync  (activeMask, virxz, 4, width);
        if (threadLocalId & 4)
        {
            virxx = virxz; // virxx now holds all 7 components' octet sums + unused padding
        }
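        /* Sketch of the lane layout at this point (assuming warp_size == 32): within each group of 8 lanes,
         * virxx of lanes 8k+0..8k+6 holds the octet sums of virxx, viryy, virzz, virxy, virxz, viryz and
         * energy respectively, and lane 8k+7 holds padding.
         */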
        /* We only need to reduce virxx now */
        for (int delta = 8; delta < width; delta <<= 1)
        {
            virxx += gmx_shfl_down_sync(activeMask, virxx, delta, width);
        }
        /* Now the first 7 threads of each warp have the full output contributions in virxx */
        const int  componentIndex      = threadLocalId & (warp_size - 1);
        const bool validComponentIndex = (componentIndex < c_virialAndEnergyCount);
        /* Reduce 7 outputs per warp in the shared memory */
        const int stride = 8; // this is c_virialAndEnergyCount == 7 rounded up to a power of 2 for convenience, hence the assert
        assert(c_virialAndEnergyCount == 7);
        const int        reductionBufferSize = (c_solveMaxThreadsPerBlock / warp_size) * stride;
        __shared__ float sm_virialAndEnergy[reductionBufferSize];
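        /* Worked example of the buffer size (illustrative, assuming warp_size == 32):
         * c_solveMaxThreadsPerBlock / warp_size = 256 / 32 = 8 warps, times stride 8 = 64 floats.
         */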
        if (validComponentIndex)
        {
            const int warpIndex = threadLocalId / warp_size;
            sm_virialAndEnergy[warpIndex * stride + componentIndex] = virxx;
        }
        __syncthreads();

        /* Reduce to the single warp size */
        const int targetIndex = threadLocalId;
        for (int reductionStride = reductionBufferSize >> 1; reductionStride >= warp_size; reductionStride >>= 1)
        {
            const int sourceIndex = targetIndex + reductionStride;
            if ((targetIndex < reductionStride) & (sourceIndex < activeWarps * stride))
            {
                // TODO: the second conditional is only needed on first iteration, actually - see if compiler eliminates it!
                sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[sourceIndex];
            }
            __syncthreads();
        }
        /* Now use shuffle again */
        if (threadLocalId < warp_size)
        {
            float output = sm_virialAndEnergy[threadLocalId];
            for (int delta = stride; delta < warp_size; delta <<= 1)
            {
                output += gmx_shfl_down_sync(activeMask, output, delta, warp_size);
            }
            if (validComponentIndex)
            {
                assert(isfinite(output));
                atomicAdd(gm_virialAndEnergy + componentIndex, output);
            }
        }
#else
        /* Shared memory reduction with atomics for compute capability < 3.0.
         * Each component is first reduced into warp_size positions in the shared memory;
         * then the first c_virialAndEnergyCount warps reduce everything further and add to the global memory.
         * This can likely be improved, but is anyway faster than the previous straightforward reduction,
         * which was using too much shared memory (for storing all 7 floats on each thread):
         * 48KB (shared mem limit per SM on CC 2.x) / sizeof(float) (4) / c_solveMaxThreadsPerBlock (256) / c_virialAndEnergyCount (7)
         * == 6 blocks per SM instead of 16, which is the maximum on CC 2.x.
         */
        const int  lane      = threadLocalId & (warp_size - 1);
        const int  warpIndex = threadLocalId / warp_size;
        const bool firstWarp = (warpIndex == 0);
        __shared__ float sm_virialAndEnergy[c_virialAndEnergyCount * warp_size];
        if (firstWarp)
        {
            sm_virialAndEnergy[0 * warp_size + lane] = virxx;
            sm_virialAndEnergy[1 * warp_size + lane] = viryy;
            sm_virialAndEnergy[2 * warp_size + lane] = virzz;
            sm_virialAndEnergy[3 * warp_size + lane] = virxy;
            sm_virialAndEnergy[4 * warp_size + lane] = virxz;
            sm_virialAndEnergy[5 * warp_size + lane] = viryz;
            sm_virialAndEnergy[6 * warp_size + lane] = energy;
        }
        __syncthreads();
        if (!firstWarp)
        {
            atomicAdd(sm_virialAndEnergy + 0 * warp_size + lane, virxx);
            atomicAdd(sm_virialAndEnergy + 1 * warp_size + lane, viryy);
            atomicAdd(sm_virialAndEnergy + 2 * warp_size + lane, virzz);
            atomicAdd(sm_virialAndEnergy + 3 * warp_size + lane, virxy);
            atomicAdd(sm_virialAndEnergy + 4 * warp_size + lane, virxz);
            atomicAdd(sm_virialAndEnergy + 5 * warp_size + lane, viryz);
            atomicAdd(sm_virialAndEnergy + 6 * warp_size + lane, energy);
        }
        __syncthreads();
        GMX_UNUSED_VALUE(activeWarps);
        assert(activeWarps >= c_virialAndEnergyCount); // we need to cover all components, or have multiple iterations otherwise
        const int componentIndex = warpIndex;
        if (componentIndex < c_virialAndEnergyCount)
        {
            const int targetIndex = threadLocalId;
            for (int reductionStride = warp_size >> 1; reductionStride >= 1; reductionStride >>= 1)
            {
                if (lane < reductionStride)
                {
                    sm_virialAndEnergy[targetIndex] += sm_virialAndEnergy[targetIndex + reductionStride];
                }
            }
            if (lane == 0)
            {
                atomicAdd(gm_virialAndEnergy + componentIndex, sm_virialAndEnergy[targetIndex]);
            }
        }
#endif
    }
}
void pme_gpu_solve(const PmeGpu *pmeGpu, t_complex *h_grid,
                   GridOrdering gridOrdering, bool computeEnergyAndVirial)
{
    const bool copyInputAndOutputGrid = pme_gpu_is_testing(pmeGpu) || !pme_gpu_performs_FFT(pmeGpu);

    cudaStream_t stream          = pmeGpu->archSpecific->pmeStream;
    const auto  *kernelParamsPtr = pmeGpu->kernelParams.get();

    if (copyInputAndOutputGrid)
    {
        cu_copy_H2D(kernelParamsPtr->grid.d_fourierGrid, h_grid, pmeGpu->archSpecific->complexGridSize * sizeof(float),
                    pmeGpu->settings.transferKind, stream);
    }
    int majorDim = -1, middleDim = -1, minorDim = -1;
    switch (gridOrdering)
    {
        case GridOrdering::YZX:
            majorDim  = YY;
            middleDim = ZZ;
            minorDim  = XX;
            break;

        case GridOrdering::XYZ:
            majorDim  = XX;
            middleDim = YY;
            minorDim  = ZZ;
            break;

        default:
            GMX_ASSERT(false, "Implement grid ordering here and below for the kernel launch");
    }
    const int maxBlockSize      = c_solveMaxThreadsPerBlock;
    const int gridLineSize      = pmeGpu->kernelParams->grid.complexGridSize[minorDim];
    const int gridLinesPerBlock = std::max(maxBlockSize / gridLineSize, 1);
    const int blocksPerGridLine = (gridLineSize + maxBlockSize - 1) / maxBlockSize;
    const int cellsPerBlock     = gridLineSize * gridLinesPerBlock;
    const int blockSize         = (cellsPerBlock + warp_size - 1) / warp_size * warp_size;
    // rounding up to full warps so that shuffle operations produce defined results
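    // Worked example (illustrative numbers): with complexGridSize[minorDim] == 36, gridLinesPerBlock = 256/36 = 7,
    // blocksPerGridLine = 1, cellsPerBlock = 36*7 = 252, and blockSize rounds up to 256 (8 full warps).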
    dim3 threads(blockSize);
    dim3 blocks(blocksPerGridLine,
                (pmeGpu->kernelParams->grid.complexGridSize[middleDim] + gridLinesPerBlock - 1) / gridLinesPerBlock,
                pmeGpu->kernelParams->grid.complexGridSize[majorDim]);
    pme_gpu_start_timing(pmeGpu, gtPME_SOLVE);
    if (gridOrdering == GridOrdering::YZX)
    {
        if (computeEnergyAndVirial)
        {
            pme_solve_kernel<GridOrderingInternal::YZX, true> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
        }
        else
        {
            pme_solve_kernel<GridOrderingInternal::YZX, false> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
        }
    }
    else if (gridOrdering == GridOrdering::XYZ)
    {
        if (computeEnergyAndVirial)
        {
            pme_solve_kernel<GridOrderingInternal::XYZ, true> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
        }
        else
        {
            pme_solve_kernel<GridOrderingInternal::XYZ, false> <<< blocks, threads, 0, stream>>> (*kernelParamsPtr);
        }
    }
    CU_LAUNCH_ERR("pme_solve_kernel");
    pme_gpu_stop_timing(pmeGpu, gtPME_SOLVE);
    if (computeEnergyAndVirial)
    {
        cu_copy_D2H(pmeGpu->staging.h_virialAndEnergy, kernelParamsPtr->constants.d_virialAndEnergy,
                    c_virialAndEnergyCount * sizeof(float), pmeGpu->settings.transferKind, stream);
    }

    if (copyInputAndOutputGrid)
    {
        cu_copy_D2H(h_grid, kernelParamsPtr->grid.d_fourierGrid, pmeGpu->archSpecific->complexGridSize * sizeof(float),
                    pmeGpu->settings.transferKind, stream);
    }
}